Ali Mohsin commited on
Commit
08dc87f
Β·
1 Parent(s): e13ff13

nee last fixes

Browse files
Files changed (1) hide show
  1. inference.py +345 -91
inference.py CHANGED
@@ -551,28 +551,62 @@ class InferenceService:
551
  return "unknown"
552
 
553
  def calculate_color_consistency_score(items: List[int]) -> float:
554
- """Calculate color consistency score for outfit items"""
555
  colors = [extract_color_from_category(cat_str(i)) for i in items]
556
  color_counts = {}
557
  for color in colors:
558
  color_counts[color] = color_counts.get(color, 0) + 1
559
 
560
- # Prefer outfits with 2-3 dominant colors
561
- dominant_colors = [c for c, count in color_counts.items() if count >= 2]
562
- if len(dominant_colors) == 0:
563
- return 0.5 # Neutral score for all different colors
564
- elif len(dominant_colors) == 1:
565
- return 0.8 # Good consistency
566
- elif len(dominant_colors) == 2:
567
- return 1.0 # Perfect balance
568
- else:
569
- return 0.3 # Too many dominant colors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
570
 
571
  def calculate_style_consistency_score(items: List[int]) -> float:
572
- """Calculate style consistency based on template preferences"""
573
  categories = [cat_str(i) for i in items]
574
  preferred_cats = template["preferred_categories"]
575
 
 
576
  matches = 0
577
  for cat in categories:
578
  if any(pref in cat for pref in preferred_cats):
@@ -580,28 +614,97 @@ class InferenceService:
580
 
581
  base_score = matches / len(categories) if categories else 0.0
582
 
583
- # Bonus for context-specific categories
584
- context_bonus = 0.0
585
  occasion = template["context"]["occasion"]
586
  weather = template["context"]["weather"]
 
587
 
588
- # Occasion-specific bonuses
589
- if occasion == "business" and any("shirt" in cat or "pants" in cat for cat in categories):
590
- context_bonus += 0.2
591
- elif occasion == "formal" and any("shirt" in cat or "pants" in cat for cat in categories):
592
- context_bonus += 0.3
593
- elif occasion == "sport" and any("shirt" in cat or "pants" in cat for cat in categories):
594
- context_bonus += 0.2
595
-
596
- # Weather-specific bonuses
597
- if weather == "hot" and any("shirt" in cat for cat in categories):
598
- context_bonus += 0.1
599
- elif weather == "cold" and any("shirt" in cat or "pants" in cat for cat in categories):
600
- context_bonus += 0.1
601
- elif weather == "rain" and any("shoes" in cat for cat in categories):
602
- context_bonus += 0.1
603
-
604
- return min(1.0, base_score + context_bonus)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
605
 
606
  def get_category_type(cat: str) -> str:
607
  """Map category to outfit slot type with comprehensive taxonomy"""
@@ -716,46 +819,140 @@ class InferenceService:
716
  return len(unique_categories) >= 2
717
 
718
  def calculate_outfit_score(subset: List[int]) -> float:
719
- """Calculate overall outfit quality score"""
720
  if len(subset) < 2:
721
  return 0.0
722
-
723
- # Base score from category diversity
724
- diversity_score = len(set(get_category_type(cat_str(i)) for i in subset)) / 4.0
725
 
726
- # Style consistency score
 
 
 
 
 
 
 
 
 
 
 
 
 
727
  style_score = calculate_style_consistency_score(subset)
728
 
729
- # Color consistency score
730
  color_score = calculate_color_consistency_score(subset)
731
 
732
- # Length appropriateness (prefer 3-4 items)
733
- length_score = 1.0 if 3 <= len(subset) <= 4 else 0.7
 
 
 
 
 
 
 
 
 
 
 
734
 
735
- # Weighted combination
736
- return 0.3 * diversity_score + 0.3 * style_score + 0.2 * color_score + 0.2 * length_score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
737
 
738
- # Context-aware candidate generation with enhanced randomization
739
  for _ in range(num_samples):
740
  subset = []
741
 
742
- # Context-aware outfit length selection
743
  occasion = template["context"]["occasion"]
 
 
 
 
744
  if occasion == "formal":
745
- outfit_length = rng.choice([3, 4, 5], p=[0.2, 0.5, 0.3]) # Formal outfits tend to be more complete
 
 
 
746
  elif occasion == "business":
747
- outfit_length = rng.choice([3, 4, 5], p=[0.3, 0.5, 0.2]) # Business outfits are well-rounded
748
  elif occasion == "sport":
749
- outfit_length = rng.choice([2, 3, 4], p=[0.3, 0.5, 0.2]) # Sport outfits can be minimal
 
 
 
750
  else: # casual
751
- outfit_length = rng.choice([2, 3, 4, 5], p=[0.2, 0.4, 0.3, 0.1]) # Casual is flexible
 
 
 
752
 
753
- # Context-aware strategy selection
754
- strategy_weights = [0.4, 0.3, 0.3] # Default weights
 
 
755
  if occasion == "formal":
756
- strategy_weights = [0.6, 0.2, 0.2] # Favor core outfits for formal
 
 
 
 
757
  elif occasion == "sport":
758
- strategy_weights = [0.3, 0.5, 0.2] # Favor flexible combinations for sport
 
 
 
 
 
 
759
 
760
  strategy = rng.choice([0, 1, 2], p=strategy_weights)
761
 
@@ -858,9 +1055,9 @@ class InferenceService:
858
  if len(subset) >= 2 and len(subset) <= max_size and has_category_diversity(subset):
859
  # Add randomization factor to prevent identical recommendations
860
  subset = rng.permutation(subset).tolist() # Randomize order
861
- candidates.append(subset)
862
- if len(candidates) % 10 == 0: # Log every 10 candidates
863
- print(f"πŸ” DEBUG: Generated {len(candidates)} candidates so far...")
864
 
865
  print(f"πŸ” DEBUG: Generated {len(candidates)} total candidates")
866
 
@@ -904,7 +1101,7 @@ class InferenceService:
904
  return True
905
 
906
  def calculate_outfit_penalty(subset: List[int], base_score: float) -> float:
907
- """Calculate penalty-adjusted score for outfit quality with style/color bonuses"""
908
  categories = [get_category_type(cat_str(i)) for i in subset]
909
  category_counts = {}
910
 
@@ -914,7 +1111,8 @@ class InferenceService:
914
  penalty = 0.0
915
  bonus = 0.0
916
 
917
- # Missing core slots: -∞ penalty
 
918
  if category_counts.get("upper", 0) == 0:
919
  penalty += -1000.0
920
  if category_counts.get("bottom", 0) == 0:
@@ -922,54 +1120,110 @@ class InferenceService:
922
  if category_counts.get("shoe", 0) == 0:
923
  penalty += -1000.0
924
 
925
- # Special penalty for formal outfits missing outerwear
 
 
 
 
 
 
926
  if occasion == "formal" and category_counts.get("outerwear", 0) == 0:
927
- penalty += -500.0 # Strong penalty but not -∞ for formal without jacket
 
 
 
 
928
 
929
- # Duplicate non-accessory categories: -∞ penalty
930
- for cat, count in category_counts.items():
931
- if cat != "accessory" and count > 1:
932
- penalty += -1000.0
 
 
 
 
933
 
934
- # Too many accessories: moderate penalty
935
  max_accs = template["accessory_limit"]
936
- if category_counts.get("accessory", 0) > max_accs:
937
- penalty += -2.0
 
938
 
939
- # Unbalanced outfit: small penalty
940
- if len(subset) < 3:
941
- penalty += -1.0
942
  elif len(subset) > 6:
943
- penalty += -0.5
 
 
944
 
945
- # Style consistency bonus
 
946
  style_score = calculate_style_consistency_score(subset)
947
- bonus += style_score * 0.5 # Up to 0.5 bonus for style consistency
948
 
949
- # Color consistency bonus
950
  color_score = calculate_color_consistency_score(subset)
951
- bonus += color_score * 0.3 # Up to 0.3 bonus for color consistency
952
 
953
- # Template adherence bonus
954
- if style_score > 0.6: # Good style match
955
- bonus += 0.2
 
 
 
 
 
 
956
 
957
- # Accessory inclusion bonus
958
- categories = [get_category_type(cat_str(i)) for i in subset]
959
- if "accessory" in categories:
960
- bonus += 0.3 # Bonus for including accessories
 
 
 
 
961
 
962
- # Formal outfit bonus for including outerwear
963
- if occasion == "formal" and "outerwear" in categories:
964
- bonus += 0.5 # Strong bonus for formal outfits with jackets
 
 
 
965
 
966
- # Traditional Pakistani outfit bonus
967
  if outfit_style == "traditional":
968
  traditional_items = [cat for cat in categories if any(traditional in cat.lower() for traditional in ["kameez", "kurta", "shalwar", "peshawari", "chappal"])]
969
- if len(traditional_items) >= 2: # At least 2 traditional items
970
- bonus += 0.6 # Strong bonus for traditional outfit combinations
971
- if len(traditional_items) == 3: # Complete traditional set
972
- bonus += 0.3 # Additional bonus for complete traditional outfit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
973
 
974
  return base_score + penalty + bonus
975
 
@@ -1031,7 +1285,7 @@ class InferenceService:
1031
  else:
1032
  # If we have fewer candidates than requested, shuffle them
1033
  rng.shuffle(scored)
1034
- topk = scored[:num_outfits]
1035
 
1036
  results = []
1037
  for subset, adjusted_score, base_score in topk:
@@ -1098,8 +1352,8 @@ class InferenceService:
1098
  # Disable gradients
1099
  for m in [self.resnet, self.vit]:
1100
  if m is not None:
1101
- for p in m.parameters():
1102
- p.requires_grad_(False)
1103
 
1104
  # Update overall status
1105
  self.models_loaded = self.resnet_loaded and self.vit_loaded
 
551
  return "unknown"
552
 
553
  def calculate_color_consistency_score(items: List[int]) -> float:
554
+ """Calculate sophisticated color harmony score using fashion theory"""
555
  colors = [extract_color_from_category(cat_str(i)) for i in items]
556
  color_counts = {}
557
  for color in colors:
558
  color_counts[color] = color_counts.get(color, 0) + 1
559
 
560
+ # Advanced color harmony rules
561
+ base_score = 0.0
562
+
563
+ # 1. Monochromatic harmony (same color family)
564
+ dominant_color = max(color_counts.items(), key=lambda x: x[1])[0] if color_counts else "unknown"
565
+ if color_counts.get(dominant_color, 0) >= 2:
566
+ base_score += 0.4 # Strong monochromatic bonus
567
+
568
+ # 2. Complementary color harmony
569
+ complementary_pairs = [
570
+ ("black", "white"), ("navy", "white"), ("brown", "beige"),
571
+ ("red", "green"), ("blue", "orange"), ("purple", "yellow")
572
+ ]
573
+ for color1, color2 in complementary_pairs:
574
+ if color_counts.get(color1, 0) > 0 and color_counts.get(color2, 0) > 0:
575
+ base_score += 0.3 # Complementary harmony bonus
576
+ break
577
+
578
+ # 3. Neutral base with accent colors
579
+ neutral_colors = ["black", "white", "gray", "navy", "brown", "beige"]
580
+ neutral_count = sum(color_counts.get(c, 0) for c in neutral_colors)
581
+ if neutral_count >= 2 and len([c for c in colors if c not in neutral_colors]) <= 1:
582
+ base_score += 0.2 # Neutral base bonus
583
+
584
+ # 4. Color distribution penalty
585
+ if len(color_counts) > 4:
586
+ base_score -= 0.2 # Too many different colors
587
+ elif len(color_counts) == 1 and len(items) > 2:
588
+ base_score -= 0.1 # Too monotonous
589
+
590
+ # 5. Context-aware color scoring
591
+ if occasion == "formal":
592
+ formal_colors = ["black", "navy", "white", "gray", "charcoal"]
593
+ formal_count = sum(color_counts.get(c, 0) for c in formal_colors)
594
+ if formal_count >= 2:
595
+ base_score += 0.2 # Formal color bonus
596
+ elif occasion == "business":
597
+ business_colors = ["navy", "white", "gray", "black", "brown"]
598
+ business_count = sum(color_counts.get(c, 0) for c in business_colors)
599
+ if business_count >= 2:
600
+ base_score += 0.15 # Business color bonus
601
+
602
+ return min(1.0, max(0.0, base_score + 0.3)) # Ensure score is between 0-1
603
 
604
  def calculate_style_consistency_score(items: List[int]) -> float:
605
+ """Calculate advanced style consistency using fashion expert rules"""
606
  categories = [cat_str(i) for i in items]
607
  preferred_cats = template["preferred_categories"]
608
 
609
+ # Base template matching
610
  matches = 0
611
  for cat in categories:
612
  if any(pref in cat for pref in preferred_cats):
 
614
 
615
  base_score = matches / len(categories) if categories else 0.0
616
 
617
+ # Advanced fashion rules scoring
618
+ fashion_bonus = 0.0
619
  occasion = template["context"]["occasion"]
620
  weather = template["context"]["weather"]
621
+ outfit_style = template["context"]["style"]
622
 
623
+ # 1. Occasion-appropriate style rules
624
+ if occasion == "formal":
625
+ # Formal requires structured, tailored pieces
626
+ formal_items = ["jacket", "blazer", "suit", "dress shirt", "dress pant", "oxford"]
627
+ formal_count = sum(1 for cat in categories if any(f in cat for f in formal_items))
628
+ if formal_count >= 3: # At least 3 formal items
629
+ fashion_bonus += 0.4
630
+ elif formal_count >= 2:
631
+ fashion_bonus += 0.2
632
+
633
+ elif occasion == "business":
634
+ # Business requires professional but not overly formal
635
+ business_items = ["shirt", "blazer", "pants", "loafer", "oxford", "dress pant"]
636
+ business_count = sum(1 for cat in categories if any(b in cat for b in business_items))
637
+ if business_count >= 3:
638
+ fashion_bonus += 0.3
639
+ elif business_count >= 2:
640
+ fashion_bonus += 0.15
641
+
642
+ elif occasion == "sport":
643
+ # Sport requires athletic, functional pieces
644
+ sport_items = ["athletic", "running", "jogger", "sneaker", "tank", "legging"]
645
+ sport_count = sum(1 for cat in categories if any(s in cat for s in sport_items))
646
+ if sport_count >= 2:
647
+ fashion_bonus += 0.3
648
+
649
+ # 2. Style coherence rules
650
+ if outfit_style == "formal":
651
+ # Formal style coherence
652
+ if "jacket" in categories and "shirt" in categories:
653
+ fashion_bonus += 0.2 # Proper layering
654
+ if len([c for c in categories if c in ["jacket", "shirt", "pants", "shoes"]]) >= 3:
655
+ fashion_bonus += 0.2 # Complete formal set
656
+
657
+ elif outfit_style == "smart_casual":
658
+ # Smart casual balance
659
+ if "shirt" in categories and "pants" in categories:
660
+ fashion_bonus += 0.15
661
+ if "blazer" in categories or "jacket" in categories:
662
+ fashion_bonus += 0.1 # Elevated casual
663
+
664
+ elif outfit_style == "traditional":
665
+ # Traditional Pakistani coherence
666
+ traditional_items = ["kameez", "kurta", "shalwar", "peshawari", "chappal"]
667
+ traditional_count = sum(1 for cat in categories if any(t in cat for t in traditional_items))
668
+ if traditional_count >= 2:
669
+ fashion_bonus += 0.4
670
+ if traditional_count == 3: # Complete traditional set
671
+ fashion_bonus += 0.2
672
+
673
+ # 3. Weather-appropriate logic
674
+ if weather == "hot":
675
+ # Hot weather preferences
676
+ if any("light" in cat or "cotton" in cat or "tank" in cat for cat in categories):
677
+ fashion_bonus += 0.1
678
+ if "jacket" in categories and len(categories) > 3:
679
+ fashion_bonus -= 0.1 # Too many layers for hot weather
680
+
681
+ elif weather == "cold":
682
+ # Cold weather preferences
683
+ if "jacket" in categories or "blazer" in categories:
684
+ fashion_bonus += 0.15
685
+ if "sweater" in categories or "hoodie" in categories:
686
+ fashion_bonus += 0.1
687
+
688
+ elif weather == "rain":
689
+ # Rain weather preferences
690
+ if any("waterproof" in cat or "boot" in cat for cat in categories):
691
+ fashion_bonus += 0.2
692
+ if "jacket" in categories:
693
+ fashion_bonus += 0.1
694
+
695
+ # 4. Proportions and balance
696
+ category_types = [get_category_type(cat) for cat in categories]
697
+ type_counts = {}
698
+ for cat_type in category_types:
699
+ type_counts[cat_type] = type_counts.get(cat_type, 0) + 1
700
+
701
+ # Balanced outfit proportions
702
+ if len(type_counts) >= 3: # Good diversity
703
+ fashion_bonus += 0.1
704
+ if type_counts.get("accessory", 0) <= 2: # Not over-accessorized
705
+ fashion_bonus += 0.05
706
+
707
+ return min(1.0, base_score + fashion_bonus)
708
 
709
  def get_category_type(cat: str) -> str:
710
  """Map category to outfit slot type with comprehensive taxonomy"""
 
819
  return len(unique_categories) >= 2
820
 
821
  def calculate_outfit_score(subset: List[int]) -> float:
822
+ """Calculate sophisticated outfit quality score using advanced fashion reasoning"""
823
  if len(subset) < 2:
824
  return 0.0
 
 
 
825
 
826
+ # 1. Category diversity and completeness
827
+ category_types = [get_category_type(cat_str(i)) for i in subset]
828
+ unique_types = set(category_types)
829
+ diversity_score = len(unique_types) / 5.0 # Normalize to 5 categories max
830
+
831
+ # Completeness bonus for essential categories
832
+ essential_categories = {"upper", "bottom", "shoe"}
833
+ completeness_bonus = 0.0
834
+ if essential_categories.issubset(unique_types):
835
+ completeness_bonus += 0.3 # All essential categories present
836
+ elif len(essential_categories.intersection(unique_types)) >= 2:
837
+ completeness_bonus += 0.15 # Most essential categories present
838
+
839
+ # 2. Advanced style consistency
840
  style_score = calculate_style_consistency_score(subset)
841
 
842
+ # 3. Sophisticated color harmony
843
  color_score = calculate_color_consistency_score(subset)
844
 
845
+ # 4. Context-appropriate length scoring
846
+ occasion = template["context"]["occasion"]
847
+ if occasion == "formal":
848
+ length_score = 1.0 if 4 <= len(subset) <= 5 else 0.6 # Formal prefers complete sets
849
+ elif occasion == "business":
850
+ length_score = 1.0 if 3 <= len(subset) <= 4 else 0.7 # Business balanced
851
+ elif occasion == "sport":
852
+ length_score = 1.0 if 2 <= len(subset) <= 3 else 0.8 # Sport can be minimal
853
+ else: # casual
854
+ length_score = 1.0 if 2 <= len(subset) <= 4 else 0.7 # Casual flexible
855
+
856
+ # 5. Fashion rule compliance
857
+ fashion_rules_score = 0.0
858
 
859
+ # Rule: No more than one item per core category (except accessories)
860
+ core_categories = {"upper", "bottom", "shoe", "outerwear"}
861
+ core_counts = {cat: category_types.count(cat) for cat in core_categories}
862
+ if all(count <= 1 for count in core_counts.values()):
863
+ fashion_rules_score += 0.2 # Perfect core category distribution
864
+
865
+ # Rule: Appropriate accessory count
866
+ accessory_count = category_types.count("accessory")
867
+ max_accessories = template.get("accessory_limit", 3)
868
+ if accessory_count <= max_accessories:
869
+ fashion_rules_score += 0.1
870
+ if accessory_count > 0 and accessory_count <= 2:
871
+ fashion_rules_score += 0.1 # Bonus for tasteful accessorizing
872
+
873
+ # Rule: Occasion-appropriate formality
874
+ if occasion == "formal" and "outerwear" in unique_types:
875
+ fashion_rules_score += 0.2 # Formal requires outerwear
876
+ elif occasion == "business" and len(unique_types) >= 3:
877
+ fashion_rules_score += 0.15 # Business requires completeness
878
+ elif occasion == "sport" and any("athletic" in cat_str(i) for i in subset):
879
+ fashion_rules_score += 0.1 # Sport requires athletic items
880
+
881
+ # 6. Advanced weighted combination with reasoning
882
+ base_score = (
883
+ 0.25 * (diversity_score + completeness_bonus) + # Structure and completeness
884
+ 0.30 * style_score + # Style coherence
885
+ 0.20 * color_score + # Color harmony
886
+ 0.15 * length_score + # Appropriate length
887
+ 0.10 * fashion_rules_score # Fashion rule compliance
888
+ )
889
+
890
+ # 7. Context-specific adjustments
891
+ context_adjustment = 0.0
892
+
893
+ # Weather appropriateness
894
+ weather = template["context"]["weather"]
895
+ if weather == "hot" and len(subset) > 4:
896
+ context_adjustment -= 0.1 # Too many layers for hot weather
897
+ elif weather == "cold" and "outerwear" not in unique_types:
898
+ context_adjustment -= 0.1 # Missing outerwear for cold weather
899
+ elif weather == "rain" and not any("boot" in cat_str(i) for i in subset):
900
+ context_adjustment -= 0.05 # Missing weather-appropriate footwear
901
+
902
+ # Occasion-specific adjustments
903
+ if occasion == "formal" and len(subset) < 4:
904
+ context_adjustment -= 0.1 # Formal outfits should be complete
905
+ elif occasion == "sport" and len(subset) > 4:
906
+ context_adjustment -= 0.05 # Sport outfits can be simpler
907
+
908
+ return min(1.0, max(0.0, base_score + context_adjustment))
909
 
910
+ # Advanced candidate generation with sophisticated reasoning
911
  for _ in range(num_samples):
912
  subset = []
913
 
914
+ # 1. Advanced context-aware outfit length selection
915
  occasion = template["context"]["occasion"]
916
+ weather = template["context"]["weather"]
917
+ outfit_style = template["context"]["style"]
918
+
919
+ # Base length probabilities
920
  if occasion == "formal":
921
+ if weather == "hot":
922
+ outfit_length = rng.choice([3, 4], p=[0.4, 0.6]) # Formal but weather-appropriate
923
+ else:
924
+ outfit_length = rng.choice([4, 5], p=[0.6, 0.4]) # Complete formal sets
925
  elif occasion == "business":
926
+ outfit_length = rng.choice([3, 4, 5], p=[0.3, 0.5, 0.2]) # Professional balance
927
  elif occasion == "sport":
928
+ if weather == "hot":
929
+ outfit_length = rng.choice([2, 3], p=[0.6, 0.4]) # Minimal for hot weather
930
+ else:
931
+ outfit_length = rng.choice([2, 3, 4], p=[0.3, 0.5, 0.2]) # Sport flexibility
932
  else: # casual
933
+ if weather == "hot":
934
+ outfit_length = rng.choice([2, 3, 4], p=[0.4, 0.4, 0.2]) # Casual, weather-appropriate
935
+ else:
936
+ outfit_length = rng.choice([2, 3, 4, 5], p=[0.2, 0.4, 0.3, 0.1]) # Casual flexibility
937
 
938
+ # 2. Advanced strategy selection with reasoning
939
+ strategy_weights = [0.4, 0.3, 0.3] # Default: Core, Accessory-focused, Flexible
940
+
941
+ # Formal occasions prioritize complete, structured outfits
942
  if occasion == "formal":
943
+ strategy_weights = [0.7, 0.1, 0.2] # Strongly favor core outfits
944
+ # Business occasions need professional balance
945
+ elif occasion == "business":
946
+ strategy_weights = [0.6, 0.2, 0.2] # Favor core with some flexibility
947
+ # Sport occasions can be more flexible
948
  elif occasion == "sport":
949
+ strategy_weights = [0.3, 0.2, 0.5] # Favor flexible combinations
950
+ # Traditional outfits need cultural coherence
951
+ elif outfit_style == "traditional":
952
+ strategy_weights = [0.8, 0.1, 0.1] # Strongly favor traditional core sets
953
+ # Casual occasions allow more creativity
954
+ else:
955
+ strategy_weights = [0.4, 0.3, 0.3] # Balanced approach
956
 
957
  strategy = rng.choice([0, 1, 2], p=strategy_weights)
958
 
 
1055
  if len(subset) >= 2 and len(subset) <= max_size and has_category_diversity(subset):
1056
  # Add randomization factor to prevent identical recommendations
1057
  subset = rng.permutation(subset).tolist() # Randomize order
1058
+ candidates.append(subset)
1059
+ if len(candidates) % 10 == 0: # Log every 10 candidates
1060
+ print(f"πŸ” DEBUG: Generated {len(candidates)} candidates so far...")
1061
 
1062
  print(f"πŸ” DEBUG: Generated {len(candidates)} total candidates")
1063
 
 
1101
  return True
1102
 
1103
  def calculate_outfit_penalty(subset: List[int], base_score: float) -> float:
1104
+ """Calculate sophisticated penalty-adjusted score with advanced fashion reasoning"""
1105
  categories = [get_category_type(cat_str(i)) for i in subset]
1106
  category_counts = {}
1107
 
 
1111
  penalty = 0.0
1112
  bonus = 0.0
1113
 
1114
+ # 1. Critical fashion violations (severe penalties)
1115
+ # Missing essential categories: -∞ penalty
1116
  if category_counts.get("upper", 0) == 0:
1117
  penalty += -1000.0
1118
  if category_counts.get("bottom", 0) == 0:
 
1120
  if category_counts.get("shoe", 0) == 0:
1121
  penalty += -1000.0
1122
 
1123
+ # Duplicate core categories: -∞ penalty (fashion rule violation)
1124
+ core_categories = {"upper", "bottom", "shoe", "outerwear"}
1125
+ for cat in core_categories:
1126
+ if category_counts.get(cat, 0) > 1:
1127
+ penalty += -1000.0
1128
+
1129
+ # 2. Context-specific critical violations
1130
  if occasion == "formal" and category_counts.get("outerwear", 0) == 0:
1131
+ penalty += -500.0 # Formal without jacket is inappropriate
1132
+ elif occasion == "business" and len(subset) < 3:
1133
+ penalty += -200.0 # Business outfits should be complete
1134
+ elif occasion == "sport" and not any("athletic" in cat_str(i) for i in subset):
1135
+ penalty += -300.0 # Sport outfits need athletic items
1136
 
1137
+ # 3. Weather-appropriate violations
1138
+ weather = template["context"]["weather"]
1139
+ if weather == "hot" and len(subset) > 5:
1140
+ penalty += -100.0 # Too many layers for hot weather
1141
+ elif weather == "cold" and category_counts.get("outerwear", 0) == 0:
1142
+ penalty += -150.0 # Missing outerwear for cold weather
1143
+ elif weather == "rain" and not any("boot" in cat_str(i) for i in subset):
1144
+ penalty += -50.0 # Missing weather-appropriate footwear
1145
 
1146
+ # 4. Accessory violations
1147
  max_accs = template["accessory_limit"]
1148
+ accessory_count = category_counts.get("accessory", 0)
1149
+ if accessory_count > max_accs:
1150
+ penalty += -50.0 * (accessory_count - max_accs) # Proportional penalty
1151
 
1152
+ # 5. Outfit balance violations
1153
+ if len(subset) < 2:
1154
+ penalty += -200.0 # Too minimal
1155
  elif len(subset) > 6:
1156
+ penalty += -100.0 # Too complex
1157
+ elif len(subset) == 2 and occasion in ["formal", "business"]:
1158
+ penalty += -100.0 # Too minimal for formal/business
1159
 
1160
+ # 6. Advanced bonus system
1161
+ # Style consistency bonus (weighted by importance)
1162
  style_score = calculate_style_consistency_score(subset)
1163
+ bonus += style_score * 0.6 # Increased weight for style
1164
 
1165
+ # Color harmony bonus
1166
  color_score = calculate_color_consistency_score(subset)
1167
+ bonus += color_score * 0.4 # Increased weight for color
1168
 
1169
+ # 7. Context-specific bonuses
1170
+ # Formal outfit bonuses
1171
+ if occasion == "formal":
1172
+ if "outerwear" in categories:
1173
+ bonus += 0.6 # Strong bonus for proper formal layering
1174
+ if len([c for c in categories if c in ["upper", "bottom", "shoe", "outerwear"]]) >= 4:
1175
+ bonus += 0.4 # Complete formal set bonus
1176
+ if style_score > 0.7:
1177
+ bonus += 0.3 # High style coherence bonus
1178
 
1179
+ # Business outfit bonuses
1180
+ elif occasion == "business":
1181
+ if len(categories) >= 3:
1182
+ bonus += 0.3 # Professional completeness
1183
+ if "outerwear" in categories:
1184
+ bonus += 0.2 # Elevated business look
1185
+ if style_score > 0.6:
1186
+ bonus += 0.2 # Professional style bonus
1187
 
1188
+ # Sport outfit bonuses
1189
+ elif occasion == "sport":
1190
+ if any("athletic" in cat_str(i) for i in subset):
1191
+ bonus += 0.4 # Athletic functionality
1192
+ if len(subset) <= 3:
1193
+ bonus += 0.2 # Appropriate minimalism for sport
1194
 
1195
+ # 8. Traditional Pakistani outfit bonuses
1196
  if outfit_style == "traditional":
1197
  traditional_items = [cat for cat in categories if any(traditional in cat.lower() for traditional in ["kameez", "kurta", "shalwar", "peshawari", "chappal"])]
1198
+ if len(traditional_items) >= 2:
1199
+ bonus += 0.7 # Strong cultural appropriateness bonus
1200
+ if len(traditional_items) == 3:
1201
+ bonus += 0.4 # Complete traditional set bonus
1202
+ if style_score > 0.6:
1203
+ bonus += 0.3 # Traditional style coherence
1204
+
1205
+ # 9. Fashion rule compliance bonuses
1206
+ # Perfect category distribution
1207
+ if all(category_counts.get(cat, 0) <= 1 for cat in core_categories):
1208
+ bonus += 0.3 # Perfect fashion rule compliance
1209
+
1210
+ # Tasteful accessorizing
1211
+ if 1 <= accessory_count <= 2:
1212
+ bonus += 0.2 # Tasteful accessorizing bonus
1213
+
1214
+ # 10. Weather-appropriate bonuses
1215
+ if weather == "hot" and len(subset) <= 4:
1216
+ bonus += 0.1 # Appropriate for hot weather
1217
+ elif weather == "cold" and "outerwear" in categories:
1218
+ bonus += 0.2 # Proper cold weather preparation
1219
+ elif weather == "rain" and any("boot" in cat_str(i) for i in subset):
1220
+ bonus += 0.15 # Weather-appropriate footwear
1221
+
1222
+ # 11. Overall outfit quality bonus
1223
+ if style_score > 0.8 and color_score > 0.7:
1224
+ bonus += 0.3 # Exceptional outfit quality
1225
+ elif style_score > 0.6 and color_score > 0.5:
1226
+ bonus += 0.2 # Good outfit quality
1227
 
1228
  return base_score + penalty + bonus
1229
 
 
1285
  else:
1286
  # If we have fewer candidates than requested, shuffle them
1287
  rng.shuffle(scored)
1288
+ topk = scored[:num_outfits]
1289
 
1290
  results = []
1291
  for subset, adjusted_score, base_score in topk:
 
1352
  # Disable gradients
1353
  for m in [self.resnet, self.vit]:
1354
  if m is not None:
1355
+ for p in m.parameters():
1356
+ p.requires_grad_(False)
1357
 
1358
  # Update overall status
1359
  self.models_loaded = self.resnet_loaded and self.vit_loaded