Ali Mohsin commited on
Commit
8d2d202
Β·
1 Parent(s): 4a2dbbf

Final fixes 5666444

Browse files
Files changed (2) hide show
  1. app.py +37 -0
  2. inference.py +64 -12
app.py CHANGED
@@ -296,6 +296,43 @@ threading.Thread(target=_background_bootstrap, daemon=True).start()
296
  def health() -> dict:
297
  return {"status": "ok", "device": service.device, "resnet": service.resnet_version, "vit": service.vit_version}
298
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
  @app.post("/embed")
301
  def embed(req: EmbedRequest, x_api_key: Optional[str] = Header(None)) -> dict:
 
296
  def health() -> dict:
297
  return {"status": "ok", "device": service.device, "resnet": service.resnet_version, "vit": service.vit_version}
298
 
299
+ @app.get("/model-status")
300
+ def model_status() -> dict:
301
+ """Get detailed model loading status."""
302
+ return service.get_model_status()
303
+
304
+ @app.post("/reload-models")
305
+ def reload_models() -> dict:
306
+ """Force reload models - useful for debugging."""
307
+ try:
308
+ service.force_reload_models()
309
+ return {"status": "success", "message": "Models reloaded successfully"}
310
+ except Exception as e:
311
+ return {"status": "error", "message": str(e)}
312
+
313
+ @app.post("/test-recommend")
314
+ def test_recommend() -> dict:
315
+ """Test recommendation with dummy data to debug the issue."""
316
+ try:
317
+ # Create dummy items for testing
318
+ dummy_items = [
319
+ {"id": "test_1", "image": None, "category": "shirt"},
320
+ {"id": "test_2", "image": None, "category": "pants"},
321
+ {"id": "test_3", "image": None, "category": "shoes"}
322
+ ]
323
+
324
+ # Try to get recommendations
325
+ result = service.compose_outfits(dummy_items, {"num_outfits": 1})
326
+
327
+ return {
328
+ "status": "success",
329
+ "model_status": service.get_model_status(),
330
+ "result": result,
331
+ "result_length": len(result) if result else 0
332
+ }
333
+ except Exception as e:
334
+ return {"status": "error", "message": str(e), "model_status": service.get_model_status()}
335
+
336
 
337
  @app.post("/embed")
338
  def embed(req: EmbedRequest, x_api_key: Optional[str] = Header(None)) -> dict:
inference.py CHANGED
@@ -45,8 +45,8 @@ class InferenceService:
45
  # Disable gradients
46
  for m in [self.resnet, self.vit]:
47
  if m is not None:
48
- for p in m.parameters():
49
- p.requires_grad_(False)
50
 
51
  # Update overall status
52
  self.models_loaded = self.resnet_loaded and self.vit_loaded
@@ -177,8 +177,8 @@ class InferenceService:
177
  # Disable gradients
178
  for m in [self.resnet, self.vit]:
179
  if m is not None:
180
- for p in m.parameters():
181
- p.requires_grad_(False)
182
 
183
  # Update overall status
184
  self.models_loaded = self.resnet_loaded and self.vit_loaded
@@ -202,12 +202,12 @@ class InferenceService:
202
  return []
203
 
204
  try:
205
- batch = torch.stack([self.transform(img) for img in images])
206
- batch = batch.to(self.device, memory_format=torch.channels_last)
207
- use_amp = (self.device == "cuda")
208
- with torch.autocast(device_type=("cuda" if use_amp else "cpu"), enabled=use_amp):
209
- emb = self.resnet(batch)
210
- emb = nn.functional.normalize(emb, dim=-1)
211
  result = [e.detach().cpu().numpy().astype(np.float32) for e in emb]
212
  print(f"πŸ” DEBUG: Successfully generated {len(result)} embeddings")
213
  return result
@@ -257,6 +257,8 @@ class InferenceService:
257
  if len(proc_items) < 2:
258
  print("πŸ” DEBUG: Returning empty array - not enough items (< 2)")
259
  return []
 
 
260
 
261
  # 2) Candidate generation with outfit templates
262
  rng = np.random.default_rng(int(context.get("seed", 42)))
@@ -299,6 +301,11 @@ class InferenceService:
299
  # Enhanced category-aware pools with diversity checks
300
  def cat_str(i: int) -> str:
301
  return (proc_items[i].get("category") or "").lower()
 
 
 
 
 
302
 
303
  def extract_color_from_category(category: str) -> str:
304
  """Extract color information from category name"""
@@ -356,6 +363,7 @@ class InferenceService:
356
  def get_category_type(cat: str) -> str:
357
  """Map category to outfit slot type with comprehensive taxonomy"""
358
  cat_lower = cat.lower().strip()
 
359
 
360
  # Upper body items (tops, outerwear)
361
  upper_keywords = [
@@ -405,14 +413,23 @@ class InferenceService:
405
  return "other"
406
 
407
  # Create category pools
 
408
  uppers = [i for i in ids if get_category_type(cat_str(i)) == "upper"]
409
  bottoms = [i for i in ids if get_category_type(cat_str(i)) == "bottom"]
410
  shoes = [i for i in ids if get_category_type(cat_str(i)) == "shoe"]
411
  accs = [i for i in ids if get_category_type(cat_str(i)) == "accessory"]
412
  others = [i for i in ids if get_category_type(cat_str(i)) == "other"]
 
 
 
 
 
 
 
413
 
414
  candidates: List[List[int]] = []
415
  num_samples = max(num_outfits * 12, 24)
 
416
 
417
  def has_category_diversity(subset: List[int]) -> bool:
418
  """Check if subset has good category diversity"""
@@ -468,7 +485,11 @@ class InferenceService:
468
  subset = list(set(subset))
469
  if len(subset) >= 3: # At least 3 items for a valid outfit
470
  candidates.append(subset)
 
 
471
 
 
 
472
  # 3) Score using ViT
473
  def score_subset(idx_subset: List[int]) -> float:
474
  embs = torch.tensor(
@@ -609,7 +630,7 @@ class InferenceService:
609
  })
610
 
611
  return results
612
-
613
  def get_model_status(self) -> Dict[str, Any]:
614
  """Get current model loading status and errors."""
615
  return {
@@ -617,7 +638,38 @@ class InferenceService:
617
  "resnet_loaded": self.resnet_loaded,
618
  "vit_loaded": self.vit_loaded,
619
  "errors": self.model_errors,
620
- "can_recommend": self.models_loaded
 
 
621
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
622
 
623
 
 
45
  # Disable gradients
46
  for m in [self.resnet, self.vit]:
47
  if m is not None:
48
+ for p in m.parameters():
49
+ p.requires_grad_(False)
50
 
51
  # Update overall status
52
  self.models_loaded = self.resnet_loaded and self.vit_loaded
 
177
  # Disable gradients
178
  for m in [self.resnet, self.vit]:
179
  if m is not None:
180
+ for p in m.parameters():
181
+ p.requires_grad_(False)
182
 
183
  # Update overall status
184
  self.models_loaded = self.resnet_loaded and self.vit_loaded
 
202
  return []
203
 
204
  try:
205
+ batch = torch.stack([self.transform(img) for img in images])
206
+ batch = batch.to(self.device, memory_format=torch.channels_last)
207
+ use_amp = (self.device == "cuda")
208
+ with torch.autocast(device_type=("cuda" if use_amp else "cpu"), enabled=use_amp):
209
+ emb = self.resnet(batch)
210
+ emb = nn.functional.normalize(emb, dim=-1)
211
  result = [e.detach().cpu().numpy().astype(np.float32) for e in emb]
212
  print(f"πŸ” DEBUG: Successfully generated {len(result)} embeddings")
213
  return result
 
257
  if len(proc_items) < 2:
258
  print("πŸ” DEBUG: Returning empty array - not enough items (< 2)")
259
  return []
260
+
261
+ print("πŸ” DEBUG: Starting candidate generation...")
262
 
263
  # 2) Candidate generation with outfit templates
264
  rng = np.random.default_rng(int(context.get("seed", 42)))
 
301
  # Enhanced category-aware pools with diversity checks
302
  def cat_str(i: int) -> str:
303
  return (proc_items[i].get("category") or "").lower()
304
+
305
+ print("πŸ” DEBUG: Building category pools...")
306
+ # Debug: Print all categories
307
+ for i in range(len(proc_items)):
308
+ print(f"πŸ” DEBUG: Item {i}: category='{proc_items[i].get('category')}' -> cat_str='{cat_str(i)}'")
309
 
310
  def extract_color_from_category(category: str) -> str:
311
  """Extract color information from category name"""
 
363
  def get_category_type(cat: str) -> str:
364
  """Map category to outfit slot type with comprehensive taxonomy"""
365
  cat_lower = cat.lower().strip()
366
+ print(f"πŸ” DEBUG: Mapping category '{cat}' -> '{cat_lower}'")
367
 
368
  # Upper body items (tops, outerwear)
369
  upper_keywords = [
 
413
  return "other"
414
 
415
  # Create category pools
416
+ print("πŸ” DEBUG: Building category pools...")
417
  uppers = [i for i in ids if get_category_type(cat_str(i)) == "upper"]
418
  bottoms = [i for i in ids if get_category_type(cat_str(i)) == "bottom"]
419
  shoes = [i for i in ids if get_category_type(cat_str(i)) == "shoe"]
420
  accs = [i for i in ids if get_category_type(cat_str(i)) == "accessory"]
421
  others = [i for i in ids if get_category_type(cat_str(i)) == "other"]
422
+
423
+ print(f"πŸ” DEBUG: Category pools - uppers: {len(uppers)}, bottoms: {len(bottoms)}, shoes: {len(shoes)}, accessories: {len(accs)}, others: {len(others)}")
424
+
425
+ # Check if we have the minimum required items
426
+ if len(uppers) == 0 or len(bottoms) == 0 or len(shoes) == 0:
427
+ print(f"πŸ” DEBUG: Missing required categories - uppers: {len(uppers)}, bottoms: {len(bottoms)}, shoes: {len(shoes)}")
428
+ return []
429
 
430
  candidates: List[List[int]] = []
431
  num_samples = max(num_outfits * 12, 24)
432
+ print(f"πŸ” DEBUG: Generating {num_samples} candidate outfits...")
433
 
434
  def has_category_diversity(subset: List[int]) -> bool:
435
  """Check if subset has good category diversity"""
 
485
  subset = list(set(subset))
486
  if len(subset) >= 3: # At least 3 items for a valid outfit
487
  candidates.append(subset)
488
+ if len(candidates) % 10 == 0: # Log every 10 candidates
489
+ print(f"πŸ” DEBUG: Generated {len(candidates)} candidates so far...")
490
 
491
+ print(f"πŸ” DEBUG: Generated {len(candidates)} total candidates")
492
+
493
  # 3) Score using ViT
494
  def score_subset(idx_subset: List[int]) -> float:
495
  embs = torch.tensor(
 
630
  })
631
 
632
  return results
633
+
634
  def get_model_status(self) -> Dict[str, Any]:
635
  """Get current model loading status and errors."""
636
  return {
 
638
  "resnet_loaded": self.resnet_loaded,
639
  "vit_loaded": self.vit_loaded,
640
  "errors": self.model_errors,
641
+ "can_recommend": self.models_loaded,
642
+ "resnet_model": self.resnet is not None,
643
+ "vit_model": self.vit is not None
644
  }
645
+
646
+ def force_reload_models(self) -> None:
647
+ """Force reload models and update status - useful for debugging."""
648
+ print("πŸ”„ Force reloading models...")
649
+ self.resnet, self.resnet_loaded = self._load_resnet()
650
+ self.vit, self.vit_loaded = self._load_vit()
651
+
652
+ # Move to device and set eval mode
653
+ if self.resnet_loaded:
654
+ self.resnet = self.resnet.to(self.device).eval()
655
+ if self.vit_loaded:
656
+ self.vit = self.vit.to(self.device).eval()
657
+
658
+ # Disable gradients
659
+ for m in [self.resnet, self.vit]:
660
+ if m is not None:
661
+ for p in m.parameters():
662
+ p.requires_grad_(False)
663
+
664
+ # Update overall status
665
+ self.models_loaded = self.resnet_loaded and self.vit_loaded
666
+ print(f"πŸ”„ Models reloaded: resnet={self.resnet_loaded}, vit={self.vit_loaded}, overall={self.models_loaded}")
667
+
668
+ if not self.models_loaded:
669
+ self.model_errors = []
670
+ if not self.resnet_loaded:
671
+ self.model_errors.append("ResNet: No trained weights found")
672
+ if not self.vit_loaded:
673
+ self.model_errors.append("ViT: No trained weights found")
674
 
675