Spaces:
Paused
Paused
Ali Mohsin
commited on
Commit
Β·
6fa8724
1
Parent(s):
25ea1c9
I promise, these are the last changes
Browse files- inference.py +154 -26
inference.py
CHANGED
|
@@ -386,42 +386,113 @@ class InferenceService:
|
|
| 386 |
print("π DEBUG: Starting candidate generation...")
|
| 387 |
|
| 388 |
# 2) Candidate generation with outfit templates
|
| 389 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
num_outfits = int(context.get("num_outfits", 5)) # Increased default from 3 to 5
|
| 391 |
min_size, max_size = 2, 6 # Allow smaller outfits (2 items minimum)
|
| 392 |
ids = list(range(len(proc_items)))
|
| 393 |
|
| 394 |
-
#
|
| 395 |
outfit_templates = {
|
| 396 |
"casual": {
|
| 397 |
"style": "relaxed, comfortable, everyday",
|
| 398 |
-
"preferred_categories": ["tshirt", "jean", "sneaker", "hoodie", "sweatpant"],
|
| 399 |
"color_palette": ["neutral", "denim", "white", "black", "gray"],
|
| 400 |
-
"accessory_limit": 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
},
|
| 402 |
"smart_casual": {
|
| 403 |
"style": "polished but relaxed, business casual",
|
| 404 |
-
"preferred_categories": ["shirt", "chino", "loafer", "blazer", "polo"],
|
| 405 |
"color_palette": ["navy", "white", "khaki", "brown", "gray"],
|
| 406 |
-
"accessory_limit": 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
},
|
| 408 |
"formal": {
|
| 409 |
"style": "professional, elegant, sophisticated",
|
| 410 |
-
"preferred_categories": ["blazer", "dress shirt", "dress pant", "oxford", "suit"],
|
| 411 |
"color_palette": ["navy", "black", "white", "gray", "charcoal"],
|
| 412 |
-
"accessory_limit": 4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
},
|
| 414 |
"sporty": {
|
| 415 |
"style": "athletic, active, performance",
|
| 416 |
-
"preferred_categories": ["athletic shirt", "jogger", "running shoe", "tank", "legging"],
|
| 417 |
"color_palette": ["bright", "neon", "white", "black", "primary colors"],
|
| 418 |
-
"accessory_limit": 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
}
|
| 420 |
}
|
| 421 |
|
| 422 |
-
#
|
| 423 |
-
|
| 424 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 425 |
|
| 426 |
# Enhanced category-aware pools with diversity checks
|
| 427 |
def cat_str(i: int) -> str:
|
|
@@ -483,7 +554,30 @@ class InferenceService:
|
|
| 483 |
if any(pref in cat for pref in preferred_cats):
|
| 484 |
matches += 1
|
| 485 |
|
| 486 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
|
| 488 |
def get_category_type(cat: str) -> str:
|
| 489 |
"""Map category to outfit slot type with comprehensive taxonomy"""
|
|
@@ -599,15 +693,32 @@ class InferenceService:
|
|
| 599 |
# Weighted combination
|
| 600 |
return 0.3 * diversity_score + 0.3 * style_score + 0.2 * color_score + 0.2 * length_score
|
| 601 |
|
| 602 |
-
#
|
| 603 |
for _ in range(num_samples):
|
| 604 |
subset = []
|
| 605 |
|
| 606 |
-
#
|
| 607 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 608 |
|
| 609 |
# Strategy 1: Core outfit (shirt + pants + shoes) + accessories
|
| 610 |
-
if
|
| 611 |
# Core outfit: exactly 1 of each required slot
|
| 612 |
subset.append(int(rng.choice(uppers)))
|
| 613 |
subset.append(int(rng.choice(bottoms)))
|
|
@@ -634,7 +745,7 @@ class InferenceService:
|
|
| 634 |
subset.extend(selected_others.tolist())
|
| 635 |
|
| 636 |
# Strategy 2: Accessory-focused outfit (prioritize accessories)
|
| 637 |
-
elif
|
| 638 |
# Start with accessories if available
|
| 639 |
num_accs = min(outfit_length, len(accs))
|
| 640 |
selected_accs = rng.choice(accs, size=num_accs, replace=False)
|
|
@@ -656,7 +767,7 @@ class InferenceService:
|
|
| 656 |
subset.extend(selected_others.tolist())
|
| 657 |
|
| 658 |
# Strategy 3: Flexible combination (no strict slot requirements)
|
| 659 |
-
|
| 660 |
# Randomly select items from all categories
|
| 661 |
all_items = list(ids)
|
| 662 |
rng.shuffle(all_items)
|
|
@@ -807,16 +918,33 @@ class InferenceService:
|
|
| 807 |
print(f"π DEBUG: Removed {len(scored) - len(unique_scored)} duplicate outfits")
|
| 808 |
scored = unique_scored
|
| 809 |
|
| 810 |
-
#
|
| 811 |
if len(scored) > num_outfits:
|
| 812 |
-
#
|
| 813 |
-
|
| 814 |
-
|
| 815 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 816 |
else:
|
| 817 |
# If we have fewer candidates than requested, shuffle them
|
| 818 |
rng.shuffle(scored)
|
| 819 |
-
|
| 820 |
|
| 821 |
results = []
|
| 822 |
for subset, adjusted_score, base_score in topk:
|
|
|
|
| 386 |
print("π DEBUG: Starting candidate generation...")
|
| 387 |
|
| 388 |
# 2) Candidate generation with outfit templates
|
| 389 |
+
# Use timestamp-based seed for better randomization
|
| 390 |
+
import time
|
| 391 |
+
seed = context.get("seed", int(time.time() * 1000) % 10000)
|
| 392 |
+
rng = np.random.default_rng(seed)
|
| 393 |
+
print(f"π DEBUG: Using random seed: {seed}")
|
| 394 |
num_outfits = int(context.get("num_outfits", 5)) # Increased default from 3 to 5
|
| 395 |
min_size, max_size = 2, 6 # Allow smaller outfits (2 items minimum)
|
| 396 |
ids = list(range(len(proc_items)))
|
| 397 |
|
| 398 |
+
# Enhanced context-aware outfit templates
|
| 399 |
outfit_templates = {
|
| 400 |
"casual": {
|
| 401 |
"style": "relaxed, comfortable, everyday",
|
| 402 |
+
"preferred_categories": ["tshirt", "jean", "sneaker", "hoodie", "sweatpant", "shirt", "pants", "shoes"],
|
| 403 |
"color_palette": ["neutral", "denim", "white", "black", "gray"],
|
| 404 |
+
"accessory_limit": 2,
|
| 405 |
+
"weather_modifiers": {
|
| 406 |
+
"hot": {"preferred_categories": ["tank", "short", "sandal", "light shirt"]},
|
| 407 |
+
"cold": {"preferred_categories": ["hoodie", "sweater", "jacket", "boot"]},
|
| 408 |
+
"rain": {"preferred_categories": ["jacket", "boot", "waterproof"]}
|
| 409 |
+
},
|
| 410 |
+
"occasion_modifiers": {
|
| 411 |
+
"business": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 3},
|
| 412 |
+
"formal": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 4}
|
| 413 |
+
}
|
| 414 |
},
|
| 415 |
"smart_casual": {
|
| 416 |
"style": "polished but relaxed, business casual",
|
| 417 |
+
"preferred_categories": ["shirt", "chino", "loafer", "blazer", "polo", "pants", "shoes"],
|
| 418 |
"color_palette": ["navy", "white", "khaki", "brown", "gray"],
|
| 419 |
+
"accessory_limit": 3,
|
| 420 |
+
"weather_modifiers": {
|
| 421 |
+
"hot": {"preferred_categories": ["polo", "light shirt", "loafer"]},
|
| 422 |
+
"cold": {"preferred_categories": ["blazer", "sweater", "boot"]},
|
| 423 |
+
"rain": {"preferred_categories": ["blazer", "boot", "umbrella"]}
|
| 424 |
+
},
|
| 425 |
+
"occasion_modifiers": {
|
| 426 |
+
"business": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 4},
|
| 427 |
+
"formal": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 4}
|
| 428 |
+
}
|
| 429 |
},
|
| 430 |
"formal": {
|
| 431 |
"style": "professional, elegant, sophisticated",
|
| 432 |
+
"preferred_categories": ["blazer", "dress shirt", "dress pant", "oxford", "suit", "shirt", "pants", "shoes"],
|
| 433 |
"color_palette": ["navy", "black", "white", "gray", "charcoal"],
|
| 434 |
+
"accessory_limit": 4,
|
| 435 |
+
"weather_modifiers": {
|
| 436 |
+
"hot": {"preferred_categories": ["light shirt", "light pant", "oxford"]},
|
| 437 |
+
"cold": {"preferred_categories": ["blazer", "suit", "boot"]},
|
| 438 |
+
"rain": {"preferred_categories": ["blazer", "boot", "umbrella"]}
|
| 439 |
+
},
|
| 440 |
+
"occasion_modifiers": {
|
| 441 |
+
"business": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 4},
|
| 442 |
+
"casual": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 3}
|
| 443 |
+
}
|
| 444 |
},
|
| 445 |
"sporty": {
|
| 446 |
"style": "athletic, active, performance",
|
| 447 |
+
"preferred_categories": ["athletic shirt", "jogger", "running shoe", "tank", "legging", "shirt", "pants", "shoes"],
|
| 448 |
"color_palette": ["bright", "neon", "white", "black", "primary colors"],
|
| 449 |
+
"accessory_limit": 1,
|
| 450 |
+
"weather_modifiers": {
|
| 451 |
+
"hot": {"preferred_categories": ["tank", "short", "running shoe"]},
|
| 452 |
+
"cold": {"preferred_categories": ["hoodie", "legging", "running shoe"]},
|
| 453 |
+
"rain": {"preferred_categories": ["jacket", "running shoe", "cap"]}
|
| 454 |
+
},
|
| 455 |
+
"occasion_modifiers": {
|
| 456 |
+
"business": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 2},
|
| 457 |
+
"formal": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 3}
|
| 458 |
+
}
|
| 459 |
}
|
| 460 |
}
|
| 461 |
|
| 462 |
+
# Context-aware template selection and modification
|
| 463 |
+
occasion = context.get("occasion", "casual")
|
| 464 |
+
weather = context.get("weather", "any")
|
| 465 |
+
outfit_style = context.get("outfit_style", "casual")
|
| 466 |
+
|
| 467 |
+
# Select base template
|
| 468 |
+
template_name = outfit_style
|
| 469 |
+
template = outfit_templates[template_name].copy()
|
| 470 |
+
|
| 471 |
+
# Apply weather modifications
|
| 472 |
+
if weather != "any" and weather in template.get("weather_modifiers", {}):
|
| 473 |
+
weather_mod = template["weather_modifiers"][weather]
|
| 474 |
+
template["preferred_categories"].extend(weather_mod.get("preferred_categories", []))
|
| 475 |
+
if "accessory_limit" in weather_mod:
|
| 476 |
+
template["accessory_limit"] = weather_mod["accessory_limit"]
|
| 477 |
+
|
| 478 |
+
# Apply occasion modifications
|
| 479 |
+
if occasion in template.get("occasion_modifiers", {}):
|
| 480 |
+
occasion_mod = template["occasion_modifiers"][occasion]
|
| 481 |
+
template["preferred_categories"].extend(occasion_mod.get("preferred_categories", []))
|
| 482 |
+
if "accessory_limit" in occasion_mod:
|
| 483 |
+
template["accessory_limit"] = occasion_mod["accessory_limit"]
|
| 484 |
+
|
| 485 |
+
# Remove duplicates and add context info
|
| 486 |
+
template["preferred_categories"] = list(set(template["preferred_categories"]))
|
| 487 |
+
template["context"] = {
|
| 488 |
+
"occasion": occasion,
|
| 489 |
+
"weather": weather,
|
| 490 |
+
"style": outfit_style
|
| 491 |
+
}
|
| 492 |
+
|
| 493 |
+
print(f"π DEBUG: Using template '{template_name}' with context: occasion={occasion}, weather={weather}")
|
| 494 |
+
print(f"π DEBUG: Template categories: {template['preferred_categories']}")
|
| 495 |
+
print(f"π DEBUG: Accessory limit: {template['accessory_limit']}")
|
| 496 |
|
| 497 |
# Enhanced category-aware pools with diversity checks
|
| 498 |
def cat_str(i: int) -> str:
|
|
|
|
| 554 |
if any(pref in cat for pref in preferred_cats):
|
| 555 |
matches += 1
|
| 556 |
|
| 557 |
+
base_score = matches / len(categories) if categories else 0.0
|
| 558 |
+
|
| 559 |
+
# Bonus for context-specific categories
|
| 560 |
+
context_bonus = 0.0
|
| 561 |
+
occasion = template["context"]["occasion"]
|
| 562 |
+
weather = template["context"]["weather"]
|
| 563 |
+
|
| 564 |
+
# Occasion-specific bonuses
|
| 565 |
+
if occasion == "business" and any("shirt" in cat or "pants" in cat for cat in categories):
|
| 566 |
+
context_bonus += 0.2
|
| 567 |
+
elif occasion == "formal" and any("shirt" in cat or "pants" in cat for cat in categories):
|
| 568 |
+
context_bonus += 0.3
|
| 569 |
+
elif occasion == "sport" and any("shirt" in cat or "pants" in cat for cat in categories):
|
| 570 |
+
context_bonus += 0.2
|
| 571 |
+
|
| 572 |
+
# Weather-specific bonuses
|
| 573 |
+
if weather == "hot" and any("shirt" in cat for cat in categories):
|
| 574 |
+
context_bonus += 0.1
|
| 575 |
+
elif weather == "cold" and any("shirt" in cat or "pants" in cat for cat in categories):
|
| 576 |
+
context_bonus += 0.1
|
| 577 |
+
elif weather == "rain" and any("shoes" in cat for cat in categories):
|
| 578 |
+
context_bonus += 0.1
|
| 579 |
+
|
| 580 |
+
return min(1.0, base_score + context_bonus)
|
| 581 |
|
| 582 |
def get_category_type(cat: str) -> str:
|
| 583 |
"""Map category to outfit slot type with comprehensive taxonomy"""
|
|
|
|
| 693 |
# Weighted combination
|
| 694 |
return 0.3 * diversity_score + 0.3 * style_score + 0.2 * color_score + 0.2 * length_score
|
| 695 |
|
| 696 |
+
# Context-aware candidate generation with enhanced randomization
|
| 697 |
for _ in range(num_samples):
|
| 698 |
subset = []
|
| 699 |
|
| 700 |
+
# Context-aware outfit length selection
|
| 701 |
+
occasion = template["context"]["occasion"]
|
| 702 |
+
if occasion == "formal":
|
| 703 |
+
outfit_length = rng.choice([3, 4, 5], p=[0.2, 0.5, 0.3]) # Formal outfits tend to be more complete
|
| 704 |
+
elif occasion == "business":
|
| 705 |
+
outfit_length = rng.choice([3, 4, 5], p=[0.3, 0.5, 0.2]) # Business outfits are well-rounded
|
| 706 |
+
elif occasion == "sport":
|
| 707 |
+
outfit_length = rng.choice([2, 3, 4], p=[0.3, 0.5, 0.2]) # Sport outfits can be minimal
|
| 708 |
+
else: # casual
|
| 709 |
+
outfit_length = rng.choice([2, 3, 4, 5], p=[0.2, 0.4, 0.3, 0.1]) # Casual is flexible
|
| 710 |
+
|
| 711 |
+
# Context-aware strategy selection
|
| 712 |
+
strategy_weights = [0.4, 0.3, 0.3] # Default weights
|
| 713 |
+
if occasion == "formal":
|
| 714 |
+
strategy_weights = [0.6, 0.2, 0.2] # Favor core outfits for formal
|
| 715 |
+
elif occasion == "sport":
|
| 716 |
+
strategy_weights = [0.3, 0.5, 0.2] # Favor flexible combinations for sport
|
| 717 |
+
|
| 718 |
+
strategy = rng.choice([0, 1, 2], p=strategy_weights)
|
| 719 |
|
| 720 |
# Strategy 1: Core outfit (shirt + pants + shoes) + accessories
|
| 721 |
+
if strategy == 0 and uppers and bottoms and shoes:
|
| 722 |
# Core outfit: exactly 1 of each required slot
|
| 723 |
subset.append(int(rng.choice(uppers)))
|
| 724 |
subset.append(int(rng.choice(bottoms)))
|
|
|
|
| 745 |
subset.extend(selected_others.tolist())
|
| 746 |
|
| 747 |
# Strategy 2: Accessory-focused outfit (prioritize accessories)
|
| 748 |
+
elif strategy == 1 and accs:
|
| 749 |
# Start with accessories if available
|
| 750 |
num_accs = min(outfit_length, len(accs))
|
| 751 |
selected_accs = rng.choice(accs, size=num_accs, replace=False)
|
|
|
|
| 767 |
subset.extend(selected_others.tolist())
|
| 768 |
|
| 769 |
# Strategy 3: Flexible combination (no strict slot requirements)
|
| 770 |
+
elif strategy == 2:
|
| 771 |
# Randomly select items from all categories
|
| 772 |
all_items = list(ids)
|
| 773 |
rng.shuffle(all_items)
|
|
|
|
| 918 |
print(f"π DEBUG: Removed {len(scored) - len(unique_scored)} duplicate outfits")
|
| 919 |
scored = unique_scored
|
| 920 |
|
| 921 |
+
# Enhanced randomization with context awareness
|
| 922 |
if len(scored) > num_outfits:
|
| 923 |
+
# Context-aware selection: prefer higher-scoring outfits but add diversity
|
| 924 |
+
top_third = scored[:max(num_outfits * 3, len(scored) // 3)]
|
| 925 |
+
middle_third = scored[max(num_outfits * 3, len(scored) // 3):max(num_outfits * 6, len(scored) * 2 // 3)]
|
| 926 |
+
|
| 927 |
+
# Select mix of high-scoring and diverse outfits
|
| 928 |
+
selected = []
|
| 929 |
+
|
| 930 |
+
# Take 70% from top third (high quality)
|
| 931 |
+
top_count = int(num_outfits * 0.7)
|
| 932 |
+
rng.shuffle(top_third)
|
| 933 |
+
selected.extend(top_third[:top_count])
|
| 934 |
+
|
| 935 |
+
# Take 30% from middle third (diversity)
|
| 936 |
+
middle_count = num_outfits - len(selected)
|
| 937 |
+
if middle_count > 0 and middle_third:
|
| 938 |
+
rng.shuffle(middle_third)
|
| 939 |
+
selected.extend(middle_third[:middle_count])
|
| 940 |
+
|
| 941 |
+
# Shuffle final selection for randomness
|
| 942 |
+
rng.shuffle(selected)
|
| 943 |
+
topk = selected[:num_outfits]
|
| 944 |
else:
|
| 945 |
# If we have fewer candidates than requested, shuffle them
|
| 946 |
rng.shuffle(scored)
|
| 947 |
+
topk = scored[:num_outfits]
|
| 948 |
|
| 949 |
results = []
|
| 950 |
for subset, adjusted_score, base_score in topk:
|