Ali Mohsin committed on
Commit
ac45468
·
1 Parent(s): 3dd2128

More robust recommendations overall

Browse files
Files changed (1) hide show
  1. inference.py +130 -51
inference.py CHANGED
@@ -382,7 +382,7 @@ class InferenceService:
382
  if len(proc_items) < 2:
383
  print("πŸ” DEBUG: Returning empty array - not enough items (< 2)")
384
  return []
385
-
386
  print("πŸ” DEBUG: Starting candidate generation...")
387
 
388
  # 2) Candidate generation with outfit templates
@@ -426,7 +426,7 @@ class InferenceService:
426
  # Enhanced category-aware pools with diversity checks
427
  def cat_str(i: int) -> str:
428
  return (proc_items[i].get("category") or "").lower()
429
-
430
  print("πŸ” DEBUG: Building category pools...")
431
  # Debug: Print all categories
432
  for i in range(len(proc_items)):
@@ -490,6 +490,16 @@ class InferenceService:
490
  cat_lower = cat.lower().strip()
491
  print(f"πŸ” DEBUG: Mapping category '{cat}' -> '{cat_lower}'")
492
 
 
 
 
 
 
 
 
 
 
 
493
  # Upper body items (tops, outerwear)
494
  upper_keywords = [
495
  "top", "shirt", "tshirt", "t-shirt", "blouse", "tank", "camisole", "cami",
@@ -547,74 +557,131 @@ class InferenceService:
547
 
548
  print(f"πŸ” DEBUG: Category pools - uppers: {len(uppers)}, bottoms: {len(bottoms)}, shoes: {len(shoes)}, accessories: {len(accs)}, others: {len(others)}")
549
 
550
- # Check if we have the minimum required items
551
- if len(uppers) == 0 or len(bottoms) == 0 or len(shoes) == 0:
552
- print(f"πŸ” DEBUG: Missing required categories - uppers: {len(uppers)}, bottoms: {len(bottoms)}, shoes: {len(shoes)}")
 
553
  return []
 
 
 
 
 
554
 
555
  candidates: List[List[int]] = []
556
- num_samples = max(num_outfits * 12, 24)
557
  print(f"πŸ” DEBUG: Generating {num_samples} candidate outfits...")
558
 
559
  def has_category_diversity(subset: List[int]) -> bool:
560
  """Check if subset has good category diversity"""
561
  categories = [get_category_type(cat_str(i)) for i in subset]
562
  unique_categories = set(categories)
563
- # Require at least 3 different category types for good diversity
564
- return len(unique_categories) >= 3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
565
 
 
566
  for _ in range(num_samples):
567
  subset = []
568
 
569
- # EXACT SLOT CONSTRAINTS: Exactly 1 upper, 1 bottom, 1 shoe, ≤2 accessories
570
- if uppers and bottoms and shoes:
 
 
 
571
  # Core outfit: exactly 1 of each required slot
572
  subset.append(int(rng.choice(uppers)))
573
  subset.append(int(rng.choice(bottoms)))
574
  subset.append(int(rng.choice(shoes)))
575
 
576
- # Add accessories based on template limit
577
- if accs:
578
- max_accs = template["accessory_limit"]
579
- num_accs = rng.integers(1, min(max_accs + 1, len(accs) + 1))
 
580
  available_accs = [i for i in accs if i not in subset]
581
- if available_accs:
582
  selected_accs = rng.choice(available_accs, size=min(num_accs, len(available_accs)), replace=False)
583
  subset.extend(selected_accs.tolist())
584
 
585
- # Add 0-1 other items for variety (but not if it would exceed max_size)
586
- if others and len(subset) < max_size:
 
587
  available_others = [i for i in others if i not in subset]
588
- if available_others and rng.random() < 0.3: # 30% chance to add other item
589
- subset.append(int(rng.choice(available_others)))
590
- else:
591
- # Fallback: ensure we have at least 3 items with category diversity
592
- required_categories = []
593
- if uppers: required_categories.append(("upper", uppers))
594
- if bottoms: required_categories.append(("bottom", bottoms))
595
- if shoes: required_categories.append(("shoe", shoes))
 
 
596
 
597
- # Add one from each available required category
598
- for cat_type, cat_items in required_categories:
599
- subset.append(int(rng.choice(cat_items)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
600
 
601
- # Add accessories if available
602
- if accs and len(subset) < max_size:
603
- num_accs = rng.integers(1, min(3, len(accs) + 1))
604
- available_accs = [i for i in accs if i not in subset]
605
- if available_accs:
606
- selected_accs = rng.choice(available_accs, size=min(num_accs, len(available_accs)), replace=False)
607
- subset.extend(selected_accs.tolist())
 
 
 
 
 
 
 
608
 
609
  # Remove duplicates and validate
610
  subset = list(set(subset))
611
- if len(subset) >= 3: # At least 3 items for a valid outfit
 
 
612
  candidates.append(subset)
613
  if len(candidates) % 10 == 0: # Log every 10 candidates
614
  print(f"πŸ” DEBUG: Generated {len(candidates)} candidates so far...")
615
 
616
  print(f"πŸ” DEBUG: Generated {len(candidates)} total candidates")
617
-
618
  # 3) Score using ViT
619
  def score_subset(idx_subset: List[int]) -> float:
620
  embs = torch.tensor(
@@ -628,26 +695,28 @@ class InferenceService:
628
 
629
  # Enhanced validation with strict slot constraints
630
  def is_valid_outfit(subset: List[int]) -> bool:
631
- """Check if outfit meets exact slot requirements"""
 
 
 
632
  categories = [get_category_type(cat_str(i)) for i in subset]
633
  category_counts = {}
634
 
635
  for cat in categories:
636
  category_counts[cat] = category_counts.get(cat, 0) + 1
637
 
638
- # STRICT VALIDATION:
639
- # - Exactly 1 upper, 1 bottom, 1 shoe
640
- # - ≀2 accessories
641
- # - No other duplicates
642
- if category_counts.get("upper", 0) != 1:
 
643
  return False
644
- if category_counts.get("bottom", 0) != 1:
645
- return False
646
- if category_counts.get("shoe", 0) != 1:
647
- return False
648
- if category_counts.get("accessory", 0) > 2:
649
  return False
650
- if category_counts.get("other", 0) > 1:
651
  return False
652
 
653
  return True
@@ -714,9 +783,19 @@ class InferenceService:
714
  adjusted_score = calculate_outfit_penalty(subset, base_score)
715
  scored.append((subset, adjusted_score, base_score))
716
 
717
- # Sort by penalty-adjusted score
718
  scored.sort(key=lambda x: x[1], reverse=True)
719
- topk = scored[:num_outfits]
 
 
 
 
 
 
 
 
 
 
720
 
721
  results = []
722
  for subset, adjusted_score, base_score in topk:
 
382
  if len(proc_items) < 2:
383
  print("πŸ” DEBUG: Returning empty array - not enough items (< 2)")
384
  return []
385
+
386
  print("πŸ” DEBUG: Starting candidate generation...")
387
 
388
  # 2) Candidate generation with outfit templates
 
426
  # Enhanced category-aware pools with diversity checks
427
  def cat_str(i: int) -> str:
428
  return (proc_items[i].get("category") or "").lower()
429
+
430
  print("πŸ” DEBUG: Building category pools...")
431
  # Debug: Print all categories
432
  for i in range(len(proc_items)):
 
490
  cat_lower = cat.lower().strip()
491
  print(f"πŸ” DEBUG: Mapping category '{cat}' -> '{cat_lower}'")
492
 
493
+ # Direct mapping for CLIP-detected categories
494
+ if cat_lower == "shirt":
495
+ return "upper"
496
+ elif cat_lower == "pants":
497
+ return "bottom"
498
+ elif cat_lower == "shoes":
499
+ return "shoe"
500
+ elif cat_lower == "accessory":
501
+ return "accessory"
502
+
503
  # Upper body items (tops, outerwear)
504
  upper_keywords = [
505
  "top", "shirt", "tshirt", "t-shirt", "blouse", "tank", "camisole", "cami",
 
557
 
558
  print(f"πŸ” DEBUG: Category pools - uppers: {len(uppers)}, bottoms: {len(bottoms)}, shoes: {len(shoes)}, accessories: {len(accs)}, others: {len(others)}")
559
 
560
+ # Check if we have enough items to create outfits
561
+ total_items = len(uppers) + len(bottoms) + len(shoes) + len(accs) + len(others)
562
+ if total_items < 2:
563
+ print(f"πŸ” DEBUG: Not enough items to create outfits - total: {total_items}")
564
  return []
565
+
566
+ # Warn if we're missing core categories but still try to generate
567
+ if len(uppers) == 0 or len(bottoms) == 0 or len(shoes) == 0:
568
+ print(f"πŸ” DEBUG: Missing some core categories - uppers: {len(uppers)}, bottoms: {len(bottoms)}, shoes: {len(shoes)}")
569
+ print(f"πŸ” DEBUG: Will use flexible outfit generation with available items")
570
 
571
  candidates: List[List[int]] = []
572
+ num_samples = max(num_outfits * 15, 30) # Increased for more variety
573
  print(f"πŸ” DEBUG: Generating {num_samples} candidate outfits...")
574
 
575
  def has_category_diversity(subset: List[int]) -> bool:
576
  """Check if subset has good category diversity"""
577
  categories = [get_category_type(cat_str(i)) for i in subset]
578
  unique_categories = set(categories)
579
+ # Require at least 2 different category types for good diversity
580
+ return len(unique_categories) >= 2
581
+
582
+ def calculate_outfit_score(subset: List[int]) -> float:
583
+ """Calculate overall outfit quality score"""
584
+ if len(subset) < 2:
585
+ return 0.0
586
+
587
+ # Base score from category diversity
588
+ diversity_score = len(set(get_category_type(cat_str(i)) for i in subset)) / 4.0
589
+
590
+ # Style consistency score
591
+ style_score = calculate_style_consistency_score(subset)
592
+
593
+ # Color consistency score
594
+ color_score = calculate_color_consistency_score(subset)
595
+
596
+ # Length appropriateness (prefer 3-4 items)
597
+ length_score = 1.0 if 3 <= len(subset) <= 4 else 0.7
598
+
599
+ # Weighted combination
600
+ return 0.3 * diversity_score + 0.3 * style_score + 0.2 * color_score + 0.2 * length_score
601
 
602
+ # Generate diverse outfit combinations with randomization
603
  for _ in range(num_samples):
604
  subset = []
605
 
606
+ # VARIABLE OUTFIT LENGTH: 2-5 items with different strategies
607
+ outfit_length = rng.choice([2, 3, 4, 5], p=[0.1, 0.4, 0.4, 0.1]) # Prefer 3-4 items
608
+
609
+ # Strategy 1: Core outfit (shirt + pants + shoes) + accessories
610
+ if rng.random() < 0.6 and uppers and bottoms and shoes:
611
  # Core outfit: exactly 1 of each required slot
612
  subset.append(int(rng.choice(uppers)))
613
  subset.append(int(rng.choice(bottoms)))
614
  subset.append(int(rng.choice(shoes)))
615
 
616
+ # Add accessories based on template limit and remaining slots
617
+ remaining_slots = outfit_length - len(subset)
618
+ if accs and remaining_slots > 0:
619
+ max_accs = min(template["accessory_limit"], remaining_slots, len(accs))
620
+ num_accs = rng.integers(0, max_accs + 1)
621
  available_accs = [i for i in accs if i not in subset]
622
+ if available_accs and num_accs > 0:
623
  selected_accs = rng.choice(available_accs, size=min(num_accs, len(available_accs)), replace=False)
624
  subset.extend(selected_accs.tolist())
625
 
626
+ # Fill remaining slots with other items
627
+ remaining_slots = outfit_length - len(subset)
628
+ if others and remaining_slots > 0:
629
  available_others = [i for i in others if i not in subset]
630
+ if available_others:
631
+ num_others = min(remaining_slots, len(available_others))
632
+ selected_others = rng.choice(available_others, size=num_others, replace=False)
633
+ subset.extend(selected_others.tolist())
634
+
635
+ # Strategy 2: Flexible combination (no strict slot requirements)
636
+ elif rng.random() < 0.3:
637
+ # Randomly select items from all categories
638
+ all_items = list(ids)
639
+ rng.shuffle(all_items)
640
 
641
+ # Select items ensuring diversity
642
+ selected_categories = set()
643
+ for item in all_items:
644
+ if len(subset) >= outfit_length:
645
+ break
646
+ item_category = get_category_type(cat_str(item))
647
+ if item_category not in selected_categories or len(subset) < 2:
648
+ subset.append(item)
649
+ selected_categories.add(item_category)
650
+
651
+ # Strategy 3: Accessory-focused outfit (for small wardrobes)
652
+ else:
653
+ # Start with accessories if available
654
+ if accs:
655
+ num_accs = min(outfit_length, len(accs))
656
+ selected_accs = rng.choice(accs, size=num_accs, replace=False)
657
+ subset.extend(selected_accs.tolist())
658
 
659
+ # Fill remaining with other categories
660
+ remaining_slots = outfit_length - len(subset)
661
+ if remaining_slots > 0:
662
+ other_categories = []
663
+ if uppers: other_categories.extend(uppers)
664
+ if bottoms: other_categories.extend(bottoms)
665
+ if shoes: other_categories.extend(shoes)
666
+ if others: other_categories.extend(others)
667
+
668
+ available_others = [i for i in other_categories if i not in subset]
669
+ if available_others:
670
+ num_others = min(remaining_slots, len(available_others))
671
+ selected_others = rng.choice(available_others, size=num_others, replace=False)
672
+ subset.extend(selected_others.tolist())
673
 
674
  # Remove duplicates and validate
675
  subset = list(set(subset))
676
+ if len(subset) >= 2 and len(subset) <= max_size and has_category_diversity(subset):
677
+ # Add randomization factor to prevent identical recommendations
678
+ subset = rng.permutation(subset).tolist() # Randomize order
679
  candidates.append(subset)
680
  if len(candidates) % 10 == 0: # Log every 10 candidates
681
  print(f"πŸ” DEBUG: Generated {len(candidates)} candidates so far...")
682
 
683
  print(f"πŸ” DEBUG: Generated {len(candidates)} total candidates")
684
+
685
  # 3) Score using ViT
686
  def score_subset(idx_subset: List[int]) -> float:
687
  embs = torch.tensor(
 
695
 
696
  # Enhanced validation with strict slot constraints
697
  def is_valid_outfit(subset: List[int]) -> bool:
698
+ """Check if outfit meets flexible requirements"""
699
+ if len(subset) < 2 or len(subset) > max_size:
700
+ return False
701
+
702
  categories = [get_category_type(cat_str(i)) for i in subset]
703
  category_counts = {}
704
 
705
  for cat in categories:
706
  category_counts[cat] = category_counts.get(cat, 0) + 1
707
 
708
+ # FLEXIBLE VALIDATION:
709
+ # - At least 2 different categories
710
+ # - Reasonable limits per category
711
+ # - Allow variable outfit lengths
712
+ unique_categories = len(set(categories))
713
+ if unique_categories < 2:
714
  return False
715
+
716
+ # Reasonable limits (more flexible than before)
717
+ if category_counts.get("accessory", 0) > 3: # Allow up to 3 accessories
 
 
718
  return False
719
+ if category_counts.get("other", 0) > 2: # Allow up to 2 other items
720
  return False
721
 
722
  return True
 
783
  adjusted_score = calculate_outfit_penalty(subset, base_score)
784
  scored.append((subset, adjusted_score, base_score))
785
 
786
+ # Sort by penalty-adjusted score with randomization
787
  scored.sort(key=lambda x: x[1], reverse=True)
788
+
789
+ # Add randomization to prevent identical recommendations
790
+ if len(scored) > num_outfits:
791
+ # Take top 50% and randomly sample from them
792
+ top_half = scored[:max(num_outfits * 2, len(scored) // 2)]
793
+ rng.shuffle(top_half)
794
+ topk = top_half[:num_outfits]
795
+ else:
796
+ # If we have fewer candidates than requested, shuffle them
797
+ rng.shuffle(scored)
798
+ topk = scored[:num_outfits]
799
 
800
  results = []
801
  for subset, adjusted_score, base_score in topk: