Spaces:
Build error
Build error
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| from model import NetFeat, NetClassifier # Import feature extractor & classifier | |
| # Define Clothing1M class labels | |
# Display names for the 14 classifier outputs. run_inference indexes this list
# with the argmax of the classifier output, so position i must be the name of
# label id i used when the checkpoint was trained.
# NOTE(review): these names/order may not match the category list published
# with the Clothing1M dataset (which, as far as I recall, has "Windbreaker"
# rather than "Cardigan" and a different order) -- confirm against the label
# mapping used for training before relying on the predicted names.
CLOTHING_CLASSES = [
    "T-shirt",
    "Shirt",
    "Shawl",
    "Dress",
    "Vest",
    "Underwear",
    "Cardigan",
    "Jacket",
    "Sweater",
    "Hoodie",
    "Knitwear",
    "Chiffon",
    "Downcoat",
    "Suit",
]
| # Load the pre-trained model from Hugging Face Hub | |
def load_model(model_filename="netBest.pth", repo_id=None):
    """Build the feature extractor and classifier and load their weights.

    Args:
        model_filename: Name of the checkpoint file. Used as a local path
            when ``repo_id`` is ``None``; otherwise used as the file name to
            fetch from the Hub repository.
        repo_id: Optional Hugging Face Hub repository id. When given, the
            checkpoint is fetched with ``hf_hub_download`` before loading.
            When omitted, the behaviour is the same as before: the file is
            read from the local path ``model_filename``.

    Returns:
        A ``(net_feat, net_cls)`` tuple with ``eval()`` already called on
        both networks.

    Warns:
        RuntimeWarning: if the checkpoint has no ``"feat"`` or ``"cls"``
            entry, because the corresponding network then keeps its freshly
            initialised weights and predictions will be meaningless.
    """
    import warnings

    # Resolve where the checkpoint lives: Hub download if a repo was named,
    # plain local file otherwise.
    if repo_id is not None:
        checkpoint_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
    else:
        checkpoint_path = model_filename

    # Initialize both feature extractor & classifier.
    net_feat = NetFeat(arch='resnet18', pretrained=False, dataset='Clothing1M')
    net_cls = NetClassifier(feat_dim=512, nb_cls=14)  # Feature dim from ResNet-18

    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider weights_only=True where the installed
    # torch version supports it).
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    # The checkpoint is expected to hold one sub-dict per network. A missing
    # section is reported instead of being skipped silently.
    for section, network in (("feat", net_feat), ("cls", net_cls)):
        if section in state_dict:
            # strict=False is kept from the original: mismatching keys are
            # tolerated rather than raising.
            network.load_state_dict(state_dict[section], strict=False)
        else:
            warnings.warn(
                f"Checkpoint {checkpoint_path!r} has no {section!r} entry; "
                "that network keeps its randomly initialised weights.",
                RuntimeWarning,
                stacklevel=2,
            )

    # Set both models to evaluation mode.
    net_feat.eval()
    net_cls.eval()
    return net_feat, net_cls
| # Preprocess image for model input | |
def preprocess_image(image):
    """Turn an image into a normalised, batched tensor for the model.

    Args:
        image: A path or file object accepted by ``PIL.Image.open``, or an
            already-opened ``PIL.Image.Image``.

    Returns:
        The image converted to RGB, resized to 224x224, converted to a
        tensor, normalised per channel with the constants below, and given
        a leading batch dimension via ``unsqueeze(0)``.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    if isinstance(image, Image.Image):
        # Caller-owned image: convert without closing it.
        rgb_image = image.convert("RGB")
    else:
        # convert() returns a new in-memory image, so the underlying file
        # handle can be released as soon as the conversion is done.
        with Image.open(image) as opened:
            rgb_image = opened.convert("RGB")

    return transform(rgb_image).unsqueeze(0)  # Add batch dimension
| # Run inference on a single image | |
def run_inference(image_path, net_feat, net_cls):
    """Classify one image and return the name of the predicted class.

    Args:
        image_path: Image source, passed unchanged to ``preprocess_image``.
        net_feat: Callable feature extractor applied to the image tensor.
        net_cls: Callable classification head applied to the features.

    Returns:
        The entry of ``CLOTHING_CLASSES`` at the index of the highest score
        along dimension 1 of the classifier output.
    """
    batch = preprocess_image(image_path)

    # Forward pass only; gradient tracking is disabled.
    with torch.no_grad():
        scores = net_cls(net_feat(batch))

    best_index = torch.argmax(scores, dim=1).item()
    return CLOTHING_CLASSES[best_index]