# Hugging Face Spaces app: clothing image classifier (Gradio + PyTorch).
import functools

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image

from model import model
# Static model checkpoint path
RESUME_PATH = "netBest1.pth"

# Class mapping: classifier output index -> human-readable Clothing1M label.
_CLASS_NAMES = (
    "T-shirt", "Shirt", "Knitwear", "Chiffon", "Sweater",
    "Hoodie", "Windbreaker", "Jacket", "Down Coat", "Suit",
    "Shawl", "Dress", "Vest", "Underwear",
)
CLOTHING1M_CLASSES = dict(enumerate(_CLASS_NAMES))
# Load model
@functools.cache
def load_model():
    """Build the feature extractor and classifier, then load checkpoint weights.

    Cached with ``functools.cache`` so the checkpoint is read from disk only
    once per process instead of on every inference request (``classify_image``
    calls this per prediction).

    Returns:
        tuple: ``(net_feat, net_cls)``, both in eval mode on CPU.
    """
    net_feat = model.NetFeat(arch='resnet18', pretrained=False, dataset="Clothing1M")
    net_cls = model.NetClassifier(feat_dim=net_feat.feat_dim, nb_cls=14)
    # weights_only=True: the checkpoint is a plain dict of state_dicts
    # ('feat'/'cls'), so safe loading works and avoids arbitrary code
    # execution from an untrusted pickle.
    param = torch.load(RESUME_PATH, map_location=torch.device("cpu"), weights_only=True)
    net_feat.load_state_dict(param['feat'])
    net_cls.load_state_dict(param['cls'])
    net_feat.eval()
    net_cls.eval()
    return net_feat, net_cls
# Image Preprocessing
# Standard ImageNet evaluation transform. Built once at import time instead of
# being reconstructed on every call to preprocess_image.
_EVAL_TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

def preprocess_image(image):
    """Convert a PIL image into a normalized float tensor ready for the model.

    Args:
        image: PIL.Image in any mode; converted to RGB first.

    Returns:
        torch.Tensor of shape (1, 3, 224, 224) — batch dimension added.
    """
    rgb = image.convert("RGB")
    return _EVAL_TRANSFORM(rgb).unsqueeze(0)  # add batch dimension
# Image Classification
def classify_image(image):
    """Predict the clothing category of a PIL image.

    Runs the cached feature extractor + classifier on CPU and maps the
    arg-max logit index to its human-readable label.

    Returns:
        str: the predicted category name.
    """
    feat_net, cls_net = load_model()
    batch = preprocess_image(image).to("cpu")
    with torch.no_grad():
        logits = cls_net(feat_net(batch))
    pred_idx = int(torch.argmax(logits, dim=1).item())
    return CLOTHING1M_CLASSES[pred_idx]
# Gradio Interface
# Example image ids shipped under examples/.
_EXAMPLE_IDS = (83, 48, 74, 147, 148, 525, 325, 550)

demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Predicted Category"),
    title="Clothing Image Classifier",
    description="Upload an image to classify its clothing category.",
    examples=[[f"examples/{i}.jpg"] for i in _EXAMPLE_IDS],
)

demo.launch()