Ali Mohsin committed on
Commit
3dd2128
·
1 Parent(s): 1d3b4c2

bbbbbhtt555

Browse files
Files changed (1) hide show
  1. inference.py +47 -7
inference.py CHANGED
@@ -6,7 +6,11 @@ import torch
6
  import torch.nn as nn
7
  from PIL import Image
8
  from huggingface_hub import hf_hub_download
9
- import clip
 
 
 
 
10
 
11
  from utils.transforms import build_inference_transform
12
  from models.resnet_embedder import ResNetItemEmbedder
@@ -64,9 +68,16 @@ class InferenceService:
64
 
65
  def _load_clip(self) -> None:
66
  """Load CLIP model for category detection."""
 
 
 
 
 
67
  try:
68
  print("🔄 Loading CLIP model for category detection...")
69
- self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
 
 
70
  print("✅ CLIP model loaded successfully")
71
  except Exception as e:
72
  print(f"❌ Failed to load CLIP model: {e}")
@@ -95,7 +106,7 @@ class InferenceService:
95
 
96
  # Prepare image and text
97
  image_input = self.clip_preprocess(image).unsqueeze(0).to(self.device)
98
- text_inputs = clip.tokenize(categories).to(self.device)
99
 
100
  # Get predictions
101
  with torch.no_grad():
@@ -130,6 +141,32 @@ class InferenceService:
130
  except Exception as e:
131
  print(f"❌ CLIP category detection failed: {e}")
132
  return "other"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  def _load_resnet(self) -> tuple[nn.Module, bool]:
135
  strategy = os.getenv("MODEL_LOAD_STRATEGY", "state_dict")
@@ -311,15 +348,18 @@ class InferenceService:
311
  for i, it in enumerate(items):
312
  print(f"🔍 DEBUG: Processing item {i}: id={it.get('id')}, has_image={it.get('image') is not None}, has_embedding={it.get('embedding') is not None}")
313
 
314
- # Auto-detect category using CLIP if not provided or is None
315
  category = it.get("category")
316
  if not category or category == "None" or category == "":
317
- if it.get("image") is not None:
318
  print(f"🔍 DEBUG: Auto-detecting category for item {i} using CLIP...")
319
  category = self._detect_category_with_clip(it["image"])
320
  else:
321
- category = "other"
322
- print(f"🔍 DEBUG: No image available for item {i}, using 'other' category")
 
 
 
323
 
324
  emb = it.get("embedding")
325
  if emb is None and it.get("image") is not None:
 
6
  import torch.nn as nn
7
  from PIL import Image
8
  from huggingface_hub import hf_hub_download
9
+ try:
10
+ import open_clip
11
+ CLIP_AVAILABLE = True
12
+ except ImportError:
13
+ CLIP_AVAILABLE = False
14
 
15
  from utils.transforms import build_inference_transform
16
  from models.resnet_embedder import ResNetItemEmbedder
 
68
 
69
  def _load_clip(self) -> None:
70
  """Load CLIP model for category detection."""
71
+ if not CLIP_AVAILABLE:
72
+ print("⚠️ CLIP not available, using filename-based category detection")
73
+ self.clip_model, self.clip_preprocess = None, None
74
+ return
75
+
76
  try:
77
  print("🔄 Loading CLIP model for category detection...")
78
+ self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
79
+ 'ViT-B-32', pretrained='laion2b_s34b_b79k', device=self.device
80
+ )
81
  print("✅ CLIP model loaded successfully")
82
  except Exception as e:
83
  print(f"❌ Failed to load CLIP model: {e}")
 
106
 
107
  # Prepare image and text
108
  image_input = self.clip_preprocess(image).unsqueeze(0).to(self.device)
109
+ text_inputs = open_clip.tokenize(categories).to(self.device)
110
 
111
  # Get predictions
112
  with torch.no_grad():
 
141
  except Exception as e:
142
  print(f"❌ CLIP category detection failed: {e}")
143
  return "other"
144
+
145
+ def _detect_category_from_filename(self, filename: str) -> str:
146
+ """Fallback: Detect category from filename using keyword matching."""
147
+ if not filename:
148
+ return "other"
149
+
150
+ filename_lower = filename.lower()
151
+
152
+ # Upper body items
153
+ if any(kw in filename_lower for kw in ["shirt", "top", "blouse", "tank", "hoodie", "sweater", "jacket", "blazer", "coat"]):
154
+ return "shirt"
155
+
156
+ # Bottom items
157
+ if any(kw in filename_lower for kw in ["pant", "jean", "short", "skirt", "trouser", "legging", "jogger"]):
158
+ return "pants"
159
+
160
+ # Shoes
161
+ if any(kw in filename_lower for kw in ["shoe", "boot", "sneaker", "sandal", "heel", "loafer", "oxford"]):
162
+ return "shoes"
163
+
164
+ # Accessories
165
+ if any(kw in filename_lower for kw in ["watch", "ring", "necklace", "bracelet", "bag", "hat", "belt", "scarf"]):
166
+ return "accessory"
167
+
168
+ # Default fallback
169
+ return "other"
170
 
171
  def _load_resnet(self) -> tuple[nn.Module, bool]:
172
  strategy = os.getenv("MODEL_LOAD_STRATEGY", "state_dict")
 
348
  for i, it in enumerate(items):
349
  print(f"🔍 DEBUG: Processing item {i}: id={it.get('id')}, has_image={it.get('image') is not None}, has_embedding={it.get('embedding') is not None}")
350
 
351
+ # Auto-detect category if not provided or is None
352
  category = it.get("category")
353
  if not category or category == "None" or category == "":
354
+ if it.get("image") is not None and self.clip_model is not None:
355
  print(f"🔍 DEBUG: Auto-detecting category for item {i} using CLIP...")
356
  category = self._detect_category_with_clip(it["image"])
357
  else:
358
+ # Fallback to filename-based detection
359
+ filename = it.get("id", "")
360
+ print(f"🔍 DEBUG: Auto-detecting category for item {i} using filename '{filename}'...")
361
+ category = self._detect_category_from_filename(filename)
362
+ print(f"🔍 DEBUG: Filename-based detection result: '{category}'")
363
 
364
  emb = it.get("embedding")
365
  if emb is None and it.get("image") is not None: