# app.py — Hugging Face Space "Maverick98/EcommerceClassifier" (~3.6 KB).
# Web-page scrape artifacts (Space status lines, git blame hashes, and the
# line-number gutter) removed from the top of this file.
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import json
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModel
from model import FineGrainedClassifier
# Download model files
try:
model_path = hf_hub_download(repo_id="Maverick98/EcommerceClassifier", filename="model_checkpoint.pth")
label_path = hf_hub_download(repo_id="Maverick98/EcommerceClassifier", filename="label_to_class.json")
with open(label_path, 'r') as f:
label_to_class = json.load(f)
num_classes = len(label_to_class)
model = FineGrainedClassifier(num_classes=num_classes)
# Load checkpoint and extract model_state_dict
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
state_dict = checkpoint.get('model_state_dict', checkpoint)
model.load_state_dict(state_dict)
model.eval()
# Load text tokenizer
tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en", trust_remote_code=True)
text_encoder = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-base-en", trust_remote_code=True)
text_encoder.eval()
model_loaded = True
except Exception as e:
print(f"Error loading model: {e}")
model_loaded = False
label_to_class = {}
# Image preprocessing
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def classify_product(image, text=""):
if not model_loaded:
return {"Error": "Model not loaded properly"}
try:
# Process image
if image is None:
return {"Error": "Please provide an image"}
img = Image.fromarray(image).convert('RGB')
img_tensor = transform(img).unsqueeze(0)
# Process text
if text.strip():
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
with torch.no_grad():
text_embeddings = text_encoder(**inputs).last_hidden_state.mean(dim=1)
else:
# Use zero embeddings if no text provided
text_embeddings = torch.zeros(1, 768)
# Get predictions
with torch.no_grad():
outputs = model(img_tensor, text_embeddings)
probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
# Get top predictions
top_k = min(10, len(label_to_class))
top_probs, top_indices = torch.topk(probabilities, top_k)
results = {}
for prob, idx in zip(top_probs, top_indices):
label = label_to_class.get(str(idx.item()), f"Class {idx.item()}")
results[label] = float(prob)
return results
except Exception as e:
return {"Error": f"Classification failed: {str(e)}"}
# Create Gradio interface
demo = gr.Interface(
fn=classify_product,
inputs=[
gr.Image(label="Product Image"),
gr.Textbox(label="Product Description (optional)", placeholder="Enter product title or description...", lines=2, value="")
],
outputs=gr.Label(label="Classification Results", num_top_classes=10),
title="🛍️ E-Commerce Product Classifier",
description="Fast and accurate e-commerce product classification powered by EcommerceClassifier. Upload a product image and optionally provide a text description to classify it into the appropriate category.",
examples=[],
theme="soft"
)
if __name__ == "__main__":
demo.launch() |