Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from PIL import Image | |
| import json | |
| from huggingface_hub import hf_hub_download | |
| from transformers import AutoTokenizer, AutoModel | |
| from model import FineGrainedClassifier | |
# Fetch the checkpoint and label map from the Hub, then build the image/text
# classifier and the text encoder. Any failure is printed and leaves the app
# in a "model not loaded" state (model_loaded = False, empty label map)
# instead of aborting start-up.
_CLASSIFIER_REPO = "Maverick98/EcommerceClassifier"
_TEXT_ENCODER_NAME = "jinaai/jina-embeddings-v2-base-en"

try:
    model_path = hf_hub_download(repo_id=_CLASSIFIER_REPO, filename="model_checkpoint.pth")
    label_path = hf_hub_download(repo_id=_CLASSIFIER_REPO, filename="label_to_class.json")

    # The label map decides how many output classes the classifier has.
    with open(label_path, 'r') as labels_file:
        label_to_class = json.load(labels_file)
    num_classes = len(label_to_class)

    # NOTE(review): torch.load unpickles the downloaded file; only point this
    # at a repository you trust.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    # Accept either a checkpoint dict holding 'model_state_dict' or a bare
    # state dict.
    state_dict = checkpoint.get('model_state_dict', checkpoint)

    model = FineGrainedClassifier(num_classes=num_classes)
    model.load_state_dict(state_dict)
    model.eval()

    # NOTE(review): trust_remote_code=True runs Python code shipped with the
    # encoder repository; confirm that is acceptable for this deployment.
    tokenizer = AutoTokenizer.from_pretrained(_TEXT_ENCODER_NAME, trust_remote_code=True)
    text_encoder = AutoModel.from_pretrained(_TEXT_ENCODER_NAME, trust_remote_code=True)
    text_encoder.eval()
except Exception as e:
    print(f"Error loading model: {e}")
    model_loaded = False
    label_to_class = {}
else:
    model_loaded = True
# Image preprocessing: resize to a fixed 224x224 input, convert to a tensor,
# then normalize each of the three channels with the constants below.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = transforms.Compose(_preprocessing_steps)
def classify_product(image, text=""):
    """Classify a product image, optionally aided by a text description.

    Args:
        image: Image data that is handed to ``Image.fromarray``, or ``None``
            when no image was supplied.
        text: Optional product title or description. ``None``, an empty
            string and a whitespace-only string are all treated as
            "no text provided".

    Returns:
        A dict mapping class label to probability for the highest-scoring
        classes (at most 10). When the model is not loaded, the image is
        missing, or inference raises, a single-entry ``{"Error": message}``
        dict is returned instead; this function does not raise.
    """
    if not model_loaded:
        return {"Error": "Model not loaded properly"}

    try:
        if image is None:
            return {"Error": "Please provide an image"}

        img = Image.fromarray(image).convert('RGB')
        # unsqueeze(0) adds a leading dimension of size 1 (a batch of one).
        img_tensor = transform(img).unsqueeze(0)

        # Guard against text=None: calling .strip() on it would raise and be
        # reported as a classification failure even though the description
        # is optional.
        if text and text.strip():
            inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
            with torch.no_grad():
                # Average the encoder's last hidden states over dim 1 to get
                # one vector per input text.
                text_embeddings = text_encoder(**inputs).last_hidden_state.mean(dim=1)
        else:
            # No usable text: substitute an all-zero embedding.
            # NOTE(review): 768 presumably matches the text encoder's hidden
            # size - confirm against the encoder config.
            text_embeddings = torch.zeros(1, 768)

        with torch.no_grad():
            outputs = model(img_tensor, text_embeddings)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

        top_k = min(10, len(label_to_class))
        top_probs, top_indices = torch.topk(probabilities, top_k)

        results = {}
        for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
            # Class indices are looked up by their string form; unknown
            # indices fall back to a generic "Class N" label.
            label = label_to_class.get(str(idx), f"Class {idx}")
            results[label] = prob
        return results
    except Exception as e:
        # Top-level UI boundary: report the failure as a result rather than
        # letting the exception propagate.
        return {"Error": f"Classification failed: {str(e)}"}
# Build the Gradio interface: an image input and an optional description box
# are passed to classify_product, and the returned dict is shown by a Label
# component.
_image_input = gr.Image(label="Product Image")
_description_input = gr.Textbox(
    label="Product Description (optional)",
    placeholder="Enter product title or description...",
    lines=2,
    value="",
)
_results_output = gr.Label(label="Classification Results", num_top_classes=10)

demo = gr.Interface(
    fn=classify_product,
    inputs=[_image_input, _description_input],
    outputs=_results_output,
    title="🛍️ E-Commerce Product Classifier",
    description="Fast and accurate e-commerce product classification powered by EcommerceClassifier. Upload a product image and optionally provide a text description to classify it into the appropriate category.",
    examples=[],
    theme="soft",
)

# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()