Ali Mohsin commited on
Commit
6fa8724
Β·
1 Parent(s): 25ea1c9

I promise, these are the last changes

Browse files
Files changed (1) hide show
  1. inference.py +154 -26
inference.py CHANGED
@@ -386,42 +386,113 @@ class InferenceService:
386
  print("πŸ” DEBUG: Starting candidate generation...")
387
 
388
  # 2) Candidate generation with outfit templates
389
- rng = np.random.default_rng(int(context.get("seed", 42)))
 
 
 
 
390
  num_outfits = int(context.get("num_outfits", 5)) # Increased default from 3 to 5
391
  min_size, max_size = 2, 6 # Allow smaller outfits (2 items minimum)
392
  ids = list(range(len(proc_items)))
393
 
394
- # Outfit templates for cohesive styling
395
  outfit_templates = {
396
  "casual": {
397
  "style": "relaxed, comfortable, everyday",
398
- "preferred_categories": ["tshirt", "jean", "sneaker", "hoodie", "sweatpant"],
399
  "color_palette": ["neutral", "denim", "white", "black", "gray"],
400
- "accessory_limit": 2
 
 
 
 
 
 
 
 
 
401
  },
402
  "smart_casual": {
403
  "style": "polished but relaxed, business casual",
404
- "preferred_categories": ["shirt", "chino", "loafer", "blazer", "polo"],
405
  "color_palette": ["navy", "white", "khaki", "brown", "gray"],
406
- "accessory_limit": 3
 
 
 
 
 
 
 
 
 
407
  },
408
  "formal": {
409
  "style": "professional, elegant, sophisticated",
410
- "preferred_categories": ["blazer", "dress shirt", "dress pant", "oxford", "suit"],
411
  "color_palette": ["navy", "black", "white", "gray", "charcoal"],
412
- "accessory_limit": 4
 
 
 
 
 
 
 
 
 
413
  },
414
  "sporty": {
415
  "style": "athletic, active, performance",
416
- "preferred_categories": ["athletic shirt", "jogger", "running shoe", "tank", "legging"],
417
  "color_palette": ["bright", "neon", "white", "black", "primary colors"],
418
- "accessory_limit": 1
 
 
 
 
 
 
 
 
 
419
  }
420
  }
421
 
422
- # Select outfit template (can be passed in context or randomly selected)
423
- template_name = context.get("outfit_style", rng.choice(list(outfit_templates.keys())))
424
- template = outfit_templates[template_name]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
 
426
  # Enhanced category-aware pools with diversity checks
427
  def cat_str(i: int) -> str:
@@ -483,7 +554,30 @@ class InferenceService:
483
  if any(pref in cat for pref in preferred_cats):
484
  matches += 1
485
 
486
- return matches / len(categories) if categories else 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
487
 
488
  def get_category_type(cat: str) -> str:
489
  """Map category to outfit slot type with comprehensive taxonomy"""
@@ -599,15 +693,32 @@ class InferenceService:
599
  # Weighted combination
600
  return 0.3 * diversity_score + 0.3 * style_score + 0.2 * color_score + 0.2 * length_score
601
 
602
- # Generate diverse outfit combinations with randomization
603
  for _ in range(num_samples):
604
  subset = []
605
 
606
- # VARIABLE OUTFIT LENGTH: 2-5 items with different strategies
607
- outfit_length = rng.choice([2, 3, 4, 5], p=[0.1, 0.4, 0.4, 0.1]) # Prefer 3-4 items
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
608
 
609
  # Strategy 1: Core outfit (shirt + pants + shoes) + accessories
610
- if rng.random() < 0.5 and uppers and bottoms and shoes:
611
  # Core outfit: exactly 1 of each required slot
612
  subset.append(int(rng.choice(uppers)))
613
  subset.append(int(rng.choice(bottoms)))
@@ -634,7 +745,7 @@ class InferenceService:
634
  subset.extend(selected_others.tolist())
635
 
636
  # Strategy 2: Accessory-focused outfit (prioritize accessories)
637
- elif rng.random() < 0.3 and accs:
638
  # Start with accessories if available
639
  num_accs = min(outfit_length, len(accs))
640
  selected_accs = rng.choice(accs, size=num_accs, replace=False)
@@ -656,7 +767,7 @@ class InferenceService:
656
  subset.extend(selected_others.tolist())
657
 
658
  # Strategy 3: Flexible combination (no strict slot requirements)
659
- else:
660
  # Randomly select items from all categories
661
  all_items = list(ids)
662
  rng.shuffle(all_items)
@@ -807,16 +918,33 @@ class InferenceService:
807
  print(f"πŸ” DEBUG: Removed {len(scored) - len(unique_scored)} duplicate outfits")
808
  scored = unique_scored
809
 
810
- # Add randomization to prevent identical recommendations
811
  if len(scored) > num_outfits:
812
- # Take top 50% and randomly sample from them
813
- top_half = scored[:max(num_outfits * 2, len(scored) // 2)]
814
- rng.shuffle(top_half)
815
- topk = top_half[:num_outfits]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
816
  else:
817
  # If we have fewer candidates than requested, shuffle them
818
  rng.shuffle(scored)
819
- topk = scored[:num_outfits]
820
 
821
  results = []
822
  for subset, adjusted_score, base_score in topk:
 
386
  print("πŸ” DEBUG: Starting candidate generation...")
387
 
388
  # 2) Candidate generation with outfit templates
389
+ # Use timestamp-based seed for better randomization
390
+ import time
391
+ seed = context.get("seed", int(time.time() * 1000) % 10000)
392
+ rng = np.random.default_rng(seed)
393
+ print(f"πŸ” DEBUG: Using random seed: {seed}")
394
  num_outfits = int(context.get("num_outfits", 5)) # Increased default from 3 to 5
395
  min_size, max_size = 2, 6 # Allow smaller outfits (2 items minimum)
396
  ids = list(range(len(proc_items)))
397
 
398
+ # Enhanced context-aware outfit templates
399
  outfit_templates = {
400
  "casual": {
401
  "style": "relaxed, comfortable, everyday",
402
+ "preferred_categories": ["tshirt", "jean", "sneaker", "hoodie", "sweatpant", "shirt", "pants", "shoes"],
403
  "color_palette": ["neutral", "denim", "white", "black", "gray"],
404
+ "accessory_limit": 2,
405
+ "weather_modifiers": {
406
+ "hot": {"preferred_categories": ["tank", "short", "sandal", "light shirt"]},
407
+ "cold": {"preferred_categories": ["hoodie", "sweater", "jacket", "boot"]},
408
+ "rain": {"preferred_categories": ["jacket", "boot", "waterproof"]}
409
+ },
410
+ "occasion_modifiers": {
411
+ "business": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 3},
412
+ "formal": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 4}
413
+ }
414
  },
415
  "smart_casual": {
416
  "style": "polished but relaxed, business casual",
417
+ "preferred_categories": ["shirt", "chino", "loafer", "blazer", "polo", "pants", "shoes"],
418
  "color_palette": ["navy", "white", "khaki", "brown", "gray"],
419
+ "accessory_limit": 3,
420
+ "weather_modifiers": {
421
+ "hot": {"preferred_categories": ["polo", "light shirt", "loafer"]},
422
+ "cold": {"preferred_categories": ["blazer", "sweater", "boot"]},
423
+ "rain": {"preferred_categories": ["blazer", "boot", "umbrella"]}
424
+ },
425
+ "occasion_modifiers": {
426
+ "business": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 4},
427
+ "formal": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 4}
428
+ }
429
  },
430
  "formal": {
431
  "style": "professional, elegant, sophisticated",
432
+ "preferred_categories": ["blazer", "dress shirt", "dress pant", "oxford", "suit", "shirt", "pants", "shoes"],
433
  "color_palette": ["navy", "black", "white", "gray", "charcoal"],
434
+ "accessory_limit": 4,
435
+ "weather_modifiers": {
436
+ "hot": {"preferred_categories": ["light shirt", "light pant", "oxford"]},
437
+ "cold": {"preferred_categories": ["blazer", "suit", "boot"]},
438
+ "rain": {"preferred_categories": ["blazer", "boot", "umbrella"]}
439
+ },
440
+ "occasion_modifiers": {
441
+ "business": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 4},
442
+ "casual": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 3}
443
+ }
444
  },
445
  "sporty": {
446
  "style": "athletic, active, performance",
447
+ "preferred_categories": ["athletic shirt", "jogger", "running shoe", "tank", "legging", "shirt", "pants", "shoes"],
448
  "color_palette": ["bright", "neon", "white", "black", "primary colors"],
449
+ "accessory_limit": 1,
450
+ "weather_modifiers": {
451
+ "hot": {"preferred_categories": ["tank", "short", "running shoe"]},
452
+ "cold": {"preferred_categories": ["hoodie", "legging", "running shoe"]},
453
+ "rain": {"preferred_categories": ["jacket", "running shoe", "cap"]}
454
+ },
455
+ "occasion_modifiers": {
456
+ "business": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 2},
457
+ "formal": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 3}
458
+ }
459
  }
460
  }
461
 
462
+ # Context-aware template selection and modification
463
+ occasion = context.get("occasion", "casual")
464
+ weather = context.get("weather", "any")
465
+ outfit_style = context.get("outfit_style", "casual")
466
+
467
+ # Select base template
468
+ template_name = outfit_style
469
+ template = outfit_templates[template_name].copy()
470
+
471
+ # Apply weather modifications
472
+ if weather != "any" and weather in template.get("weather_modifiers", {}):
473
+ weather_mod = template["weather_modifiers"][weather]
474
+ template["preferred_categories"].extend(weather_mod.get("preferred_categories", []))
475
+ if "accessory_limit" in weather_mod:
476
+ template["accessory_limit"] = weather_mod["accessory_limit"]
477
+
478
+ # Apply occasion modifications
479
+ if occasion in template.get("occasion_modifiers", {}):
480
+ occasion_mod = template["occasion_modifiers"][occasion]
481
+ template["preferred_categories"].extend(occasion_mod.get("preferred_categories", []))
482
+ if "accessory_limit" in occasion_mod:
483
+ template["accessory_limit"] = occasion_mod["accessory_limit"]
484
+
485
+ # Remove duplicates and add context info
486
+ template["preferred_categories"] = list(set(template["preferred_categories"]))
487
+ template["context"] = {
488
+ "occasion": occasion,
489
+ "weather": weather,
490
+ "style": outfit_style
491
+ }
492
+
493
+ print(f"πŸ” DEBUG: Using template '{template_name}' with context: occasion={occasion}, weather={weather}")
494
+ print(f"πŸ” DEBUG: Template categories: {template['preferred_categories']}")
495
+ print(f"πŸ” DEBUG: Accessory limit: {template['accessory_limit']}")
496
 
497
  # Enhanced category-aware pools with diversity checks
498
  def cat_str(i: int) -> str:
 
554
  if any(pref in cat for pref in preferred_cats):
555
  matches += 1
556
 
557
+ base_score = matches / len(categories) if categories else 0.0
558
+
559
+ # Bonus for context-specific categories
560
+ context_bonus = 0.0
561
+ occasion = template["context"]["occasion"]
562
+ weather = template["context"]["weather"]
563
+
564
+ # Occasion-specific bonuses
565
+ if occasion == "business" and any("shirt" in cat or "pants" in cat for cat in categories):
566
+ context_bonus += 0.2
567
+ elif occasion == "formal" and any("shirt" in cat or "pants" in cat for cat in categories):
568
+ context_bonus += 0.3
569
+ elif occasion == "sport" and any("shirt" in cat or "pants" in cat for cat in categories):
570
+ context_bonus += 0.2
571
+
572
+ # Weather-specific bonuses
573
+ if weather == "hot" and any("shirt" in cat for cat in categories):
574
+ context_bonus += 0.1
575
+ elif weather == "cold" and any("shirt" in cat or "pants" in cat for cat in categories):
576
+ context_bonus += 0.1
577
+ elif weather == "rain" and any("shoes" in cat for cat in categories):
578
+ context_bonus += 0.1
579
+
580
+ return min(1.0, base_score + context_bonus)
581
 
582
  def get_category_type(cat: str) -> str:
583
  """Map category to outfit slot type with comprehensive taxonomy"""
 
693
  # Weighted combination
694
  return 0.3 * diversity_score + 0.3 * style_score + 0.2 * color_score + 0.2 * length_score
695
 
696
+ # Context-aware candidate generation with enhanced randomization
697
  for _ in range(num_samples):
698
  subset = []
699
 
700
+ # Context-aware outfit length selection
701
+ occasion = template["context"]["occasion"]
702
+ if occasion == "formal":
703
+ outfit_length = rng.choice([3, 4, 5], p=[0.2, 0.5, 0.3]) # Formal outfits tend to be more complete
704
+ elif occasion == "business":
705
+ outfit_length = rng.choice([3, 4, 5], p=[0.3, 0.5, 0.2]) # Business outfits are well-rounded
706
+ elif occasion == "sport":
707
+ outfit_length = rng.choice([2, 3, 4], p=[0.3, 0.5, 0.2]) # Sport outfits can be minimal
708
+ else: # casual
709
+ outfit_length = rng.choice([2, 3, 4, 5], p=[0.2, 0.4, 0.3, 0.1]) # Casual is flexible
710
+
711
+ # Context-aware strategy selection
712
+ strategy_weights = [0.4, 0.3, 0.3] # Default weights
713
+ if occasion == "formal":
714
+ strategy_weights = [0.6, 0.2, 0.2] # Favor core outfits for formal
715
+ elif occasion == "sport":
716
+ strategy_weights = [0.3, 0.5, 0.2] # Favor flexible combinations for sport
717
+
718
+ strategy = rng.choice([0, 1, 2], p=strategy_weights)
719
 
720
  # Strategy 1: Core outfit (shirt + pants + shoes) + accessories
721
+ if strategy == 0 and uppers and bottoms and shoes:
722
  # Core outfit: exactly 1 of each required slot
723
  subset.append(int(rng.choice(uppers)))
724
  subset.append(int(rng.choice(bottoms)))
 
745
  subset.extend(selected_others.tolist())
746
 
747
  # Strategy 2: Accessory-focused outfit (prioritize accessories)
748
+ elif strategy == 1 and accs:
749
  # Start with accessories if available
750
  num_accs = min(outfit_length, len(accs))
751
  selected_accs = rng.choice(accs, size=num_accs, replace=False)
 
767
  subset.extend(selected_others.tolist())
768
 
769
  # Strategy 3: Flexible combination (no strict slot requirements)
770
+ elif strategy == 2:
771
  # Randomly select items from all categories
772
  all_items = list(ids)
773
  rng.shuffle(all_items)
 
918
  print(f"πŸ” DEBUG: Removed {len(scored) - len(unique_scored)} duplicate outfits")
919
  scored = unique_scored
920
 
921
+ # Enhanced randomization with context awareness
922
  if len(scored) > num_outfits:
923
+ # Context-aware selection: prefer higher-scoring outfits but add diversity
924
+ top_third = scored[:max(num_outfits * 3, len(scored) // 3)]
925
+ middle_third = scored[max(num_outfits * 3, len(scored) // 3):max(num_outfits * 6, len(scored) * 2 // 3)]
926
+
927
+ # Select mix of high-scoring and diverse outfits
928
+ selected = []
929
+
930
+ # Take 70% from top third (high quality)
931
+ top_count = int(num_outfits * 0.7)
932
+ rng.shuffle(top_third)
933
+ selected.extend(top_third[:top_count])
934
+
935
+ # Take 30% from middle third (diversity)
936
+ middle_count = num_outfits - len(selected)
937
+ if middle_count > 0 and middle_third:
938
+ rng.shuffle(middle_third)
939
+ selected.extend(middle_third[:middle_count])
940
+
941
+ # Shuffle final selection for randomness
942
+ rng.shuffle(selected)
943
+ topk = selected[:num_outfits]
944
  else:
945
  # If we have fewer candidates than requested, shuffle them
946
  rng.shuffle(scored)
947
+ topk = scored[:num_outfits]
948
 
949
  results = []
950
  for subset, adjusted_score, base_score in topk: