Spaces:
Build error
Build error
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| from model import NetFeat, NetClassifier | |
# Human-readable Clothing1M label names. run_inference maps the model's
# argmax index i to entry i, so the order of this list matters.
CLOTHING_CLASSES = [
    "T-shirt",
    "Shirt",
    "Shawl",
    "Dress",
    "Vest",
    "Underwear",
    "Cardigan",
    "Jacket",
    "Sweater",
    "Hoodie",
    "Knitwear",
    "Chiffon",
    "Downcoat",
    "Suit",
]
# Load the model
def load_model(model_filename='netBest.pth'):
    """Build the feature extractor and classifier and load trained weights.

    Args:
        model_filename: Path of the checkpoint passed to ``torch.load``.
            Defaults to ``'netBest.pth'`` relative to the working directory.

    Returns:
        A ``(net_feat, net_cls)`` tuple, both switched to eval mode.

    Notes:
        The checkpoint is read as a mapping with optional ``'feat'`` and
        ``'cls'`` entries, holding the state dicts of the two sub-networks.
        Weights are loaded with ``strict=False``. Missing or unexpected keys
        are reported through ``warnings.warn`` rather than silently dropped.
        A warning is also issued when an entry is absent, because that
        network would otherwise keep its initial (untrained) weights.
    """
    import warnings

    net_feat = NetFeat(arch='resnet18', pretrained=False, dataset='Clothing1M')
    # Tie the classifier's output width to the label list, so the
    # CLOTHING_CLASSES[predicted_index] lookup in run_inference cannot go
    # out of range if the list is ever edited.
    net_cls = NetClassifier(feat_dim=512, nb_cls=len(CLOTHING_CLASSES))

    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
    # checkpoints (consider weights_only=True on torch versions that support it).
    state_dict = torch.load(model_filename, map_location=torch.device('cpu'))

    for key, net in (('feat', net_feat), ('cls', net_cls)):
        if key not in state_dict:
            warnings.warn(
                f"checkpoint {model_filename!r} has no {key!r} entry; "
                f"{type(net).__name__} keeps its initial weights"
            )
            continue
        result = net.load_state_dict(state_dict[key], strict=False)
        # strict=False would otherwise hide key mismatches completely.
        missing = getattr(result, 'missing_keys', [])
        unexpected = getattr(result, 'unexpected_keys', [])
        if missing or unexpected:
            warnings.warn(
                f"{type(net).__name__} loaded from {key!r} with "
                f"missing keys {list(missing)} and unexpected keys {list(unexpected)}"
            )

    net_feat.eval()
    net_cls.eval()
    return net_feat, net_cls
# Preprocess image for model input
def preprocess_image(image):
    """Convert an image into a normalized single-image batch tensor.

    Args:
        image: A path or file object accepted by ``PIL.Image.open``. This is
            what ``gr.Image(type="filepath")`` supplies. An already-opened
            ``PIL.Image.Image`` is also accepted.

    Returns:
        A tensor of shape ``(1, 3, 224, 224)``. The image is resized to
        224x224, converted by ``ToTensor`` and normalized with the given
        per-channel mean/std.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    if isinstance(image, Image.Image):
        rgb = image.convert("RGB")
    else:
        # The context manager closes the underlying file handle promptly.
        # convert() returns a new, fully loaded image, so it remains valid
        # after the original is closed.
        with Image.open(image) as opened:
            rgb = opened.convert("RGB")
    return transform(rgb).unsqueeze(0)
def run_inference(image, net_feat, net_cls):
    """Classify *image* and return the predicted clothing label.

    Args:
        image: Anything ``preprocess_image`` accepts.
        net_feat: Feature extractor; its output is fed directly into ``net_cls``.
        net_cls: Classifier whose output is reduced with ``argmax`` over dim 1.

    Returns:
        The ``CLOTHING_CLASSES`` entry at the highest-scoring index.
    """
    batch = preprocess_image(image)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = net_cls(net_feat(batch))
    best = logits.argmax(dim=1).item()
    return CLOTHING_CLASSES[best]
# Build both networks once, at import time; classify_image reuses these
# module-level instances for every request.
(net_feat, net_cls) = load_model()
def classify_image(image):
    """Gradio callback: return the predicted class name for the input image.

    Args:
        image: File path supplied by ``gr.Image(type="filepath")``, or
            ``None`` when the user submits without providing an image.

    Returns:
        The predicted label, or a prompt asking for an image when none was given.
    """
    # Gradio passes None for an empty image input. Without this guard the
    # call would fail inside preprocessing instead of showing a helpful message.
    if image is None:
        return "Please upload an image."
    return run_inference(image, net_feat, net_cls)
# Example files offered in the UI; paths are relative to the working directory.
example_images = ["example.jpeg", "example2.webp", "image2.jpg"]

interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath"),  # the callback receives a file path
    outputs=gr.Textbox(label="Predicted Clothing1M Class"),
    title="Clothing1M Classifier",
    # Derive the category count from the label list so the text cannot
    # drift out of sync with it.
    description=(
        "Upload an image of clothing to classify it into one of "
        f"{len(CLOTHING_CLASSES)} categories."
    ),
    examples=example_images,
)

if __name__ == "__main__":
    interface.launch()