Spaces:
Paused
Paused
Ali Mohsin
Update category mappings and enhance outfit templates for better filtering and accuracy
7669ee7
| import os | |
| from typing import List, Dict, Any | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| try: | |
| import open_clip | |
| CLIP_AVAILABLE = True | |
| except ImportError: | |
| CLIP_AVAILABLE = False | |
| from utils.transforms import build_inference_transform | |
| from models.resnet_embedder import ResNetItemEmbedder | |
| from models.vit_outfit import OutfitCompatibilityModel | |
| from utils.tag_system import TagProcessor, get_all_tag_options, validate_tags | |
| from utils.image_utils import ensure_rgb_image, validate_image_format | |
| def _get_device() -> str: | |
| if torch.cuda.is_available(): | |
| return "cuda" | |
| if torch.backends.mps.is_available(): | |
| return "mps" | |
| return "cpu" | |
| class InferenceService: | |
| def __init__(self) -> None: | |
| self.device = _get_device() | |
| self.transform = build_inference_transform() | |
| self.embed_dim = int(os.getenv("EMBED_DIM", "512")) | |
| self.resnet_version = "resnet_v1" | |
| self.vit_version = "vit_v1" | |
| # Model loading status tracking | |
| self.models_loaded = False | |
| self.model_errors = [] | |
| # Tag processing system | |
| self.tag_processor = TagProcessor() | |
| # Load CLIP for category detection | |
| self.clip_model, self.clip_preprocess = None, None | |
| self._load_clip() | |
| # Load models with validation | |
| self.resnet, self.resnet_loaded = self._load_resnet() | |
| self.vit, self.vit_loaded = self._load_vit() | |
| # Move to device and set eval mode | |
| if self.resnet_loaded: | |
| self.resnet = self.resnet.to(self.device).eval() | |
| if self.vit_loaded: | |
| self.vit = self.vit.to(self.device).eval() | |
| # Disable gradients | |
| for m in [self.resnet, self.vit]: | |
| if m is not None: | |
| for p in m.parameters(): | |
| p.requires_grad_(False) | |
| # Update overall status | |
| self.models_loaded = self.resnet_loaded and self.vit_loaded | |
| if not self.models_loaded: | |
| self.model_errors = [] | |
| if not self.resnet_loaded: | |
| self.model_errors.append("ResNet: No trained weights found") | |
| if not self.vit_loaded: | |
| self.model_errors.append("ViT: No trained weights found") | |
| def _load_clip(self) -> None: | |
| """Load CLIP model for category detection.""" | |
| if not CLIP_AVAILABLE: | |
| print("β οΈ CLIP not available, using filename-based category detection") | |
| self.clip_model, self.clip_preprocess = None, None | |
| return | |
| try: | |
| print("π Loading CLIP model for category detection...") | |
| self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms( | |
| 'ViT-B-32', pretrained='laion2b_s34b_b79k', device=self.device | |
| ) | |
| print("β CLIP model loaded successfully") | |
| except Exception as e: | |
| print(f"β Failed to load CLIP model: {e}") | |
| self.clip_model, self.clip_preprocess = None, None | |
| def _detect_category_with_clip(self, image: Image.Image) -> str: | |
| """Detect clothing category using CLIP.""" | |
| if self.clip_model is None or self.clip_preprocess is None: | |
| return "other" | |
| try: | |
| # Define clothing categories with descriptions (including Pakistani traditional wear) | |
| categories = [ | |
| "a shirt, t-shirt, blouse, or top", | |
| "pants, jeans, trousers, or bottoms", | |
| "shoes, sneakers, boots, or footwear", | |
| "a jacket, blazer, coat, or outerwear", | |
| "a dress or gown", | |
| "a skirt or shorts", | |
| "a sweater, hoodie, or pullover", | |
| "a watch, ring, necklace, or jewelry", | |
| "a bag, purse, or handbag", | |
| "a hat, cap, or headwear", | |
| "a belt or accessory", | |
| "a kameez, kurta, or traditional Pakistani shirt", | |
| "shalwar, traditional Pakistani pants, or loose trousers", | |
| "Peshawari chappal, traditional Pakistani sandals, or ethnic footwear" | |
| ] | |
| # Prepare image and text | |
| image_input = self.clip_preprocess(image).unsqueeze(0).to(self.device) | |
| text_inputs = open_clip.tokenize(categories).to(self.device) | |
| # Get predictions | |
| with torch.no_grad(): | |
| image_features = self.clip_model.encode_image(image_input) | |
| text_features = self.clip_model.encode_text(text_inputs) | |
| # Compute similarity | |
| similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1) | |
| values, indices = similarity[0].topk(1) | |
| # Map to outfit categories (including Pakistani traditional wear) | |
| category_map = { | |
| 0: "shirt", # shirt, t-shirt, blouse, top | |
| 1: "pants", # pants, jeans, trousers, bottoms | |
| 2: "shoes", # shoes, sneakers, boots, footwear | |
| 3: "jacket", # jacket, blazer, coat, outerwear | |
| 4: "dress", # dress, gown | |
| 5: "shorts", # skirt, shorts (Updated from pants to shorts for better filtering) | |
| 6: "shirt", # sweater, hoodie, pullover (map to shirt) | |
| 7: "accessory", # watch, ring, necklace, jewelry | |
| 8: "accessory", # bag, purse, handbag | |
| 9: "accessory", # hat, cap, headwear | |
| 10: "accessory", # belt, accessory | |
| 11: "kameez", # kameez, kurta, traditional Pakistani shirt | |
| 12: "shalwar", # shalwar, traditional Pakistani pants | |
| 13: "sandals" # Peshawari chappal, traditional Pakistani sandals (Updated to sandals) | |
| } | |
| predicted_category = category_map.get(indices[0].item(), "other") | |
| confidence = values[0].item() | |
| print(f"π CLIP detected: '{predicted_category}' (confidence: {confidence:.3f})") | |
| return predicted_category | |
| except Exception as e: | |
| print(f"β CLIP category detection failed: {e}") | |
| return "other" | |
| def _detect_category_from_filename(self, filename: str) -> str: | |
| """Fallback: Detect category from filename using keyword matching.""" | |
| if not filename: | |
| return "other" | |
| filename_lower = filename.lower() | |
| # Traditional Pakistani Wear | |
| if any(kw in filename_lower for kw in ["kameez", "kurta", "kurti"]): | |
| return "kameez" | |
| if any(kw in filename_lower for kw in ["shalwar", "salwar", "pyjama", "pajama"]): | |
| return "shalwar" | |
| if any(kw in filename_lower for kw in ["peshawari", "chappal", "khussa", "kolhapuri"]): | |
| return "sandals" | |
| # Specific Bottoms | |
| if any(kw in filename_lower for kw in ["short", "shorts", "bermuda"]): | |
| return "shorts" | |
| if any(kw in filename_lower for kw in ["jean", "jeans", "denim"]): | |
| return "jeans" | |
| if any(kw in filename_lower for kw in ["skirt", "miniskirt"]): | |
| return "skirt" | |
| if any(kw in filename_lower for kw in ["pant", "trouser", "slack", "chino", "legging", "jogger"]): | |
| return "pants" | |
| # Specific Footwear | |
| if any(kw in filename_lower for kw in ["sandal", "flip flop", "slide", "slipper"]): | |
| return "sandals" | |
| if any(kw in filename_lower for kw in ["sneaker", "trainer", "runner", "athletic shoe"]): | |
| return "sneakers" | |
| if any(kw in filename_lower for kw in ["boot", "bootie"]): | |
| return "boots" | |
| if any(kw in filename_lower for kw in ["shoe", "heel", "loafer", "oxford", "pump", "flat"]): | |
| return "shoes" | |
| # Specific Tops/Outerwear | |
| if any(kw in filename_lower for kw in ["waistcoat", "vest"]): | |
| return "waistcoat" | |
| if any(kw in filename_lower for kw in ["blazer", "suit jacket", "coat", "jacket"]): | |
| return "jacket" | |
| if any(kw in filename_lower for kw in ["hoodie", "sweatshirt"]): | |
| return "hoodie" | |
| if any(kw in filename_lower for kw in ["shirt", "top", "blouse", "tank", "tee", "t-shirt", "polo", "sweater", "cardigan"]): | |
| return "shirt" | |
| # Accessories | |
| if any(kw in filename_lower for kw in ["watch", "ring", "necklace", "bracelet", "bag", "hat", "belt", "scarf", "tie", "pocket square"]): | |
| return "accessory" | |
| # Default fallback | |
| return "other" | |
| def _load_resnet(self) -> tuple[nn.Module, bool]: | |
| strategy = os.getenv("MODEL_LOAD_STRATEGY", "state_dict") | |
| ckpt_path = os.getenv("RESNET_CHECKPOINT", "models/exports/resnet_item_embedder.pth") | |
| if strategy == "random": | |
| print("β οΈ Random strategy selected - no trained weights will be loaded!") | |
| return ResNetItemEmbedder(embedding_dim=self.embed_dim), False | |
| # Try to download from Hugging Face Hub first | |
| try: | |
| print("π Attempting to download ResNet from Hugging Face Hub...") | |
| hf_path = hf_hub_download( | |
| repo_id="Stylique/dressify-models", | |
| filename="resnet_item_embedder_best.pth", | |
| local_dir="models/exports", | |
| local_dir_use_symlinks=False | |
| ) | |
| print(f"π₯ Downloaded ResNet from HF Hub: {hf_path}") | |
| model = ResNetItemEmbedder(embedding_dim=self.embed_dim) | |
| state = torch.load(hf_path, map_location="cpu") | |
| state_dict = state.get("state_dict", state) if isinstance(state, dict) else state | |
| model.load_state_dict(state_dict, strict=False) | |
| print("β ResNet model loaded successfully from HF Hub") | |
| return model, True | |
| except Exception as e: | |
| print(f"β Failed to download ResNet from HF Hub: {e}") | |
| # Check for local best checkpoint first | |
| best_path = os.path.join(os.path.dirname(ckpt_path), "resnet_item_embedder_best.pth") | |
| if os.path.exists(best_path): | |
| print(f"π Loading ResNet from best checkpoint: {best_path}") | |
| model = ResNetItemEmbedder(embedding_dim=self.embed_dim) | |
| state = torch.load(best_path, map_location="cpu") | |
| state_dict = state.get("state_dict", state) if isinstance(state, dict) else state | |
| model.load_state_dict(state_dict, strict=False) | |
| print("β ResNet model loaded successfully from best checkpoint") | |
| return model, True | |
| # Check for regular checkpoint | |
| if os.path.exists(ckpt_path): | |
| print(f"π Loading ResNet from checkpoint: {ckpt_path}") | |
| model = ResNetItemEmbedder(embedding_dim=self.embed_dim) | |
| state = torch.load(ckpt_path, map_location="cpu") | |
| state_dict = state.get("state_dict", state) if isinstance(state, dict) else state | |
| model.load_state_dict(state_dict, strict=False) | |
| print("β ResNet model loaded successfully from checkpoint") | |
| return model, True | |
| print("β CRITICAL: No trained ResNet weights found!") | |
| print("π¨ Cannot provide recommendations without trained weights!") | |
| print("π‘ Please train the ResNet model first using the training tabs.") | |
| return ResNetItemEmbedder(embedding_dim=self.embed_dim), False | |
| def _load_vit(self) -> tuple[nn.Module, bool]: | |
| strategy = os.getenv("MODEL_LOAD_STRATEGY", "state_dict") | |
| ckpt_path = os.getenv("VIT_CHECKPOINT", "models/exports/vit_outfit_model.pth") | |
| if strategy == "random": | |
| print("β οΈ Random strategy selected - no trained weights will be loaded!") | |
| return OutfitCompatibilityModel(embedding_dim=self.embed_dim), False | |
| # Try to download from Hugging Face Hub first | |
| try: | |
| print("π Attempting to download ViT from Hugging Face Hub...") | |
| hf_path = hf_hub_download( | |
| repo_id="Stylique/dressify-models", | |
| filename="vit_outfit_model_best.pth", | |
| local_dir="models/exports", | |
| local_dir_use_symlinks=False | |
| ) | |
| print(f"π₯ Downloaded ViT from HF Hub: {hf_path}") | |
| model = OutfitCompatibilityModel(embedding_dim=self.embed_dim) | |
| state = torch.load(hf_path, map_location="cpu") | |
| state_dict = state.get("state_dict", state) if isinstance(state, dict) else state | |
| model.load_state_dict(state_dict, strict=False) | |
| print("β ViT model loaded successfully from HF Hub") | |
| return model, True | |
| except Exception as e: | |
| print(f"β Failed to download ViT from HF Hub: {e}") | |
| # Check for local best checkpoint first | |
| best_path = os.path.join(os.path.dirname(ckpt_path), "vit_outfit_model_best.pth") | |
| if os.path.exists(best_path): | |
| print(f"π Loading ViT from best checkpoint: {best_path}") | |
| model = OutfitCompatibilityModel(embedding_dim=self.embed_dim) | |
| state = torch.load(best_path, map_location="cpu") | |
| state_dict = state.get("state_dict", state) if isinstance(state, dict) else state | |
| model.load_state_dict(state_dict, strict=False) | |
| print("β ViT model loaded successfully from best checkpoint") | |
| return model, True | |
| # Check for regular checkpoint | |
| if os.path.exists(ckpt_path): | |
| print(f"π Loading ViT from checkpoint: {ckpt_path}") | |
| model = OutfitCompatibilityModel(embedding_dim=self.embed_dim) | |
| state = torch.load(ckpt_path, map_location="cpu") | |
| state_dict = state.get("state_dict", state) if isinstance(state, dict) else state | |
| model.load_state_dict(state_dict, strict=False) | |
| print("β ViT model loaded successfully from checkpoint") | |
| return model, True | |
| print("β CRITICAL: No trained ViT weights found!") | |
| print("π¨ Cannot provide recommendations without trained weights!") | |
| print("π‘ Please train the ViT model first using the training tabs.") | |
| return OutfitCompatibilityModel(embedding_dim=self.embed_dim), False | |
| def reload_models(self) -> None: | |
| """Reload weights from current checkpoint locations (used after background training).""" | |
| self.resnet, self.resnet_loaded = self._load_resnet() | |
| self.vit, self.vit_loaded = self._load_vit() | |
| # Move to device and set eval mode | |
| if self.resnet_loaded: | |
| self.resnet = self.resnet.to(self.device).eval() | |
| if self.vit_loaded: | |
| self.vit = self.vit.to(self.device).eval() | |
| # Disable gradients | |
| for m in [self.resnet, self.vit]: | |
| if m is not None: | |
| for p in m.parameters(): | |
| p.requires_grad_(False) | |
| # Update overall status | |
| self.models_loaded = self.resnet_loaded and self.vit_loaded | |
| if not self.models_loaded: | |
| self.model_errors = [] | |
| if not self.resnet_loaded: | |
| self.model_errors.append("ResNet: No trained weights found") | |
| if not self.vit_loaded: | |
| self.model_errors.append("ViT: No trained weights found") | |
| def embed_images(self, images: List[Image.Image]) -> List[np.ndarray]: | |
| """ | |
| Generate embeddings for images with comprehensive format support. | |
| All images are validated and converted to RGB before processing. | |
| """ | |
| print(f"π DEBUG: embed_images called with {len(images)} images") | |
| if len(images) == 0: | |
| print("π DEBUG: No images provided, returning empty list") | |
| return [] | |
| print(f"π DEBUG: ResNet model is None: {self.resnet is None}") | |
| if self.resnet is None: | |
| print("π DEBUG: ResNet model is None, returning empty list") | |
| return [] | |
| # Validate and convert all images to RGB | |
| processed_images = [] | |
| for i, img in enumerate(images): | |
| is_valid, error_msg = validate_image_format(img) | |
| if not is_valid: | |
| print(f"β οΈ Skipping invalid image {i}: {error_msg}") | |
| continue | |
| # Ensure RGB mode (required for ResNet) | |
| rgb_img = ensure_rgb_image(img) | |
| processed_images.append(rgb_img) | |
| if len(processed_images) == 0: | |
| print("β οΈ No valid images after processing") | |
| return [] | |
| print(f"π DEBUG: Processing {len(processed_images)} valid images") | |
| try: | |
| batch = torch.stack([self.transform(img) for img in processed_images]) | |
| batch = batch.to(self.device, memory_format=torch.channels_last) | |
| use_amp = (self.device == "cuda") | |
| with torch.autocast(device_type=("cuda" if use_amp else "cpu"), enabled=use_amp): | |
| emb = self.resnet(batch) | |
| emb = nn.functional.normalize(emb, dim=-1) | |
| result = [e.detach().cpu().numpy().astype(np.float32) for e in emb] | |
| print(f"π DEBUG: Successfully generated {len(result)} embeddings") | |
| return result | |
| except Exception as e: | |
| print(f"π DEBUG: Error in embed_images: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return [] | |
| def compose_outfits(self, items: List[Dict[str, Any]], context: Dict[str, Any]) -> List[Dict[str, Any]]: | |
| # Debug: Print model status | |
| print(f"π DEBUG: models_loaded={self.models_loaded}, resnet_loaded={self.resnet_loaded}, vit_loaded={self.vit_loaded}") | |
| print(f"π DEBUG: model_errors={self.model_errors}") | |
| print(f"π DEBUG: items count={len(items)}") | |
| # Validate that models are properly loaded | |
| if not self.models_loaded: | |
| error_msg = f"β Cannot provide recommendations: Models not properly loaded. Errors: {self.model_errors}" | |
| print(error_msg) | |
| return [{ | |
| "error": "Models not trained or loaded properly", | |
| "details": self.model_errors, | |
| "message": "Please ensure models are trained and checkpoints exist before generating recommendations." | |
| }] | |
| # 1) Ensure embeddings for each input item | |
| proc_items: List[Dict[str, Any]] = [] | |
| for i, it in enumerate(items): | |
| print(f"π DEBUG: Processing item {i}: id={it.get('id')}, has_image={it.get('image') is not None}, has_embedding={it.get('embedding') is not None}") | |
| # Auto-detect category if not provided or is None | |
| category = it.get("category") | |
| if not category or category == "None" or category == "": | |
| if it.get("image") is not None and self.clip_model is not None: | |
| print(f"π DEBUG: Auto-detecting category for item {i} using CLIP...") | |
| category = self._detect_category_with_clip(it["image"]) | |
| else: | |
| # Fallback to filename-based detection | |
| filename = it.get("id", "") | |
| print(f"π DEBUG: Auto-detecting category for item {i} using filename '{filename}'...") | |
| category = self._detect_category_from_filename(filename) | |
| print(f"π DEBUG: Filename-based detection result: '{category}'") | |
| emb = it.get("embedding") | |
| if emb is None and it.get("image") is not None: | |
| # Compute on-the-fly if image provided | |
| print(f"π DEBUG: Generating embedding for item {i}") | |
| emb = self.embed_images([it["image"]])[0] | |
| if emb is None: | |
| # Skip if we cannot get an embedding | |
| print(f"π DEBUG: Skipping item {i} - no embedding generated") | |
| continue | |
| emb_np = np.asarray(emb, dtype=np.float32) | |
| proc_items.append({ | |
| "id": it.get("id"), | |
| "embedding": emb_np, | |
| "category": category | |
| }) | |
| print(f"π DEBUG: Added item {i} to proc_items with category '{category}', total: {len(proc_items)}") | |
| print(f"π DEBUG: Final proc_items count: {len(proc_items)}") | |
| if len(proc_items) < 2: | |
| print("π DEBUG: Returning empty array - not enough items (< 2)") | |
| return [] | |
| print("π DEBUG: Starting candidate generation...") | |
| # 2) Candidate generation with outfit templates | |
| # Use timestamp-based seed for better randomization | |
| import time | |
| seed = context.get("seed", int(time.time() * 1000) % 10000) | |
| rng = np.random.default_rng(seed) | |
| print(f"π DEBUG: Using random seed: {seed}") | |
| num_outfits = int(context.get("num_outfits", 5)) # Increased default from 3 to 5 | |
| min_size, max_size = 2, 6 # Allow smaller outfits (2 items minimum) | |
| ids = list(range(len(proc_items))) | |
| # Enhanced context-aware outfit templates | |
| outfit_templates = { | |
| "casual": { | |
| "style": "relaxed, comfortable, everyday", | |
| "preferred_categories": ["tshirt", "jean", "sneaker", "hoodie", "sweatpant", "shirt", "pants", "shoes", "shorts", "jeans", "sneakers"], | |
| "excluded_categories": ["waistcoat", "suit jacket", "dress pant", "oxford"], | |
| "color_palette": ["neutral", "denim", "white", "black", "gray"], | |
| "accessory_limit": 2, | |
| "weather_modifiers": { | |
| "hot": {"preferred_categories": ["tank", "shorts", "sandals", "light shirt"], "excluded_categories": ["hoodie", "sweater", "jacket", "boots"]}, | |
| "cold": {"preferred_categories": ["hoodie", "sweater", "jacket", "boots"], "excluded_categories": ["shorts", "sandals", "tank"]}, | |
| "rain": {"preferred_categories": ["jacket", "boots", "waterproof"], "excluded_categories": ["sandals", "suede"]} | |
| }, | |
| "occasion_modifiers": { | |
| "business": {"preferred_categories": ["shirt", "pants", "shoes", "blazer"], "excluded_categories": ["shorts", "sandals", "tank", "sweatpant", "hoodie", "legging"], "accessory_limit": 3}, | |
| "formal": {"preferred_categories": ["shirt", "pants", "shoes", "blazer"], "excluded_categories": ["shorts", "sandals", "sneakers", "jeans", "tshirt"], "accessory_limit": 4} | |
| } | |
| }, | |
| "smart_casual": { | |
| "style": "polished but relaxed, business casual", | |
| "preferred_categories": ["shirt", "chino", "loafer", "blazer", "polo", "pants", "shoes", "jeans", "boots"], | |
| "excluded_categories": ["shorts", "sandals", "tank", "sweatpant", "hoodie", "athletic"], | |
| "color_palette": ["navy", "white", "khaki", "brown", "gray"], | |
| "accessory_limit": 3, | |
| "weather_modifiers": { | |
| "hot": {"preferred_categories": ["polo", "light shirt", "loafer"], "excluded_categories": ["boots", "heavy jacket"]}, | |
| "cold": {"preferred_categories": ["blazer", "sweater", "boots"], "excluded_categories": ["loafer"]}, | |
| "rain": {"preferred_categories": ["blazer", "boots", "umbrella"], "excluded_categories": ["suede"]} | |
| }, | |
| "occasion_modifiers": { | |
| "business": {"preferred_categories": ["shirt", "pants", "shoes", "blazer"], "excluded_categories": ["jeans", "sneakers"], "accessory_limit": 4}, | |
| "formal": {"preferred_categories": ["shirt", "pants", "shoes", "blazer", "suit"], "excluded_categories": ["jeans", "sneakers", "polo"], "accessory_limit": 4} | |
| } | |
| }, | |
| "formal": { | |
| "style": "professional, elegant, sophisticated", | |
| "preferred_categories": ["blazer", "jacket", "suit jacket", "dress shirt", "dress pant", "oxford", "suit", "shirt", "pants", "shoes", "waistcoat"], | |
| "excluded_categories": ["shorts", "sandals", "sneakers", "jeans", "tshirt", "hoodie", "sweatpant", "tank", "legging"], | |
| "color_palette": ["navy", "black", "white", "gray", "charcoal"], | |
| "accessory_limit": 4, | |
| "requires_outerwear": True, # Flag to indicate formal outfits should include jackets | |
| "weather_modifiers": { | |
| "hot": {"preferred_categories": ["light shirt", "light pant", "oxford"], "requires_outerwear": False}, | |
| "cold": {"preferred_categories": ["blazer", "suit", "boots", "waistcoat"], "requires_outerwear": True}, | |
| "rain": {"preferred_categories": ["blazer", "boots", "umbrella"], "requires_outerwear": True} | |
| }, | |
| "occasion_modifiers": { | |
| "business": {"preferred_categories": ["shirt", "pants", "shoes", "waistcoat"], "accessory_limit": 4, "requires_outerwear": True}, | |
| "casual": {"preferred_categories": ["shirt", "pants", "shoes", "blazer"], "accessory_limit": 3, "requires_outerwear": False} | |
| } | |
| }, | |
| "sporty": { | |
| "style": "athletic, active, performance", | |
| "preferred_categories": ["athletic shirt", "jogger", "running shoe", "tank", "legging", "shirt", "pants", "shoes", "sneakers", "shorts", "hoodie"], | |
| "excluded_categories": ["blazer", "suit", "dress pant", "oxford", "loafer", "waistcoat", "jeans", "sandals"], | |
| "color_palette": ["bright", "neon", "white", "black", "primary colors"], | |
| "accessory_limit": 1, | |
| "weather_modifiers": { | |
| "hot": {"preferred_categories": ["tank", "shorts", "running shoe"]}, | |
| "cold": {"preferred_categories": ["hoodie", "legging", "running shoe"]}, | |
| "rain": {"preferred_categories": ["jacket", "running shoe", "cap"]} | |
| }, | |
| "occasion_modifiers": { | |
| "business": {"preferred_categories": ["shirt", "pants", "shoes"], "excluded_categories": ["tank", "shorts", "legging"], "accessory_limit": 2}, | |
| "formal": {"preferred_categories": ["shirt", "pants", "shoes"], "excluded_categories": ["tank", "shorts", "legging", "hoodie"], "accessory_limit": 3} | |
| } | |
| }, | |
| "traditional": { | |
| "style": "Pakistani traditional, cultural, ethnic", | |
| "preferred_categories": ["kameez", "kurta", "shalwar", "peshawari", "chappal", "traditional", "ethnic", "waistcoat", "sandals"], | |
| "excluded_categories": ["shorts", "jeans", "sneakers", "hoodie", "tank", "suit", "tie"], | |
| "color_palette": ["white", "black", "navy", "maroon", "gold", "green", "traditional colors"], | |
| "accessory_limit": 3, | |
| "requires_traditional": True, # Flag for traditional outfit combinations | |
| "weather_modifiers": { | |
| "hot": {"preferred_categories": ["light kameez", "cotton shalwar", "peshawari chappal", "sandals"]}, | |
| "cold": {"preferred_categories": ["warm kameez", "thick shalwar", "traditional boots", "waistcoat", "shawl"]}, | |
| "rain": {"preferred_categories": ["waterproof kameez", "shalwar", "traditional boots"]} | |
| }, | |
| "occasion_modifiers": { | |
| "business": {"preferred_categories": ["formal kameez", "shalwar", "peshawari", "waistcoat"], "accessory_limit": 2}, | |
| "formal": {"preferred_categories": ["elegant kameez", "shalwar", "peshawari", "waistcoat"], "accessory_limit": 3}, | |
| "casual": {"preferred_categories": ["casual kameez", "shalwar", "chappal", "sandals"], "accessory_limit": 2} | |
| } | |
| } | |
| } | |
| # Process tags using the tag processor | |
| processed_tags = self.tag_processor.process_tags(context) | |
| # Extract primary tags (backward compatible) | |
| occasion = context.get("occasion", processed_tags["primary_tags"].get("occasion", "casual")) | |
| weather = context.get("weather", processed_tags["primary_tags"].get("weather", "any")) | |
| # Support both "outfit_style" and "style" for backward compatibility | |
| outfit_style = context.get("outfit_style") or context.get("style") or processed_tags["primary_tags"].get("outfit_style") or processed_tags["primary_tags"].get("style", "casual") | |
| # Select base template | |
| template_name = outfit_style | |
| if template_name not in outfit_templates: | |
| # Fallback to closest match | |
| if template_name in ["semi_formal", "business_casual"]: | |
| template_name = "smart_casual" | |
| elif template_name in ["athletic", "workout"]: | |
| template_name = "sporty" | |
| else: | |
| template_name = "casual" | |
| template = outfit_templates[template_name].copy() | |
| # Apply processed tag preferences | |
| if processed_tags["preferences"]["preferred_categories"]: | |
| template["preferred_categories"].extend(processed_tags["preferences"]["preferred_categories"]) | |
| # Apply constraints from processed tags | |
| constraints = processed_tags["constraints"] | |
| if constraints.get("accessory_limit"): | |
| template["accessory_limit"] = constraints["accessory_limit"] | |
| if constraints.get("requires_outerwear"): | |
| template["requires_outerwear"] = constraints["requires_outerwear"] | |
| # Initialize excluded categories if not present | |
| if "excluded_categories" not in template: | |
| template["excluded_categories"] = [] | |
| # Apply weather modifications | |
| if weather != "any" and weather in template.get("weather_modifiers", {}): | |
| weather_mod = template["weather_modifiers"][weather] | |
| template["preferred_categories"].extend(weather_mod.get("preferred_categories", [])) | |
| if "excluded_categories" in weather_mod: | |
| template["excluded_categories"].extend(weather_mod["excluded_categories"]) | |
| if "accessory_limit" in weather_mod: | |
| template["accessory_limit"] = weather_mod["accessory_limit"] | |
| # Apply occasion modifications | |
| if occasion in template.get("occasion_modifiers", {}): | |
| occasion_mod = template["occasion_modifiers"][occasion] | |
| template["preferred_categories"].extend(occasion_mod.get("preferred_categories", [])) | |
| if "excluded_categories" in occasion_mod: | |
| template["excluded_categories"].extend(occasion_mod["excluded_categories"]) | |
| if "accessory_limit" in occasion_mod: | |
| template["accessory_limit"] = occasion_mod["accessory_limit"] | |
| # Remove duplicates and add context info | |
| template["preferred_categories"] = list(set(template["preferred_categories"])) | |
| template["excluded_categories"] = list(set(template["excluded_categories"])) | |
| template["context"] = { | |
| "occasion": occasion, | |
| "weather": weather, | |
| "style": outfit_style, | |
| "processed_tags": processed_tags, # Include full processed tags | |
| "tag_weights": processed_tags["weights"], | |
| "tag_synergies": processed_tags["synergies"] | |
| } | |
| print(f"π DEBUG: Using template '{template_name}' with context: occasion={occasion}, weather={weather}") | |
| print(f"π DEBUG: Template categories: {template['preferred_categories']}") | |
| print(f"π DEBUG: Excluded categories: {template['excluded_categories']}") | |
| print(f"π DEBUG: Accessory limit: {template['accessory_limit']}") | |
| # Enhanced category-aware pools with diversity checks | |
| def cat_str(i: int) -> str: | |
| return (proc_items[i].get("category") or "").lower() | |
| print("π DEBUG: Building category pools...") | |
| # Debug: Print all categories | |
| for i in range(len(proc_items)): | |
| print(f"π DEBUG: Item {i}: category='{proc_items[i].get('category')}' -> cat_str='{cat_str(i)}'") | |
| def extract_color_from_category(category: str) -> str: | |
| """Extract color information from category name""" | |
| category_lower = category.lower() | |
| color_keywords = { | |
| "black": ["black", "dark", "charcoal", "navy"], | |
| "white": ["white", "cream", "ivory", "off-white"], | |
| "gray": ["gray", "grey", "silver", "ash"], | |
| "brown": ["brown", "tan", "beige", "khaki", "camel"], | |
| "blue": ["blue", "navy", "denim", "indigo", "royal"], | |
| "red": ["red", "burgundy", "maroon", "crimson"], | |
| "green": ["green", "olive", "emerald", "forest"], | |
| "yellow": ["yellow", "gold", "mustard", "lemon"], | |
| "pink": ["pink", "rose", "coral", "salmon"], | |
| "purple": ["purple", "violet", "lavender", "plum"], | |
| "orange": ["orange", "peach", "apricot", "tangerine"], | |
| "neutral": ["neutral", "nude", "natural", "earth"] | |
| } | |
| for color, keywords in color_keywords.items(): | |
| if any(kw in category_lower for kw in keywords): | |
| return color | |
| return "unknown" | |
def calculate_color_consistency_score(items: List[int]) -> float:
    """Score the colour harmony of an outfit on a 0-1 scale.

    Colours are read from each item's category string via
    ``extract_color_from_category``. Items whose colour could not be
    determined (``"unknown"``) are ignored by every rule: "unknown" is not a
    colour family, so two undetected items must not be rewarded as a
    monochromatic match nor counted as accent colours.

    Args:
        items: Indices of the outfit's items (resolved with ``cat_str``).

    Returns:
        Heuristic harmony score: rule bonuses/penalties plus a 0.3 offset,
        clamped to ``[0.0, 1.0]``.
    """
    colors = [extract_color_from_category(cat_str(i)) for i in items]
    known_colors = [c for c in colors if c != "unknown"]
    color_counts: Dict[str, int] = {}
    for color in known_colors:
        color_counts[color] = color_counts.get(color, 0) + 1

    base_score = 0.0

    # 1. Monochromatic harmony: the most frequent detected colour appears twice or more.
    if color_counts and max(color_counts.values()) >= 2:
        base_score += 0.4

    # 2. Complementary harmony: rewarded once, for the first matching pair.
    complementary_pairs = [
        ("black", "white"), ("navy", "white"), ("brown", "beige"),
        ("red", "green"), ("blue", "orange"), ("purple", "yellow"),
    ]
    if any(
        color_counts.get(first, 0) > 0 and color_counts.get(second, 0) > 0
        for first, second in complementary_pairs
    ):
        base_score += 0.3

    # 3. Neutral base with at most one accent colour.
    neutral_colors = ["black", "white", "gray", "navy", "brown", "beige"]
    neutral_count = sum(color_counts.get(c, 0) for c in neutral_colors)
    accent_count = sum(1 for c in known_colors if c not in neutral_colors)
    if neutral_count >= 2 and accent_count <= 1:
        base_score += 0.2

    # 4. Colour distribution.
    if len(color_counts) > 4:
        base_score -= 0.2  # Too many different colours
    elif len(color_counts) == 1 and len(known_colors) > 2:
        base_score -= 0.1  # Too monotonous

    # 5. Occasion-aware colours. The occasion is read from the template instead
    # of relying on a loop variable leaked from the enclosing scope.
    occasion = template["context"]["occasion"]
    if occasion == "formal":
        formal_colors = ["black", "navy", "white", "gray", "charcoal"]
        if sum(color_counts.get(c, 0) for c in formal_colors) >= 2:
            base_score += 0.2  # Formal colour bonus
    elif occasion == "business":
        business_colors = ["navy", "white", "gray", "black", "brown"]
        if sum(color_counts.get(c, 0) for c in business_colors) >= 2:
            base_score += 0.15  # Business colour bonus

    return min(1.0, max(0.0, base_score + 0.3))
def calculate_style_consistency_score(items: List[int]) -> float:
    """Score how well an outfit matches the active template's style.

    The score is the fraction of items whose category contains one of
    ``template["preferred_categories"]``, plus a rule-based bonus (which can
    be negative) driven by ``template["context"]`` (occasion, weather, style).

    Args:
        items: Indices of the outfit's items (resolved with ``cat_str``).

    Returns:
        Style-consistency score clamped to ``[0.0, 1.0]``. The lower clamp
        matters because weather rules can subtract from a zero base.
    """
    categories = [cat_str(i) for i in items]
    preferred_cats = template["preferred_categories"]

    # Base: share of items whose category contains a preferred keyword.
    matches = sum(1 for cat in categories if any(pref in cat for pref in preferred_cats))
    base_score = matches / len(categories) if categories else 0.0

    def count_with_keywords(keywords: List[str]) -> int:
        """Number of categories containing at least one of ``keywords``."""
        return sum(1 for cat in categories if any(k in cat for k in keywords))

    fashion_bonus = 0.0
    context = template["context"]
    occasion = context["occasion"]
    weather = context["weather"]
    outfit_style = context["style"]

    # 1. Occasion-appropriate pieces.
    if occasion == "formal":
        formal_count = count_with_keywords(
            ["jacket", "blazer", "suit", "dress shirt", "dress pant", "oxford"]
        )
        if formal_count >= 3:
            fashion_bonus += 0.4
        elif formal_count >= 2:
            fashion_bonus += 0.2
    elif occasion == "business":
        business_count = count_with_keywords(
            ["shirt", "blazer", "pants", "loafer", "oxford", "dress pant"]
        )
        if business_count >= 3:
            fashion_bonus += 0.3
        elif business_count >= 2:
            fashion_bonus += 0.15
    elif occasion == "sport":
        if count_with_keywords(["athletic", "running", "jogger", "sneaker", "tank", "legging"]) >= 2:
            fashion_bonus += 0.3

    # 2. Style coherence (exact category labels).
    if outfit_style == "formal":
        if "jacket" in categories and "shirt" in categories:
            fashion_bonus += 0.2  # Proper layering
        if len([c for c in categories if c in ["jacket", "shirt", "pants", "shoes"]]) >= 3:
            fashion_bonus += 0.2  # Complete formal set
    elif outfit_style == "smart_casual":
        if "shirt" in categories and "pants" in categories:
            fashion_bonus += 0.15
        if "blazer" in categories or "jacket" in categories:
            fashion_bonus += 0.1  # Elevated casual
    elif outfit_style == "traditional":
        traditional_count = count_with_keywords(["kameez", "kurta", "shalwar", "peshawari", "chappal"])
        if traditional_count >= 2:
            fashion_bonus += 0.4
        if traditional_count >= 3:  # Complete traditional set (3 or more pieces)
            fashion_bonus += 0.2

    # 3. Weather-appropriate choices.
    if weather == "hot":
        if any("light" in cat or "cotton" in cat or "tank" in cat for cat in categories):
            fashion_bonus += 0.1
        if "jacket" in categories and len(categories) > 3:
            fashion_bonus -= 0.1  # Too many layers for hot weather
    elif weather == "cold":
        if "jacket" in categories or "blazer" in categories:
            fashion_bonus += 0.15
        if "sweater" in categories or "hoodie" in categories:
            fashion_bonus += 0.1
    elif weather == "rain":
        if any("waterproof" in cat or "boot" in cat for cat in categories):
            fashion_bonus += 0.2
        if "jacket" in categories:
            fashion_bonus += 0.1

    # 4. Proportions and balance across outfit slots.
    category_types = [get_category_type(cat) for cat in categories]
    if len(set(category_types)) >= 3:
        fashion_bonus += 0.1  # Good diversity
    if category_types.count("accessory") <= 2:
        fashion_bonus += 0.05  # Not over-accessorized

    return min(1.0, max(0.0, base_score + fashion_bonus))
# Exact category labels (e.g. from the CLIP detector) mapped straight to an
# outfit slot; consulted before any keyword matching.
_EXACT_CATEGORY_TYPES = {
    "shirt": "upper",
    "pants": "bottom",
    "shoes": "shoe",
    "jacket": "outerwear",  # Separate slot for jackets/blazers
    "accessory": "accessory",
    "kameez": "upper",  # Kameez is upper body wear
    "shalwar": "bottom",  # Shalwar is bottom wear
    "peshawari": "shoe",  # Peshawari chappal is footwear
    "shorts": "bottom",
    "sandals": "shoe",
    "sneakers": "shoe",
    "jeans": "bottom",
    "waistcoat": "outerwear",
}

# Ordered (slot, keywords) pairs for substring matching, built once rather
# than on every call. Order matters: the first slot with a matching keyword
# wins, so "denim jacket" is outerwear before "denim" can make it a bottom.
_CATEGORY_KEYWORD_TYPES = (
    ("outerwear", (
        "jacket", "blazer", "coat", "vest", "waistcoat", "windbreaker", "bomber",
        "denim jacket", "leather jacket", "suit jacket", "sport coat", "trench coat",
        "pea coat", "overcoat", "cardigan", "sweater jacket",
    )),
    ("upper", (
        "top", "shirt", "tshirt", "t-shirt", "blouse", "tank", "camisole", "cami",
        "hoodie", "sweater", "pullover", "cardigan", "polo", "henley", "tunic",
        "crop top", "bodysuit", "romper", "jumpsuit", "kameez", "kurta", "shalwar kameez",
    )),
    ("bottom", (
        "pant", "pants", "trouser", "trousers", "jean", "jeans", "denim",
        "skirt", "short", "shorts", "legging", "leggings", "tights",
        "chino", "khaki", "cargo", "jogger", "sweatpant", "sweatpants",
        "culotte", "palazzo", "mini skirt", "midi skirt", "maxi skirt",
        "bermuda", "capri", "bike short", "bike shorts", "shalwar", "shalwar kameez",
    )),
    ("shoe", (
        "shoe", "shoes", "sneaker", "sneakers", "boot", "boots", "heel", "heels",
        "sandal", "sandals", "flat", "flats", "loafer", "loafers", "oxford",
        "pump", "pumps", "stiletto", "wedge", "ankle boot", "knee high boot",
        "combat boot", "hiking boot", "running shoe", "athletic shoe",
        "mule", "mules", "clog", "clogs", "espadrille", "espadrilles",
        "peshawari", "chappal", "peshawari chappal", "traditional sandal",
    )),
    ("accessory", (
        "watch", "belt", "ring", "rings", "bracelet", "bracelets", "necklace", "necklaces",
        "earring", "earrings", "bag", "bags", "handbag", "purse", "clutch", "tote",
        "hat", "cap", "beanie", "scarf", "scarves", "glove", "gloves", "sunglass", "sunglasses",
        "tie", "bow tie", "pocket square", "cufflink", "cufflinks", "brooch", "pin",
        "hair accessory", "headband", "hair clip", "barrette", "scrunchy", "scrunchies",
    )),
)


def get_category_type(cat: str) -> str:
    """Map a raw category label to an outfit slot type.

    Matching is case-insensitive and ignores surrounding whitespace. An exact
    label match in ``_EXACT_CATEGORY_TYPES`` wins first. Otherwise the first
    slot in ``_CATEGORY_KEYWORD_TYPES`` with a keyword contained in the label
    is used.

    Args:
        cat: Raw category label. ``None`` is treated as an empty label.

    Returns:
        One of ``"upper"``, ``"bottom"``, ``"shoe"``, ``"outerwear"``,
        ``"accessory"`` or ``"other"``.
    """
    cat_lower = ("" if cat is None else str(cat)).lower().strip()

    exact = _EXACT_CATEGORY_TYPES.get(cat_lower)
    if exact is not None:
        return exact

    for slot, keywords in _CATEGORY_KEYWORD_TYPES:
        if any(keyword in cat_lower for keyword in keywords):
            return slot
    return "other"
| # Create category pools | |
| print("π DEBUG: Building category pools...") | |
| uppers = [i for i in ids if get_category_type(cat_str(i)) == "upper"] | |
| bottoms = [i for i in ids if get_category_type(cat_str(i)) == "bottom"] | |
| shoes = [i for i in ids if get_category_type(cat_str(i)) == "shoe"] | |
| outerwear = [i for i in ids if get_category_type(cat_str(i)) == "outerwear"] | |
| accs = [i for i in ids if get_category_type(cat_str(i)) == "accessory"] | |
| others = [i for i in ids if get_category_type(cat_str(i)) == "other"] | |
| print(f"π DEBUG: Category pools - uppers: {len(uppers)}, bottoms: {len(bottoms)}, shoes: {len(shoes)}, outerwear: {len(outerwear)}, accessories: {len(accs)}, others: {len(others)}") | |
| # Check if we have enough items to create outfits | |
| total_items = len(uppers) + len(bottoms) + len(shoes) + len(outerwear) + len(accs) + len(others) | |
| if total_items < 2: | |
| print(f"π DEBUG: Not enough items to create outfits - total: {total_items}") | |
| return [] | |
| # Warn if we're missing core categories but still try to generate | |
| if len(uppers) == 0 or len(bottoms) == 0 or len(shoes) == 0: | |
| print(f"π DEBUG: Missing some core categories - uppers: {len(uppers)}, bottoms: {len(bottoms)}, shoes: {len(shoes)}") | |
| print(f"π DEBUG: Will use flexible outfit generation with available items") | |
| candidates: List[List[int]] = [] | |
| num_samples = max(num_outfits * 25, 50) # Further increased for more variety | |
| print(f"π DEBUG: Generating {num_samples} candidate outfits...") | |
def has_category_diversity(subset: List[int]) -> bool:
    """Return True when ``subset`` covers at least two distinct outfit slot types."""
    distinct_slots = {get_category_type(cat_str(item)) for item in subset}
    # A varied outfit needs two or more slot types (e.g. upper + bottom).
    return len(distinct_slots) > 1
def calculate_outfit_score(subset: List[int]) -> float:
    """Heuristic 0-1 quality score for a candidate outfit.

    Combines slot diversity/completeness, style consistency, colour harmony,
    an occasion-dependent preferred outfit length and simple fashion rules.
    It then applies weather and occasion adjustments from ``template["context"]``.

    Args:
        subset: Indices of the outfit's items.

    Returns:
        Score clamped to ``[0.0, 1.0]``; outfits with fewer than two items
        score 0.0.
    """
    if len(subset) < 2:
        return 0.0

    context = template["context"]
    occasion = context["occasion"]
    weather = context["weather"]
    raw_categories = [cat_str(i) for i in subset]

    # 1. Slot diversity and completeness.
    category_types = [get_category_type(raw) for raw in raw_categories]
    unique_types = set(category_types)
    # get_category_type can yield six slot types (including "other"), so cap
    # the term at 1.0 instead of letting it reach 6/5.
    diversity_score = min(1.0, len(unique_types) / 5.0)

    essential_categories = {"upper", "bottom", "shoe"}
    essentials_present = len(essential_categories & unique_types)
    if essentials_present == len(essential_categories):
        completeness_bonus = 0.3  # All essential categories present
    elif essentials_present >= 2:
        completeness_bonus = 0.15  # Most essential categories present
    else:
        completeness_bonus = 0.0

    # 2-3. Style coherence and colour harmony.
    style_score = calculate_style_consistency_score(subset)
    color_score = calculate_color_consistency_score(subset)

    # 4. Context-appropriate outfit length.
    size = len(subset)
    if occasion == "formal":
        length_score = 1.0 if 4 <= size <= 5 else 0.6  # Formal prefers complete sets
    elif occasion == "business":
        length_score = 1.0 if 3 <= size <= 4 else 0.7  # Business balanced
    elif occasion == "sport":
        length_score = 1.0 if 2 <= size <= 3 else 0.8  # Sport can be minimal
    else:  # casual
        length_score = 1.0 if 2 <= size <= 4 else 0.7  # Casual flexible

    # 5. Fashion-rule compliance.
    fashion_rules_score = 0.0
    core_categories = {"upper", "bottom", "shoe", "outerwear"}
    if all(category_types.count(cat) <= 1 for cat in core_categories):
        fashion_rules_score += 0.2  # At most one item per core slot
    accessory_count = category_types.count("accessory")
    if accessory_count <= template.get("accessory_limit", 3):
        fashion_rules_score += 0.1
    if 0 < accessory_count <= 2:
        fashion_rules_score += 0.1  # Tasteful accessorizing
    if occasion == "formal" and "outerwear" in unique_types:
        fashion_rules_score += 0.2
    elif occasion == "business" and len(unique_types) >= 3:
        fashion_rules_score += 0.15
    elif occasion == "sport" and any("athletic" in raw for raw in raw_categories):
        fashion_rules_score += 0.1

    # 6. Weighted combination.
    base_score = (
        0.25 * (diversity_score + completeness_bonus)
        + 0.30 * style_score
        + 0.20 * color_score
        + 0.15 * length_score
        + 0.10 * fashion_rules_score
    )

    # 7. Context-specific adjustments.
    context_adjustment = 0.0
    if weather == "hot" and size > 4:
        context_adjustment -= 0.1  # Too many layers for hot weather
    elif weather == "cold" and "outerwear" not in unique_types:
        context_adjustment -= 0.1  # Missing outerwear for cold weather
    elif weather == "rain" and not any("boot" in raw for raw in raw_categories):
        context_adjustment -= 0.05  # Missing weather-appropriate footwear
    if occasion == "formal" and size < 4:
        context_adjustment -= 0.1  # Formal outfits should be complete
    elif occasion == "sport" and size > 4:
        context_adjustment -= 0.05  # Sport outfits can be simpler

    return min(1.0, max(0.0, base_score + context_adjustment))
| # Advanced candidate generation with sophisticated reasoning | |
| for _ in range(num_samples): | |
| subset = [] | |
| # 1. Advanced context-aware outfit length selection | |
| occasion = template["context"]["occasion"] | |
| weather = template["context"]["weather"] | |
| outfit_style = template["context"]["style"] | |
| # Base length probabilities | |
| if occasion == "formal": | |
| if weather == "hot": | |
| outfit_length = rng.choice([3, 4], p=[0.4, 0.6]) # Formal but weather-appropriate | |
| else: | |
| outfit_length = rng.choice([4, 5], p=[0.6, 0.4]) # Complete formal sets | |
| elif occasion == "business": | |
| outfit_length = rng.choice([3, 4, 5], p=[0.3, 0.5, 0.2]) # Professional balance | |
| elif occasion == "sport": | |
| if weather == "hot": | |
| outfit_length = rng.choice([2, 3], p=[0.6, 0.4]) # Minimal for hot weather | |
| else: | |
| outfit_length = rng.choice([2, 3, 4], p=[0.3, 0.5, 0.2]) # Sport flexibility | |
| else: # casual | |
| if weather == "hot": | |
| outfit_length = rng.choice([2, 3, 4], p=[0.4, 0.4, 0.2]) # Casual, weather-appropriate | |
| else: | |
| outfit_length = rng.choice([2, 3, 4, 5], p=[0.2, 0.4, 0.3, 0.1]) # Casual flexibility | |
| # 2. Advanced strategy selection with reasoning | |
| strategy_weights = [0.4, 0.3, 0.3] # Default: Core, Accessory-focused, Flexible | |
| # Formal occasions prioritize complete, structured outfits | |
| if occasion == "formal": | |
| strategy_weights = [0.7, 0.1, 0.2] # Strongly favor core outfits | |
| # Business occasions need professional balance | |
| elif occasion == "business": | |
| strategy_weights = [0.6, 0.2, 0.2] # Favor core with some flexibility | |
| # Sport occasions can be more flexible | |
| elif occasion == "sport": | |
| strategy_weights = [0.3, 0.2, 0.5] # Favor flexible combinations | |
| # Traditional outfits need cultural coherence | |
| elif outfit_style == "traditional": | |
| strategy_weights = [0.8, 0.1, 0.1] # Strongly favor traditional core sets | |
| # Casual occasions allow more creativity | |
| else: | |
| strategy_weights = [0.4, 0.3, 0.3] # Balanced approach | |
| strategy = rng.choice([0, 1, 2], p=strategy_weights) | |
| # Strategy 1: Core outfit (shirt + pants + shoes) + accessories | |
| if strategy == 0 and uppers and bottoms and shoes: | |
| # Special handling for traditional Pakistani outfits: kameez + shalwar + peshawari | |
| if outfit_style == "traditional": | |
| # Check for traditional items | |
| traditional_uppers = [i for i in uppers if "kameez" in cat_str(i) or "kurta" in cat_str(i)] | |
| traditional_bottoms = [i for i in bottoms if "shalwar" in cat_str(i)] | |
| traditional_shoes = [i for i in shoes if "peshawari" in cat_str(i) or "chappal" in cat_str(i)] | |
| if traditional_uppers and traditional_bottoms and traditional_shoes: | |
| # Traditional Pakistani outfit: kameez + shalwar + peshawari | |
| subset.append(int(rng.choice(traditional_uppers))) # Kameez/Kurta | |
| subset.append(int(rng.choice(traditional_bottoms))) # Shalwar | |
| subset.append(int(rng.choice(traditional_shoes))) # Peshawari chappal | |
| print(f"π DEBUG: Generated traditional Pakistani outfit: kameez + shalwar + peshawari") | |
| else: | |
| # Fallback to regular outfit if traditional items not available | |
| subset.append(int(rng.choice(uppers))) | |
| subset.append(int(rng.choice(bottoms))) | |
| subset.append(int(rng.choice(shoes))) | |
| print(f"π DEBUG: Generated regular outfit (traditional items not available)") | |
| # Special handling for formal outfits: require jacket + shirt + pants + shoes | |
| elif occasion == "formal" and outerwear and len(outerwear) > 0: | |
| # Formal 3-piece suit: jacket + shirt + pants + shoes | |
| subset.append(int(rng.choice(outerwear))) # Jacket/blazer | |
| subset.append(int(rng.choice(uppers))) # Shirt | |
| subset.append(int(rng.choice(bottoms))) # Pants | |
| subset.append(int(rng.choice(shoes))) # Shoes | |
| print(f"π DEBUG: Generated formal 3-piece suit with jacket") | |
| else: | |
| # Regular core outfit: shirt + pants + shoes | |
| subset.append(int(rng.choice(uppers))) | |
| subset.append(int(rng.choice(bottoms))) | |
| subset.append(int(rng.choice(shoes))) | |
| # Prioritize accessories - higher chance of including them | |
| remaining_slots = outfit_length - len(subset) | |
| if accs and remaining_slots > 0: | |
| max_accs = min(template["accessory_limit"], remaining_slots, len(accs)) | |
| # Higher probability of including accessories | |
| num_accs = rng.integers(1, max_accs + 1) if rng.random() < 0.8 else 0 | |
| available_accs = [i for i in accs if i not in subset] | |
| if available_accs and num_accs > 0: | |
| selected_accs = rng.choice(available_accs, size=min(num_accs, len(available_accs)), replace=False) | |
| subset.extend(selected_accs.tolist()) | |
| # Fill remaining slots with other items | |
| remaining_slots = outfit_length - len(subset) | |
| if others and remaining_slots > 0: | |
| available_others = [i for i in others if i not in subset] | |
| if available_others: | |
| num_others = min(remaining_slots, len(available_others)) | |
| selected_others = rng.choice(available_others, size=num_others, replace=False) | |
| subset.extend(selected_others.tolist()) | |
| # Strategy 2: Accessory-focused outfit (prioritize accessories) | |
| elif strategy == 1 and accs: | |
| # Start with accessories if available | |
| num_accs = min(outfit_length, len(accs)) | |
| selected_accs = rng.choice(accs, size=num_accs, replace=False) | |
| subset.extend(selected_accs.tolist()) | |
| # Fill remaining with other categories | |
| remaining_slots = outfit_length - len(subset) | |
| if remaining_slots > 0: | |
| other_categories = [] | |
| if uppers: other_categories.extend(uppers) | |
| if bottoms: other_categories.extend(bottoms) | |
| if shoes: other_categories.extend(shoes) | |
| if others: other_categories.extend(others) | |
| available_others = [i for i in other_categories if i not in subset] | |
| if available_others: | |
| num_others = min(remaining_slots, len(available_others)) | |
| selected_others = rng.choice(available_others, size=num_others, replace=False) | |
| subset.extend(selected_others.tolist()) | |
| # Strategy 3: Flexible combination (no strict slot requirements) | |
| elif strategy == 2: | |
| # Randomly select items from all categories | |
| all_items = list(ids) | |
| rng.shuffle(all_items) | |
| # Select items ensuring diversity | |
| selected_categories = set() | |
| for item in all_items: | |
| if len(subset) >= outfit_length: | |
| break | |
| item_category = get_category_type(cat_str(item)) | |
| if item_category not in selected_categories or len(subset) < 2: | |
| subset.append(item) | |
| selected_categories.add(item_category) | |
| # Remove duplicates and validate | |
| subset = list(set(subset)) | |
| if len(subset) >= 2 and len(subset) <= max_size and has_category_diversity(subset): | |
| # Add randomization factor to prevent identical recommendations | |
| subset = rng.permutation(subset).tolist() # Randomize order | |
| candidates.append(subset) | |
| if len(candidates) % 10 == 0: # Log every 10 candidates | |
| print(f"π DEBUG: Generated {len(candidates)} candidates so far...") | |
| print(f"π DEBUG: Generated {len(candidates)} total candidates") | |
# 3) Score using ViT
def score_subset(idx_subset: List[int]) -> float:
    """Return the ViT compatibility score for the items at ``idx_subset``.

    Each item's precomputed ``"embedding"`` is stacked along a new first axis.
    Per the original (N, D) note, each embedding is assumed to be a 1-D
    vector (TODO confirm). A leading batch axis is added, and the float32
    tensor on ``self.device`` is passed to ``self.vit.score_compatibility``.
    """
    stacked = np.stack([proc_items[idx]["embedding"] for idx in idx_subset], axis=0)
    batch = torch.tensor(stacked, dtype=torch.float32, device=self.device)[None, ...]  # (1, N, D)
    return float(self.vit.score_compatibility(batch).item())
# Validation of candidate outfits (flexible slot constraints)
def is_valid_outfit(subset: List[int]) -> bool:
    """Check whether a candidate outfit passes the flexible validity rules.

    An outfit is valid when it:
      * has between 2 and ``max_size`` items (inclusive),
      * contains no item whose raw category contains any entry of
        ``template["excluded_categories"]``,
      * spans at least two distinct slot types, and
      * holds at most 3 accessories and at most 2 "other" items.
    """
    if not 2 <= len(subset) <= max_size:
        return False

    raw_categories = [cat_str(i) for i in subset]
    excluded = template.get("excluded_categories", [])
    if any(blocked in raw for raw in raw_categories for blocked in excluded):
        return False

    slot_counts: Dict[str, int] = {}
    for raw in raw_categories:
        slot = get_category_type(raw)
        slot_counts[slot] = slot_counts.get(slot, 0) + 1

    if len(slot_counts) < 2:
        return False
    # Caps: up to three accessories and up to two unclassified items.
    return slot_counts.get("accessory", 0) <= 3 and slot_counts.get("other", 0) <= 2
def calculate_outfit_penalty(subset: List[int], base_score: float) -> float:
    """Adjust a model score with rule-based penalties and bonuses.

    Hard violations each subtract 1000: a missing upper, bottom or shoe slot,
    or a duplicated core slot (a waistcoat plus one other outerwear piece is
    allowed). Such outfits therefore rank below any compliant one. Softer
    context, weather, accessory and size violations subtract smaller amounts.
    Bonuses reward style/colour consistency and occasion-, tradition- and
    weather-appropriate choices.

    The occasion, weather and style are read from ``template["context"]``
    rather than from loop variables leaked by the candidate-generation loop.

    Args:
        subset: Indices of the outfit's items.
        base_score: Model compatibility score for the outfit.

    Returns:
        ``base_score + penalty + bonus`` (unbounded).
    """
    context = template["context"]
    occasion = context["occasion"]
    weather = context["weather"]
    outfit_style = context["style"]

    raw_categories = [cat_str(i) for i in subset]
    categories = [get_category_type(raw) for raw in raw_categories]
    category_counts: Dict[str, int] = {}
    for cat in categories:
        category_counts[cat] = category_counts.get(cat, 0) + 1

    has_athletic = any("athletic" in raw for raw in raw_categories)
    has_boot = any("boot" in raw for raw in raw_categories)
    has_waistcoat = any("waistcoat" in raw for raw in raw_categories)

    penalty = 0.0
    bonus = 0.0

    # 1. Critical violations: missing essential slots.
    for essential in ("upper", "bottom", "shoe"):
        if category_counts.get(essential, 0) == 0:
            penalty -= 1000.0

    # Duplicate core slots; a waistcoat plus one other outerwear (3-piece) is allowed.
    core_categories = {"upper", "bottom", "shoe", "outerwear"}
    for cat in core_categories:
        count = category_counts.get(cat, 0)
        if cat == "outerwear" and has_waistcoat and count <= 2:
            continue
        if count > 1:
            penalty -= 1000.0

    # 2. Context-specific violations.
    if occasion == "formal" and category_counts.get("outerwear", 0) == 0:
        penalty -= 500.0  # Formal without jacket is inappropriate
    elif occasion == "business" and len(subset) < 3:
        penalty -= 200.0  # Business outfits should be complete
    elif occasion == "sport" and not has_athletic:
        penalty -= 300.0  # Sport outfits need athletic items

    # 3. Weather violations.
    if weather == "hot" and len(subset) > 5:
        penalty -= 100.0  # Too many layers for hot weather
    elif weather == "cold" and category_counts.get("outerwear", 0) == 0:
        penalty -= 150.0  # Missing outerwear for cold weather
    elif weather == "rain" and not has_boot:
        penalty -= 50.0  # Missing weather-appropriate footwear

    # 4. Accessory limit: proportional penalty beyond the template limit.
    max_accs = template["accessory_limit"]
    accessory_count = category_counts.get("accessory", 0)
    if accessory_count > max_accs:
        penalty -= 50.0 * (accessory_count - max_accs)

    # 5. Outfit size.
    if len(subset) < 2:
        penalty -= 200.0  # Too minimal
    elif len(subset) > 6:
        penalty -= 100.0  # Too complex
    elif len(subset) == 2 and occasion in ["formal", "business"]:
        penalty -= 100.0  # Too minimal for formal/business

    # 6. Style and colour consistency bonuses.
    style_score = calculate_style_consistency_score(subset)
    bonus += style_score * 0.6
    color_score = calculate_color_consistency_score(subset)
    bonus += color_score * 0.4

    # 7. Occasion bonuses.
    if occasion == "formal":
        if "outerwear" in categories:
            bonus += 0.6  # Proper formal layering
        if len([c for c in categories if c in core_categories]) >= 4:
            bonus += 0.4  # Complete formal set
        if style_score > 0.7:
            bonus += 0.3  # High style coherence
        if has_waistcoat and category_counts.get("outerwear", 0) == 2:
            bonus += 0.5  # 3-piece suit
    elif occasion == "business":
        if len(categories) >= 3:
            bonus += 0.3  # Professional completeness
        if "outerwear" in categories:
            bonus += 0.2  # Elevated business look
        if style_score > 0.6:
            bonus += 0.2  # Professional style
    elif occasion == "sport":
        if has_athletic:
            bonus += 0.4  # Athletic functionality
        if len(subset) <= 3:
            bonus += 0.2  # Appropriate minimalism for sport

    # 8. Traditional Pakistani outfit bonuses.
    if outfit_style == "traditional":
        traditional_keywords = ["kameez", "kurta", "shalwar", "peshawari", "chappal", "waistcoat"]
        traditional_items = [
            raw for raw in raw_categories if any(t in raw for t in traditional_keywords)
        ]
        if len(traditional_items) >= 2:
            bonus += 0.7  # Cultural appropriateness
        if len(traditional_items) >= 3:
            bonus += 0.4  # Complete traditional set
        if style_score > 0.6:
            bonus += 0.3  # Traditional style coherence
        if has_waistcoat:
            bonus += 0.3  # Waistcoat with traditional wear

    # 9. Rule-compliance bonuses.
    if all(category_counts.get(cat, 0) <= 1 for cat in core_categories):
        bonus += 0.3  # One item per core slot
    if 1 <= accessory_count <= 2:
        bonus += 0.2  # Tasteful accessorizing

    # 10. Weather bonuses.
    if weather == "hot" and len(subset) <= 4:
        bonus += 0.1
    elif weather == "cold" and "outerwear" in categories:
        bonus += 0.2
    elif weather == "rain" and has_boot:
        bonus += 0.15

    # 11. Overall quality.
    if style_score > 0.8 and color_score > 0.7:
        bonus += 0.3  # Exceptional outfit quality
    elif style_score > 0.6 and color_score > 0.5:
        bonus += 0.2  # Good outfit quality

    return base_score + penalty + bonus
| # Score and filter valid outfits with penalty adjustment | |
| valid_candidates = [subset for subset in candidates if is_valid_outfit(subset)] | |
| if not valid_candidates: | |
| # Fallback: use all candidates if no valid ones found | |
| valid_candidates = candidates | |
| # Score with penalty adjustment | |
| scored = [] | |
| for subset in valid_candidates: | |
| base_score = score_subset(subset) | |
| adjusted_score = calculate_outfit_penalty(subset, base_score) | |
| scored.append((subset, adjusted_score, base_score)) | |
| # Sort by penalty-adjusted score with randomization | |
| scored.sort(key=lambda x: x[1], reverse=True) | |
| # Remove duplicate outfits (same items, different order) | |
| def normalize_outfit(subset): | |
| """Normalize outfit by sorting item IDs for duplicate detection""" | |
| return tuple(sorted(subset)) | |
| seen_outfits = set() | |
| unique_scored = [] | |
| for subset, adjusted_score, base_score in scored: | |
| normalized = normalize_outfit(subset) | |
| if normalized not in seen_outfits: | |
| seen_outfits.add(normalized) | |
| unique_scored.append((subset, adjusted_score, base_score)) | |
| print(f"π DEBUG: Removed {len(scored) - len(unique_scored)} duplicate outfits") | |
| scored = unique_scored | |
| # Enhanced randomization with context awareness | |
| if len(scored) > num_outfits: | |
| # Context-aware selection: prefer higher-scoring outfits but add diversity | |
| top_third = scored[:max(num_outfits * 3, len(scored) // 3)] | |
| middle_third = scored[max(num_outfits * 3, len(scored) // 3):max(num_outfits * 6, len(scored) * 2 // 3)] | |
| # Select mix of high-scoring and diverse outfits | |
| selected = [] | |
| # Take 70% from top third (high quality) | |
| top_count = int(num_outfits * 0.7) | |
| rng.shuffle(top_third) | |
| selected.extend(top_third[:top_count]) | |
| # Take 30% from middle third (diversity) | |
| middle_count = num_outfits - len(selected) | |
| if middle_count > 0 and middle_third: | |
| rng.shuffle(middle_third) | |
| selected.extend(middle_third[:middle_count]) | |
| # Shuffle final selection for randomness | |
| rng.shuffle(selected) | |
| topk = selected[:num_outfits] | |
| else: | |
| # If we have fewer candidates than requested, shuffle them | |
| rng.shuffle(scored) | |
| topk = scored[:num_outfits] | |
| results = [] | |
| for subset, adjusted_score, base_score in topk: | |
| # Double-check validity and get item details | |
| outfit_items = [] | |
| for i in subset: | |
| item = proc_items[i] | |
| outfit_items.append({ | |
| "id": item["id"], | |
| "category": item.get("category", "unknown"), | |
| "category_type": get_category_type(item.get("category", "")) | |
| }) | |
| # Calculate additional metrics | |
| style_score = calculate_style_consistency_score(subset) | |
| color_score = calculate_color_consistency_score(subset) | |
| colors = [extract_color_from_category(cat_str(i)) for i in subset] | |
| results.append({ | |
| "item_ids": [item["id"] for item in outfit_items], | |
| "items": outfit_items, | |
| "score": float(adjusted_score), | |
| "base_score": float(base_score), | |
| "categories": [item["category"] for item in outfit_items], | |
| "category_types": [item["category_type"] for item in outfit_items], | |
| "outfit_size": len(outfit_items), | |
| "is_valid": is_valid_outfit(subset), | |
| "template": { | |
| "name": template_name, | |
| "style": template["style"], | |
| "style_score": float(style_score), | |
| "color_score": float(color_score), | |
| "colors": colors, | |
| "accessory_limit": template["accessory_limit"] | |
| } | |
| }) | |
| return results | |
def get_model_status(self) -> Dict[str, Any]:
    """Report model loading state.

    Returns:
        A dict with these entries:
          * ``models_loaded``: the overall flag.
          * ``can_recommend``: mirrors ``models_loaded``.
          * ``resnet_loaded`` / ``vit_loaded``: the per-model flags.
          * ``errors``: the ``model_errors`` list itself, not a copy.
          * ``resnet_model`` / ``vit_model``: whether each model object is
            non-None.
    """
    overall = self.models_loaded
    status: Dict[str, Any] = dict(
        models_loaded=overall,
        resnet_loaded=self.resnet_loaded,
        vit_loaded=self.vit_loaded,
        errors=self.model_errors,
        can_recommend=overall,
        resnet_model=self.resnet is not None,
        vit_model=self.vit is not None,
    )
    return status
def force_reload_models(self) -> None:
    """Force reload models and update status - useful for debugging.

    Calls ``_load_resnet``/``_load_vit``. Each model whose loader reports
    success is moved to ``self.device`` and put in eval mode. Every non-None
    model has its parameters frozen with ``requires_grad_(False)``.
    ``self.models_loaded`` is True only when both loaded. ``self.model_errors``
    is rebuilt from this reload's outcome: previously it was cleared only on
    failure, so errors from an earlier failed load survived a successful reload.
    """
    print("π Force reloading models...")
    self.resnet, self.resnet_loaded = self._load_resnet()
    self.vit, self.vit_loaded = self._load_vit()
    # Move to device and set eval mode
    if self.resnet_loaded:
        self.resnet = self.resnet.to(self.device).eval()
    if self.vit_loaded:
        self.vit = self.vit.to(self.device).eval()
    # Disable gradients
    for m in [self.resnet, self.vit]:
        if m is not None:
            for p in m.parameters():
                p.requires_grad_(False)
    # Update overall status
    self.models_loaded = self.resnet_loaded and self.vit_loaded
    print(f"π Models reloaded: resnet={self.resnet_loaded}, vit={self.vit_loaded}, overall={self.models_loaded}")
    # Always reset so stale errors are not reported after a successful reload.
    self.model_errors = []
    if not self.models_loaded:
        if not self.resnet_loaded:
            self.model_errors.append("ResNet: No trained weights found")
        if not self.vit_loaded:
            self.model_errors.append("ViT: No trained weights found")