"""Gradio demo that classifies a clothing image into one of 14 Clothing1M categories."""

import functools

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image

from model import model

# Static model checkpoint path
RESUME_PATH = "netBest1.pth"

# Class Mapping: index produced by the classifier head -> human-readable label.
CLOTHING1M_CLASSES = {
    0: "T-shirt",
    1: "Shirt",
    2: "Knitwear",
    3: "Chiffon",
    4: "Sweater",
    5: "Hoodie",
    6: "Windbreaker",
    7: "Jacket",
    8: "Down Coat",
    9: "Suit",
    10: "Shawl",
    11: "Dress",
    12: "Vest",
    13: "Underwear",
}

# Preprocessing pipeline, built once and shared by every request.
# The Normalize constants are the standard ImageNet channel mean/std;
# NOTE(review): presumably they match the checkpoint's training pipeline — confirm.
_TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


# Load model
@functools.lru_cache(maxsize=1)
def load_model():
    """Build the feature extractor and classifier, then load their checkpoint weights.

    The result is cached, so RESUME_PATH is read from disk once per process
    rather than on every prediction.

    Returns:
        A ``(net_feat, net_cls)`` tuple. Both modules are loaded with CPU
        placement and switched to eval mode.
    """
    net_feat = model.NetFeat(arch="resnet18", pretrained=False, dataset="Clothing1M")
    # Derive the head size from the label map so the two cannot drift apart.
    net_cls = model.NetClassifier(
        feat_dim=net_feat.feat_dim, nb_cls=len(CLOTHING1M_CLASSES)
    )
    # NOTE(review): torch.load unpickles the file. Only ever point RESUME_PATH
    # at a trusted checkpoint; consider weights_only=True on torch >= 1.13.
    param = torch.load(RESUME_PATH, map_location=torch.device("cpu"))
    # The checkpoint is a mapping with separate state dicts under 'feat' and 'cls'.
    net_feat.load_state_dict(param["feat"])
    net_cls.load_state_dict(param["cls"])
    net_feat.eval()
    net_cls.eval()
    return net_feat, net_cls


# Image Preprocessing
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a normalized float tensor of shape (1, 3, 224, 224).

    The image is forced to RGB first, so grayscale and RGBA uploads also work.
    """
    image = image.convert("RGB")
    return _TRANSFORM(image).unsqueeze(0)  # Add batch dimension


# Image Classification
def classify_image(image):
    """Return the predicted Clothing1M category name for a PIL image.

    Raises:
        gr.Error: if no image was provided (Gradio passes ``None`` in that case).
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    net_feat, net_cls = load_model()  # cached after the first call
    image_tensor = preprocess_image(image)
    with torch.no_grad():
        features = net_feat(image_tensor)
        output = net_cls(features)
    # Index of the highest-scoring class for the single image in the batch.
    predicted = output.argmax(dim=1)
    return CLOTHING1M_CLASSES[predicted.item()]


# Gradio Interface
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Predicted Category"),
    title="Clothing Image Classifier",
    description="Upload an image to classify its clothing category.",
    examples=[
        ["examples/83.jpg"],
        ["examples/48.jpg"],
        ["examples/74.jpg"],
        ["examples/147.jpg"],
        ["examples/148.jpg"],
        ["examples/525.jpg"],
        ["examples/325.jpg"],
        ["examples/550.jpg"],
    ],
)

if __name__ == "__main__":
    demo.launch()