runthebandsup commited on
Commit
23aa755
·
verified ·
1 Parent(s): b0ca6f6

Fix model loading to use model_checkpoint.pth and add proper text encoding

Browse files
Files changed (1) hide show
  1. app.py +22 -27
app.py CHANGED
@@ -1,31 +1,16 @@
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
- from torchvision import models, transforms
5
  from PIL import Image
6
  import json
7
  from huggingface_hub import hf_hub_download
8
- import os
9
-
10
- # Define the model architecture
11
- class FineGrainedClassifier(nn.Module):
12
- def __init__(self, num_classes, text_dim=768):
13
- super(FineGrainedClassifier, self).__init__()
14
- self.resnet = models.resnet50(pretrained=False)
15
- self.resnet.fc = nn.Identity()
16
- self.text_fc = nn.Linear(text_dim, 1024)
17
- self.fusion_fc = nn.Linear(2048 + 1024, num_classes)
18
-
19
- def forward(self, images, text_embeddings):
20
- image_features = self.resnet(images)
21
- text_features = self.text_fc(text_embeddings)
22
- combined = torch.cat((image_features, text_features), dim=1)
23
- output = self.fusion_fc(combined)
24
- return output
25
 
26
  # Download model files
27
  try:
28
- model_path = hf_hub_download(repo_id="Maverick98/EcommerceClassifier", filename="best_model.pth")
29
  label_path = hf_hub_download(repo_id="Maverick98/EcommerceClassifier", filename="label_to_class.json")
30
 
31
  with open(label_path, 'r') as f:
@@ -35,6 +20,12 @@ try:
35
  model = FineGrainedClassifier(num_classes=num_classes)
36
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
37
  model.eval()
 
 
 
 
 
 
38
  model_loaded = True
39
  except Exception as e:
40
  print(f"Error loading model: {e}")
@@ -48,7 +39,7 @@ transform = transforms.Compose([
48
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
49
  ])
50
 
51
- def classify_product(image, text):
52
  if not model_loaded:
53
  return {"Error": "Model not loaded properly"}
54
 
@@ -60,8 +51,14 @@ def classify_product(image, text):
60
  img = Image.fromarray(image).convert('RGB')
61
  img_tensor = transform(img).unsqueeze(0)
62
 
63
- # For now, use dummy text embeddings (in real implementation, use Jina embeddings)
64
- text_embeddings = torch.zeros(1, 768)
 
 
 
 
 
 
65
 
66
  # Get predictions
67
  with torch.no_grad():
@@ -86,14 +83,12 @@ demo = gr.Interface(
86
  fn=classify_product,
87
  inputs=[
88
  gr.Image(label="Product Image"),
89
- gr.Textbox(label="Product Description (optional)", placeholder="Enter product title or description...", lines=2)
90
  ],
91
  outputs=gr.Label(label="Classification Results", num_top_classes=10),
92
  title="🛍️ E-Commerce Product Classifier",
93
- description="Fast and accurate e-commerce product classification. Upload a product image to classify it into the appropriate category.",
94
- examples=[
95
- ["https://raw.githubusercontent.com/gradio-app/gradio/main/guides/assets/demo_files/T-shirt.png", "Cotton T-Shirt"],
96
- ],
97
  theme="soft"
98
  )
99
 
 
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
+ from torchvision import transforms
5
  from PIL import Image
6
  import json
7
  from huggingface_hub import hf_hub_download
8
+ from transformers import AutoTokenizer, AutoModel
9
+ from model import FineGrainedClassifier
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  # Download model files
12
  try:
13
+ model_path = hf_hub_download(repo_id="Maverick98/EcommerceClassifier", filename="model_checkpoint.pth")
14
  label_path = hf_hub_download(repo_id="Maverick98/EcommerceClassifier", filename="label_to_class.json")
15
 
16
  with open(label_path, 'r') as f:
 
20
  model = FineGrainedClassifier(num_classes=num_classes)
21
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
22
  model.eval()
23
+
24
+ # Load text tokenizer
25
+ tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en", trust_remote_code=True)
26
+ text_encoder = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-base-en", trust_remote_code=True)
27
+ text_encoder.eval()
28
+
29
  model_loaded = True
30
  except Exception as e:
31
  print(f"Error loading model: {e}")
 
39
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
40
  ])
41
 
42
+ def classify_product(image, text=""):
43
  if not model_loaded:
44
  return {"Error": "Model not loaded properly"}
45
 
 
51
  img = Image.fromarray(image).convert('RGB')
52
  img_tensor = transform(img).unsqueeze(0)
53
 
54
+ # Process text
55
+ if text.strip():
56
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
57
+ with torch.no_grad():
58
+ text_embeddings = text_encoder(**inputs).last_hidden_state.mean(dim=1)
59
+ else:
60
+ # Use zero embeddings if no text provided
61
+ text_embeddings = torch.zeros(1, 768)
62
 
63
  # Get predictions
64
  with torch.no_grad():
 
83
  fn=classify_product,
84
  inputs=[
85
  gr.Image(label="Product Image"),
86
+ gr.Textbox(label="Product Description (optional)", placeholder="Enter product title or description...", lines=2, value="")
87
  ],
88
  outputs=gr.Label(label="Classification Results", num_top_classes=10),
89
  title="🛍️ E-Commerce Product Classifier",
90
+ description="Fast and accurate e-commerce product classification powered by EcommerceClassifier. Upload a product image and optionally provide a text description to classify it into the appropriate category.",
91
+ examples=[],
 
 
92
  theme="soft"
93
  )
94