"""Gradio demo: classify a clothing image into one of 14 Clothing1M categories."""

import torch
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
# NOTE(review): hf_hub_download is imported but never used — load_model() reads
# a local checkpoint instead. Either wire it up or drop the import.
from huggingface_hub import hf_hub_download
from model import NetFeat, NetClassifier

# Human-readable labels for the 14 Clothing1M categories, indexed by the
# classifier's argmax output position.
CLOTHING_CLASSES = [
    "T-shirt", "Shirt", "Shawl", "Dress", "Vest", "Underwear", "Cardigan",
    "Jacket", "Sweater", "Hoodie", "Knitwear", "Chiffon", "Downcoat", "Suit"
]

# Built once at import time instead of on every request: the pipeline is
# stateless, so sharing it across calls is safe. Normalization uses the
# standard ImageNet mean/std that ResNet-18 backbones expect.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


# Load the model
def load_model():
    """Build the feature extractor and classifier and load their weights.

    Reads the checkpoint from ``netBest.pth`` on CPU. The checkpoint may
    store the two sub-networks under the ``'feat'`` and ``'cls'`` keys;
    ``strict=False`` tolerates minor key mismatches.

    Returns:
        tuple: ``(net_feat, net_cls)``, both set to eval mode.
    """
    model_filename = 'netBest.pth'  # Adjust the path as necessary
    net_feat = NetFeat(arch='resnet18', pretrained=False, dataset='Clothing1M')
    net_cls = NetClassifier(feat_dim=512, nb_cls=14)
    state_dict = torch.load(model_filename, map_location=torch.device('cpu'))
    if "feat" in state_dict:
        net_feat.load_state_dict(state_dict['feat'], strict=False)
    if "cls" in state_dict:
        net_cls.load_state_dict(state_dict['cls'], strict=False)
    net_feat.eval()
    net_cls.eval()
    return net_feat, net_cls


# Preprocess image for model input
def preprocess_image(image):
    """Open an image file and convert it to a normalized model input tensor.

    Args:
        image: Filesystem path to an image readable by PIL.

    Returns:
        torch.Tensor: shape ``(1, 3, 224, 224)``, normalized.
    """
    image = Image.open(image).convert("RGB")
    return _PREPROCESS(image).unsqueeze(0)


def run_inference(image, net_feat, net_cls):
    """Classify one image file and return the predicted class label.

    Args:
        image: Filesystem path to the input image.
        net_feat: Feature-extractor network (eval mode).
        net_cls: Classifier head mapping features to 14 logits.

    Returns:
        str: The predicted entry from ``CLOTHING_CLASSES``.
    """
    image_tensor = preprocess_image(image)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        feature_vector = net_feat(image_tensor)
        output = net_cls(feature_vector)
        predicted_index = output.argmax(dim=1).item()
    return CLOTHING_CLASSES[predicted_index]


# Load once at startup so every Gradio request reuses the same weights.
net_feat, net_cls = load_model()


def classify_image(image):
    """Gradio callback: map an uploaded image path to a class name."""
    return run_inference(image, net_feat, net_cls)


example_images = ["example.jpeg", "example2.webp", "image2.jpg"]

interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath"),  # Simple Image input
    outputs=gr.Textbox(label="Predicted Clothing1M Class"),
    title="Clothing1M Classifier",
    description="Upload an image of clothing to classify it into one of 14 categories.",
    examples=example_images
)

if __name__ == "__main__":
    interface.launch()