import os import base64 import io import tempfile from typing import List, Optional, Any, Dict import gradio as gr import numpy as np import requests import torch from fastapi import FastAPI, Header, HTTPException, UploadFile, File, Form from fastapi.responses import JSONResponse from pydantic import BaseModel from typing import List as TypingList from PIL import Image from starlette.staticfiles import StaticFiles import threading import json from inference import InferenceService from utils.data_fetch import ensure_dataset_ready from utils.tag_system import get_all_tag_options, validate_tags, TagProcessor from utils.image_utils import ( load_images_from_files, load_image_from_bytes, load_image_from_url, is_image_file, get_supported_formats, get_supported_extensions, ensure_rgb_image ) # Global state BOOT_STATUS = "starting" DATASET_ROOT: Optional[str] = None def get_artifact_overview(): """Get comprehensive artifact overview.""" try: from utils.artifact_manager import create_artifact_manager manager = create_artifact_manager() return manager.get_artifact_summary() except Exception as e: return {"error": str(e)} def export_artifact_summary(): """Export artifact summary as JSON file.""" try: from utils.artifact_manager import create_artifact_manager manager = create_artifact_manager() summary = manager.get_artifact_summary() # Save to exports directory export_dir = os.getenv("EXPORT_DIR", "models/exports") os.makedirs(export_dir, exist_ok=True) summary_path = os.path.join(export_dir, "system_summary.json") with open(summary_path, 'w') as f: json.dump(summary, f, indent=2) return summary_path except Exception as e: return None def create_download_package(package_type: str): """Create a downloadable package.""" try: from utils.artifact_manager import create_artifact_manager manager = create_artifact_manager() # Extract package type from the dropdown choice if "complete" in package_type: pkg_type = "complete" elif "splits_only" in package_type: pkg_type = "splits_only" elif "models_only" in package_type: pkg_type = "models_only" else: return f"❌ Invalid package type: {package_type}", get_available_packages() package_path = manager.create_download_package(pkg_type) package_name = os.path.basename(package_path) return f"✅ Package created: {package_name}", get_available_packages() except Exception as e: return f"❌ Failed to create package: {e}", get_available_packages() def get_available_packages(): """Get list of available packages.""" try: export_dir = os.getenv("EXPORT_DIR", "models/exports") packages = [] if os.path.exists(export_dir): for file in os.listdir(export_dir): if file.endswith((".tar.gz", ".zip")): file_path = os.path.join(export_dir, file) packages.append({ "name": file, "size_mb": round(os.path.getsize(file_path) / (1024 * 1024), 2), "path": file_path, "url": f"/files/{file}" }) return {"packages": packages} except Exception as e: return {"error": str(e)} def get_individual_files(): """Get list of individual downloadable files.""" try: from utils.artifact_manager import create_artifact_manager manager = create_artifact_manager() files = manager.get_downloadable_files() # Group by category categorized = {} for file in files: category = file["category"] if category not in categorized: categorized[category] = [] categorized[category].append(file) return categorized except Exception as e: return {"error": str(e)} def download_all_files(): """Download all files as a ZIP archive.""" try: from utils.artifact_manager import create_artifact_manager manager = create_artifact_manager() files = manager.get_downloadable_files() # Create ZIP with all files export_dir = os.getenv("EXPORT_DIR", "models/exports") os.makedirs(export_dir, exist_ok=True) zip_path = os.path.join(export_dir, "all_artifacts.zip") import zipfile with zipfile.ZipFile(zip_path, 'w') as zipf: for file in files: if os.path.exists(file["path"]): zipf.write(file["path"], file["name"]) return zip_path except Exception as e: return None def get_training_status(): """Get current training status from the monitor.""" try: from ui.monitor import create_monitor monitor = create_monitor() status = monitor.get_status() return status if status else {"status": "no-training"} except Exception as e: return {"status": "error", "error": str(e)} def push_splits_to_hf(token, username): """Push splits to HF Hub.""" if not token or not username: return "❌ Please provide HF token and username" try: from utils.hf_utils import HFModelManager hf = HFModelManager(token=token, username=username) result = hf.upload_model("splits", "Dressify-Helper") if result.get("success"): return f"✅ Successfully uploaded splits to {username}/Dressify-Helper" else: return f"❌ Failed to upload splits: {result.get('error', 'Unknown error')}" except Exception as e: return f"❌ Upload failed: {e}" def push_models_to_hf(token, username): """Push models to HF Hub.""" if not token or not username: return "❌ Please provide HF token and username" try: from utils.hf_utils import HFModelManager hf = HFModelManager(token=token, username=username) result = hf.upload_model("models", "dressify-models") if result.get("success"): return f"✅ Successfully uploaded models to {username}/dressify-models" else: return f"❌ Failed to upload models: {result.get('error', 'Unknown error')}" except Exception as e: return f"❌ Upload failed: {e}" def push_everything_to_hf(token, username): """Push everything to HF Hub.""" if not token or not username: return "❌ Please provide HF token and username" try: from utils.hf_utils import HFModelManager hf = HFModelManager(token=token, username=username) result = hf.upload_model("everything", "dressify-complete") if result.get("success"): return f"✅ Successfully uploaded everything to HF Hub" else: return f"❌ Failed to upload everything: {result.get('error', 'Unknown error')}" except Exception as e: return f"❌ Upload failed: {e}" AI_API_KEY = os.getenv("AI_API_KEY") def require_api_key(x_api_key: Optional[str]): if AI_API_KEY and x_api_key != AI_API_KEY: raise HTTPException(status_code=401, detail="Invalid API key") class EmbedRequest(BaseModel): image_urls: Optional[List[str]] = None images_base64: Optional[List[str]] = None class Item(BaseModel): id: str embedding: Optional[List[float]] = None category: Optional[str] = None image_url: Optional[str] = None image_base64: Optional[str] = None # Base64 encoded image class ComposeRequest(BaseModel): items: List[Item] context: Optional[Dict[str, Any]] = None # Expanded tag fields for better API usability occasion: Optional[str] = "casual" weather: Optional[str] = "any" style: Optional[str] = "casual" outfit_style: Optional[str] = None # Alias for style num_outfits: Optional[int] = 5 # Optional tags color_preference: Optional[str] = None fit_preference: Optional[str] = None material_preference: Optional[str] = None season: Optional[str] = None time_of_day: Optional[str] = None budget: Optional[str] = None personal_style: Optional[str] = None age_group: Optional[str] = None gender: Optional[str] = None class Config: extra = "allow" # Allow additional fields for flexibility app = FastAPI(title="Dressify Recommendation Service") service = InferenceService() # Non-blocking bootstrap: fetch data, prepare splits, and train if needed in background BOOT_STATUS = "idle" DATASET_ROOT: Optional[str] = None def _background_bootstrap(): global BOOT_STATUS global DATASET_ROOT try: # Only check if dataset exists - DO NOT prepare it automatically root = os.path.abspath(os.path.join(os.getcwd(), "data", "Polyvore")) images_dir = os.path.join(root, "images") splits_dir = os.path.join(root, "splits") # Check if dataset already exists has_images = os.path.isdir(images_dir) and any(os.listdir(images_dir)) has_splits = ( os.path.isfile(os.path.join(splits_dir, "train.json")) or os.path.isfile(os.path.join(splits_dir, "outfit_triplets_train.json")) ) if has_images and has_splits: print("✅ Dataset and splits already prepared") DATASET_ROOT = root BOOT_STATUS = "ready" elif has_images: print("✅ Dataset images exist, but splits may be missing (use Advanced Training to prepare)") DATASET_ROOT = root BOOT_STATUS = "ready" else: print("ℹ️ Dataset not prepared. Use 'Download & Prepare Dataset' button in Advanced Training tab if needed.") DATASET_ROOT = None BOOT_STATUS = "ready" # System is ready, just dataset not prepared # NO automatic training - models should be pre-trained or trained manually via UI BOOT_STATUS = "ready" except Exception as e: BOOT_STATUS = f"error: {e}" threading.Thread(target=_background_bootstrap, daemon=True).start() @app.get("/health") def health() -> dict: return {"status": "ok", "device": service.device, "resnet": service.resnet_version, "vit": service.vit_version} @app.get("/tags") def get_tags() -> dict: """ Get all available tag options for API integration. Returns comprehensive list of all tag categories and their valid values. """ return { "tag_categories": get_all_tag_options(), "description": "Available tags for personalized outfit recommendations", "usage": { "primary_tags": ["occasion", "weather", "style"], "optional_tags": ["color_preference", "fit_preference", "material_preference", "season", "time_of_day", "budget", "personal_style", "age_group", "gender"] } } @app.get("/image-formats") def get_image_formats() -> dict: """ Get all supported image formats for API integration. """ return { "supported_formats": get_supported_formats(), "supported_extensions": get_supported_extensions(), "description": "All major image formats are supported including JPG, PNG, WEBP, GIF, BMP, TIFF, and more", "note": "Images are automatically converted to RGB mode for model processing" } @app.post("/compose/upload") async def compose_with_upload( files: TypingList[UploadFile] = File(...), occasion: str = Form("casual"), weather: str = Form("any"), style: str = Form("casual"), num_outfits: int = Form(5), color_preference: Optional[str] = Form(None), fit_preference: Optional[str] = Form(None), material_preference: Optional[str] = Form(None), season: Optional[str] = Form(None), time_of_day: Optional[str] = Form(None), budget: Optional[str] = Form(None), personal_style: Optional[str] = Form(None), x_api_key: Optional[str] = Header(None) ) -> dict: """ Generate outfit recommendations from uploaded image files. This endpoint accepts multipart/form-data with: - files: Image files (JPG, PNG, WEBP, GIF, BMP, TIFF, etc.) - All tag parameters as form fields Returns personalized outfit recommendations. """ require_api_key(x_api_key) if not files or len(files) < 2: raise HTTPException(status_code=400, detail="At least 2 images required for outfit recommendations") # Load images from uploaded files items = [] errors = [] for i, file in enumerate(files): try: # Read file content contents = await file.read() # Load image from bytes img = load_image_from_bytes(contents, convert_to_rgb=True, raise_on_error=False) if img is None: errors.append(f"Failed to load image from file {file.filename}") continue items.append({ "id": f"item_{i}", "image": img, "category": None # Will be auto-detected }) except Exception as e: errors.append(f"Error processing file {file.filename}: {str(e)}") if len(items) < 2: error_msg = f"Not enough valid images. Need at least 2, got {len(items)}." if errors: error_msg += f" Errors: {', '.join(errors[:3])}" raise HTTPException(status_code=400, detail=error_msg) # Build context context = { "occasion": occasion, "weather": weather, "style": style, "outfit_style": style, "num_outfits": num_outfits } # Add optional tags if color_preference and color_preference != "None": context["color_preference"] = color_preference if fit_preference and fit_preference != "None": context["fit_preference"] = fit_preference if material_preference and material_preference != "None": context["material_preference"] = material_preference if season and season != "None": context["season"] = season if time_of_day and time_of_day != "None": context["time_of_day"] = time_of_day if budget and budget != "None": context["budget"] = budget if personal_style and personal_style != "None": context["personal_style"] = personal_style # Validate tags is_valid, tag_errors = validate_tags(context) if not is_valid: return JSONResponse( status_code=400, content={ "error": "Invalid tags provided", "errors": tag_errors, "valid_tag_options": get_all_tag_options() } ) # Generate recommendations try: outfits = service.compose_outfits(items, context=context) # Check for errors if outfits and isinstance(outfits, list) and len(outfits) > 0: if isinstance(outfits[0], dict) and "error" in outfits[0]: return JSONResponse( status_code=500, content={ "error": "Recommendation generation failed", "details": outfits[0].get("details", []), "message": outfits[0].get("message", "Unknown error") } ) return { "outfits": outfits, "version": service.vit_version, "tags_processed": True, "context_used": context, "items_processed": len(items), "warnings": errors if errors else None } except Exception as e: import traceback traceback.print_exc() return JSONResponse( status_code=500, content={ "error": "Internal server error", "message": str(e), "model_status": service.get_model_status() } ) @app.post("/tags/validate") def validate_request_tags(tags: Dict[str, Any], x_api_key: Optional[str] = Header(None)) -> dict: """ Validate tag values before making a recommendation request. Useful for API clients to check tag validity. """ require_api_key(x_api_key) is_valid, errors = validate_tags(tags) return { "valid": is_valid, "errors": errors if not is_valid else [], "validated_tags": tags if is_valid else None } @app.get("/model-status") def model_status() -> dict: """Get detailed model loading status.""" return service.get_model_status() @app.post("/reload-models") def reload_models() -> dict: """Force reload models - useful for debugging.""" try: service.force_reload_models() return {"status": "success", "message": "Models reloaded successfully"} except Exception as e: return {"status": "error", "message": str(e)} @app.post("/test-recommend") def test_recommend() -> dict: """Test recommendation with dummy data to debug the issue.""" try: # Create dummy items for testing dummy_items = [ {"id": "test_1", "image": None, "category": "shirt"}, {"id": "test_2", "image": None, "category": "pants"}, {"id": "test_3", "image": None, "category": "shoes"} ] # Try to get recommendations result = service.compose_outfits(dummy_items, {"num_outfits": 1}) return { "status": "success", "model_status": service.get_model_status(), "result": result, "result_length": len(result) if result else 0 } except Exception as e: return {"status": "error", "message": str(e), "model_status": service.get_model_status()} @app.post("/embed") def embed(req: EmbedRequest, x_api_key: Optional[str] = Header(None)) -> dict: """ Generate embeddings for images with comprehensive format support. Supports JPG, PNG, WEBP, GIF, BMP, TIFF, and other major formats. """ require_api_key(x_api_key) images: List[Image.Image] = [] errors = [] # Load from URLs if req.image_urls: for url in req.image_urls: img = load_image_from_url(url, timeout=20, convert_to_rgb=True, raise_on_error=False) if img is not None: images.append(img) else: errors.append(f"Failed to load image from URL: {url}") # Load from base64 if req.images_base64: for b64 in req.images_base64: try: image_bytes = base64.b64decode(b64) img = load_image_from_bytes(image_bytes, convert_to_rgb=True, raise_on_error=False) if img is not None: images.append(img) else: errors.append("Failed to load image from base64") except Exception as e: errors.append(f"Error decoding base64 image: {str(e)}") if not images: error_msg = "No images provided or all images failed to load" if errors: error_msg += f". Errors: {', '.join(errors[:3])}" raise HTTPException(status_code=400, detail=error_msg) # Ensure all images are RGB images = [ensure_rgb_image(img) for img in images] embs = service.embed_images(images) return { "embeddings": [e.tolist() for e in embs], "model_version": service.resnet_version, "images_loaded": len(images), "errors": errors if errors else None } @app.post("/compose") def compose(req: ComposeRequest, x_api_key: Optional[str] = Header(None)) -> dict: """ Generate personalized outfit recommendations with expanded tag support. Supports both legacy context dict format and new tag-based format. Tags are processed and prioritized automatically. Items can provide: - image_url: URL to image (will be downloaded) - image_base64: Base64 encoded image - embedding: Pre-computed embedding (skips ResNet) - category: Item category (optional, auto-detected if not provided) """ require_api_key(x_api_key) # Build items with image loading items = [] errors = [] for it in req.items: item_dict = { "id": it.id, "embedding": np.array(it.embedding, dtype=np.float32) if it.embedding is not None else None, "category": it.category, } # Load image from URL if provided if it.image_url: img = load_image_from_url(it.image_url, timeout=20, convert_to_rgb=True, raise_on_error=False) if img is not None: item_dict["image"] = img else: errors.append(f"Failed to load image from URL for item {it.id}: {it.image_url}") # Load image from base64 if provided elif it.image_base64: try: image_bytes = base64.b64decode(it.image_base64) img = load_image_from_bytes(image_bytes, convert_to_rgb=True, raise_on_error=False) if img is not None: item_dict["image"] = img else: errors.append(f"Failed to load image from base64 for item {it.id}") except Exception as e: errors.append(f"Error decoding base64 image for item {it.id}: {str(e)}") # If no image and no embedding, skip this item if item_dict.get("image") is None and item_dict.get("embedding") is None: errors.append(f"Item {it.id} has no image or embedding - skipping") continue items.append(item_dict) if not items: error_msg = "No valid items provided. All items failed to load or have no images/embeddings." if errors: error_msg += f" Errors: {', '.join(errors[:5])}" raise HTTPException(status_code=400, detail=error_msg) # Build context from request context = req.context or {} # Add explicit tag fields if provided (takes precedence over context dict) if req.occasion: context["occasion"] = req.occasion if req.weather: context["weather"] = req.weather if req.style: context["style"] = req.style if req.outfit_style: context["outfit_style"] = req.outfit_style context["style"] = req.outfit_style # Also set style for consistency if req.num_outfits: context["num_outfits"] = req.num_outfits # Add optional tags optional_tags = { "color_preference": req.color_preference, "fit_preference": req.fit_preference, "material_preference": req.material_preference, "season": req.season, "time_of_day": req.time_of_day, "budget": req.budget, "personal_style": req.personal_style, "age_group": req.age_group, "gender": req.gender, } for tag_name, tag_value in optional_tags.items(): if tag_value: context[tag_name] = tag_value # Validate tags is_valid, tag_errors = validate_tags(context) if not is_valid: return JSONResponse( status_code=400, content={ "error": "Invalid tags provided", "errors": tag_errors, "valid_tag_options": get_all_tag_options() } ) # Generate recommendations try: outfits = service.compose_outfits(items, context=context) # Check if compose_outfits returned an error if outfits and isinstance(outfits, list) and len(outfits) > 0: if isinstance(outfits[0], dict) and "error" in outfits[0]: return JSONResponse( status_code=500, content={ "error": "Recommendation generation failed", "details": outfits[0].get("details", []), "message": outfits[0].get("message", "Unknown error") } ) return { "outfits": outfits, "version": service.vit_version, "tags_processed": True, "context_used": context, "items_processed": len(items), "warnings": errors if errors else None } except Exception as e: import traceback traceback.print_exc() return JSONResponse( status_code=500, content={ "error": "Internal server error during recommendation generation", "message": str(e), "model_status": service.get_model_status() } ) @app.get("/artifacts") def artifacts() -> dict: # list exported model artifacts for download export_dir = os.getenv("EXPORT_DIR", "models/exports") files = [] if os.path.isdir(export_dir): for fn in os.listdir(export_dir): if fn.endswith((".pth", ".pt", ".onnx", ".ts", ".json")): files.append({ "name": fn, "path": f"{export_dir}/{fn}", "url": f"/files/{fn}", }) return {"artifacts": files} # --------- Gradio UI --------- def _load_images_from_files(files: List[str]) -> List[Image.Image]: """ Load images from file paths with comprehensive format support. Supports JPG, PNG, WEBP, GIF, BMP, TIFF, and other major formats. """ return load_images_from_files(files, convert_to_rgb=True, skip_errors=True) def gradio_embed(files: List[str]): if not files: return "[]" images = _load_images_from_files(files) if not images: return "[]" embs = service.embed_images(images) return str([e.tolist() for e in embs]) def _stitch_strip(imgs: List[Image.Image], height: int = 256, pad: int = 6, bg=(245, 245, 245)) -> Image.Image: if not imgs: return Image.new("RGB", (1, height), color=bg) resized = [] for im in imgs: if im.mode != "RGB": im = im.convert("RGB") w, h = im.size scale = height / float(h) nw = max(1, int(w * scale)) resized.append(im.resize((nw, height))) total_w = sum(im.size[0] for im in resized) + pad * (len(resized) + 1) out = Image.new("RGB", (total_w, height + 2 * pad), color=bg) x = pad for im in resized: out.paste(im, (x, pad)) x += im.size[0] + pad return out def gradio_recommend( files: List[str], occasion: str, weather: str, num_outfits: int, outfit_style: str = "casual", color_preference: str = None, fit_preference: str = None, material_preference: str = None, season: str = None, time_of_day: str = None, budget: str = None, personal_style: str = None ): # Check model status first model_status = service.get_model_status() if not model_status["can_recommend"]: error_msg = "❌ Models not ready for recommendations!\n\n" error_msg += "**Model Status:**\n" error_msg += f"- ResNet: {'✅ Loaded' if model_status['resnet_loaded'] else '❌ Not loaded'}\n" error_msg += f"- ViT: {'✅ Loaded' if model_status['vit_loaded'] else '❌ Not loaded'}\n\n" error_msg += "**Errors:**\n" for error in model_status["errors"]: error_msg += f"- {error}\n\n" error_msg += "**Solution:**\n" error_msg += "Please train the models first using the 'Simple Training' or 'Advanced Training' tabs, or ensure trained checkpoints are available." return [], {"error": error_msg, "model_status": model_status} # Return stitched outfit images and a JSON with details if not files: return [], {"error": "No files uploaded"} # Enhanced debug: Log detailed file information for API troubleshooting print(f"🔍 DEBUG: gradio_recommend called with {len(files)} files") file_info = [] for i, f in enumerate(files): if isinstance(f, str): from pathlib import Path path = Path(f) file_size = path.stat().st_size if path.exists() else 0 file_info.append(f"File {i+1}: {path.name} ({file_size} bytes)") print(f"🔍 DEBUG: File {i+1}: path={f}, exists={path.exists()}, size={file_size}, name={path.name}") else: file_info.append(f"Type: {type(f).__name__}, Value: {str(f)[:100]}") print(f"🔍 DEBUG: File {i+1}: type={type(f).__name__}, value={str(f)[:100]}") try: print(f"🔍 DEBUG: Attempting to load {len(files)} images...") images = _load_images_from_files(files) print(f"🔍 DEBUG: Successfully loaded {len(images)} images from {len(files)} files") if not images: error_msg = "Could not load images from uploaded files.\n\n" error_msg += f"Files received: {len(files)}\n" error_msg += f"File details: {', '.join(file_info[:3])}\n\n" error_msg += f"Supported formats: {', '.join(get_supported_extensions())}\n" error_msg += "Please ensure files are valid image files (JPG, PNG, WEBP, GIF, BMP, TIFF, etc.)" print(f"🔍 DEBUG: ERROR - No images loaded. Files: {len(files)}, Images: {len(images)}") return [], {"error": error_msg, "files_received": len(files), "file_info": file_info[:5]} except Exception as e: import traceback error_msg = f"Error processing files: {str(e)}\n\n" error_msg += f"Files received: {len(files)}\n" error_msg += f"File details: {', '.join(file_info[:3])}\n\n" error_msg += "Please check that files are valid image files." print(f"🔍 DEBUG: EXCEPTION during image loading: {e}") traceback.print_exc() return [], {"error": error_msg, "exception": str(e), "files_received": len(files)} # Tag normalization: Map frontend values to backend values # This handles variations and synonyms from different frontend implementations def normalize_tag_value(tag_name: str, value: str) -> str: """Normalize tag values to match backend expectations""" if not value or value == "None" or value is None: return None value_lower = value.lower().strip() # Color preference mappings - normalize variations to standard values if tag_name == "color_preference": color_mappings = { "monochrome": "monochromatic", # Frontend uses "monochrome", backend uses "monochromatic" "mono": "monochromatic", "single_color": "monochromatic", "one_color": "monochromatic", } normalized = color_mappings.get(value_lower, value) # Validate against allowed values allowed_colors = ["neutral", "monochromatic", "complementary", "bold", "subtle", "bright", "muted", "pastel", "dark", "light", "earth_tones", "jewel_tones", "black_white", "navy_white", "colorful", "minimal_color"] if normalized in allowed_colors: return normalized return value # Return original if not in allowed list # Fit preference mappings if tag_name == "fit_preference": fit_mappings = { "slim": "fitted", "tight_fit": "fitted", "baggy": "loose", "wide": "loose", } normalized = fit_mappings.get(value_lower, value) # Validate against allowed values allowed_fits = ["fitted", "loose", "oversized", "relaxed", "comfortable", "structured", "flowy", "tailored", "athletic_fit", "regular_fit"] if normalized in allowed_fits: return normalized return value # Season mappings - both "fall" and "autumn" are valid, keep as-is if tag_name == "season": # Both "fall" and "autumn" are in the Literal type, so no normalization needed return value # Return original value if no mapping found return value # Normalize all tag values occasion = normalize_tag_value("occasion", occasion) or occasion weather = normalize_tag_value("weather", weather) or weather outfit_style = normalize_tag_value("outfit_style", outfit_style) or outfit_style color_preference = normalize_tag_value("color_preference", color_preference) if color_preference else None fit_preference = normalize_tag_value("fit_preference", fit_preference) if fit_preference else None material_preference = normalize_tag_value("material_preference", material_preference) if material_preference else None season = normalize_tag_value("season", season) if season else None time_of_day = normalize_tag_value("time_of_day", time_of_day) if time_of_day else None budget = normalize_tag_value("budget", budget) if budget else None personal_style = normalize_tag_value("personal_style", personal_style) if personal_style else None # Build comprehensive context with all tags context = { "occasion": occasion, "weather": weather, "style": outfit_style, "outfit_style": outfit_style, # Backward compatibility "num_outfits": int(num_outfits) } # Add optional tags if provided if color_preference and color_preference != "None": context["color_preference"] = color_preference if fit_preference and fit_preference != "None": context["fit_preference"] = fit_preference if material_preference and material_preference != "None": context["material_preference"] = material_preference if season and season != "None": context["season"] = season if time_of_day and time_of_day != "None": context["time_of_day"] = time_of_day if budget and budget != "None": context["budget"] = budget if personal_style and personal_style != "None": context["personal_style"] = personal_style # Build items that allow on-the-fly embedding in service items = [ {"id": f"item_{i}", "image": images[i], "category": None} for i in range(len(images)) ] print(f"🔍 DEBUG: Calling compose_outfits with {len(items)} items, context={context}") res = service.compose_outfits(items, context=context) print(f"🔍 DEBUG: compose_outfits returned {len(res) if res else 0} results") if res: print(f"🔍 DEBUG: First result type: {type(res[0])}, keys: {res[0].keys() if isinstance(res[0], dict) else 'N/A'}") # Check if compose_outfits returned an error if res and isinstance(res[0], dict) and "error" in res[0]: print(f"🔍 DEBUG: Error in compose_outfits result: {res[0]}") return [], res[0] # Prepare stitched previews - save to temp files for Gradio API compatibility strips: List[str] = [] # Changed to List[str] for file paths print(f"🔍 DEBUG: Preparing stitched previews for {len(res)} outfits...") for i, r in enumerate(res): idxs = [] item_ids = r.get("item_ids", []) print(f"🔍 DEBUG: Outfit {i+1}: item_ids={item_ids}") for iid in item_ids: try: idx = int(str(iid).split("_")[-1]) idxs.append(idx) print(f"🔍 DEBUG: Mapped {iid} -> index {idx}") except Exception as e: print(f"🔍 DEBUG: Failed to parse {iid}: {e}") continue imgs = [images[i] for i in idxs if 0 <= i < len(images)] print(f"🔍 DEBUG: Extracted {len(imgs)} images from indices {idxs}") if imgs: strip = _stitch_strip(imgs) print(f"🔍 DEBUG: Created stitched image: {strip.size}") # Save to temporary file (Gradio will convert to URL) temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='/tmp') strip.save(temp_file.name, 'PNG') temp_file.close() strips.append(temp_file.name) # Return file path instead of PIL Image print(f"🔍 DEBUG: Saved to temp file: {temp_file.name}") else: print(f"⚠️ DEBUG: No images extracted for outfit {i+1}") print(f"🔍 DEBUG: Returning {len(strips)} stitched image file paths and {len(res)} outfit results") return strips, {"outfits": res} def start_training_advanced( # Dataset size dataset_size: str, # ResNet parameters resnet_epochs: int, resnet_batch_size: int, resnet_lr: float, resnet_optimizer: str, resnet_weight_decay: float, resnet_triplet_margin: float, resnet_embedding_dim: int, resnet_backbone: str, resnet_use_pretrained: bool, resnet_dropout: float, # ViT parameters vit_epochs: int, vit_batch_size: int, vit_max_samples: int, vit_lr: float, vit_optimizer: str, vit_weight_decay: float, vit_triplet_margin: float, vit_embedding_dim: int, vit_num_layers: int, vit_num_heads: int, vit_ff_multiplier: int, vit_dropout: float, # Advanced parameters use_mixed_precision: bool, channels_last: bool, gradient_clip: float, warmup_epochs: int, scheduler_type: str, early_stopping_patience: int, mining_strategy: str, augmentation_level: str, seed: int ): """Start advanced training with custom parameters.""" # Use global dataset size if not specified if not dataset_size or dataset_size == "full": dataset_size = os.getenv("DATASET_SIZE_LIMIT", "2000") if not DATASET_ROOT: return "❌ Dataset not ready. Please wait for bootstrap to complete." log_message = "🚀 Advanced training started with custom parameters! Check the log below for progress." def _runner(): nonlocal log_message try: import subprocess import json export_dir = os.getenv("EXPORT_DIR", "models/exports") os.makedirs(export_dir, exist_ok=True) # Create custom config files resnet_config = { "model": { "backbone": resnet_backbone, "embedding_dim": resnet_embedding_dim, "pretrained": resnet_use_pretrained, "dropout": resnet_dropout }, "training": { "batch_size": resnet_batch_size, "epochs": resnet_epochs, "lr": resnet_lr, "weight_decay": resnet_weight_decay, "triplet_margin": resnet_triplet_margin, "optimizer": resnet_optimizer, "scheduler": scheduler_type, "warmup_epochs": warmup_epochs, "early_stopping_patience": early_stopping_patience, "use_amp": use_mixed_precision, "channels_last": channels_last, "gradient_clip": gradient_clip }, "data": { "image_size": 224, "augmentation_level": augmentation_level }, "advanced": { "mining_strategy": mining_strategy, "seed": seed } } vit_config = { "model": { "embedding_dim": vit_embedding_dim, "num_layers": vit_num_layers, "num_heads": vit_num_heads, "ff_multiplier": vit_ff_multiplier, "dropout": vit_dropout }, "training": { "batch_size": vit_batch_size, "epochs": vit_epochs, "lr": vit_lr, "weight_decay": vit_weight_decay, "triplet_margin": vit_triplet_margin, "optimizer": vit_optimizer, "scheduler": scheduler_type, "warmup_epochs": warmup_epochs, "early_stopping_patience": early_stopping_patience, "use_amp": use_mixed_precision }, "advanced": { "mining_strategy": mining_strategy, "seed": seed } } # Save configs with open(os.path.join(export_dir, "resnet_config_custom.json"), "w") as f: json.dump(resnet_config, f, indent=2) with open(os.path.join(export_dir, "vit_config_custom.json"), "w") as f: json.dump(vit_config, f, indent=2) # Train ResNet with custom parameters log_message = f"🚀 Starting ResNet training with custom parameters...\n" log_message += f"Dataset Size: {dataset_size} samples\n" log_message += f"Backbone: {resnet_backbone}, Embedding Dim: {resnet_embedding_dim}\n" log_message += f"Epochs: {resnet_epochs}, Batch Size: {resnet_batch_size}, LR: {resnet_lr}\n" log_message += f"Optimizer: {resnet_optimizer}, Triplet Margin: {resnet_triplet_margin}\n" # Add dataset size limit if not full dataset_args = [] if dataset_size != "full": dataset_args = ["--max_samples", dataset_size] resnet_cmd = [ "python", "training/train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(resnet_epochs), "--batch_size", str(resnet_batch_size), "--lr", str(resnet_lr), "--weight_decay", str(resnet_weight_decay), "--triplet_margin", str(resnet_triplet_margin), "--embedding_dim", str(resnet_embedding_dim), "--out", os.path.join(export_dir, "resnet_item_embedder_custom.pth") ] + dataset_args if resnet_backbone != "resnet50": resnet_cmd.extend(["--backbone", resnet_backbone]) result = subprocess.run(resnet_cmd, capture_output=True, text=True, check=False) if result.returncode == 0: log_message += "✅ ResNet training completed successfully!\n" log_message += f"📊 ResNet Output:\n{result.stdout}\n\n" else: log_message += f"❌ ResNet training failed: {result.stderr}\n\n" return log_message # Wait a moment for file system sync and ensure ResNet is fully saved import time time.sleep(3) log_message += "⏳ Waiting for ResNet checkpoint to be fully saved...\n" # Verify ResNet checkpoint exists before proceeding resnet_checkpoint = os.path.join(export_dir, "resnet_item_embedder_custom.pth") if not os.path.exists(resnet_checkpoint): log_message += f"❌ ResNet checkpoint not found at {resnet_checkpoint}\n" log_message += "Cannot proceed with ViT training without ResNet embeddings.\n" return log_message log_message += f"✅ ResNet checkpoint verified: {resnet_checkpoint}\n" # Train ViT with custom parameters log_message += f"🚀 Starting ViT training with custom parameters...\n" log_message += f"Dataset Size: {dataset_size} samples\n" log_message += f"Layers: {vit_num_layers}, Heads: {vit_num_heads}, FF Multiplier: {vit_ff_multiplier}\n" log_message += f"Epochs: {vit_epochs}, Batch Size: {vit_batch_size}, LR: {vit_lr}\n" log_message += f"Optimizer: {vit_optimizer}, Triplet Margin: {vit_triplet_margin}\n" vit_cmd = [ "python", "training/train_vit.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs), "--batch_size", str(vit_batch_size), "--max_samples", str(vit_max_samples), "--lr", str(vit_lr), "--weight_decay", str(vit_weight_decay), "--triplet_margin", str(vit_triplet_margin), "--embedding_dim", str(vit_embedding_dim), "--export", os.path.join(export_dir, "vit_outfit_model_custom.pth") ] + dataset_args result = subprocess.run(vit_cmd, capture_output=True, text=True, check=False) if result.returncode == 0: log_message += "✅ ViT training completed successfully!\n" log_message += f"📊 ViT Output:\n{result.stdout}\n\n" log_message += "🎉 All training completed! Models saved to models/exports/\n" log_message += "🔄 Reloading models for inference...\n" service.reload_models() # Check if models loaded successfully model_status = service.get_model_status() if model_status["can_recommend"]: log_message += "✅ Models reloaded and ready for inference!\n" log_message += "🎉 You can now generate outfit recommendations!\n" else: log_message += "⚠️ Models reloaded but validation failed!\n" log_message += "**Model Status:**\n" log_message += f"- ResNet: {'✅ Loaded' if model_status['resnet_loaded'] else '❌ Failed'}\n" log_message += f"- ViT: {'✅ Loaded' if model_status['vit_loaded'] else '❌ Failed'}\n" for error in model_status["errors"]: log_message += f"- {error}\n" # Auto-upload to HF Hub if token is available hf_token = os.getenv("HF_TOKEN") if hf_token: log_message += "📤 Auto-uploading artifacts to Hugging Face Hub...\n" try: from utils.hf_utils import HFModelManager hf = HFModelManager(token=hf_token, username="Stylique") result = hf.upload_model("everything", "dressify-complete") if result.get("success"): log_message += "✅ Successfully uploaded to HF Hub!\n" log_message += "🔗 Models: https://huggingface.co/Stylique/dressify-models\n" log_message += "🔗 Data: https://huggingface.co/datasets/Stylique/Dressify-Helper\n" else: log_message += f"⚠️ Upload failed: {result.get('error', 'Unknown error')}\n" except Exception as e: log_message += f"⚠️ Auto-upload failed: {str(e)}\n" else: log_message += "💡 Set HF_TOKEN env var for automatic uploads\n" else: log_message += f"❌ ViT training failed: {result.stderr}\n" except Exception as e: log_message += f"\n❌ Training error: {str(e)}" threading.Thread(target=_runner, daemon=True).start() return log_message def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int): """Start simple training with basic parameters.""" # Use global dataset size if not specified if not dataset_size or dataset_size == "full": dataset_size = os.getenv("DATASET_SIZE_LIMIT", "2000") log_message = f"Starting training on {dataset_size} samples..." def _runner(): nonlocal log_message try: import subprocess if not DATASET_ROOT: log_message = "Dataset not ready." return export_dir = os.getenv("EXPORT_DIR", "models/exports") os.makedirs(export_dir, exist_ok=True) log_message = f"Training ResNet on {dataset_size} samples...\n" # Add dataset size limit if not full dataset_args = [] if dataset_size != "full": dataset_args = ["--max_samples", dataset_size] # Train ResNet first and wait for completion log_message += f"\n🚀 Starting ResNet training on {dataset_size} samples...\n" resnet_result = subprocess.run([ "python", "training/train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs), "--batch_size", "4", "--lr", "1e-3", "--early_stopping_patience", "3", "--out", os.path.join(export_dir, "resnet_item_embedder.pth") ] + dataset_args, capture_output=True, text=True, check=False) if resnet_result.returncode == 0: log_message += "✅ ResNet training completed successfully!\n" log_message += f"📊 ResNet Output:\n{resnet_result.stdout}\n" else: log_message += f"❌ ResNet training failed: {resnet_result.stderr}\n" return log_message # Wait a moment for file system sync import time time.sleep(2) # Verify ResNet checkpoint exists before proceeding resnet_checkpoint = os.path.join(export_dir, "resnet_item_embedder.pth") if not os.path.exists(resnet_checkpoint): log_message += f"❌ ResNet checkpoint not found at {resnet_checkpoint}\n" log_message += "Cannot proceed with ViT training without ResNet embeddings.\n" return log_message log_message += f"✅ ResNet checkpoint verified: {resnet_checkpoint}\n" log_message += f"\n🚀 Starting ViT training on {dataset_size} samples...\n" vit_result = subprocess.run([ "python", "training/train_vit.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs), "--batch_size", "4", "--lr", "5e-4", "--early_stopping_patience", "5", "--max_samples", "5000", "--triplet_margin", "0.5", "--gradient_clip", "1.0", "--warmup_epochs", "2", "--export", os.path.join(export_dir, "vit_outfit_model.pth") ] + dataset_args, capture_output=True, text=True, check=False) if vit_result.returncode == 0: log_message += "✅ ViT training completed successfully!\n" log_message += f"📊 ViT Output:\n{vit_result.stdout}\n" else: log_message += f"❌ ViT training failed: {vit_result.stderr}\n" return log_message service.reload_models() # Check if models loaded successfully model_status = service.get_model_status() if model_status["can_recommend"]: log_message += "\n✅ Training completed! Models reloaded and ready for inference.\n" log_message += "🎉 You can now generate outfit recommendations!\n" else: log_message += "\n⚠️ Training completed but models failed to load properly!\n" log_message += "**Model Status:**\n" log_message += f"- ResNet: {'✅ Loaded' if model_status['resnet_loaded'] else '❌ Failed'}\n" log_message += f"- ViT: {'✅ Loaded' if model_status['vit_loaded'] else '❌ Failed'}\n" for error in model_status["errors"]: log_message += f"- {error}\n" log_message += "\nArtifacts saved to models/exports/" # Auto-upload to HF Hub if token is available hf_token = os.getenv("HF_TOKEN") if hf_token: log_message += "\n📤 Auto-uploading artifacts to Hugging Face Hub...\n" try: from utils.hf_utils import HFModelManager hf = HFModelManager(token=hf_token, username="Stylique") result = hf.upload_model("everything", "dressify-complete") if result.get("success"): log_message += "✅ Successfully uploaded to HF Hub!\n" log_message += "🔗 Models: https://huggingface.co/Stylique/dressify-models\n" log_message += "🔗 Data: https://huggingface.co/datasets/Stylique/Dressify-Helper\n" else: log_message += f"⚠️ Upload failed: {result.get('error', 'Unknown error')}\n" except Exception as e: log_message += f"⚠️ Auto-upload failed: {str(e)}\n" else: log_message += "\n💡 Set HF_TOKEN env var for automatic uploads\n" except Exception as e: log_message += f"\nError: {e}" threading.Thread(target=_runner, daemon=True).start() return log_message with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendation") as demo: gr.Markdown("## 🏆 Dressify – Advanced Outfit Recommendation System\n*Research-grade, self-contained outfit recommendation with comprehensive training controls*") gr.Markdown("💡 **Pro Tip**: Start with 2000 samples for quick testing, then increase to 50000+ for production training!") with gr.Tab("🎨 Recommend"): gr.Markdown("### 🎯 Personalized Outfit Recommendations\n*Upload your wardrobe and customize recommendations with advanced tag preferences*") gr.Markdown(f"**Supported Formats:** {', '.join(get_supported_extensions())} (JPG, PNG, WEBP, GIF, BMP, TIFF, and more)") inp2 = gr.Files( label="Upload wardrobe images", file_count="multiple" # Note: file_types removed to allow API client flexibility # Validation is handled by our image_utils.load_images_from_files() ) with gr.Accordion("🎯 Primary Tags (Required)", open=True): with gr.Row(): occasion = gr.Dropdown( choices=["casual", "business", "formal", "semi_formal", "business_casual", "cocktail", "wedding", "party", "date", "sport", "workout", "travel", "beach", "outdoor", "night_out", "brunch", "dinner", "meeting", "interview", "cultural", "traditional"], value="casual", label="Occasion", info="Select the occasion or event type" ) weather = gr.Dropdown( choices=["any", "hot", "warm", "mild", "cool", "cold", "freezing", "rain", "snow", "windy", "humid", "sunny", "cloudy"], value="any", label="Weather", info="Current or expected weather conditions" ) outfit_style = gr.Dropdown( choices=["casual", "smart_casual", "formal", "sporty", "athletic", "streetwear", "minimalist", "classic", "modern", "elegant", "sophisticated", "traditional", "ethnic"], value="casual", label="Outfit Style", info="Preferred fashion aesthetic" ) with gr.Accordion("🎨 Style & Preference Tags (Optional)", open=False): with gr.Row(): color_preference = gr.Dropdown( choices=["None", "neutral", "monochromatic", "monochrome", "complementary", "bold", "subtle", "bright", "muted", "pastel", "dark", "light", "earth_tones", "jewel_tones", "black_white", "navy_white", "colorful", "minimal_color"], value="None", label="Color Preference", info="Preferred color scheme" ) fit_preference = gr.Dropdown( choices=["None", "fitted", "loose", "oversized", "relaxed", "comfortable", "structured", "flowy", "tailored", "athletic_fit", "regular_fit"], value="None", label="Fit Preference", info="Preferred fit and silhouette" ) material_preference = gr.Dropdown( choices=["None", "cotton", "linen", "silk", "wool", "cashmere", "denim", "leather", "breathable", "waterproof", "moisture_wicking", "sustainable"], value="None", label="Material Preference", info="Preferred fabric or material type" ) with gr.Accordion("📅 Context Tags (Optional)", open=False): with gr.Row(): season = gr.Dropdown( choices=["None", "spring", "summer", "fall", "autumn", "winter", "year_round", "transitional"], value="None", label="Season", info="Current season" ) time_of_day = gr.Dropdown( choices=["None", "morning", "afternoon", "evening", "night", "all_day"], value="None", label="Time of Day", info="When will you wear this outfit?" ) budget = gr.Dropdown( choices=["None", "luxury", "premium", "mid_range", "affordable", "budget", "value"], value="None", label="Budget Preference", info="Price range preference (informational)" ) personal_style = gr.Dropdown( choices=["None", "conservative", "moderate", "bold", "experimental", "traditional", "trendy", "timeless", "fashion_forward", "classic", "eclectic"], value="None", label="Personal Style", info="Your personal style preference" ) with gr.Row(): num_outfits = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of Outfits", info="How many outfit recommendations to generate") out_gallery = gr.Gallery(label="Recommended Outfits", columns=1, height=400, show_label=True) out_json = gr.JSON(label="Outfit Details & Tag Analysis", show_label=True) btn2 = gr.Button("✨ Generate Personalized Outfits", variant="primary", size="lg") btn2.click( fn=gradio_recommend, inputs=[inp2, occasion, weather, num_outfits, outfit_style, color_preference, fit_preference, material_preference, season, time_of_day, budget, personal_style], outputs=[out_gallery, out_json] ) with gr.Tab("🔬 Advanced Training"): gr.Markdown("### 🎯 Comprehensive Training Parameter Control\nCustomize every aspect of model training for research and experimentation.") # Dataset Preparation Section with gr.Accordion("📦 Dataset Preparation (Optional)", open=False): gr.Markdown("**Note**: Dataset preparation is now manual only. Click the button below to download and prepare the dataset when needed.") with gr.Row(): prepare_dataset_btn = gr.Button("📥 Download & Prepare Dataset", variant="secondary") prepare_status = gr.Textbox(label="Dataset Preparation Status", value="Dataset will be prepared if missing", interactive=False) def prepare_dataset_manual(): """Manually trigger dataset preparation.""" global DATASET_ROOT, BOOT_STATUS try: BOOT_STATUS = "preparing-dataset" # Check if dataset already exists root = os.path.abspath(os.path.join(os.getcwd(), "data", "Polyvore")) images_dir = os.path.join(root, "images") has_images = os.path.isdir(images_dir) and any(os.listdir(images_dir)) if has_images: print("✅ Images already exist, skipping download/extraction") ds_root = root else: print("📥 Downloading and extracting dataset...") ds_root = ensure_dataset_ready() DATASET_ROOT = ds_root if not ds_root: BOOT_STATUS = "dataset-not-prepared" return "❌ Failed to prepare dataset" # Prepare splits if missing splits_dir = os.path.join(ds_root, "splits") has_splits = ( os.path.isfile(os.path.join(splits_dir, "train.json")) or os.path.isfile(os.path.join(splits_dir, "outfit_triplets_train.json")) ) if not has_splits: os.makedirs(splits_dir, exist_ok=True) from scripts.prepare_polyvore import main as prepare_main os.environ.setdefault("PYTHONWARNINGS", "ignore") import sys argv_bak = sys.argv try: sys.argv = ["prepare_polyvore.py", "--root", ds_root, "--max_samples", "500"] prepare_main() BOOT_STATUS = "ready" return "✅ Dataset and splits prepared successfully!" finally: sys.argv = argv_bak else: BOOT_STATUS = "ready" return "✅ Dataset already prepared (images and splits exist)" except Exception as e: BOOT_STATUS = "error" import traceback return f"❌ Error: {str(e)}\n{traceback.format_exc()}" prepare_dataset_btn.click(fn=prepare_dataset_manual, inputs=[], outputs=prepare_status) # Global Dataset Size Control with gr.Row(): gr.Markdown("#### 🎯 **Global Dataset Size Control**") gr.Markdown("**Note**: Use 'Apply' button to regenerate splits with different size limits.") with gr.Row(): gr.Markdown("#### 📊 **Current Behavior**") gr.Markdown("• **Bootstrap**: Downloads full dataset (53K outfits) + generates splits with **500 samples by default**\n• **Training**: Uses 500 samples (ultra-fast training!)\n• **Apply Button**: Regenerates splits with your selected size limit") with gr.Row(): global_dataset_size = gr.Dropdown( choices=["160", "500", "2000", "5000", "10000", "25000", "50000", "full"], value="500", label="Global Dataset Size (Affects Prep + Training)" ) gr.Markdown("**160**: Ultra-fast testing (~30 sec prep, ~1-2 min training)\n**2000**: Fast testing (~1-2 min prep, ~2-5 min training)\n**5000**: Fast testing (~2-3 min prep, ~5-10 min training)\n**10000**: Good testing (~3-5 min prep, ~10-20 min training)\n**full**: Production (~5-10 min prep, ~1-4 hours training)") with gr.Row(): # Apply dataset size button apply_size_btn = gr.Button("🔄 Apply Dataset Size & Regenerate Splits", variant="primary") size_status = gr.Textbox(label="Dataset Size Status", value="Dataset size: 500 samples (click Apply to regenerate splits)", interactive=False) # Current dataset info gr.Markdown("#### 📊 **Current Dataset Status**") gr.Markdown("• **Full dataset downloaded**: 53,306 outfits (required for system)\n• **Splits generated**: **500 samples by default** (ultra-fast training!)\n• **Training will use**: 500 samples (ultra-fast training!)\n• **Scale up**: Use Apply button to increase to larger sizes") def apply_dataset_size(size: str): """Apply global dataset size and regenerate splits.""" try: if size == "full": return f"✅ Using full dataset ({size}) - no size limit applied" # Call the dataset preparation with size limit import subprocess import os # Set environment variable for dataset size os.environ["DATASET_SIZE_LIMIT"] = size # Check if script exists script_path = "scripts/prepare_polyvore.py" if not os.path.exists(script_path): return f"❌ Script not found: {script_path}" # Regenerate splits with size limit using subprocess cmd = [ "python", script_path, "--root", "/home/user/app/data/Polyvore", "--out", "/home/user/app/data/Polyvore/splits", "--max_samples", size ] print(f"Running command: {' '.join(cmd)}") print(f"Current working directory: {os.getcwd()}") # Run from the correct directory result = subprocess.run(cmd, capture_output=True, text=True, check=False, cwd="/home/user/app") if result.returncode == 0: return f"✅ Successfully regenerated splits with {size} samples limit" else: error_msg = f"❌ Failed to regenerate splits:\n" error_msg += f"Return code: {result.returncode}\n" error_msg += f"STDOUT: {result.stdout}\n" error_msg += f"STDERR: {result.stderr}" return error_msg except Exception as e: return f"❌ Failed to apply dataset size: {str(e)}" apply_size_btn.click(fn=apply_dataset_size, inputs=[global_dataset_size], outputs=[size_status]) with gr.Row(): with gr.Column(scale=1): gr.Markdown("#### 📊 Dataset Size Control") gr.Markdown("Start small for testing, increase for production training") dataset_size = gr.Dropdown( choices=["160", "500", "2000", "5000", "10000", "25000", "50000", "full"], value="500", label="Training Dataset Size" ) gr.Markdown("**2000**: Quick testing (~2-5 min)\n**5000**: Fast validation (~5-10 min)\n**10000**: Good validation (~10-20 min)\n**25000+**: Production training") with gr.Column(scale=1): gr.Markdown("#### 🖼️ ResNet Item Embedder") # Model architecture resnet_backbone = gr.Dropdown( choices=["resnet50", "resnet101"], value="resnet50", label="Backbone Architecture" ) resnet_embedding_dim = gr.Slider(128, 1024, value=512, step=128, label="Embedding Dimension") resnet_use_pretrained = gr.Checkbox(value=True, label="Use ImageNet Pretrained") resnet_dropout = gr.Slider(0.0, 0.5, value=0.1, step=0.05, label="Dropout Rate") # Training parameters resnet_epochs = gr.Slider(1, 100, value=20, step=1, label="Epochs") resnet_batch_size = gr.Slider(4, 128, value=4, step=4, label="Batch Size") resnet_lr = gr.Slider(1e-5, 1e-2, value=1e-3, step=1e-5, label="Learning Rate") resnet_optimizer = gr.Dropdown( choices=["adamw", "adam", "sgd", "rmsprop"], value="adamw", label="Optimizer" ) resnet_weight_decay = gr.Slider(1e-6, 1e-2, value=1e-4, step=1e-6, label="Weight Decay") resnet_triplet_margin = gr.Slider(0.1, 1.0, value=0.2, step=0.05, label="Triplet Margin") with gr.Column(scale=1): gr.Markdown("#### 🧠 ViT Outfit Encoder") # Model architecture vit_embedding_dim = gr.Slider(128, 1024, value=512, step=128, label="Embedding Dimension") vit_num_layers = gr.Slider(2, 12, value=6, step=1, label="Transformer Layers") vit_num_heads = gr.Slider(4, 16, value=8, step=2, label="Attention Heads") vit_ff_multiplier = gr.Slider(2, 8, value=4, step=1, label="Feed-Forward Multiplier") vit_dropout = gr.Slider(0.0, 0.5, value=0.1, step=0.05, label="Dropout Rate") # Training parameters vit_epochs = gr.Slider(1, 100, value=30, step=1, label="Epochs") vit_batch_size = gr.Slider(2, 64, value=4, step=2, label="Batch Size") vit_max_samples = gr.Slider(100, 5000, value=500, step=100, label="Max Training Samples") vit_lr = gr.Slider(1e-5, 1e-2, value=5e-4, step=1e-5, label="Learning Rate") vit_optimizer = gr.Dropdown( choices=["adamw", "adam", "sgd", "rmsprop"], value="adamw", label="Optimizer" ) vit_weight_decay = gr.Slider(1e-4, 1e-1, value=5e-2, step=1e-4, label="Weight Decay") vit_triplet_margin = gr.Slider(0.1, 1.0, value=0.3, step=0.05, label="Triplet Margin") with gr.Row(): with gr.Column(scale=1): gr.Markdown("#### ⚙️ Advanced Training Settings") # Hardware optimization use_mixed_precision = gr.Checkbox(value=True, label="Mixed Precision (AMP)") channels_last = gr.Checkbox(value=True, label="Channels Last Memory Format") gradient_clip = gr.Slider(0.1, 5.0, value=1.0, step=0.1, label="Gradient Clipping") # Learning rate scheduling warmup_epochs = gr.Slider(0, 10, value=3, step=1, label="Warmup Epochs") scheduler_type = gr.Dropdown( choices=["cosine", "step", "plateau", "linear"], value="cosine", label="Learning Rate Scheduler" ) early_stopping_patience = gr.Slider(5, 20, value=10, step=1, label="Early Stopping Patience") # Training strategy mining_strategy = gr.Dropdown( choices=["semi_hard", "hardest", "random"], value="semi_hard", label="Triplet Mining Strategy" ) augmentation_level = gr.Dropdown( choices=["minimal", "standard", "aggressive"], value="standard", label="Data Augmentation Level" ) seed = gr.Slider(0, 9999, value=42, step=1, label="Random Seed") with gr.Column(scale=1): gr.Markdown("#### 🚀 Training Control") # Quick training gr.Markdown("**Quick Training (Basic Parameters)**") epochs_res = gr.Slider(1, 50, value=3, step=1, label="ResNet epochs") epochs_vit = gr.Slider(1, 100, value=3, step=1, label="ViT epochs") start_btn = gr.Button("🚀 Start Quick Training", variant="secondary") # Advanced training gr.Markdown("**Advanced Training (Custom Parameters)**") start_advanced_btn = gr.Button("🎯 Start Advanced Training", variant="primary") # Training log train_log = gr.Textbox(label="Training Log", lines=15, max_lines=20) # Status gr.Markdown("**Training Status**") training_status = gr.Textbox(label="Status", value="Ready to train", interactive=False) # Event handlers start_btn.click( fn=start_training_simple, inputs=[dataset_size, epochs_res, epochs_vit], outputs=train_log ) start_advanced_btn.click( fn=start_training_advanced, inputs=[ # Dataset size dataset_size, # ResNet parameters resnet_epochs, resnet_batch_size, resnet_lr, resnet_optimizer, resnet_weight_decay, resnet_triplet_margin, resnet_embedding_dim, resnet_backbone, resnet_use_pretrained, resnet_dropout, # ViT parameters vit_epochs, vit_batch_size, vit_max_samples, vit_lr, vit_optimizer, vit_weight_decay, vit_triplet_margin, vit_embedding_dim, vit_num_layers, vit_num_heads, vit_ff_multiplier, vit_dropout, # Advanced parameters use_mixed_precision, channels_last, gradient_clip, warmup_epochs, scheduler_type, early_stopping_patience, mining_strategy, augmentation_level, seed ], outputs=train_log ) with gr.Tab("📦 Artifact Management"): gr.Markdown("### 🎯 Comprehensive Artifact Management\nManage, package, and upload all system artifacts to Hugging Face Hub.") with gr.Row(): with gr.Column(scale=1): gr.Markdown("#### 📊 Artifact Overview") artifact_overview = gr.JSON(label="System Artifacts", value=get_artifact_overview) refresh_overview = gr.Button("🔄 Refresh Overview") refresh_overview.click(fn=get_artifact_overview, inputs=[], outputs=artifact_overview) gr.Markdown("#### 📦 Create Packages") package_type = gr.Dropdown( choices=["complete", "splits_only", "models_only"], value="complete", label="Package Type" ) create_package_btn = gr.Button("📦 Create Package") package_result = gr.Textbox(label="Package Result", interactive=False) available_packages = gr.JSON(label="Available Packages", value=get_available_packages) create_package_btn.click( fn=create_download_package, inputs=[package_type], outputs=[package_result, available_packages] ) with gr.Column(scale=1): gr.Markdown("#### 🚀 Hugging Face Hub Integration") gr.Markdown("💡 **Pro Tip**: Set `HF_TOKEN` environment variable for automatic uploads after training!") hf_token = gr.Textbox(label="HF Token", type="password", placeholder="hf_...") hf_username = gr.Textbox(label="Username", placeholder="your-username") with gr.Row(): push_splits_btn = gr.Button("📤 Push Splits", variant="secondary") push_models_btn = gr.Button("📤 Push Models", variant="secondary") push_everything_btn = gr.Button("📤 Push Everything", variant="primary") hf_result = gr.Textbox(label="Upload Result", interactive=False, lines=3) push_splits_btn.click(fn=push_splits_to_hf, inputs=[hf_token, hf_username], outputs=hf_result) push_models_btn.click(fn=push_models_to_hf, inputs=[hf_token, hf_username], outputs=hf_result) push_everything_btn.click(fn=push_everything_to_hf, inputs=[hf_token, hf_username], outputs=hf_result) gr.Markdown("#### 📥 Download Management") individual_files = gr.JSON(label="Individual Files", value=get_individual_files) download_all_btn = gr.Button("📥 Download All as ZIP") download_result = gr.Textbox(label="Download Result", interactive=False) download_all_btn.click(fn=download_all_files, inputs=[], outputs=download_result) with gr.Tab("📈 Status"): gr.Markdown("### 🚦 System Status and Monitoring\nReal-time status of dataset preparation, training, and system health.") status = gr.Textbox(label="Bootstrap Status", value=lambda: BOOT_STATUS) refresh_status = gr.Button("🔄 Refresh Status") refresh_status.click(fn=lambda: BOOT_STATUS, inputs=[], outputs=status) # Model Status gr.Markdown("#### 🤖 Model Status") model_status = gr.JSON(label="Model Loading Status", value=lambda: service.get_model_status()) refresh_models = gr.Button("🔄 Refresh Model Status") refresh_models.click(fn=lambda: service.get_model_status(), inputs=[], outputs=model_status) # System info gr.Markdown("#### 💻 System Information") device_info = gr.Textbox(label="Device", value=lambda: f"Device: {service.device}") resnet_version = gr.Textbox(label="ResNet Version", value=lambda: f"ResNet: {service.resnet_version}") vit_version = gr.Textbox(label="ViT Version", value=lambda: f"ViT: {service.vit_version}") # Health check gr.Markdown("#### 🏥 Health Check") health_btn = gr.Button("🔍 Check Health") health_status = gr.Textbox(label="Health Status", value="Click to check") def check_health(): try: health = app.get("/health") return f"✅ System Healthy - {health}" except Exception as e: return f"❌ Health Check Failed: {str(e)}" health_btn.click(fn=check_health, inputs=[], outputs=health_status) try: # Mount Gradio onto FastAPI root path (disable SSR to avoid stray port fetches) demo.queue() app = gr.mount_gradio_app(app, demo, path="/") except Exception: # In case mounting fails in certain runners, we still want FastAPI to be available pass # Mount static files for direct artifact download export_dir = os.getenv("EXPORT_DIR", "models/exports") os.makedirs(export_dir, exist_ok=True) try: app.mount("/files", StaticFiles(directory=export_dir), name="files") except Exception: pass if __name__ == "__main__": # Local/Space run demo.queue().launch(ssr_mode=False)