|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import gradio as gr |
|
|
import torchvision.transforms as transforms |
|
|
from PIL import Image |
|
|
from huggingface_hub import hf_hub_download |
|
|
import requests |
|
|
from io import BytesIO |
|
|
from resnet import SupCEResNet |
|
|
|
|
|
|
|
|
# Human-readable names for the 14 output classes. classify_image() indexes
# this list with the predicted class index, so the order is significant.
# NOTE(review): presumably this matches the label order used when the
# checkpoint was trained -- confirm against the training code.
class_labels = [
    "T-shirt",
    "Shirt",
    "Knitwear",
    "Chiffon",
    "Sweater",
    "Hoodie",
    "Windbreaker",
    "Jacket",
    "Down Coat",
    "Suit",
    "Shawl",
    "Dress",
    "Vest",
    "Underwear",
]
|
|
|
|
|
|
|
|
def load_model_from_huggingface(repo_id="tfarhan10/Clothing1M-Pretrained-ResNet50", filename="model.pth"):
    """Download a checkpoint from the Hugging Face Hub and build the classifier.

    Args:
        repo_id: Hub repository that hosts the checkpoint file.
        filename: Name of the checkpoint file inside that repository.

    Returns:
        A ``SupCEResNet`` (``resnet50``, 14 classes) switched to eval mode, or
        ``None`` if any step fails. Failures are printed, never raised, so the
        caller must handle the ``None`` case.
    """
    try:
        print("Downloading model from Hugging Face...")
        checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # Python objects, i.e. the downloaded file can execute code. Only point
        # this function at repositories you trust. Kept as-is because the
        # checkpoint format is not visible from here.
        checkpoint = torch.load(
            checkpoint_path,
            map_location=torch.device("cpu"),
            weights_only=False,
        )

        # The file is either a wrapper dict holding the weights under "model"
        # or the state dict itself.
        if isinstance(checkpoint, dict) and "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint

        # Remove every "module." segment from the parameter names.
        # NOTE(review): presumably this undoes a DataParallel wrapper used at
        # training time. It strips the substring anywhere in the key, not only
        # as a prefix -- confirm the key layout before narrowing it.
        new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

        model = SupCEResNet(name='resnet50', num_classes=14, pool=True)

        # strict=False tolerates key mismatches, which would otherwise go
        # unnoticed and leave layers at their random initial values. Report
        # them so a bad checkpoint is visible in the logs.
        load_result = model.load_state_dict(new_state_dict, strict=False)
        missing_keys = list(getattr(load_result, "missing_keys", None) or [])
        unexpected_keys = list(getattr(load_result, "unexpected_keys", None) or [])
        if missing_keys:
            print(f"Warning: {len(missing_keys)} parameter(s) missing from checkpoint: {missing_keys}")
        if unexpected_keys:
            print(f"Warning: {len(unexpected_keys)} unexpected parameter(s) in checkpoint: {unexpected_keys}")

        model.eval()

        print("Model loaded successfully from Hugging Face!")
        return model

    except Exception as e:
        # Best-effort boundary: the app still starts without a model.
        print(f"Error loading model: {e}")
        return None


# Loaded once at import time; None when loading failed.
model = load_model_from_huggingface()
|
|
|
|
|
def classify_image(image):
    """Classify a PIL image into one of the ``class_labels`` categories.

    Args:
        image: The ``PIL.Image.Image`` to classify, or ``None`` when the
            caller submitted nothing.

    Returns:
        ``"Predicted Category: <label>"`` on success. If ``image`` is ``None``
        or the module-level ``model`` failed to load, a short explanatory
        message is returned instead of raising.
    """
    # Guard clauses: without them a missing upload or a failed model load
    # surfaces as an AttributeError.
    if image is None:
        return "No image provided. Please upload an image."
    if model is None:
        return "Model is not available. Check the startup logs for the loading error."

    # Normalize expects three channels, so force RGB (drops alpha, expands
    # grayscale/palette images).
    if image.mode != "RGB":
        image = image.convert("RGB")

    # NOTE(review): resize/crop sizes and the mean/std values are presumably
    # the preprocessing used at training time -- confirm.
    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])

    # unsqueeze(0) adds the leading batch dimension for a single image.
    image_tensor = transform_test(image).unsqueeze(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    image_tensor = image_tensor.to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        output = model(image_tensor)
        _, pred = torch.max(output, 1)

    predicted_label = class_labels[pred.item()]
    return f"Predicted Category: {predicted_label}"
|
|
|
|
|
|
|
|
# Sample picture offered as a ready-made example in the UI.
example_url = "https://huggingface.co/tfarhan10/Clothing1M-Pretrained-ResNet50/resolve/main/content/drive/MyDrive/CS5930/download.jpeg"


def load_example_image(url=None, timeout=10.0):
    """Download an example image and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch. Defaults to the module-level
            ``example_url``, looked up at call time.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` can block forever, and this function is
            called while the module is being imported.

    Returns:
        The downloaded image converted to RGB, or ``None`` if the request
        fails, times out, returns a status other than 200, or the payload
        cannot be decoded. Errors are printed, never raised.
    """
    if url is None:
        url = example_url

    try:
        response = requests.get(url, timeout=timeout)
        if response.status_code == 200:
            return Image.open(BytesIO(response.content)).convert("RGB")

        print("Failed to fetch example image.")
        return None

    except Exception as e:
        # Best-effort: the example is optional, so never block app startup.
        print(f"Error loading example image: {e}")
        return None
|
|
|
|
|
|
|
|
# Fetch the optional example picture once at startup; None if unavailable.
example_image = load_example_image()

# Single-image-in, text-out UI around classify_image. The example row is
# only offered when the example picture could be loaded.
interface = gr.Interface(
    title="Clothing Image Classifier",
    description="Upload an image or use the example below. The model will classify it into one of 14 clothing categories.",
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    examples=None if not example_image else [[example_image]],
    allow_flagging="never",
)


# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    interface.launch()
|
|
|