File size: 1,719 Bytes
7a59163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import torchvision.transforms as transforms
import gradio as gr
from PIL import Image

# ResNet18 with the stock `pretrained=True` weights from the torchvision
# v0.10.0 hub entry point.
# NOTE(review): these weights are NOT fine-tuned on a fashion dataset; the
# network's output classes therefore do not correspond to `class_labels`.
# `.eval()` returns the module itself, so chaining it both switches the model
# to inference mode and binds it to `model`.
model = torch.hub.load(
    'pytorch/vision:v0.10.0', 'resnet18', pretrained=True
).eval()

# Human-readable clothing categories (example labels). Position i in this list
# is the label used for class index i.
class_labels = [
    "T-shirt",
    "Shirt",
    "Sweater",
    "Dress",
    "Jacket",
    "Coat",
    "Pants",
    "Shorts",
    "Skirt",
    "Jeans",
]

# Per-image preprocessing:
#   1. squash to exactly 224x224 (aspect ratio is not preserved),
#   2. convert the PIL image to a float tensor,
#   3. normalise each channel with the ImageNet mean/std used by the
#      pretrained torchvision weights.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Classification Function
#
# The hub model is the stock ImageNet-pretrained ResNet18, whose head has one
# logit per ImageNet class rather than one per entry in `class_labels`.
# Folding the ImageNet index with `% len(class_labels)` produced arbitrary
# labels, so instead only the logits of garment-related ImageNet classes are
# compared and the winner is mapped onto a clothing label. Labels without an
# entry here ("Shirt", "Jacket", "Pants") can only be produced by a model whose
# head has exactly len(class_labels) outputs.
IMAGENET_CLOTHING_CLASSES = {
    610: "T-shirt",  # jersey, T-shirt, tee shirt
    841: "Sweater",  # sweatshirt
    474: "Sweater",  # cardigan
    578: "Dress",    # gown
    869: "Coat",     # trench coat
    568: "Coat",     # fur coat
    617: "Coat",     # lab coat
    842: "Shorts",   # swimming trunks
    655: "Skirt",    # miniskirt
    689: "Skirt",    # overskirt
    601: "Skirt",    # hoopskirt, crinoline
    608: "Jeans",    # jean, blue jean, denim
}


def classify_dress(image):
    """Classify a clothing photo and return a human-readable label.

    Args:
        image: PIL image delivered by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        ``"Predicted Clothing Type: <label>"``, or an upload prompt when
        ``image`` is ``None``.

    Raises:
        ValueError: if the model's output size matches neither
            ``len(class_labels)`` nor an ImageNet-sized head.
    """
    if image is None:
        return "Please upload an image of a clothing item."

    # Uploads may be RGBA/palette (PNG) or grayscale; the 3-channel Normalize
    # in `transform` and the model both require RGB input.
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dim

    with torch.no_grad():
        logits = model(batch)[0]  # drop the batch dimension

    num_logits = logits.shape[0]
    if num_logits == len(class_labels):
        # Head trained on exactly these labels: index the list directly.
        predicted_class = class_labels[int(logits.argmax())]
    elif num_logits > max(IMAGENET_CLOTHING_CLASSES):
        # ImageNet-style head: pick the best-scoring garment class only.
        candidates = list(IMAGENET_CLOTHING_CLASSES)
        best = candidates[int(logits[candidates].argmax())]
        predicted_class = IMAGENET_CLOTHING_CLASSES[best]
    else:
        raise ValueError(
            f"Model produced {num_logits} logits; expected "
            f"{len(class_labels)} (clothing head) or an ImageNet-sized head."
        )
    return f"Predicted Clothing Type: {predicted_class}"

# Example inputs shown under the upload box. These are relative paths.
# NOTE(review): presumably the files ship alongside the app; confirm they
# exist in the directory the app is launched from.
example_images = ["image1.jpg", "image2.jpg", "image3.jpg"]

# Gradio UI: one PIL image in, one text label out.
interface = gr.Interface(
    fn=classify_dress,
    inputs=gr.Image(type="pil"),  # hand uploads to classify_dress as PIL images
    outputs=gr.Textbox(label="Predicted Clothing1M Class"),
    title="Clothing1M Classifier",
    description=(
        "Upload an image of clothing to classify it into one of the "
        "supported clothing categories."
    ),
    examples=example_images,
)

# Start the web server only when run as a script (e.g. `python app.py`), so
# that importing this module does not launch the app as a side effect.
if __name__ == "__main__":
    interface.launch()