# Author: PrachiY
# initial commit
# 7a59163 verified
import torch
import torchvision.transforms as transforms
import gradio as gr
from PIL import Image
# Load an ImageNet-pretrained ResNet-18 from the torchvision hub (tag v0.10.0).
# NOTE(review): these are the generic torchvision ImageNet weights, NOT a model
# fine-tuned on a fashion / Clothing1M dataset. Its 1000-way ImageNet output does
# not correspond to the clothing `class_labels`; load fine-tuned weights here
# if meaningful clothing predictions are required.
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
model.eval()  # Inference mode: BatchNorm layers use their running statistics
# Define Class Labels for Clothing Items (Example Labels)
class_labels = [
"T-shirt", "Shirt", "Sweater", "Dress", "Jacket", "Coat", "Pants", "Shorts", "Skirt", "Jeans"
]
# Image preprocessing pipeline applied to every uploaded image before inference:
# fixed-size resize, PIL -> tensor conversion, then per-channel normalization
# with the ImageNet mean/std used by torchvision's pretrained weights.
_preprocess_steps = [
    # Stretch to exactly 224x224 (aspect ratio is not preserved).
    transforms.Resize((224, 224)),
    # PIL image -> float tensor of shape (C, H, W) with values in [0, 1].
    transforms.ToTensor(),
    # Per-channel (x - mean) / std.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)
# Classification Function
def classify_dress(image):
    """Classify an uploaded image and return a human-readable label string.

    Args:
        image: A ``PIL.Image.Image`` delivered by the Gradio
            ``Image(type="pil")`` input, or ``None`` when the user submits
            without uploading anything.

    Returns:
        ``"Predicted Clothing Type: <label>"``, or a prompt asking for an
        image when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image."

    # Uploads may be RGBA/palette PNGs or grayscale images; Normalize() is
    # configured with 3-channel mean/std, so force exactly 3 channels first.
    batch = transform(image.convert("RGB")).unsqueeze(0)  # (1, 3, 224, 224)

    with torch.no_grad():  # Inference only: skip autograd bookkeeping
        logits = model(batch)
    predicted_class_index = logits.argmax(dim=1).item()

    # NOTE(review): `model` is the stock ImageNet ResNet-18 (1000 outputs), so
    # folding its index into `class_labels` with modulo keeps the lookup in
    # range but is NOT a semantic class mapping. Replace with a fine-tuned
    # head whose outputs line up with `class_labels`.
    predicted_class = class_labels[predicted_class_index % len(class_labels)]
    return f"Predicted Clothing Type: {predicted_class}"
# Sample inputs listed under the upload widget. With a single input component
# Gradio accepts a flat list of file paths.
# NOTE(review): these files must exist next to the script -- confirm they ship with the app.
example_images = ["image1.jpg", "image2.jpg", "image3.jpg"]

# Define Gradio UI: one PIL image in, one text label out.
interface = gr.Interface(
    fn=classify_dress,  # Callback invoked on each submission
    inputs=gr.Image(type="pil"),  # Deliver uploads to fn as PIL images
    outputs=gr.Textbox(label="Predicted Clothing1M Class"),
    title="Clothing1M Classifier",
    description=(
        "Upload an image of clothing to classify it into one of the "
        "supported categories."
    ),
    examples=example_images,
)

# Launch Gradio App (executes as soon as this module runs).
interface.launch()