# Hugging Face Space app by tfarhan10
# Last commit: "Update app.py" (5b6f9be, verified)
import torch
import gradio as gr
import torchvision.transforms as transforms
from PIL import Image
from huggingface_hub import hf_hub_download
import requests
from io import BytesIO
from resnet import SupCEResNet
# Clothing categories in model-output order: the predicted class index i
# is mapped to class_labels[i] by classify_image.
class_labels = [
    "T-shirt",
    "Shirt",
    "Knitwear",
    "Chiffon",
    "Sweater",
    "Hoodie",
    "Windbreaker",
    "Jacket",
    "Down Coat",
    "Suit",
    "Shawl",
    "Dress",
    "Vest",
    "Underwear",
]
# Load model from Hugging Face Hub
def load_model_from_huggingface(
    repo_id="tfarhan10/Clothing1M-Pretrained-ResNet50",
    filename="model.pth",
    num_classes=14,
):
    """Download a checkpoint from the Hugging Face Hub and build the classifier.

    Args:
        repo_id: Hub repository that hosts the checkpoint.
        filename: Checkpoint file name inside ``repo_id``.
        num_classes: Size of the classification head. It must match the
            checkpoint and the length of the label list used for predictions.

    Returns:
        A ``SupCEResNet`` built with ``name='resnet50'``, with weights loaded
        onto CPU and switched to eval mode. Returns ``None`` if downloading or
        loading fails; the error is printed.
    """
    try:
        print("Downloading model from Hugging Face...")
        checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code. It is presumably kept because
        # the checkpoint may hold non-tensor entries (TODO confirm). Only point
        # this loader at repositories you trust.
        checkpoint = torch.load(
            checkpoint_path, map_location=torch.device("cpu"), weights_only=False
        )

        # The weights are either the whole checkpoint or nested under a key.
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for key in ("model", "state_dict"):
                if key in checkpoint:
                    state_dict = checkpoint[key]
                    break

        # Drop "module." segments left by (Distributed)DataParallel wrapping.
        # str.replace removes them anywhere in the key, not only as a prefix.
        # That also covers checkpoints where only a sub-module was wrapped.
        new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

        model = SupCEResNet(name="resnet50", num_classes=num_classes, pool=True)

        # strict=False tolerates mismatches. Without a report, a checkpoint
        # whose keys do not line up would load nothing and silently leave
        # random weights, so print whatever did not match.
        incompatible = model.load_state_dict(new_state_dict, strict=False)
        if incompatible.missing_keys:
            print(
                f"Warning: {len(incompatible.missing_keys)} model parameter(s) "
                f"missing from checkpoint, e.g. {incompatible.missing_keys[:5]}"
            )
        if incompatible.unexpected_keys:
            print(
                f"Warning: {len(incompatible.unexpected_keys)} unexpected "
                f"checkpoint key(s) ignored, e.g. {incompatible.unexpected_keys[:5]}"
            )

        model.eval()  # inference behavior for layers such as BatchNorm
        print("Model loaded successfully from Hugging Face!")
        return model
    except Exception as e:
        # Best-effort startup: report the failure and hand None to the caller.
        print(f"Error loading model: {e}")
        return None


# Load the model once at import time; every request reuses this instance.
# It is None if loading failed (see the printed error).
model = load_model_from_huggingface()
def classify_image(image):
    """Classify an uploaded clothing image into one of ``class_labels``.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without providing an image.

    Returns:
        ``"Predicted Category: <label>"`` on success. Otherwise, a short
        message explaining why no prediction could be made.
    """
    # Gradio hands the function None when no image has been provided.
    if image is None:
        return "Please upload an image."
    # The module-level model is None if loading failed at startup.
    if model is None:
        return "Model is not available. Check the server logs for the loading error."

    # The network expects 3-channel input; convert grayscale, RGBA, palette, etc.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocessing (the original author notes this matches training):
    # resize the shorter side to 256, center-crop 224x224, and normalize
    # with the ImageNet channel statistics.
    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])

    # Add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
    image_tensor = transform_test(image).unsqueeze(0)

    # Run on GPU when available; the model and input must share a device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    image_tensor = image_tensor.to(device)

    with torch.no_grad():
        output = model(image_tensor)
    _, pred = torch.max(output, dim=1)  # index of the highest-scoring class

    predicted_label = class_labels[pred.item()]
    return f"Predicted Category: {predicted_label}"
# Example image hosted in the model repository on the Hugging Face Hub
example_url = "https://huggingface.co/tfarhan10/Clothing1M-Pretrained-ResNet50/resolve/main/content/drive/MyDrive/CS5930/download.jpeg"


def load_example_image(url=None, timeout=10):
    """Download an example image and return it as an RGB PIL image.

    Args:
        url: Image location. Defaults to the module-level ``example_url``,
            which is looked up at call time.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely, which would hang
            app startup.

    Returns:
        The image converted to RGB. Returns ``None`` if the request fails,
        the status is not 200, or the payload cannot be decoded; the reason
        is printed.
    """
    if url is None:
        url = example_url
    try:
        response = requests.get(url, timeout=timeout)
        if response.status_code != 200:
            print(f"Failed to fetch example image (HTTP {response.status_code}).")
            return None
        return Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:
        # Best-effort: the app still works without an example image.
        print(f"Error loading example image: {e}")
        return None
# Example image for the interface; None if the download failed.
example_image = load_example_image()

# Create Gradio Interface
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),  # deliver the upload to classify_image as a PIL image
    outputs="text",
    title="Clothing Image Classifier",
    description="Upload an image or use the example below. The model will classify it into one of 14 clothing categories.",
    allow_flagging="never",  # Disable flagging feature
    # Offer the example only when the download succeeded. Compare against
    # None explicitly rather than relying on the image object's truthiness.
    examples=[[example_image]] if example_image is not None else None,
)

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()