# clothing1m / app.py
# Last change: "Update app.py" (commit ca1ce58), by prshanthreddy.
import functools

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image

from model import model
# Checkpoint file read by load_model(); a fixed path, not user-configurable.
RESUME_PATH = "netBest1.pth"

# Class index -> human-readable label for the 14 Clothing1M categories.
# The list position is the index the classifier's predicted class maps to.
CLOTHING1M_CLASSES = dict(
    enumerate(
        [
            "T-shirt",
            "Shirt",
            "Knitwear",
            "Chiffon",
            "Sweater",
            "Hoodie",
            "Windbreaker",
            "Jacket",
            "Down Coat",
            "Suit",
            "Shawl",
            "Dress",
            "Vest",
            "Underwear",
        ]
    )
)
# Load model
@functools.lru_cache(maxsize=1)
def load_model():
    """Build the feature extractor and classifier and load trained weights.

    The result is cached, so repeated calls (one per classification request)
    return the same pair. Without the cache, each call would re-read the
    checkpoint from disk and rebuild the ResNet-18 backbone.

    Returns:
        tuple: ``(net_feat, net_cls)``, both switched to eval mode.

    Raises:
        FileNotFoundError: if the checkpoint at ``RESUME_PATH`` is missing.
        KeyError: if the checkpoint lacks the ``'feat'`` or ``'cls'`` entry
            (assumes the checkpoint is a dict of state dicts -- confirm).
    """
    net_feat = model.NetFeat(arch='resnet18', pretrained=False, dataset="Clothing1M")
    # Size the head from the label map so the two can never disagree.
    net_cls = model.NetClassifier(
        feat_dim=net_feat.feat_dim, nb_cls=len(CLOTHING1M_CLASSES)
    )
    # Map all tensors to CPU so a checkpoint saved on GPU loads on CPU-only hosts.
    param = torch.load(RESUME_PATH, map_location=torch.device("cpu"))
    net_feat.load_state_dict(param['feat'])
    net_cls.load_state_dict(param['cls'])
    net_feat.eval()
    net_cls.eval()
    return net_feat, net_cls
# Image Preprocessing
def preprocess_image(image):
    """Convert a PIL image into a normalized, batched tensor.

    The image is forced to RGB, resized so its shorter side is 256,
    center-cropped to 224x224, and converted to a float tensor. The tensor is
    normalized with the standard ImageNet per-channel mean/std (presumably
    matching the training-time preprocessing -- confirm). A leading batch
    dimension is then added.

    Args:
        image: A PIL image in any mode.

    Returns:
        torch.Tensor: Tensor of shape ``(1, 3, 224, 224)``.
    """
    rgb_image = image.convert("RGB")
    pipeline = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            ),
        ]
    )
    single = pipeline(rgb_image)
    # The networks expect a batch, so wrap the single image as a batch of one.
    return single.unsqueeze(dim=0)
# Image Classification
def classify_image(image):
    """Predict the Clothing1M category name for an uploaded image.

    Args:
        image: PIL image supplied by the Gradio ``Image(type="pil")`` input.
            It is ``None`` when the user submits without uploading anything.

    Returns:
        str: The predicted label from ``CLOTHING1M_CLASSES``.

    Raises:
        gr.Error: if no image was provided. Gradio then shows a readable
            message instead of an ``AttributeError`` traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image to classify.")
    net_feat, net_cls = load_model()
    image_tensor = preprocess_image(image).to("cpu")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        features = net_feat(image_tensor)
        logits = net_cls(features)
        # Batch of one: index of the highest-scoring class in the single row.
        predicted = logits.argmax(dim=1)
    return CLOTHING1M_CLASSES[predicted.item()]
# Gradio Interface
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Predicted Category"),
    title="Clothing Image Classifier",
    description="Upload an image to classify its clothing category.",
    # Each example row holds one value per input; here that is a single
    # image path, relative to the working directory.
    examples=[
        [f"examples/{name}.jpg"]
        for name in ("83", "48", "74", "147", "148", "525", "325", "550")
    ],
)

# Start the server only when executed as a script (e.g. `python app.py`).
# The module can then be imported, for example by Gradio's reload mode,
# without launching a server as a side effect.
if __name__ == "__main__":
    demo.launch()