"""Gradio app: e-commerce product classification from an image plus optional text.

Loads a fine-grained classifier checkpoint and a Jina text encoder from the
Hugging Face Hub at import time, then serves a Gradio Interface that returns
the top-k class probabilities for an uploaded product image.
"""

import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import json
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModel
from model import FineGrainedClassifier

# Download model files and build the classifier + text encoder once at startup.
# Any failure (network, missing file, bad checkpoint) is caught so the UI can
# still launch and report the problem instead of crashing on import.
try:
    model_path = hf_hub_download(repo_id="Maverick98/EcommerceClassifier", filename="model_checkpoint.pth")
    label_path = hf_hub_download(repo_id="Maverick98/EcommerceClassifier", filename="label_to_class.json")

    # label_to_class maps stringified class indices -> human-readable labels.
    with open(label_path, 'r') as f:
        label_to_class = json.load(f)

    num_classes = len(label_to_class)
    model = FineGrainedClassifier(num_classes=num_classes)

    # Load checkpoint and extract model_state_dict; fall back to treating the
    # whole checkpoint as a raw state dict when the wrapper key is absent.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict)
    model.eval()

    # Load text tokenizer and encoder used for the optional description input.
    tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en", trust_remote_code=True)
    text_encoder = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-base-en", trust_remote_code=True)
    text_encoder.eval()

    model_loaded = True
except Exception as e:
    print(f"Error loading model: {e}")
    model_loaded = False
    label_to_class = {}

# Image preprocessing: resize to the 224x224 input the backbone expects and
# normalize with the standard ImageNet statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def classify_product(image, text=""):
    """Classify a product image (with optional description) into categories.

    Args:
        image: numpy array from the Gradio Image component, or None.
        text: optional product title/description; may be "" or None.

    Returns:
        dict mapping class label -> probability for the top-k classes, or a
        single-entry {"Error": message} dict on any failure. Errors are
        returned rather than raised so the Gradio Label component can show them.
    """
    if not model_loaded:
        return {"Error": "Model not loaded properly"}

    try:
        if image is None:
            return {"Error": "Please provide an image"}

        img = Image.fromarray(image).convert('RGB')
        img_tensor = transform(img).unsqueeze(0)

        # Encode the description when one was given. Guard against None as
        # well as "": Gradio can pass None for a cleared textbox, and the
        # original `text.strip()` would raise AttributeError in that case.
        if text and text.strip():
            inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
            with torch.no_grad():
                # Mean-pool token embeddings into a single (1, 768) vector.
                text_embeddings = text_encoder(**inputs).last_hidden_state.mean(dim=1)
        else:
            # Use zero embeddings if no text provided
            text_embeddings = torch.zeros(1, 768)

        # Get predictions
        with torch.no_grad():
            outputs = model(img_tensor, text_embeddings)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

        # Get top predictions (fewer than 10 if the label set is small).
        top_k = min(10, len(label_to_class))
        top_probs, top_indices = torch.topk(probabilities, top_k)

        results = {}
        for prob, idx in zip(top_probs, top_indices):
            # JSON keys are strings, so look up by the stringified index.
            label = label_to_class.get(str(idx.item()), f"Class {idx.item()}")
            results[label] = float(prob)

        return results
    except Exception as e:
        return {"Error": f"Classification failed: {str(e)}"}


# Create Gradio interface
demo = gr.Interface(
    fn=classify_product,
    inputs=[
        gr.Image(label="Product Image"),
        gr.Textbox(label="Product Description (optional)", placeholder="Enter product title or description...", lines=2, value="")
    ],
    outputs=gr.Label(label="Classification Results", num_top_classes=10),
    title="🛍️ E-Commerce Product Classifier",
    description="Fast and accurate e-commerce product classification powered by EcommerceClassifier. Upload a product image and optionally provide a text description to classify it into the appropriate category.",
    examples=[],
    theme="soft"
)

if __name__ == "__main__":
    demo.launch()