File size: 3,647 Bytes
2a0af58
 
d157a44
23aa755
d157a44
 
 
23aa755
 
d157a44
 
 
23aa755
d157a44
2a0af58
d157a44
 
2a0af58
d157a44
 
bc1fc68
 
 
 
 
d157a44
23aa755
 
 
 
 
 
d157a44
 
 
 
 
 
 
 
 
 
 
 
 
23aa755
d157a44
 
2a0af58
d157a44
 
 
 
 
 
 
 
23aa755
 
 
 
 
 
 
 
d157a44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a0af58
 
 
 
d157a44
 
23aa755
d157a44
2a0af58
 
23aa755
 
2a0af58
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import json
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModel
from model import FineGrainedClassifier

# Download model files and build the inference components. Any failure is
# printed and recorded as model_loaded = False, so the app still starts and
# classify_product can answer with an error message instead of the module
# crashing at import time.
try:
    model_path = hf_hub_download(repo_id="Maverick98/EcommerceClassifier", filename="model_checkpoint.pth")
    label_path = hf_hub_download(repo_id="Maverick98/EcommerceClassifier", filename="label_to_class.json")

    # JSON text is UTF-8 by spec; read it as such instead of relying on the
    # platform's locale encoding, which can fail on non-ASCII class names.
    with open(label_path, "r", encoding="utf-8") as f:
        # Presumably maps str(class index) -> class name (that is how
        # classify_product looks labels up).
        label_to_class = json.load(f)

    num_classes = len(label_to_class)
    model = FineGrainedClassifier(num_classes=num_classes)

    # The checkpoint may wrap the weights under "model_state_dict" or be a
    # bare state dict; accept either.
    # SECURITY NOTE(review): torch.load unpickles the file, and this one is
    # downloaded from a remote hub repo. Only load checkpoints you trust; on
    # torch versions where weights_only does not default to True, consider
    # passing weights_only=True explicitly.
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
    state_dict = checkpoint.get("model_state_dict", checkpoint)
    model.load_state_dict(state_dict)
    model.eval()

    # Tokenizer/encoder used by classify_product to embed the optional text
    # description. NOTE(review): trust_remote_code=True executes code from
    # the model repo; pin a revision if that repo is not fully trusted.
    tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en", trust_remote_code=True)
    text_encoder = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-base-en", trust_remote_code=True)
    text_encoder.eval()

    model_loaded = True
except Exception as e:
    # Deliberate best-effort boundary: report and keep the UI alive.
    print(f"Error loading model: {e}")
    model_loaded = False
    label_to_class = {}

# Preprocessing applied to every uploaded image before it reaches the model:
#   1. Resize to exactly 224x224 (a (h, w) size does not keep aspect ratio).
#   2. Convert the PIL image to a float C x H x W tensor scaled to [0, 1].
#   3. Normalize each RGB channel with the widely used ImageNet mean/std,
#      presumably the statistics the image backbone was trained with.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def classify_product(image, text=""):
    """Classify a product from its image and an optional text description.

    Args:
        image: Uploaded product image. ``gr.Image`` delivers a numpy array by
            default; a ``PIL.Image.Image`` is accepted as well. ``None`` yields
            an error entry instead of a prediction.
        text: Optional product title/description. ``None``, empty, or
            whitespace-only text is treated as "no description".

    Returns:
        A dict mapping class names to softmax probabilities for the top
        predictions (at most 10), or a single-entry ``{"Error": message}``
        dict when the model is not loaded, no image was given, or inference
        raised.
    """
    if not model_loaded:
        return {"Error": "Model not loaded properly"}

    try:
        if image is None:
            return {"Error": "Please provide an image"}

        # Accept PIL images as well as numpy arrays, and force 3-channel RGB
        # so grayscale/RGBA uploads fit the 3-channel normalization.
        pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
        img_tensor = transform(pil_image.convert("RGB")).unsqueeze(0)  # add batch dim

        if text and text.strip():
            inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
            with torch.no_grad():
                # Average the encoder's token states over the sequence axis.
                text_embeddings = text_encoder(**inputs).last_hidden_state.mean(dim=1)
        else:
            # No description: feed an all-zero text embedding. Size it from
            # the encoder config so both branches agree; 768 (the previous
            # hard-coded width) is the fallback.
            hidden_size = getattr(getattr(text_encoder, "config", None), "hidden_size", 768)
            text_embeddings = torch.zeros(1, hidden_size)

        with torch.no_grad():
            outputs = model(img_tensor, text_embeddings)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

        # Bound k by the model's actual number of outputs rather than the
        # label map's size, so a mismatched label file cannot make topk fail.
        top_k = min(10, probabilities.numel())
        top_probs, top_indices = torch.topk(probabilities, top_k)

        # JSON object keys are strings; unknown indices get a generic name.
        return {
            label_to_class.get(str(idx), f"Class {idx}"): float(prob)
            for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
        }
    except Exception as e:
        # Best-effort UI boundary: report the failure instead of crashing.
        # NOTE(review): gr.Label expects numeric confidences; a string value
        # here may not render in the Label output. Consider raising gr.Error.
        return {"Error": f"Classification failed: {e}"}

# Web UI: an image plus an optional free-text description are passed to
# classify_product, and its {class: probability} dict is shown in a Label
# listing up to 10 classes (the same cutoff classify_product applies).
demo = gr.Interface(
    fn=classify_product,
    inputs=[
        gr.Image(label="Product Image"),
        gr.Textbox(
            label="Product Description (optional)",
            placeholder="Enter product title or description...",
            lines=2,
            value="",
        ),
    ],
    outputs=gr.Label(
        label="Classification Results",
        num_top_classes=10,
    ),
    title="🛍️ E-Commerce Product Classifier",
    description=(
        "Fast and accurate e-commerce product classification powered by "
        "EcommerceClassifier. Upload a product image and optionally provide a "
        "text description to classify it into the appropriate category."
    ),
    examples=[],
    theme="soft",
)

# Launch the server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()