Ali Mohsin commited on
Commit
84ebcb0
Β·
1 Parent(s): 941ea8d

This is the last update

Browse files
Files changed (1) hide show
  1. inference.py +42 -17
inference.py CHANGED
@@ -27,7 +27,7 @@ class InferenceService:
27
  self.embed_dim = int(os.getenv("EMBED_DIM", "512"))
28
  self.resnet_version = "resnet_v1"
29
  self.vit_version = "vit_v1"
30
-
31
  # Model loading status tracking
32
  self.models_loaded = False
33
  self.model_errors = []
@@ -45,9 +45,9 @@ class InferenceService:
45
  # Disable gradients
46
  for m in [self.resnet, self.vit]:
47
  if m is not None:
48
- for p in m.parameters():
49
- p.requires_grad_(False)
50
-
51
  # Update overall status
52
  self.models_loaded = self.resnet_loaded and self.vit_loaded
53
  if not self.models_loaded:
@@ -177,8 +177,8 @@ class InferenceService:
177
  # Disable gradients
178
  for m in [self.resnet, self.vit]:
179
  if m is not None:
180
- for p in m.parameters():
181
- p.requires_grad_(False)
182
 
183
  # Update overall status
184
  self.models_loaded = self.resnet_loaded and self.vit_loaded
@@ -191,18 +191,37 @@ class InferenceService:
191
 
192
  @torch.inference_mode()
193
  def embed_images(self, images: List[Image.Image]) -> List[np.ndarray]:
 
194
  if len(images) == 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  return []
196
- batch = torch.stack([self.transform(img) for img in images])
197
- batch = batch.to(self.device, memory_format=torch.channels_last)
198
- use_amp = (self.device == "cuda")
199
- with torch.autocast(device_type=("cuda" if use_amp else "cpu"), enabled=use_amp):
200
- emb = self.resnet(batch)
201
- emb = nn.functional.normalize(emb, dim=-1)
202
- return [e.detach().cpu().numpy().astype(np.float32) for e in emb]
203
 
204
  @torch.inference_mode()
205
  def compose_outfits(self, items: List[Dict[str, Any]], context: Dict[str, Any]) -> List[Dict[str, Any]]:
 
 
 
 
 
206
  # Validate that models are properly loaded
207
  if not self.models_loaded:
208
  error_msg = f"❌ Cannot provide recommendations: Models not properly loaded. Errors: {self.model_errors}"
@@ -215,13 +234,16 @@ class InferenceService:
215
 
216
  # 1) Ensure embeddings for each input item
217
  proc_items: List[Dict[str, Any]] = []
218
- for it in items:
 
219
  emb = it.get("embedding")
220
  if emb is None and it.get("image") is not None:
221
  # Compute on-the-fly if image provided
 
222
  emb = self.embed_images([it["image"]])[0]
223
  if emb is None:
224
  # Skip if we cannot get an embedding
 
225
  continue
226
  emb_np = np.asarray(emb, dtype=np.float32)
227
  proc_items.append({
@@ -229,8 +251,11 @@ class InferenceService:
229
  "embedding": emb_np,
230
  "category": it.get("category")
231
  })
 
232
 
 
233
  if len(proc_items) < 2:
 
234
  return []
235
 
236
  # 2) Candidate generation with outfit templates
@@ -238,7 +263,7 @@ class InferenceService:
238
  num_outfits = int(context.get("num_outfits", 3))
239
  min_size, max_size = 4, 6
240
  ids = list(range(len(proc_items)))
241
-
242
  # Outfit templates for cohesive styling
243
  outfit_templates = {
244
  "casual": {
@@ -274,7 +299,7 @@ class InferenceService:
274
  # Enhanced category-aware pools with diversity checks
275
  def cat_str(i: int) -> str:
276
  return (proc_items[i].get("category") or "").lower()
277
-
278
  def extract_color_from_category(category: str) -> str:
279
  """Extract color information from category name"""
280
  category_lower = category.lower()
@@ -442,7 +467,7 @@ class InferenceService:
442
  # Remove duplicates and validate
443
  subset = list(set(subset))
444
  if len(subset) >= 3: # At least 3 items for a valid outfit
445
- candidates.append(subset)
446
 
447
  # 3) Score using ViT
448
  def score_subset(idx_subset: List[int]) -> float:
 
27
  self.embed_dim = int(os.getenv("EMBED_DIM", "512"))
28
  self.resnet_version = "resnet_v1"
29
  self.vit_version = "vit_v1"
30
+
31
  # Model loading status tracking
32
  self.models_loaded = False
33
  self.model_errors = []
 
45
  # Disable gradients
46
  for m in [self.resnet, self.vit]:
47
  if m is not None:
48
+ for p in m.parameters():
49
+ p.requires_grad_(False)
50
+
51
  # Update overall status
52
  self.models_loaded = self.resnet_loaded and self.vit_loaded
53
  if not self.models_loaded:
 
177
  # Disable gradients
178
  for m in [self.resnet, self.vit]:
179
  if m is not None:
180
+ for p in m.parameters():
181
+ p.requires_grad_(False)
182
 
183
  # Update overall status
184
  self.models_loaded = self.resnet_loaded and self.vit_loaded
 
191
 
192
  @torch.inference_mode()
193
  def embed_images(self, images: List[Image.Image]) -> List[np.ndarray]:
194
+ print(f"πŸ” DEBUG: embed_images called with {len(images)} images")
195
  if len(images) == 0:
196
+ print("πŸ” DEBUG: No images provided, returning empty list")
197
+ return []
198
+
199
+ print(f"πŸ” DEBUG: ResNet model is None: {self.resnet is None}")
200
+ if self.resnet is None:
201
+ print("πŸ” DEBUG: ResNet model is None, returning empty list")
202
+ return []
203
+
204
+ try:
205
+ batch = torch.stack([self.transform(img) for img in images])
206
+ batch = batch.to(self.device, memory_format=torch.channels_last)
207
+ use_amp = (self.device == "cuda")
208
+ with torch.autocast(device_type=("cuda" if use_amp else "cpu"), enabled=use_amp):
209
+ emb = self.resnet(batch)
210
+ emb = nn.functional.normalize(emb, dim=-1)
211
+ result = [e.detach().cpu().numpy().astype(np.float32) for e in emb]
212
+ print(f"πŸ” DEBUG: Successfully generated {len(result)} embeddings")
213
+ return result
214
+ except Exception as e:
215
+ print(f"πŸ” DEBUG: Error in embed_images: {e}")
216
  return []
 
 
 
 
 
 
 
217
 
218
  @torch.inference_mode()
219
  def compose_outfits(self, items: List[Dict[str, Any]], context: Dict[str, Any]) -> List[Dict[str, Any]]:
220
+ # Debug: Print model status
221
+ print(f"πŸ” DEBUG: models_loaded={self.models_loaded}, resnet_loaded={self.resnet_loaded}, vit_loaded={self.vit_loaded}")
222
+ print(f"πŸ” DEBUG: model_errors={self.model_errors}")
223
+ print(f"πŸ” DEBUG: items count={len(items)}")
224
+
225
  # Validate that models are properly loaded
226
  if not self.models_loaded:
227
  error_msg = f"❌ Cannot provide recommendations: Models not properly loaded. Errors: {self.model_errors}"
 
234
 
235
  # 1) Ensure embeddings for each input item
236
  proc_items: List[Dict[str, Any]] = []
237
+ for i, it in enumerate(items):
238
+ print(f"πŸ” DEBUG: Processing item {i}: id={it.get('id')}, has_image={it.get('image') is not None}, has_embedding={it.get('embedding') is not None}")
239
  emb = it.get("embedding")
240
  if emb is None and it.get("image") is not None:
241
  # Compute on-the-fly if image provided
242
+ print(f"πŸ” DEBUG: Generating embedding for item {i}")
243
  emb = self.embed_images([it["image"]])[0]
244
  if emb is None:
245
  # Skip if we cannot get an embedding
246
+ print(f"πŸ” DEBUG: Skipping item {i} - no embedding generated")
247
  continue
248
  emb_np = np.asarray(emb, dtype=np.float32)
249
  proc_items.append({
 
251
  "embedding": emb_np,
252
  "category": it.get("category")
253
  })
254
+ print(f"πŸ” DEBUG: Added item {i} to proc_items, total: {len(proc_items)}")
255
 
256
+ print(f"πŸ” DEBUG: Final proc_items count: {len(proc_items)}")
257
  if len(proc_items) < 2:
258
+ print("πŸ” DEBUG: Returning empty array - not enough items (< 2)")
259
  return []
260
 
261
  # 2) Candidate generation with outfit templates
 
263
  num_outfits = int(context.get("num_outfits", 3))
264
  min_size, max_size = 4, 6
265
  ids = list(range(len(proc_items)))
266
+
267
  # Outfit templates for cohesive styling
268
  outfit_templates = {
269
  "casual": {
 
299
  # Enhanced category-aware pools with diversity checks
300
  def cat_str(i: int) -> str:
301
  return (proc_items[i].get("category") or "").lower()
302
+
303
  def extract_color_from_category(category: str) -> str:
304
  """Extract color information from category name"""
305
  category_lower = category.lower()
 
467
  # Remove duplicates and validate
468
  subset = list(set(subset))
469
  if len(subset) >= 3: # At least 3 items for a valid outfit
470
+ candidates.append(subset)
471
 
472
  # 3) Score using ViT
473
  def score_subset(idx_subset: List[int]) -> float: