Spaces:
Build error
Build error
File size: 4,101 Bytes
0141596 35bc32f 0141596 946f89e 0141596 35bc32f 0141596 35bc32f 0141596 35bc32f 0141596 35bc32f 5b6f9be 0141596 35bc32f 0141596 5b6f9be 0141596 35bc32f 0141596 5d098b6 0141596 5d098b6 0141596 5d098b6 0141596 2d9baba 0141596 2d9baba 0141596 2d9baba 0141596 5b6f9be 0141596 5b6f9be 0141596 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import torch
import gradio as gr
import torchvision.transforms as transforms
from PIL import Image
from huggingface_hub import hf_hub_download
import requests
from io import BytesIO
from resnet import SupCEResNet
# Human-readable category names. Position i is the label for predicted class
# index i (classify_image indexes this list with the argmax of the model output).
class_labels = [
    "T-shirt",
    "Shirt",
    "Knitwear",
    "Chiffon",
    "Sweater",
    "Hoodie",
    "Windbreaker",
    "Jacket",
    "Down Coat",
    "Suit",
    "Shawl",
    "Dress",
    "Vest",
    "Underwear",
]
# Load model from Hugging Face Hub
def load_model_from_huggingface(
    repo_id="tfarhan10/Clothing1M-Pretrained-ResNet50",
    filename="model.pth",
    num_classes=14,
):
    """Download a checkpoint from the Hugging Face Hub and build the classifier.

    Args:
        repo_id: Hub repository that hosts the checkpoint.
        filename: Name of the checkpoint file inside ``repo_id``.
        num_classes: Size of the classification head passed to ``SupCEResNet``.
            It must agree with the checkpoint and with the label list used to
            decode predictions.

    Returns:
        A ``SupCEResNet`` in eval mode with its weights mapped to the CPU, or
        ``None`` if any step fails. The error is printed rather than raised.
    """
    try:
        print("Downloading model from Hugging Face...")
        checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

        # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
        # objects, and unpickling can execute code embedded in the file. This is
        # only acceptable for a trusted repository; never point repo_id/filename
        # at untrusted sources. NOTE(review): presumably required because the
        # checkpoint stores non-tensor objects next to the weights. Confirm, and
        # switch to weights_only=True if it loads.
        checkpoint = torch.load(
            checkpoint_path, map_location=torch.device("cpu"), weights_only=False
        )

        # The checkpoint is either a bare state_dict or a dict that wraps it
        # under the "model" key.
        if isinstance(checkpoint, dict) and "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint

        # Strip the "module." segments that torch.nn.DataParallel inserts into
        # parameter names. str.replace removes every occurrence, not just a
        # leading prefix. That is intentional: a wrapped sub-module can put the
        # segment in the middle of a key.
        new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

        model = SupCEResNet(name="resnet50", num_classes=num_classes, pool=True)

        # strict=False tolerates key mismatches. A silent mismatch, however,
        # leaves layers at their random initialisation and makes predictions
        # meaningless, so report anything that did not line up.
        incompatible = model.load_state_dict(new_state_dict, strict=False)
        if incompatible.missing_keys:
            print(
                f"Warning: {len(incompatible.missing_keys)} model parameters were not "
                f"found in the checkpoint (first few: {incompatible.missing_keys[:10]})"
            )
        if incompatible.unexpected_keys:
            print(
                f"Warning: {len(incompatible.unexpected_keys)} checkpoint entries were "
                f"not used by the model (first few: {incompatible.unexpected_keys[:10]})"
            )

        model.eval()  # switch layers such as BatchNorm/Dropout to inference behaviour
        print("Model loaded successfully from Hugging Face!")
        return model
    except Exception as e:
        # Best-effort boundary: report the failure and let callers handle None.
        print(f"Error loading model: {e}")
        return None
# Load the model once at import time; every classify_image call reuses it.
# load_model_from_huggingface() returns None if loading failed.
model = load_model_from_huggingface()
def classify_image(image):
    """Classify an uploaded PIL image into one of the clothing categories.

    Args:
        image: A ``PIL.Image.Image`` supplied by the Gradio input, or ``None``
            when the form is submitted without an image.

    Returns:
        ``"Predicted Category: <label>"`` on success. If no image was given,
        or the module-level model failed to load, returns a short message
        instead of raising.
    """
    # Gradio hands us None when the user submits without uploading anything.
    if image is None:
        return "Please upload an image."
    # load_model_from_huggingface() returns None on failure; answer with a
    # readable message instead of an AttributeError inside the request.
    if model is None:
        return "Model is not available; check the server logs."

    # The network takes 3-channel input, so convert grayscale/RGBA/palette images.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Evaluation preprocessing (the original code states it matches training):
    # resize the shorter side to 256, centre-crop to 224x224, then normalise
    # with the ImageNet mean/std.
    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    image_tensor = transform_test(image).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)

    # Run on the GPU when one is available. nn.Module.to() moves the module in
    # place, so later calls find it already on the target device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    image_tensor = image_tensor.to(device)

    with torch.no_grad():
        output = model(image_tensor)
    _, pred = torch.max(output, 1)  # index of the highest-scoring class

    return f"Predicted Category: {class_labels[pred.item()]}"
# Load example image from Hugging Face repository
example_url = "https://huggingface.co/tfarhan10/Clothing1M-Pretrained-ResNet50/resolve/main/content/drive/MyDrive/CS5930/download.jpeg"


def load_example_image(url=None, timeout=10):
    """Download the demo image and return it as an RGB PIL image.

    Args:
        url: Image location. ``None`` (the default) means the module-level
            ``example_url``, read at call time.
        timeout: Value passed to ``requests.get(timeout=...)``, in seconds.
            Without a timeout, ``requests`` can wait indefinitely on an
            unresponsive server and hang app start-up.

    Returns:
        The image converted to RGB. Returns ``None`` if the request fails, the
        server does not answer with HTTP 200, or the payload cannot be decoded
        as an image.
    """
    if url is None:
        url = example_url
    try:
        response = requests.get(url, timeout=timeout)
        if response.status_code != 200:
            print("Failed to fetch example image.")
            return None
        return Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:
        # Best-effort: the app works without the example, so never crash start-up.
        print(f"Error loading example image: {e}")
        return None
# Example image: an RGB PIL image, or None if the download failed (in which case
# the interface below is built without examples).
example_image = load_example_image()
# Gradio UI: one PIL image in, one text prediction out.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Clothing Image Classifier",
    description="Upload an image or use the example below. The model will classify it into one of 14 clothing categories.",
    # NOTE(review): newer Gradio releases deprecate allow_flagging in favour of
    # flagging_mode; confirm against the pinned Gradio version.
    allow_flagging="never",
    # Only offer the example when the download actually produced an image.
    examples=None if example_image is None else [[example_image]],
)
# Start the web server only when run as a script; importing this module just
# builds the model and the interface without serving them.
if __name__ == "__main__":
    interface.launch()
|