| import torch |
| import torchvision.transforms as transforms |
| import gradio as gr |
| from PIL import Image |
|
|
| |
# Pretrained ResNet-18 from the pinned torchvision v0.10.0 hub entry point
# (downloaded on first use). ``.eval()`` returns the module itself, so the
# inference-mode switch is chained straight onto the load.
model = torch.hub.load("pytorch/vision:v0.10.0", "resnet18", pretrained=True).eval()
|
|
| |
# Display names shown to the user. Order matters: classify_dress selects an
# entry by taking the model's output index modulo the length of this list.
class_labels = [
    "T-shirt",
    "Shirt",
    "Sweater",
    "Dress",
    "Jacket",
    "Coat",
    "Pants",
    "Shorts",
    "Skirt",
    "Jeans",
]
|
|
| |
# Per-channel normalisation statistics used by torchvision's ImageNet-pretrained
# models; the ResNet-18 above expects inputs scaled this way.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# PIL image -> float tensor (C, 224, 224): fixed-size resize, scale to [0, 1],
# then normalise each channel with the statistics above.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)
|
|
| |
def classify_dress(image):
    """Classify an uploaded clothing image and return a display string.

    The image is converted to RGB, preprocessed with the module-level
    ``transform``, and scored by the module-level ``model``. The index of the
    highest-scoring output is then folded onto ``class_labels`` with a modulo.

    Args:
        image: The ``PIL.Image.Image`` delivered by the ``gr.Image(type="pil")``
            input, or ``None`` when the user submits without uploading.

    Returns:
        A string of the form ``"Predicted Clothing Type: <label>"``.

    Raises:
        gr.Error: If no image was provided; Gradio shows this message in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an image of clothing first.")

    # PNG uploads are often RGBA/LA/P and some photos are grayscale ("L").
    # ToTensor keeps the source channel count, and the 3-channel Normalize in
    # ``transform`` raises on anything but 3 channels, so force RGB first.
    rgb_image = image.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0)  # add batch dim -> (1, C, H, W)

    with torch.no_grad():
        output = model(batch)

    predicted_class_index = output.argmax(dim=1).item()

    # NOTE(review): ``model`` is a stock ImageNet ResNet-18, so this index is one
    # of ImageNet's classes, not a clothing category. The modulo only keeps the
    # lookup in range; the resulting label is not a meaningful prediction. A
    # fine-tuned classification head or an explicit ImageNet-index -> label map
    # is needed for real results.
    predicted_class = class_labels[predicted_class_index % len(class_labels)]
    return f"Predicted Clothing Type: {predicted_class}"
|
|
# Sample files offered as clickable examples in the UI. These are relative
# paths, presumably resolved against the directory the app is launched from.
example_images = [f"image{number}.jpg" for number in range(1, 4)]
|
|
| |
# Gradio UI: one PIL image in, one text prediction out, with clickable examples.
# NOTE(review): the title and output label say "Clothing1M", but ``model`` is a
# stock ImageNet ResNet-18 that has not been trained on Clothing1M.
interface = gr.Interface(
    fn=classify_dress,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Predicted Clothing1M Class"),
    title="Clothing1M Classifier",
    # The original text read "into one of categories." (count missing); derive
    # it from the label list so it stays accurate if labels change.
    description=(
        "Upload an image of clothing to classify it into one of "
        f"{len(class_labels)} categories."
    ),
    examples=example_images,
)

# Start the web app. This runs at module level, so importing this file also
# launches the server (there is no ``if __name__ == "__main__"`` guard).
interface.launch()