Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from transformers import AutoModel, AutoTokenizer | |
| import torch | |
| import json | |
| import requests | |
| from PIL import Image | |
| from torchvision import transforms | |
| import urllib.request | |
| from torchvision import models | |
| import torch.nn as nn | |
| # --- Define the Model --- | |
class FineGrainedClassifier(nn.Module):
    """Classify a product from its image and its title.

    Image features come from a ResNet-50 whose final ``fc`` layer is replaced
    by an identity; text features are the first-token hidden state of the
    Jina text encoder. Both are concatenated and passed through an MLP head
    whose first layer expects 2048 + 768 input features.
    """

    def __init__(self, num_classes=434):
        """Build both encoders and the classification head.

        Args:
            num_classes: Number of output logits produced by the head.
        """
        super().__init__()

        # Image branch: drop the ImageNet classification layer so the
        # backbone returns its pooled features directly.
        backbone = models.resnet50(pretrained=True)
        backbone.fc = nn.Identity()
        self.image_encoder = backbone

        # Text branch.
        self.text_encoder = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en')

        # Head: two hidden stages (Linear -> BatchNorm -> ReLU -> Dropout)
        # followed by the output projection. Layer order is kept so the
        # Sequential indices (and therefore state-dict keys) stay the same.
        head_layers = []
        in_features = 2048 + 768
        for hidden_size in (1024, 512):
            head_layers.extend([
                nn.Linear(in_features, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
                nn.Dropout(0.3),
            ])
            in_features = hidden_size
        head_layers.append(nn.Linear(in_features, num_classes))
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, image, input_ids, attention_mask):
        """Return class logits for a batch of (image, tokenized title) pairs.

        Args:
            image: Batch passed to the image encoder.
            input_ids: Token ids passed to the text encoder.
            attention_mask: Attention mask passed to the text encoder.

        Returns:
            The output of the classification head for the fused features.
        """
        visual = self.image_encoder(image)
        encoded_title = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state at the first token position as the title vector.
        textual = encoded_title.last_hidden_state[:, 0, :]
        fused = torch.cat((visual, textual), dim=1)
        return self.classifier(fused)
# --- Inference Preprocessing ---
# Deterministic pipeline: resize, convert to tensor, normalize.
# Random augmentations (flip / rotation / colour jitter) are a training-time
# technique; applying them while serving predictions makes the same input
# produce different results from one call to the next, so they are not used here.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Load the label-to-class mapping from your Hugging Face repository.
# A timeout keeps startup from hanging forever on a stalled connection, and
# raise_for_status() surfaces HTTP failures (404, 5xx, ...) directly instead
# of letting them show up later as a confusing JSON decoding error.
label_map_url = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/label_to_class.json"
label_map_response = requests.get(label_map_url, timeout=30)
label_map_response.raise_for_status()
label_to_class = label_map_response.json()
# Load your custom model from Hugging Face
model = FineGrainedClassifier(num_classes=len(label_to_class))
checkpoint_url = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/model_checkpoint.pth"
checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device('cpu'))

# The checkpoint is either the state dict itself or a dict that stores it
# under 'model_state_dict'.
state_dict = checkpoint.get('model_state_dict', checkpoint)

# Clean up the state dictionary:
#   * strip a leading "module." prefix from keys that have one,
#   * keep only the entries the current model actually defines.
# The set of valid keys is computed once; calling model.state_dict() inside
# the loop would rebuild the whole dictionary for every checkpoint entry.
module_prefix = "module."
model_keys = set(model.state_dict())
new_state_dict = {}
for k, v in state_dict.items():
    new_key = k[len(module_prefix):] if k.startswith(module_prefix) else k
    if new_key in model_keys:
        new_state_dict[new_key] = v
model.load_state_dict(new_state_dict)

# Load the tokenizer from Jina
tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en")
def load_image(image_path_or_url, timeout=30):
    """Load an image, preprocess it, and return it as a single-item batch.

    Args:
        image_path_or_url: Either a string starting with ``"http"`` (fetched
            over the network) or anything ``PIL.Image.open`` accepts, such as
            a local file path.
        timeout: Seconds to wait for the remote server before giving up.
            Only used for URLs; prevents a stalled download from blocking
            the request indefinitely.

    Returns:
        The result of the module-level ``transform`` applied to the RGB
        image, with a leading batch dimension added.
    """
    is_remote = isinstance(image_path_or_url, str) and image_path_or_url.startswith("http")
    if is_remote:
        with urllib.request.urlopen(image_path_or_url, timeout=timeout) as response:
            with Image.open(response) as source:
                image = source.convert('RGB')
    else:
        # Context manager closes the underlying file handle deterministically.
        with Image.open(image_path_or_url) as source:
            image = source.convert('RGB')

    # Add batch dimension
    return transform(image).unsqueeze(0)
def predict(image_path_or_file, title, threshold=0.4):
    """Predict the most likely product categories for an image and title.

    Args:
        image_path_or_file: Image path or URL, forwarded to ``load_image``.
        title: Product title; must contain at least 3 whitespace-separated
            words.
        threshold: If the best class probability is below this value, an
            extra ``"Others"`` entry is placed first in the result.

    Returns:
        A dict mapping class names to probabilities for the top classes
        (at most 3). When the best probability is below ``threshold`` the
        dict starts with ``"Others"``, valued at ``1 - best probability``.

    Raises:
        gr.Error: If the title is too short or no image was provided.
    """
    # Validation: Check if the title is empty or has fewer than 3 words
    if not title or len(title.split()) < 3:
        raise gr.Error("Title must be at least 3 words long. Please provide a valid title.")
    # Without this check a missing upload fails deep inside image loading
    # with an error message that means nothing to the user.
    if image_path_or_file is None:
        raise gr.Error("Please upload a product image.")

    # Preprocess the image
    image = load_image(image_path_or_file)

    # Tokenize title
    title_encoding = tokenizer(title, padding='max_length', max_length=200, truncation=True, return_tensors='pt')

    # Predict
    model.eval()
    with torch.no_grad():
        output = model(
            image,
            input_ids=title_encoding['input_ids'],
            attention_mask=title_encoding['attention_mask'],
        )
        probabilities = torch.nn.functional.softmax(output, dim=1)
        # topk raises if k exceeds the number of classes, so clamp it.
        top_k = min(3, probabilities.shape[1])
        top_probabilities, top_indices = torch.topk(probabilities, top_k, dim=1)

    # Work with plain Python lists from here on (single-item batch).
    scores = top_probabilities[0].tolist()
    classes = [label_to_class[str(idx)] for idx in top_indices[0].tolist()]

    # Check if the highest probability is below the threshold
    if scores[0] < threshold:
        classes.insert(0, "Others")
        scores.insert(0, 1.0 - scores[0])

    return dict(zip(classes, scores))
# Gradio UI: one image input and one title input feed `predict`; the
# resulting class -> probability mapping is rendered as JSON.
title_input = gr.Textbox(
    label="Product Title",
    placeholder="Enter the product title here...",
)
image_input = gr.Image(
    type="filepath",
    label="Upload Image or Provide URL",
)
output = gr.JSON(label="Top 3 Predictions with Probabilities")

demo = gr.Interface(
    fn=predict,
    inputs=[image_input, title_input],
    outputs=output,
    title="Ecommerce Classifier",
    description="This model classifies ecommerce products into one of 434 categories. If the model is unsure, it outputs 'Others'.",
)
demo.launch(share=True)