Spaces:
Paused
Paused
Ali Mohsin
commited on
Commit
Β·
08dc87f
1
Parent(s):
e13ff13
nee last fixes
Browse files- inference.py +345 -91
inference.py
CHANGED
|
@@ -551,28 +551,62 @@ class InferenceService:
|
|
| 551 |
return "unknown"
|
| 552 |
|
| 553 |
def calculate_color_consistency_score(items: List[int]) -> float:
|
| 554 |
-
"""Calculate color
|
| 555 |
colors = [extract_color_from_category(cat_str(i)) for i in items]
|
| 556 |
color_counts = {}
|
| 557 |
for color in colors:
|
| 558 |
color_counts[color] = color_counts.get(color, 0) + 1
|
| 559 |
|
| 560 |
-
#
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 570 |
|
| 571 |
def calculate_style_consistency_score(items: List[int]) -> float:
|
| 572 |
-
"""Calculate style consistency
|
| 573 |
categories = [cat_str(i) for i in items]
|
| 574 |
preferred_cats = template["preferred_categories"]
|
| 575 |
|
|
|
|
| 576 |
matches = 0
|
| 577 |
for cat in categories:
|
| 578 |
if any(pref in cat for pref in preferred_cats):
|
|
@@ -580,28 +614,97 @@ class InferenceService:
|
|
| 580 |
|
| 581 |
base_score = matches / len(categories) if categories else 0.0
|
| 582 |
|
| 583 |
-
#
|
| 584 |
-
|
| 585 |
occasion = template["context"]["occasion"]
|
| 586 |
weather = template["context"]["weather"]
|
|
|
|
| 587 |
|
| 588 |
-
# Occasion-
|
| 589 |
-
if occasion == "
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
|
| 606 |
def get_category_type(cat: str) -> str:
|
| 607 |
"""Map category to outfit slot type with comprehensive taxonomy"""
|
|
@@ -716,46 +819,140 @@ class InferenceService:
|
|
| 716 |
return len(unique_categories) >= 2
|
| 717 |
|
| 718 |
def calculate_outfit_score(subset: List[int]) -> float:
|
| 719 |
-
"""Calculate
|
| 720 |
if len(subset) < 2:
|
| 721 |
return 0.0
|
| 722 |
-
|
| 723 |
-
# Base score from category diversity
|
| 724 |
-
diversity_score = len(set(get_category_type(cat_str(i)) for i in subset)) / 4.0
|
| 725 |
|
| 726 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 727 |
style_score = calculate_style_consistency_score(subset)
|
| 728 |
|
| 729 |
-
#
|
| 730 |
color_score = calculate_color_consistency_score(subset)
|
| 731 |
|
| 732 |
-
#
|
| 733 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 734 |
|
| 735 |
-
#
|
| 736 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 737 |
|
| 738 |
-
#
|
| 739 |
for _ in range(num_samples):
|
| 740 |
subset = []
|
| 741 |
|
| 742 |
-
#
|
| 743 |
occasion = template["context"]["occasion"]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 744 |
if occasion == "formal":
|
| 745 |
-
|
|
|
|
|
|
|
|
|
|
| 746 |
elif occasion == "business":
|
| 747 |
-
outfit_length = rng.choice([3, 4, 5], p=[0.3, 0.5, 0.2]) #
|
| 748 |
elif occasion == "sport":
|
| 749 |
-
|
|
|
|
|
|
|
|
|
|
| 750 |
else: # casual
|
| 751 |
-
|
|
|
|
|
|
|
|
|
|
| 752 |
|
| 753 |
-
#
|
| 754 |
-
strategy_weights = [0.4, 0.3, 0.3] # Default
|
|
|
|
|
|
|
| 755 |
if occasion == "formal":
|
| 756 |
-
strategy_weights = [0.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 757 |
elif occasion == "sport":
|
| 758 |
-
strategy_weights = [0.3, 0.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 759 |
|
| 760 |
strategy = rng.choice([0, 1, 2], p=strategy_weights)
|
| 761 |
|
|
@@ -858,9 +1055,9 @@ class InferenceService:
|
|
| 858 |
if len(subset) >= 2 and len(subset) <= max_size and has_category_diversity(subset):
|
| 859 |
# Add randomization factor to prevent identical recommendations
|
| 860 |
subset = rng.permutation(subset).tolist() # Randomize order
|
| 861 |
-
|
| 862 |
-
|
| 863 |
-
|
| 864 |
|
| 865 |
print(f"π DEBUG: Generated {len(candidates)} total candidates")
|
| 866 |
|
|
@@ -904,7 +1101,7 @@ class InferenceService:
|
|
| 904 |
return True
|
| 905 |
|
| 906 |
def calculate_outfit_penalty(subset: List[int], base_score: float) -> float:
|
| 907 |
-
"""Calculate penalty-adjusted score
|
| 908 |
categories = [get_category_type(cat_str(i)) for i in subset]
|
| 909 |
category_counts = {}
|
| 910 |
|
|
@@ -914,7 +1111,8 @@ class InferenceService:
|
|
| 914 |
penalty = 0.0
|
| 915 |
bonus = 0.0
|
| 916 |
|
| 917 |
-
#
|
|
|
|
| 918 |
if category_counts.get("upper", 0) == 0:
|
| 919 |
penalty += -1000.0
|
| 920 |
if category_counts.get("bottom", 0) == 0:
|
|
@@ -922,54 +1120,110 @@ class InferenceService:
|
|
| 922 |
if category_counts.get("shoe", 0) == 0:
|
| 923 |
penalty += -1000.0
|
| 924 |
|
| 925 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 926 |
if occasion == "formal" and category_counts.get("outerwear", 0) == 0:
|
| 927 |
-
penalty += -500.0 #
|
|
|
|
|
|
|
|
|
|
|
|
|
| 928 |
|
| 929 |
-
#
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 933 |
|
| 934 |
-
#
|
| 935 |
max_accs = template["accessory_limit"]
|
| 936 |
-
|
| 937 |
-
|
|
|
|
| 938 |
|
| 939 |
-
#
|
| 940 |
-
if len(subset) <
|
| 941 |
-
penalty += -
|
| 942 |
elif len(subset) > 6:
|
| 943 |
-
penalty += -0
|
|
|
|
|
|
|
| 944 |
|
| 945 |
-
#
|
|
|
|
| 946 |
style_score = calculate_style_consistency_score(subset)
|
| 947 |
-
bonus += style_score * 0.
|
| 948 |
|
| 949 |
-
# Color
|
| 950 |
color_score = calculate_color_consistency_score(subset)
|
| 951 |
-
bonus += color_score * 0.
|
| 952 |
|
| 953 |
-
#
|
| 954 |
-
|
| 955 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 956 |
|
| 957 |
-
#
|
| 958 |
-
|
| 959 |
-
|
| 960 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 961 |
|
| 962 |
-
#
|
| 963 |
-
|
| 964 |
-
|
|
|
|
|
|
|
|
|
|
| 965 |
|
| 966 |
-
# Traditional Pakistani outfit
|
| 967 |
if outfit_style == "traditional":
|
| 968 |
traditional_items = [cat for cat in categories if any(traditional in cat.lower() for traditional in ["kameez", "kurta", "shalwar", "peshawari", "chappal"])]
|
| 969 |
-
if len(traditional_items) >= 2:
|
| 970 |
-
bonus += 0.
|
| 971 |
-
if len(traditional_items) == 3:
|
| 972 |
-
bonus += 0.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 973 |
|
| 974 |
return base_score + penalty + bonus
|
| 975 |
|
|
@@ -1031,7 +1285,7 @@ class InferenceService:
|
|
| 1031 |
else:
|
| 1032 |
# If we have fewer candidates than requested, shuffle them
|
| 1033 |
rng.shuffle(scored)
|
| 1034 |
-
|
| 1035 |
|
| 1036 |
results = []
|
| 1037 |
for subset, adjusted_score, base_score in topk:
|
|
@@ -1098,8 +1352,8 @@ class InferenceService:
|
|
| 1098 |
# Disable gradients
|
| 1099 |
for m in [self.resnet, self.vit]:
|
| 1100 |
if m is not None:
|
| 1101 |
-
|
| 1102 |
-
|
| 1103 |
|
| 1104 |
# Update overall status
|
| 1105 |
self.models_loaded = self.resnet_loaded and self.vit_loaded
|
|
|
|
| 551 |
return "unknown"
|
| 552 |
|
| 553 |
def calculate_color_consistency_score(items: List[int]) -> float:
|
| 554 |
+
"""Calculate sophisticated color harmony score using fashion theory"""
|
| 555 |
colors = [extract_color_from_category(cat_str(i)) for i in items]
|
| 556 |
color_counts = {}
|
| 557 |
for color in colors:
|
| 558 |
color_counts[color] = color_counts.get(color, 0) + 1
|
| 559 |
|
| 560 |
+
# Advanced color harmony rules
|
| 561 |
+
base_score = 0.0
|
| 562 |
+
|
| 563 |
+
# 1. Monochromatic harmony (same color family)
|
| 564 |
+
dominant_color = max(color_counts.items(), key=lambda x: x[1])[0] if color_counts else "unknown"
|
| 565 |
+
if color_counts.get(dominant_color, 0) >= 2:
|
| 566 |
+
base_score += 0.4 # Strong monochromatic bonus
|
| 567 |
+
|
| 568 |
+
# 2. Complementary color harmony
|
| 569 |
+
complementary_pairs = [
|
| 570 |
+
("black", "white"), ("navy", "white"), ("brown", "beige"),
|
| 571 |
+
("red", "green"), ("blue", "orange"), ("purple", "yellow")
|
| 572 |
+
]
|
| 573 |
+
for color1, color2 in complementary_pairs:
|
| 574 |
+
if color_counts.get(color1, 0) > 0 and color_counts.get(color2, 0) > 0:
|
| 575 |
+
base_score += 0.3 # Complementary harmony bonus
|
| 576 |
+
break
|
| 577 |
+
|
| 578 |
+
# 3. Neutral base with accent colors
|
| 579 |
+
neutral_colors = ["black", "white", "gray", "navy", "brown", "beige"]
|
| 580 |
+
neutral_count = sum(color_counts.get(c, 0) for c in neutral_colors)
|
| 581 |
+
if neutral_count >= 2 and len([c for c in colors if c not in neutral_colors]) <= 1:
|
| 582 |
+
base_score += 0.2 # Neutral base bonus
|
| 583 |
+
|
| 584 |
+
# 4. Color distribution penalty
|
| 585 |
+
if len(color_counts) > 4:
|
| 586 |
+
base_score -= 0.2 # Too many different colors
|
| 587 |
+
elif len(color_counts) == 1 and len(items) > 2:
|
| 588 |
+
base_score -= 0.1 # Too monotonous
|
| 589 |
+
|
| 590 |
+
# 5. Context-aware color scoring
|
| 591 |
+
if occasion == "formal":
|
| 592 |
+
formal_colors = ["black", "navy", "white", "gray", "charcoal"]
|
| 593 |
+
formal_count = sum(color_counts.get(c, 0) for c in formal_colors)
|
| 594 |
+
if formal_count >= 2:
|
| 595 |
+
base_score += 0.2 # Formal color bonus
|
| 596 |
+
elif occasion == "business":
|
| 597 |
+
business_colors = ["navy", "white", "gray", "black", "brown"]
|
| 598 |
+
business_count = sum(color_counts.get(c, 0) for c in business_colors)
|
| 599 |
+
if business_count >= 2:
|
| 600 |
+
base_score += 0.15 # Business color bonus
|
| 601 |
+
|
| 602 |
+
return min(1.0, max(0.0, base_score + 0.3)) # Ensure score is between 0-1
|
| 603 |
|
| 604 |
def calculate_style_consistency_score(items: List[int]) -> float:
|
| 605 |
+
"""Calculate advanced style consistency using fashion expert rules"""
|
| 606 |
categories = [cat_str(i) for i in items]
|
| 607 |
preferred_cats = template["preferred_categories"]
|
| 608 |
|
| 609 |
+
# Base template matching
|
| 610 |
matches = 0
|
| 611 |
for cat in categories:
|
| 612 |
if any(pref in cat for pref in preferred_cats):
|
|
|
|
| 614 |
|
| 615 |
base_score = matches / len(categories) if categories else 0.0
|
| 616 |
|
| 617 |
+
# Advanced fashion rules scoring
|
| 618 |
+
fashion_bonus = 0.0
|
| 619 |
occasion = template["context"]["occasion"]
|
| 620 |
weather = template["context"]["weather"]
|
| 621 |
+
outfit_style = template["context"]["style"]
|
| 622 |
|
| 623 |
+
# 1. Occasion-appropriate style rules
|
| 624 |
+
if occasion == "formal":
|
| 625 |
+
# Formal requires structured, tailored pieces
|
| 626 |
+
formal_items = ["jacket", "blazer", "suit", "dress shirt", "dress pant", "oxford"]
|
| 627 |
+
formal_count = sum(1 for cat in categories if any(f in cat for f in formal_items))
|
| 628 |
+
if formal_count >= 3: # At least 3 formal items
|
| 629 |
+
fashion_bonus += 0.4
|
| 630 |
+
elif formal_count >= 2:
|
| 631 |
+
fashion_bonus += 0.2
|
| 632 |
+
|
| 633 |
+
elif occasion == "business":
|
| 634 |
+
# Business requires professional but not overly formal
|
| 635 |
+
business_items = ["shirt", "blazer", "pants", "loafer", "oxford", "dress pant"]
|
| 636 |
+
business_count = sum(1 for cat in categories if any(b in cat for b in business_items))
|
| 637 |
+
if business_count >= 3:
|
| 638 |
+
fashion_bonus += 0.3
|
| 639 |
+
elif business_count >= 2:
|
| 640 |
+
fashion_bonus += 0.15
|
| 641 |
+
|
| 642 |
+
elif occasion == "sport":
|
| 643 |
+
# Sport requires athletic, functional pieces
|
| 644 |
+
sport_items = ["athletic", "running", "jogger", "sneaker", "tank", "legging"]
|
| 645 |
+
sport_count = sum(1 for cat in categories if any(s in cat for s in sport_items))
|
| 646 |
+
if sport_count >= 2:
|
| 647 |
+
fashion_bonus += 0.3
|
| 648 |
+
|
| 649 |
+
# 2. Style coherence rules
|
| 650 |
+
if outfit_style == "formal":
|
| 651 |
+
# Formal style coherence
|
| 652 |
+
if "jacket" in categories and "shirt" in categories:
|
| 653 |
+
fashion_bonus += 0.2 # Proper layering
|
| 654 |
+
if len([c for c in categories if c in ["jacket", "shirt", "pants", "shoes"]]) >= 3:
|
| 655 |
+
fashion_bonus += 0.2 # Complete formal set
|
| 656 |
+
|
| 657 |
+
elif outfit_style == "smart_casual":
|
| 658 |
+
# Smart casual balance
|
| 659 |
+
if "shirt" in categories and "pants" in categories:
|
| 660 |
+
fashion_bonus += 0.15
|
| 661 |
+
if "blazer" in categories or "jacket" in categories:
|
| 662 |
+
fashion_bonus += 0.1 # Elevated casual
|
| 663 |
+
|
| 664 |
+
elif outfit_style == "traditional":
|
| 665 |
+
# Traditional Pakistani coherence
|
| 666 |
+
traditional_items = ["kameez", "kurta", "shalwar", "peshawari", "chappal"]
|
| 667 |
+
traditional_count = sum(1 for cat in categories if any(t in cat for t in traditional_items))
|
| 668 |
+
if traditional_count >= 2:
|
| 669 |
+
fashion_bonus += 0.4
|
| 670 |
+
if traditional_count == 3: # Complete traditional set
|
| 671 |
+
fashion_bonus += 0.2
|
| 672 |
+
|
| 673 |
+
# 3. Weather-appropriate logic
|
| 674 |
+
if weather == "hot":
|
| 675 |
+
# Hot weather preferences
|
| 676 |
+
if any("light" in cat or "cotton" in cat or "tank" in cat for cat in categories):
|
| 677 |
+
fashion_bonus += 0.1
|
| 678 |
+
if "jacket" in categories and len(categories) > 3:
|
| 679 |
+
fashion_bonus -= 0.1 # Too many layers for hot weather
|
| 680 |
+
|
| 681 |
+
elif weather == "cold":
|
| 682 |
+
# Cold weather preferences
|
| 683 |
+
if "jacket" in categories or "blazer" in categories:
|
| 684 |
+
fashion_bonus += 0.15
|
| 685 |
+
if "sweater" in categories or "hoodie" in categories:
|
| 686 |
+
fashion_bonus += 0.1
|
| 687 |
+
|
| 688 |
+
elif weather == "rain":
|
| 689 |
+
# Rain weather preferences
|
| 690 |
+
if any("waterproof" in cat or "boot" in cat for cat in categories):
|
| 691 |
+
fashion_bonus += 0.2
|
| 692 |
+
if "jacket" in categories:
|
| 693 |
+
fashion_bonus += 0.1
|
| 694 |
+
|
| 695 |
+
# 4. Proportions and balance
|
| 696 |
+
category_types = [get_category_type(cat) for cat in categories]
|
| 697 |
+
type_counts = {}
|
| 698 |
+
for cat_type in category_types:
|
| 699 |
+
type_counts[cat_type] = type_counts.get(cat_type, 0) + 1
|
| 700 |
+
|
| 701 |
+
# Balanced outfit proportions
|
| 702 |
+
if len(type_counts) >= 3: # Good diversity
|
| 703 |
+
fashion_bonus += 0.1
|
| 704 |
+
if type_counts.get("accessory", 0) <= 2: # Not over-accessorized
|
| 705 |
+
fashion_bonus += 0.05
|
| 706 |
+
|
| 707 |
+
return min(1.0, base_score + fashion_bonus)
|
| 708 |
|
| 709 |
def get_category_type(cat: str) -> str:
|
| 710 |
"""Map category to outfit slot type with comprehensive taxonomy"""
|
|
|
|
| 819 |
return len(unique_categories) >= 2
|
| 820 |
|
| 821 |
def calculate_outfit_score(subset: List[int]) -> float:
|
| 822 |
+
"""Calculate sophisticated outfit quality score using advanced fashion reasoning"""
|
| 823 |
if len(subset) < 2:
|
| 824 |
return 0.0
|
|
|
|
|
|
|
|
|
|
| 825 |
|
| 826 |
+
# 1. Category diversity and completeness
|
| 827 |
+
category_types = [get_category_type(cat_str(i)) for i in subset]
|
| 828 |
+
unique_types = set(category_types)
|
| 829 |
+
diversity_score = len(unique_types) / 5.0 # Normalize to 5 categories max
|
| 830 |
+
|
| 831 |
+
# Completeness bonus for essential categories
|
| 832 |
+
essential_categories = {"upper", "bottom", "shoe"}
|
| 833 |
+
completeness_bonus = 0.0
|
| 834 |
+
if essential_categories.issubset(unique_types):
|
| 835 |
+
completeness_bonus += 0.3 # All essential categories present
|
| 836 |
+
elif len(essential_categories.intersection(unique_types)) >= 2:
|
| 837 |
+
completeness_bonus += 0.15 # Most essential categories present
|
| 838 |
+
|
| 839 |
+
# 2. Advanced style consistency
|
| 840 |
style_score = calculate_style_consistency_score(subset)
|
| 841 |
|
| 842 |
+
# 3. Sophisticated color harmony
|
| 843 |
color_score = calculate_color_consistency_score(subset)
|
| 844 |
|
| 845 |
+
# 4. Context-appropriate length scoring
|
| 846 |
+
occasion = template["context"]["occasion"]
|
| 847 |
+
if occasion == "formal":
|
| 848 |
+
length_score = 1.0 if 4 <= len(subset) <= 5 else 0.6 # Formal prefers complete sets
|
| 849 |
+
elif occasion == "business":
|
| 850 |
+
length_score = 1.0 if 3 <= len(subset) <= 4 else 0.7 # Business balanced
|
| 851 |
+
elif occasion == "sport":
|
| 852 |
+
length_score = 1.0 if 2 <= len(subset) <= 3 else 0.8 # Sport can be minimal
|
| 853 |
+
else: # casual
|
| 854 |
+
length_score = 1.0 if 2 <= len(subset) <= 4 else 0.7 # Casual flexible
|
| 855 |
+
|
| 856 |
+
# 5. Fashion rule compliance
|
| 857 |
+
fashion_rules_score = 0.0
|
| 858 |
|
| 859 |
+
# Rule: No more than one item per core category (except accessories)
|
| 860 |
+
core_categories = {"upper", "bottom", "shoe", "outerwear"}
|
| 861 |
+
core_counts = {cat: category_types.count(cat) for cat in core_categories}
|
| 862 |
+
if all(count <= 1 for count in core_counts.values()):
|
| 863 |
+
fashion_rules_score += 0.2 # Perfect core category distribution
|
| 864 |
+
|
| 865 |
+
# Rule: Appropriate accessory count
|
| 866 |
+
accessory_count = category_types.count("accessory")
|
| 867 |
+
max_accessories = template.get("accessory_limit", 3)
|
| 868 |
+
if accessory_count <= max_accessories:
|
| 869 |
+
fashion_rules_score += 0.1
|
| 870 |
+
if accessory_count > 0 and accessory_count <= 2:
|
| 871 |
+
fashion_rules_score += 0.1 # Bonus for tasteful accessorizing
|
| 872 |
+
|
| 873 |
+
# Rule: Occasion-appropriate formality
|
| 874 |
+
if occasion == "formal" and "outerwear" in unique_types:
|
| 875 |
+
fashion_rules_score += 0.2 # Formal requires outerwear
|
| 876 |
+
elif occasion == "business" and len(unique_types) >= 3:
|
| 877 |
+
fashion_rules_score += 0.15 # Business requires completeness
|
| 878 |
+
elif occasion == "sport" and any("athletic" in cat_str(i) for i in subset):
|
| 879 |
+
fashion_rules_score += 0.1 # Sport requires athletic items
|
| 880 |
+
|
| 881 |
+
# 6. Advanced weighted combination with reasoning
|
| 882 |
+
base_score = (
|
| 883 |
+
0.25 * (diversity_score + completeness_bonus) + # Structure and completeness
|
| 884 |
+
0.30 * style_score + # Style coherence
|
| 885 |
+
0.20 * color_score + # Color harmony
|
| 886 |
+
0.15 * length_score + # Appropriate length
|
| 887 |
+
0.10 * fashion_rules_score # Fashion rule compliance
|
| 888 |
+
)
|
| 889 |
+
|
| 890 |
+
# 7. Context-specific adjustments
|
| 891 |
+
context_adjustment = 0.0
|
| 892 |
+
|
| 893 |
+
# Weather appropriateness
|
| 894 |
+
weather = template["context"]["weather"]
|
| 895 |
+
if weather == "hot" and len(subset) > 4:
|
| 896 |
+
context_adjustment -= 0.1 # Too many layers for hot weather
|
| 897 |
+
elif weather == "cold" and "outerwear" not in unique_types:
|
| 898 |
+
context_adjustment -= 0.1 # Missing outerwear for cold weather
|
| 899 |
+
elif weather == "rain" and not any("boot" in cat_str(i) for i in subset):
|
| 900 |
+
context_adjustment -= 0.05 # Missing weather-appropriate footwear
|
| 901 |
+
|
| 902 |
+
# Occasion-specific adjustments
|
| 903 |
+
if occasion == "formal" and len(subset) < 4:
|
| 904 |
+
context_adjustment -= 0.1 # Formal outfits should be complete
|
| 905 |
+
elif occasion == "sport" and len(subset) > 4:
|
| 906 |
+
context_adjustment -= 0.05 # Sport outfits can be simpler
|
| 907 |
+
|
| 908 |
+
return min(1.0, max(0.0, base_score + context_adjustment))
|
| 909 |
|
| 910 |
+
# Advanced candidate generation with sophisticated reasoning
|
| 911 |
for _ in range(num_samples):
|
| 912 |
subset = []
|
| 913 |
|
| 914 |
+
# 1. Advanced context-aware outfit length selection
|
| 915 |
occasion = template["context"]["occasion"]
|
| 916 |
+
weather = template["context"]["weather"]
|
| 917 |
+
outfit_style = template["context"]["style"]
|
| 918 |
+
|
| 919 |
+
# Base length probabilities
|
| 920 |
if occasion == "formal":
|
| 921 |
+
if weather == "hot":
|
| 922 |
+
outfit_length = rng.choice([3, 4], p=[0.4, 0.6]) # Formal but weather-appropriate
|
| 923 |
+
else:
|
| 924 |
+
outfit_length = rng.choice([4, 5], p=[0.6, 0.4]) # Complete formal sets
|
| 925 |
elif occasion == "business":
|
| 926 |
+
outfit_length = rng.choice([3, 4, 5], p=[0.3, 0.5, 0.2]) # Professional balance
|
| 927 |
elif occasion == "sport":
|
| 928 |
+
if weather == "hot":
|
| 929 |
+
outfit_length = rng.choice([2, 3], p=[0.6, 0.4]) # Minimal for hot weather
|
| 930 |
+
else:
|
| 931 |
+
outfit_length = rng.choice([2, 3, 4], p=[0.3, 0.5, 0.2]) # Sport flexibility
|
| 932 |
else: # casual
|
| 933 |
+
if weather == "hot":
|
| 934 |
+
outfit_length = rng.choice([2, 3, 4], p=[0.4, 0.4, 0.2]) # Casual, weather-appropriate
|
| 935 |
+
else:
|
| 936 |
+
outfit_length = rng.choice([2, 3, 4, 5], p=[0.2, 0.4, 0.3, 0.1]) # Casual flexibility
|
| 937 |
|
| 938 |
+
# 2. Advanced strategy selection with reasoning
|
| 939 |
+
strategy_weights = [0.4, 0.3, 0.3] # Default: Core, Accessory-focused, Flexible
|
| 940 |
+
|
| 941 |
+
# Formal occasions prioritize complete, structured outfits
|
| 942 |
if occasion == "formal":
|
| 943 |
+
strategy_weights = [0.7, 0.1, 0.2] # Strongly favor core outfits
|
| 944 |
+
# Business occasions need professional balance
|
| 945 |
+
elif occasion == "business":
|
| 946 |
+
strategy_weights = [0.6, 0.2, 0.2] # Favor core with some flexibility
|
| 947 |
+
# Sport occasions can be more flexible
|
| 948 |
elif occasion == "sport":
|
| 949 |
+
strategy_weights = [0.3, 0.2, 0.5] # Favor flexible combinations
|
| 950 |
+
# Traditional outfits need cultural coherence
|
| 951 |
+
elif outfit_style == "traditional":
|
| 952 |
+
strategy_weights = [0.8, 0.1, 0.1] # Strongly favor traditional core sets
|
| 953 |
+
# Casual occasions allow more creativity
|
| 954 |
+
else:
|
| 955 |
+
strategy_weights = [0.4, 0.3, 0.3] # Balanced approach
|
| 956 |
|
| 957 |
strategy = rng.choice([0, 1, 2], p=strategy_weights)
|
| 958 |
|
|
|
|
| 1055 |
if len(subset) >= 2 and len(subset) <= max_size and has_category_diversity(subset):
|
| 1056 |
# Add randomization factor to prevent identical recommendations
|
| 1057 |
subset = rng.permutation(subset).tolist() # Randomize order
|
| 1058 |
+
candidates.append(subset)
|
| 1059 |
+
if len(candidates) % 10 == 0: # Log every 10 candidates
|
| 1060 |
+
print(f"π DEBUG: Generated {len(candidates)} candidates so far...")
|
| 1061 |
|
| 1062 |
print(f"π DEBUG: Generated {len(candidates)} total candidates")
|
| 1063 |
|
|
|
|
| 1101 |
return True
|
| 1102 |
|
| 1103 |
def calculate_outfit_penalty(subset: List[int], base_score: float) -> float:
|
| 1104 |
+
"""Calculate sophisticated penalty-adjusted score with advanced fashion reasoning"""
|
| 1105 |
categories = [get_category_type(cat_str(i)) for i in subset]
|
| 1106 |
category_counts = {}
|
| 1107 |
|
|
|
|
| 1111 |
penalty = 0.0
|
| 1112 |
bonus = 0.0
|
| 1113 |
|
| 1114 |
+
# 1. Critical fashion violations (severe penalties)
|
| 1115 |
+
# Missing essential categories: -β penalty
|
| 1116 |
if category_counts.get("upper", 0) == 0:
|
| 1117 |
penalty += -1000.0
|
| 1118 |
if category_counts.get("bottom", 0) == 0:
|
|
|
|
| 1120 |
if category_counts.get("shoe", 0) == 0:
|
| 1121 |
penalty += -1000.0
|
| 1122 |
|
| 1123 |
+
# Duplicate core categories: -β penalty (fashion rule violation)
|
| 1124 |
+
core_categories = {"upper", "bottom", "shoe", "outerwear"}
|
| 1125 |
+
for cat in core_categories:
|
| 1126 |
+
if category_counts.get(cat, 0) > 1:
|
| 1127 |
+
penalty += -1000.0
|
| 1128 |
+
|
| 1129 |
+
# 2. Context-specific critical violations
|
| 1130 |
if occasion == "formal" and category_counts.get("outerwear", 0) == 0:
|
| 1131 |
+
penalty += -500.0 # Formal without jacket is inappropriate
|
| 1132 |
+
elif occasion == "business" and len(subset) < 3:
|
| 1133 |
+
penalty += -200.0 # Business outfits should be complete
|
| 1134 |
+
elif occasion == "sport" and not any("athletic" in cat_str(i) for i in subset):
|
| 1135 |
+
penalty += -300.0 # Sport outfits need athletic items
|
| 1136 |
|
| 1137 |
+
# 3. Weather-appropriate violations
|
| 1138 |
+
weather = template["context"]["weather"]
|
| 1139 |
+
if weather == "hot" and len(subset) > 5:
|
| 1140 |
+
penalty += -100.0 # Too many layers for hot weather
|
| 1141 |
+
elif weather == "cold" and category_counts.get("outerwear", 0) == 0:
|
| 1142 |
+
penalty += -150.0 # Missing outerwear for cold weather
|
| 1143 |
+
elif weather == "rain" and not any("boot" in cat_str(i) for i in subset):
|
| 1144 |
+
penalty += -50.0 # Missing weather-appropriate footwear
|
| 1145 |
|
| 1146 |
+
# 4. Accessory violations
|
| 1147 |
max_accs = template["accessory_limit"]
|
| 1148 |
+
accessory_count = category_counts.get("accessory", 0)
|
| 1149 |
+
if accessory_count > max_accs:
|
| 1150 |
+
penalty += -50.0 * (accessory_count - max_accs) # Proportional penalty
|
| 1151 |
|
| 1152 |
+
# 5. Outfit balance violations
|
| 1153 |
+
if len(subset) < 2:
|
| 1154 |
+
penalty += -200.0 # Too minimal
|
| 1155 |
elif len(subset) > 6:
|
| 1156 |
+
penalty += -100.0 # Too complex
|
| 1157 |
+
elif len(subset) == 2 and occasion in ["formal", "business"]:
|
| 1158 |
+
penalty += -100.0 # Too minimal for formal/business
|
| 1159 |
|
| 1160 |
+
# 6. Advanced bonus system
|
| 1161 |
+
# Style consistency bonus (weighted by importance)
|
| 1162 |
style_score = calculate_style_consistency_score(subset)
|
| 1163 |
+
bonus += style_score * 0.6 # Increased weight for style
|
| 1164 |
|
| 1165 |
+
# Color harmony bonus
|
| 1166 |
color_score = calculate_color_consistency_score(subset)
|
| 1167 |
+
bonus += color_score * 0.4 # Increased weight for color
|
| 1168 |
|
| 1169 |
+
# 7. Context-specific bonuses
|
| 1170 |
+
# Formal outfit bonuses
|
| 1171 |
+
if occasion == "formal":
|
| 1172 |
+
if "outerwear" in categories:
|
| 1173 |
+
bonus += 0.6 # Strong bonus for proper formal layering
|
| 1174 |
+
if len([c for c in categories if c in ["upper", "bottom", "shoe", "outerwear"]]) >= 4:
|
| 1175 |
+
bonus += 0.4 # Complete formal set bonus
|
| 1176 |
+
if style_score > 0.7:
|
| 1177 |
+
bonus += 0.3 # High style coherence bonus
|
| 1178 |
|
| 1179 |
+
# Business outfit bonuses
|
| 1180 |
+
elif occasion == "business":
|
| 1181 |
+
if len(categories) >= 3:
|
| 1182 |
+
bonus += 0.3 # Professional completeness
|
| 1183 |
+
if "outerwear" in categories:
|
| 1184 |
+
bonus += 0.2 # Elevated business look
|
| 1185 |
+
if style_score > 0.6:
|
| 1186 |
+
bonus += 0.2 # Professional style bonus
|
| 1187 |
|
| 1188 |
+
# Sport outfit bonuses
|
| 1189 |
+
elif occasion == "sport":
|
| 1190 |
+
if any("athletic" in cat_str(i) for i in subset):
|
| 1191 |
+
bonus += 0.4 # Athletic functionality
|
| 1192 |
+
if len(subset) <= 3:
|
| 1193 |
+
bonus += 0.2 # Appropriate minimalism for sport
|
| 1194 |
|
| 1195 |
+
# 8. Traditional Pakistani outfit bonuses
|
| 1196 |
if outfit_style == "traditional":
|
| 1197 |
traditional_items = [cat for cat in categories if any(traditional in cat.lower() for traditional in ["kameez", "kurta", "shalwar", "peshawari", "chappal"])]
|
| 1198 |
+
if len(traditional_items) >= 2:
|
| 1199 |
+
bonus += 0.7 # Strong cultural appropriateness bonus
|
| 1200 |
+
if len(traditional_items) == 3:
|
| 1201 |
+
bonus += 0.4 # Complete traditional set bonus
|
| 1202 |
+
if style_score > 0.6:
|
| 1203 |
+
bonus += 0.3 # Traditional style coherence
|
| 1204 |
+
|
| 1205 |
+
# 9. Fashion rule compliance bonuses
|
| 1206 |
+
# Perfect category distribution
|
| 1207 |
+
if all(category_counts.get(cat, 0) <= 1 for cat in core_categories):
|
| 1208 |
+
bonus += 0.3 # Perfect fashion rule compliance
|
| 1209 |
+
|
| 1210 |
+
# Tasteful accessorizing
|
| 1211 |
+
if 1 <= accessory_count <= 2:
|
| 1212 |
+
bonus += 0.2 # Tasteful accessorizing bonus
|
| 1213 |
+
|
| 1214 |
+
# 10. Weather-appropriate bonuses
|
| 1215 |
+
if weather == "hot" and len(subset) <= 4:
|
| 1216 |
+
bonus += 0.1 # Appropriate for hot weather
|
| 1217 |
+
elif weather == "cold" and "outerwear" in categories:
|
| 1218 |
+
bonus += 0.2 # Proper cold weather preparation
|
| 1219 |
+
elif weather == "rain" and any("boot" in cat_str(i) for i in subset):
|
| 1220 |
+
bonus += 0.15 # Weather-appropriate footwear
|
| 1221 |
+
|
| 1222 |
+
# 11. Overall outfit quality bonus
|
| 1223 |
+
if style_score > 0.8 and color_score > 0.7:
|
| 1224 |
+
bonus += 0.3 # Exceptional outfit quality
|
| 1225 |
+
elif style_score > 0.6 and color_score > 0.5:
|
| 1226 |
+
bonus += 0.2 # Good outfit quality
|
| 1227 |
|
| 1228 |
return base_score + penalty + bonus
|
| 1229 |
|
|
|
|
| 1285 |
else:
|
| 1286 |
# If we have fewer candidates than requested, shuffle them
|
| 1287 |
rng.shuffle(scored)
|
| 1288 |
+
topk = scored[:num_outfits]
|
| 1289 |
|
| 1290 |
results = []
|
| 1291 |
for subset, adjusted_score, base_score in topk:
|
|
|
|
| 1352 |
# Disable gradients
|
| 1353 |
for m in [self.resnet, self.vit]:
|
| 1354 |
if m is not None:
|
| 1355 |
+
for p in m.parameters():
|
| 1356 |
+
p.requires_grad_(False)
|
| 1357 |
|
| 1358 |
# Update overall status
|
| 1359 |
self.models_loaded = self.resnet_loaded and self.vit_loaded
|