Spaces:
Paused
Paused
Ali Mohsin
committed on
Commit
·
ac45468
1
Parent(s):
3dd2128
More robust recommendations overall
Browse files
- inference.py +130 -51
inference.py
CHANGED
|
@@ -382,7 +382,7 @@ class InferenceService:
|
|
| 382 |
if len(proc_items) < 2:
|
| 383 |
print("π DEBUG: Returning empty array - not enough items (< 2)")
|
| 384 |
return []
|
| 385 |
-
|
| 386 |
print("π DEBUG: Starting candidate generation...")
|
| 387 |
|
| 388 |
# 2) Candidate generation with outfit templates
|
|
@@ -426,7 +426,7 @@ class InferenceService:
|
|
| 426 |
# Enhanced category-aware pools with diversity checks
|
| 427 |
def cat_str(i: int) -> str:
|
| 428 |
return (proc_items[i].get("category") or "").lower()
|
| 429 |
-
|
| 430 |
print("π DEBUG: Building category pools...")
|
| 431 |
# Debug: Print all categories
|
| 432 |
for i in range(len(proc_items)):
|
|
@@ -490,6 +490,16 @@ class InferenceService:
|
|
| 490 |
cat_lower = cat.lower().strip()
|
| 491 |
print(f"π DEBUG: Mapping category '{cat}' -> '{cat_lower}'")
|
| 492 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
# Upper body items (tops, outerwear)
|
| 494 |
upper_keywords = [
|
| 495 |
"top", "shirt", "tshirt", "t-shirt", "blouse", "tank", "camisole", "cami",
|
|
@@ -547,74 +557,131 @@ class InferenceService:
|
|
| 547 |
|
| 548 |
print(f"π DEBUG: Category pools - uppers: {len(uppers)}, bottoms: {len(bottoms)}, shoes: {len(shoes)}, accessories: {len(accs)}, others: {len(others)}")
|
| 549 |
|
| 550 |
-
# Check if we have
|
| 551 |
-
|
| 552 |
-
|
|
|
|
| 553 |
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 554 |
|
| 555 |
candidates: List[List[int]] = []
|
| 556 |
-
num_samples = max(num_outfits *
|
| 557 |
print(f"π DEBUG: Generating {num_samples} candidate outfits...")
|
| 558 |
|
| 559 |
def has_category_diversity(subset: List[int]) -> bool:
|
| 560 |
"""Check if subset has good category diversity"""
|
| 561 |
categories = [get_category_type(cat_str(i)) for i in subset]
|
| 562 |
unique_categories = set(categories)
|
| 563 |
-
# Require at least
|
| 564 |
-
return len(unique_categories) >=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 565 |
|
|
|
|
| 566 |
for _ in range(num_samples):
|
| 567 |
subset = []
|
| 568 |
|
| 569 |
-
#
|
| 570 |
-
|
|
|
|
|
|
|
|
|
|
| 571 |
# Core outfit: exactly 1 of each required slot
|
| 572 |
subset.append(int(rng.choice(uppers)))
|
| 573 |
subset.append(int(rng.choice(bottoms)))
|
| 574 |
subset.append(int(rng.choice(shoes)))
|
| 575 |
|
| 576 |
-
# Add accessories based on template limit
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
|
|
|
| 580 |
available_accs = [i for i in accs if i not in subset]
|
| 581 |
-
if available_accs:
|
| 582 |
selected_accs = rng.choice(available_accs, size=min(num_accs, len(available_accs)), replace=False)
|
| 583 |
subset.extend(selected_accs.tolist())
|
| 584 |
|
| 585 |
-
#
|
| 586 |
-
|
|
|
|
| 587 |
available_others = [i for i in others if i not in subset]
|
| 588 |
-
if available_others
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
|
|
|
|
|
|
| 596 |
|
| 597 |
-
#
|
| 598 |
-
|
| 599 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 600 |
|
| 601 |
-
#
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
if
|
| 606 |
-
|
| 607 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 608 |
|
| 609 |
# Remove duplicates and validate
|
| 610 |
subset = list(set(subset))
|
| 611 |
-
if len(subset) >=
|
|
|
|
|
|
|
| 612 |
candidates.append(subset)
|
| 613 |
if len(candidates) % 10 == 0: # Log every 10 candidates
|
| 614 |
print(f"π DEBUG: Generated {len(candidates)} candidates so far...")
|
| 615 |
|
| 616 |
print(f"π DEBUG: Generated {len(candidates)} total candidates")
|
| 617 |
-
|
| 618 |
# 3) Score using ViT
|
| 619 |
def score_subset(idx_subset: List[int]) -> float:
|
| 620 |
embs = torch.tensor(
|
|
@@ -628,26 +695,28 @@ class InferenceService:
|
|
| 628 |
|
| 629 |
# Enhanced validation with strict slot constraints
|
| 630 |
def is_valid_outfit(subset: List[int]) -> bool:
|
| 631 |
-
"""Check if outfit meets
|
|
|
|
|
|
|
|
|
|
| 632 |
categories = [get_category_type(cat_str(i)) for i in subset]
|
| 633 |
category_counts = {}
|
| 634 |
|
| 635 |
for cat in categories:
|
| 636 |
category_counts[cat] = category_counts.get(cat, 0) + 1
|
| 637 |
|
| 638 |
-
#
|
| 639 |
-
# -
|
| 640 |
-
# -
|
| 641 |
-
# -
|
| 642 |
-
|
|
|
|
| 643 |
return False
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
if category_counts.get("
|
| 647 |
-
return False
|
| 648 |
-
if category_counts.get("accessory", 0) > 2:
|
| 649 |
return False
|
| 650 |
-
if category_counts.get("other", 0) >
|
| 651 |
return False
|
| 652 |
|
| 653 |
return True
|
|
@@ -714,9 +783,19 @@ class InferenceService:
|
|
| 714 |
adjusted_score = calculate_outfit_penalty(subset, base_score)
|
| 715 |
scored.append((subset, adjusted_score, base_score))
|
| 716 |
|
| 717 |
-
# Sort by penalty-adjusted score
|
| 718 |
scored.sort(key=lambda x: x[1], reverse=True)
|
| 719 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 720 |
|
| 721 |
results = []
|
| 722 |
for subset, adjusted_score, base_score in topk:
|
|
|
|
| 382 |
if len(proc_items) < 2:
|
| 383 |
print("π DEBUG: Returning empty array - not enough items (< 2)")
|
| 384 |
return []
|
| 385 |
+
|
| 386 |
print("π DEBUG: Starting candidate generation...")
|
| 387 |
|
| 388 |
# 2) Candidate generation with outfit templates
|
|
|
|
| 426 |
# Enhanced category-aware pools with diversity checks
|
| 427 |
def cat_str(i: int) -> str:
|
| 428 |
return (proc_items[i].get("category") or "").lower()
|
| 429 |
+
|
| 430 |
print("π DEBUG: Building category pools...")
|
| 431 |
# Debug: Print all categories
|
| 432 |
for i in range(len(proc_items)):
|
|
|
|
| 490 |
cat_lower = cat.lower().strip()
|
| 491 |
print(f"π DEBUG: Mapping category '{cat}' -> '{cat_lower}'")
|
| 492 |
|
| 493 |
+
# Direct mapping for CLIP-detected categories
|
| 494 |
+
if cat_lower == "shirt":
|
| 495 |
+
return "upper"
|
| 496 |
+
elif cat_lower == "pants":
|
| 497 |
+
return "bottom"
|
| 498 |
+
elif cat_lower == "shoes":
|
| 499 |
+
return "shoe"
|
| 500 |
+
elif cat_lower == "accessory":
|
| 501 |
+
return "accessory"
|
| 502 |
+
|
| 503 |
# Upper body items (tops, outerwear)
|
| 504 |
upper_keywords = [
|
| 505 |
"top", "shirt", "tshirt", "t-shirt", "blouse", "tank", "camisole", "cami",
|
|
|
|
| 557 |
|
| 558 |
print(f"π DEBUG: Category pools - uppers: {len(uppers)}, bottoms: {len(bottoms)}, shoes: {len(shoes)}, accessories: {len(accs)}, others: {len(others)}")
|
| 559 |
|
| 560 |
+
# Check if we have enough items to create outfits
|
| 561 |
+
total_items = len(uppers) + len(bottoms) + len(shoes) + len(accs) + len(others)
|
| 562 |
+
if total_items < 2:
|
| 563 |
+
print(f"π DEBUG: Not enough items to create outfits - total: {total_items}")
|
| 564 |
return []
|
| 565 |
+
|
| 566 |
+
# Warn if we're missing core categories but still try to generate
|
| 567 |
+
if len(uppers) == 0 or len(bottoms) == 0 or len(shoes) == 0:
|
| 568 |
+
print(f"π DEBUG: Missing some core categories - uppers: {len(uppers)}, bottoms: {len(bottoms)}, shoes: {len(shoes)}")
|
| 569 |
+
print(f"π DEBUG: Will use flexible outfit generation with available items")
|
| 570 |
|
| 571 |
candidates: List[List[int]] = []
|
| 572 |
+
num_samples = max(num_outfits * 15, 30) # Increased for more variety
|
| 573 |
print(f"π DEBUG: Generating {num_samples} candidate outfits...")
|
| 574 |
|
| 575 |
def has_category_diversity(subset: List[int]) -> bool:
|
| 576 |
"""Check if subset has good category diversity"""
|
| 577 |
categories = [get_category_type(cat_str(i)) for i in subset]
|
| 578 |
unique_categories = set(categories)
|
| 579 |
+
# Require at least 2 different category types for good diversity
|
| 580 |
+
return len(unique_categories) >= 2
|
| 581 |
+
|
| 582 |
+
def calculate_outfit_score(subset: List[int]) -> float:
|
| 583 |
+
"""Calculate overall outfit quality score"""
|
| 584 |
+
if len(subset) < 2:
|
| 585 |
+
return 0.0
|
| 586 |
+
|
| 587 |
+
# Base score from category diversity
|
| 588 |
+
diversity_score = len(set(get_category_type(cat_str(i)) for i in subset)) / 4.0
|
| 589 |
+
|
| 590 |
+
# Style consistency score
|
| 591 |
+
style_score = calculate_style_consistency_score(subset)
|
| 592 |
+
|
| 593 |
+
# Color consistency score
|
| 594 |
+
color_score = calculate_color_consistency_score(subset)
|
| 595 |
+
|
| 596 |
+
# Length appropriateness (prefer 3-4 items)
|
| 597 |
+
length_score = 1.0 if 3 <= len(subset) <= 4 else 0.7
|
| 598 |
+
|
| 599 |
+
# Weighted combination
|
| 600 |
+
return 0.3 * diversity_score + 0.3 * style_score + 0.2 * color_score + 0.2 * length_score
|
| 601 |
|
| 602 |
+
# Generate diverse outfit combinations with randomization
|
| 603 |
for _ in range(num_samples):
|
| 604 |
subset = []
|
| 605 |
|
| 606 |
+
# VARIABLE OUTFIT LENGTH: 2-5 items with different strategies
|
| 607 |
+
outfit_length = rng.choice([2, 3, 4, 5], p=[0.1, 0.4, 0.4, 0.1]) # Prefer 3-4 items
|
| 608 |
+
|
| 609 |
+
# Strategy 1: Core outfit (shirt + pants + shoes) + accessories
|
| 610 |
+
if rng.random() < 0.6 and uppers and bottoms and shoes:
|
| 611 |
# Core outfit: exactly 1 of each required slot
|
| 612 |
subset.append(int(rng.choice(uppers)))
|
| 613 |
subset.append(int(rng.choice(bottoms)))
|
| 614 |
subset.append(int(rng.choice(shoes)))
|
| 615 |
|
| 616 |
+
# Add accessories based on template limit and remaining slots
|
| 617 |
+
remaining_slots = outfit_length - len(subset)
|
| 618 |
+
if accs and remaining_slots > 0:
|
| 619 |
+
max_accs = min(template["accessory_limit"], remaining_slots, len(accs))
|
| 620 |
+
num_accs = rng.integers(0, max_accs + 1)
|
| 621 |
available_accs = [i for i in accs if i not in subset]
|
| 622 |
+
if available_accs and num_accs > 0:
|
| 623 |
selected_accs = rng.choice(available_accs, size=min(num_accs, len(available_accs)), replace=False)
|
| 624 |
subset.extend(selected_accs.tolist())
|
| 625 |
|
| 626 |
+
# Fill remaining slots with other items
|
| 627 |
+
remaining_slots = outfit_length - len(subset)
|
| 628 |
+
if others and remaining_slots > 0:
|
| 629 |
available_others = [i for i in others if i not in subset]
|
| 630 |
+
if available_others:
|
| 631 |
+
num_others = min(remaining_slots, len(available_others))
|
| 632 |
+
selected_others = rng.choice(available_others, size=num_others, replace=False)
|
| 633 |
+
subset.extend(selected_others.tolist())
|
| 634 |
+
|
| 635 |
+
# Strategy 2: Flexible combination (no strict slot requirements)
|
| 636 |
+
elif rng.random() < 0.3:
|
| 637 |
+
# Randomly select items from all categories
|
| 638 |
+
all_items = list(ids)
|
| 639 |
+
rng.shuffle(all_items)
|
| 640 |
|
| 641 |
+
# Select items ensuring diversity
|
| 642 |
+
selected_categories = set()
|
| 643 |
+
for item in all_items:
|
| 644 |
+
if len(subset) >= outfit_length:
|
| 645 |
+
break
|
| 646 |
+
item_category = get_category_type(cat_str(item))
|
| 647 |
+
if item_category not in selected_categories or len(subset) < 2:
|
| 648 |
+
subset.append(item)
|
| 649 |
+
selected_categories.add(item_category)
|
| 650 |
+
|
| 651 |
+
# Strategy 3: Accessory-focused outfit (for small wardrobes)
|
| 652 |
+
else:
|
| 653 |
+
# Start with accessories if available
|
| 654 |
+
if accs:
|
| 655 |
+
num_accs = min(outfit_length, len(accs))
|
| 656 |
+
selected_accs = rng.choice(accs, size=num_accs, replace=False)
|
| 657 |
+
subset.extend(selected_accs.tolist())
|
| 658 |
|
| 659 |
+
# Fill remaining with other categories
|
| 660 |
+
remaining_slots = outfit_length - len(subset)
|
| 661 |
+
if remaining_slots > 0:
|
| 662 |
+
other_categories = []
|
| 663 |
+
if uppers: other_categories.extend(uppers)
|
| 664 |
+
if bottoms: other_categories.extend(bottoms)
|
| 665 |
+
if shoes: other_categories.extend(shoes)
|
| 666 |
+
if others: other_categories.extend(others)
|
| 667 |
+
|
| 668 |
+
available_others = [i for i in other_categories if i not in subset]
|
| 669 |
+
if available_others:
|
| 670 |
+
num_others = min(remaining_slots, len(available_others))
|
| 671 |
+
selected_others = rng.choice(available_others, size=num_others, replace=False)
|
| 672 |
+
subset.extend(selected_others.tolist())
|
| 673 |
|
| 674 |
# Remove duplicates and validate
|
| 675 |
subset = list(set(subset))
|
| 676 |
+
if len(subset) >= 2 and len(subset) <= max_size and has_category_diversity(subset):
|
| 677 |
+
# Add randomization factor to prevent identical recommendations
|
| 678 |
+
subset = rng.permutation(subset).tolist() # Randomize order
|
| 679 |
candidates.append(subset)
|
| 680 |
if len(candidates) % 10 == 0: # Log every 10 candidates
|
| 681 |
print(f"π DEBUG: Generated {len(candidates)} candidates so far...")
|
| 682 |
|
| 683 |
print(f"π DEBUG: Generated {len(candidates)} total candidates")
|
| 684 |
+
|
| 685 |
# 3) Score using ViT
|
| 686 |
def score_subset(idx_subset: List[int]) -> float:
|
| 687 |
embs = torch.tensor(
|
|
|
|
| 695 |
|
| 696 |
# Enhanced validation with strict slot constraints
|
| 697 |
def is_valid_outfit(subset: List[int]) -> bool:
|
| 698 |
+
"""Check if outfit meets flexible requirements"""
|
| 699 |
+
if len(subset) < 2 or len(subset) > max_size:
|
| 700 |
+
return False
|
| 701 |
+
|
| 702 |
categories = [get_category_type(cat_str(i)) for i in subset]
|
| 703 |
category_counts = {}
|
| 704 |
|
| 705 |
for cat in categories:
|
| 706 |
category_counts[cat] = category_counts.get(cat, 0) + 1
|
| 707 |
|
| 708 |
+
# FLEXIBLE VALIDATION:
|
| 709 |
+
# - At least 2 different categories
|
| 710 |
+
# - Reasonable limits per category
|
| 711 |
+
# - Allow variable outfit lengths
|
| 712 |
+
unique_categories = len(set(categories))
|
| 713 |
+
if unique_categories < 2:
|
| 714 |
return False
|
| 715 |
+
|
| 716 |
+
# Reasonable limits (more flexible than before)
|
| 717 |
+
if category_counts.get("accessory", 0) > 3: # Allow up to 3 accessories
|
|
|
|
|
|
|
| 718 |
return False
|
| 719 |
+
if category_counts.get("other", 0) > 2: # Allow up to 2 other items
|
| 720 |
return False
|
| 721 |
|
| 722 |
return True
|
|
|
|
| 783 |
adjusted_score = calculate_outfit_penalty(subset, base_score)
|
| 784 |
scored.append((subset, adjusted_score, base_score))
|
| 785 |
|
| 786 |
+
# Sort by penalty-adjusted score with randomization
|
| 787 |
scored.sort(key=lambda x: x[1], reverse=True)
|
| 788 |
+
|
| 789 |
+
# Add randomization to prevent identical recommendations
|
| 790 |
+
if len(scored) > num_outfits:
|
| 791 |
+
# Take top 50% and randomly sample from them
|
| 792 |
+
top_half = scored[:max(num_outfits * 2, len(scored) // 2)]
|
| 793 |
+
rng.shuffle(top_half)
|
| 794 |
+
topk = top_half[:num_outfits]
|
| 795 |
+
else:
|
| 796 |
+
# If we have fewer candidates than requested, shuffle them
|
| 797 |
+
rng.shuffle(scored)
|
| 798 |
+
topk = scored[:num_outfits]
|
| 799 |
|
| 800 |
results = []
|
| 801 |
for subset, adjusted_score, base_score in topk:
|