# runthebandsup's picture
# Fix checkpoint loading to extract model_state_dict properly
# bc1fc68 verified
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import json
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModel
from model import FineGrainedClassifier
# Download model files and initialize the classifier at import time.
# NOTE: fetches from the Hugging Face Hub, so first run needs network
# access (hf_hub_download caches subsequent runs).
try:
    model_path = hf_hub_download(repo_id="Maverick98/EcommerceClassifier", filename="model_checkpoint.pth")
    label_path = hf_hub_download(repo_id="Maverick98/EcommerceClassifier", filename="label_to_class.json")
    # label_to_class maps stringified class indices -> human-readable labels
    with open(label_path, 'r') as f:
        label_to_class = json.load(f)
    num_classes = len(label_to_class)
    model = FineGrainedClassifier(num_classes=num_classes)
    # Load checkpoint and extract model_state_dict.
    # Training checkpoints wrap the weights under 'model_state_dict';
    # the .get fallback also accepts a bare state dict.
    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — acceptable only because this repo is trusted; confirm.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: disable dropout/batchnorm updates
    # Load text tokenizer and encoder for the optional description input
    tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en", trust_remote_code=True)
    text_encoder = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-base-en", trust_remote_code=True)
    text_encoder.eval()
    model_loaded = True
except Exception as e:
    # Broad catch at the startup boundary: log and let the app launch so
    # classify_product can report the failure instead of crashing the Space.
    print(f"Error loading model: {e}")
    model_loaded = False
    label_to_class = {}
# Image preprocessing: resize to the model's 224x224 input and normalize
# with the standard ImageNet channel statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
def classify_product(image, text=""):
    """Classify a product image, optionally combined with a text description.

    Args:
        image: numpy array from the Gradio Image component, or None.
        text: optional product title/description; may be "" or None.

    Returns:
        dict mapping class label -> probability for the top predictions,
        or a single-entry {"Error": ...} dict on failure.
    """
    if not model_loaded:
        return {"Error": "Model not loaded properly"}
    try:
        # Process image
        if image is None:
            return {"Error": "Please provide an image"}
        img = Image.fromarray(image).convert('RGB')
        img_tensor = transform(img).unsqueeze(0)  # add batch dim
        # Process text. Guard against None as well as "": Gradio can pass
        # None for a cleared textbox, which would crash a bare text.strip().
        if text and text.strip():
            inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
            with torch.no_grad():
                # Mean-pool token embeddings into one sentence vector
                text_embeddings = text_encoder(**inputs).last_hidden_state.mean(dim=1)
        else:
            # Use zero embeddings if no text provided
            # (768 presumably matches the text encoder's hidden size — confirm)
            text_embeddings = torch.zeros(1, 768)
        # Get predictions
        with torch.no_grad():
            outputs = model(img_tensor, text_embeddings)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
        # Get top predictions (at most 10, fewer if the label set is smaller)
        top_k = min(10, len(label_to_class))
        top_probs, top_indices = torch.topk(probabilities, top_k)
        results = {}
        for prob, idx in zip(top_probs, top_indices):
            # JSON map is keyed by stringified class index
            label = label_to_class.get(str(idx.item()), f"Class {idx.item()}")
            results[label] = float(prob)
        return results
    except Exception as e:
        # Boundary handler: surface the failure in the UI rather than crash
        return {"Error": f"Classification failed: {str(e)}"}
# Create Gradio interface: product image + optional description in,
# ranked label probabilities out.
_inputs = [
    gr.Image(label="Product Image"),
    gr.Textbox(
        label="Product Description (optional)",
        placeholder="Enter product title or description...",
        lines=2,
        value="",
    ),
]
_outputs = gr.Label(label="Classification Results", num_top_classes=10)
demo = gr.Interface(
    fn=classify_product,
    inputs=_inputs,
    outputs=_outputs,
    title="🛍️ E-Commerce Product Classifier",
    description="Fast and accurate e-commerce product classification powered by EcommerceClassifier. Upload a product image and optionally provide a text description to classify it into the appropriate category.",
    examples=[],
    theme="soft",
)

# Launch only when run as a script (the module may also be imported).
if __name__ == "__main__":
    demo.launch()