import os
import base64
import io
import tempfile
from typing import List, Optional, Any, Dict

import gradio as gr
import numpy as np
import requests
import torch
from fastapi import FastAPI, Header, HTTPException, UploadFile, File, Form
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import List as TypingList
from PIL import Image
from starlette.staticfiles import StaticFiles
import threading
import json

from inference import InferenceService
from utils.data_fetch import ensure_dataset_ready
from utils.tag_system import get_all_tag_options, validate_tags, TagProcessor
from utils.image_utils import (
    load_images_from_files,
    load_image_from_bytes,
    load_image_from_url,
    is_image_file,
    get_supported_formats,
    get_supported_extensions,
    ensure_rgb_image
)
# Global state
BOOT_STATUS = "starting"
DATASET_ROOT: Optional[str] = None
def get_artifact_overview():
    """Get a comprehensive artifact overview."""
    try:
        from utils.artifact_manager import create_artifact_manager
        manager = create_artifact_manager()
        return manager.get_artifact_summary()
    except Exception as e:
        return {"error": str(e)}
def export_artifact_summary():
    """Export the artifact summary as a JSON file and return its path."""
    try:
        from utils.artifact_manager import create_artifact_manager
        manager = create_artifact_manager()
        summary = manager.get_artifact_summary()
        # Save to the exports directory
        export_dir = os.getenv("EXPORT_DIR", "models/exports")
        os.makedirs(export_dir, exist_ok=True)
        summary_path = os.path.join(export_dir, "system_summary.json")
        with open(summary_path, "w") as f:
            json.dump(summary, f, indent=2)
        return summary_path
    except Exception:
        return None
def create_download_package(package_type: str):
    """Create a downloadable package."""
    try:
        from utils.artifact_manager import create_artifact_manager
        manager = create_artifact_manager()
        # Extract the package type from the dropdown choice
        if "complete" in package_type:
            pkg_type = "complete"
        elif "splits_only" in package_type:
            pkg_type = "splits_only"
        elif "models_only" in package_type:
            pkg_type = "models_only"
        else:
            return f"❌ Invalid package type: {package_type}", get_available_packages()
        package_path = manager.create_download_package(pkg_type)
        package_name = os.path.basename(package_path)
        return f"✅ Package created: {package_name}", get_available_packages()
    except Exception as e:
        return f"❌ Failed to create package: {e}", get_available_packages()
def get_available_packages():
    """Get the list of available packages."""
    try:
        export_dir = os.getenv("EXPORT_DIR", "models/exports")
        packages = []
        if os.path.exists(export_dir):
            for file in os.listdir(export_dir):
                if file.endswith((".tar.gz", ".zip")):
                    file_path = os.path.join(export_dir, file)
                    packages.append({
                        "name": file,
                        "size_mb": round(os.path.getsize(file_path) / (1024 * 1024), 2),
                        "path": file_path,
                        "url": f"/files/{file}"
                    })
        return {"packages": packages}
    except Exception as e:
        return {"error": str(e)}
def get_individual_files():
    """Get the list of individual downloadable files, grouped by category."""
    try:
        from utils.artifact_manager import create_artifact_manager
        manager = create_artifact_manager()
        files = manager.get_downloadable_files()
        # Group by category
        categorized = {}
        for file in files:
            category = file["category"]
            if category not in categorized:
                categorized[category] = []
            categorized[category].append(file)
        return categorized
    except Exception as e:
        return {"error": str(e)}
def download_all_files():
    """Bundle all downloadable files into a single ZIP archive."""
    try:
        import zipfile
        from utils.artifact_manager import create_artifact_manager
        manager = create_artifact_manager()
        files = manager.get_downloadable_files()
        # Create a ZIP containing all files
        export_dir = os.getenv("EXPORT_DIR", "models/exports")
        os.makedirs(export_dir, exist_ok=True)
        zip_path = os.path.join(export_dir, "all_artifacts.zip")
        with zipfile.ZipFile(zip_path, "w") as zipf:
            for file in files:
                if os.path.exists(file["path"]):
                    zipf.write(file["path"], file["name"])
        return zip_path
    except Exception:
        return None
def get_training_status():
    """Get the current training status from the monitor."""
    try:
        from ui.monitor import create_monitor
        monitor = create_monitor()
        status = monitor.get_status()
        return status if status else {"status": "no-training"}
    except Exception as e:
        return {"status": "error", "error": str(e)}
def push_splits_to_hf(token, username):
    """Push dataset splits to the HF Hub."""
    if not token or not username:
        return "❌ Please provide HF token and username"
    try:
        from utils.hf_utils import HFModelManager
        hf = HFModelManager(token=token, username=username)
        result = hf.upload_model("splits", "Dressify-Helper")
        if result.get("success"):
            return f"✅ Successfully uploaded splits to {username}/Dressify-Helper"
        else:
            return f"❌ Failed to upload splits: {result.get('error', 'Unknown error')}"
    except Exception as e:
        return f"❌ Upload failed: {e}"


def push_models_to_hf(token, username):
    """Push trained models to the HF Hub."""
    if not token or not username:
        return "❌ Please provide HF token and username"
    try:
        from utils.hf_utils import HFModelManager
        hf = HFModelManager(token=token, username=username)
        result = hf.upload_model("models", "dressify-models")
        if result.get("success"):
            return f"✅ Successfully uploaded models to {username}/dressify-models"
        else:
            return f"❌ Failed to upload models: {result.get('error', 'Unknown error')}"
    except Exception as e:
        return f"❌ Upload failed: {e}"


def push_everything_to_hf(token, username):
    """Push all artifacts to the HF Hub."""
    if not token or not username:
        return "❌ Please provide HF token and username"
    try:
        from utils.hf_utils import HFModelManager
        hf = HFModelManager(token=token, username=username)
        result = hf.upload_model("everything", "dressify-complete")
        if result.get("success"):
            return "✅ Successfully uploaded everything to HF Hub"
        else:
            return f"❌ Failed to upload everything: {result.get('error', 'Unknown error')}"
    except Exception as e:
        return f"❌ Upload failed: {e}"
AI_API_KEY = os.getenv("AI_API_KEY")


def require_api_key(x_api_key: Optional[str]):
    if AI_API_KEY and x_api_key != AI_API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API key")
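
# Client sketch (illustrative, not part of this service): FastAPI maps the
# `x_api_key` parameter to the `x-api-key` request header (case-insensitive),
# so an authenticated client would send something like:
#
#   import requests  # hypothetical client-side snippet
#   resp = requests.get("https://<space-url>/health",
#                       headers={"X-API-Key": "<your AI_API_KEY value>"})
#
# When AI_API_KEY is unset, require_api_key() accepts every request.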
class EmbedRequest(BaseModel):
    image_urls: Optional[List[str]] = None
    images_base64: Optional[List[str]] = None


class Item(BaseModel):
    id: str
    embedding: Optional[List[float]] = None
    category: Optional[str] = None
    image_url: Optional[str] = None
    image_base64: Optional[str] = None  # Base64-encoded image


class ComposeRequest(BaseModel):
    items: List[Item]
    context: Optional[Dict[str, Any]] = None
    # Expanded tag fields for better API usability
    occasion: Optional[str] = "casual"
    weather: Optional[str] = "any"
    style: Optional[str] = "casual"
    outfit_style: Optional[str] = None  # Alias for style
    num_outfits: Optional[int] = 5
    # Optional tags
    color_preference: Optional[str] = None
    fit_preference: Optional[str] = None
    material_preference: Optional[str] = None
    season: Optional[str] = None
    time_of_day: Optional[str] = None
    budget: Optional[str] = None
    personal_style: Optional[str] = None
    age_group: Optional[str] = None
    gender: Optional[str] = None

    class Config:
        extra = "allow"  # Allow additional fields for flexibility
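
# Example ComposeRequest payload (illustrative values; fields are exactly the
# ones declared above, and each item needs an image_url, image_base64, or
# pre-computed embedding):
#
# {
#   "items": [
#     {"id": "shirt_1", "image_url": "https://example.com/shirt.jpg"},
#     {"id": "pants_1", "image_base64": "<base64 string>"},
#     {"id": "shoes_1", "embedding": [0.12, -0.04], "category": "shoes"}
#   ],
#   "occasion": "business",
#   "weather": "cool",
#   "style": "smart_casual",
#   "num_outfits": 3,
#   "color_preference": "neutral"
# }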
app = FastAPI(title="Dressify Recommendation Service")
service = InferenceService()

# Non-blocking bootstrap: check for the dataset and splits in the background.
# Nothing is downloaded and no training is started automatically.
BOOT_STATUS = "idle"
DATASET_ROOT: Optional[str] = None


def _background_bootstrap():
    global BOOT_STATUS
    global DATASET_ROOT
    try:
        # Only check whether the dataset exists - DO NOT prepare it automatically
        root = os.path.abspath(os.path.join(os.getcwd(), "data", "Polyvore"))
        images_dir = os.path.join(root, "images")
        splits_dir = os.path.join(root, "splits")
        # Check if the dataset already exists
        has_images = os.path.isdir(images_dir) and any(os.listdir(images_dir))
        has_splits = (
            os.path.isfile(os.path.join(splits_dir, "train.json")) or
            os.path.isfile(os.path.join(splits_dir, "outfit_triplets_train.json"))
        )
        if has_images and has_splits:
            print("✅ Dataset and splits already prepared")
            DATASET_ROOT = root
            BOOT_STATUS = "ready"
        elif has_images:
            print("✅ Dataset images exist, but splits may be missing (use Advanced Training to prepare)")
            DATASET_ROOT = root
            BOOT_STATUS = "ready"
        else:
            print("ℹ️ Dataset not prepared. Use 'Download & Prepare Dataset' button in Advanced Training tab if needed.")
            DATASET_ROOT = None
            BOOT_STATUS = "ready"  # The system is ready; only the dataset is not prepared
        # NO automatic training - models should be pre-trained or trained manually via the UI
    except Exception as e:
        BOOT_STATUS = f"error: {e}"


threading.Thread(target=_background_bootstrap, daemon=True).start()
def health() -> dict:
    return {"status": "ok", "device": service.device, "resnet": service.resnet_version, "vit": service.vit_version}


def get_tags() -> dict:
    """
    Get all available tag options for API integration.
    Returns a comprehensive list of all tag categories and their valid values.
    """
    return {
        "tag_categories": get_all_tag_options(),
        "description": "Available tags for personalized outfit recommendations",
        "usage": {
            "primary_tags": ["occasion", "weather", "style"],
            "optional_tags": ["color_preference", "fit_preference", "material_preference",
                              "season", "time_of_day", "budget", "personal_style",
                              "age_group", "gender"]
        }
    }


def get_image_formats() -> dict:
    """
    Get all supported image formats for API integration.
    """
    return {
        "supported_formats": get_supported_formats(),
        "supported_extensions": get_supported_extensions(),
        "description": "All major image formats are supported including JPG, PNG, WEBP, GIF, BMP, TIFF, and more",
        "note": "Images are automatically converted to RGB mode for model processing"
    }
async def compose_with_upload(
    files: TypingList[UploadFile] = File(...),
    occasion: str = Form("casual"),
    weather: str = Form("any"),
    style: str = Form("casual"),
    num_outfits: int = Form(5),
    color_preference: Optional[str] = Form(None),
    fit_preference: Optional[str] = Form(None),
    material_preference: Optional[str] = Form(None),
    season: Optional[str] = Form(None),
    time_of_day: Optional[str] = Form(None),
    budget: Optional[str] = Form(None),
    personal_style: Optional[str] = Form(None),
    x_api_key: Optional[str] = Header(None)
) -> dict:
    """
    Generate outfit recommendations from uploaded image files.

    This endpoint accepts multipart/form-data with:
    - files: image files (JPG, PNG, WEBP, GIF, BMP, TIFF, etc.)
    - all tag parameters as form fields

    Returns personalized outfit recommendations.
    """
    require_api_key(x_api_key)
    if not files or len(files) < 2:
        raise HTTPException(status_code=400, detail="At least 2 images required for outfit recommendations")
    # Load images from the uploaded files
    items = []
    errors = []
    for i, file in enumerate(files):
        try:
            # Read the file content
            contents = await file.read()
            # Load an image from the raw bytes
            img = load_image_from_bytes(contents, convert_to_rgb=True, raise_on_error=False)
            if img is None:
                errors.append(f"Failed to load image from file {file.filename}")
                continue
            items.append({
                "id": f"item_{i}",
                "image": img,
                "category": None  # Will be auto-detected
            })
        except Exception as e:
            errors.append(f"Error processing file {file.filename}: {str(e)}")
    if len(items) < 2:
        error_msg = f"Not enough valid images. Need at least 2, got {len(items)}."
        if errors:
            error_msg += f" Errors: {', '.join(errors[:3])}"
        raise HTTPException(status_code=400, detail=error_msg)
    # Build the context
    context = {
        "occasion": occasion,
        "weather": weather,
        "style": style,
        "outfit_style": style,
        "num_outfits": num_outfits
    }
    # Add optional tags
    if color_preference and color_preference != "None":
        context["color_preference"] = color_preference
    if fit_preference and fit_preference != "None":
        context["fit_preference"] = fit_preference
    if material_preference and material_preference != "None":
        context["material_preference"] = material_preference
    if season and season != "None":
        context["season"] = season
    if time_of_day and time_of_day != "None":
        context["time_of_day"] = time_of_day
    if budget and budget != "None":
        context["budget"] = budget
    if personal_style and personal_style != "None":
        context["personal_style"] = personal_style
    # Validate tags
    is_valid, tag_errors = validate_tags(context)
    if not is_valid:
        return JSONResponse(
            status_code=400,
            content={
                "error": "Invalid tags provided",
                "errors": tag_errors,
                "valid_tag_options": get_all_tag_options()
            }
        )
    # Generate recommendations
    try:
        outfits = service.compose_outfits(items, context=context)
        # Check whether compose_outfits returned an error
        if outfits and isinstance(outfits, list) and len(outfits) > 0:
            if isinstance(outfits[0], dict) and "error" in outfits[0]:
                return JSONResponse(
                    status_code=500,
                    content={
                        "error": "Recommendation generation failed",
                        "details": outfits[0].get("details", []),
                        "message": outfits[0].get("message", "Unknown error")
                    }
                )
        return {
            "outfits": outfits,
            "version": service.vit_version,
            "tags_processed": True,
            "context_used": context,
            "items_processed": len(items),
            "warnings": errors if errors else None
        }
    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse(
            status_code=500,
            content={
                "error": "Internal server error",
                "message": str(e),
                "model_status": service.get_model_status()
            }
        )
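
# Usage sketch (illustrative; the route this handler is mounted on is defined
# elsewhere, so the URL path below is an assumption):
#
#   curl -X POST "https://<space-url>/compose/upload" \
#        -H "X-API-Key: <key>" \
#        -F "files=@shirt.jpg" -F "files=@jeans.png" \
#        -F "occasion=casual" -F "weather=mild" -F "num_outfits=3"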
def validate_request_tags(tags: Dict[str, Any], x_api_key: Optional[str] = Header(None)) -> dict:
    """
    Validate tag values before making a recommendation request.
    Useful for API clients to check tag validity.
    """
    require_api_key(x_api_key)
    is_valid, errors = validate_tags(tags)
    return {
        "valid": is_valid,
        "errors": errors if not is_valid else [],
        "validated_tags": tags if is_valid else None
    }
def model_status() -> dict:
    """Get detailed model loading status."""
    return service.get_model_status()


def reload_models() -> dict:
    """Force-reload the models - useful for debugging."""
    try:
        service.force_reload_models()
        return {"status": "success", "message": "Models reloaded successfully"}
    except Exception as e:
        return {"status": "error", "message": str(e)}


def test_recommend() -> dict:
    """Test recommendation with dummy data to help debug issues."""
    try:
        # Create dummy items for testing
        dummy_items = [
            {"id": "test_1", "image": None, "category": "shirt"},
            {"id": "test_2", "image": None, "category": "pants"},
            {"id": "test_3", "image": None, "category": "shoes"}
        ]
        # Try to get recommendations
        result = service.compose_outfits(dummy_items, {"num_outfits": 1})
        return {
            "status": "success",
            "model_status": service.get_model_status(),
            "result": result,
            "result_length": len(result) if result else 0
        }
    except Exception as e:
        return {"status": "error", "message": str(e), "model_status": service.get_model_status()}
def embed(req: EmbedRequest, x_api_key: Optional[str] = Header(None)) -> dict:
    """
    Generate embeddings for images with comprehensive format support.
    Supports JPG, PNG, WEBP, GIF, BMP, TIFF, and other major formats.
    """
    require_api_key(x_api_key)
    images: List[Image.Image] = []
    errors = []
    # Load from URLs
    if req.image_urls:
        for url in req.image_urls:
            img = load_image_from_url(url, timeout=20, convert_to_rgb=True, raise_on_error=False)
            if img is not None:
                images.append(img)
            else:
                errors.append(f"Failed to load image from URL: {url}")
    # Load from base64
    if req.images_base64:
        for b64 in req.images_base64:
            try:
                image_bytes = base64.b64decode(b64)
                img = load_image_from_bytes(image_bytes, convert_to_rgb=True, raise_on_error=False)
                if img is not None:
                    images.append(img)
                else:
                    errors.append("Failed to load image from base64")
            except Exception as e:
                errors.append(f"Error decoding base64 image: {str(e)}")
    if not images:
        error_msg = "No images provided or all images failed to load"
        if errors:
            error_msg += f". Errors: {', '.join(errors[:3])}"
        raise HTTPException(status_code=400, detail=error_msg)
    # Ensure all images are RGB
    images = [ensure_rgb_image(img) for img in images]
    embs = service.embed_images(images)
    return {
        "embeddings": [e.tolist() for e in embs],
        "model_version": service.resnet_version,
        "images_loaded": len(images),
        "errors": errors if errors else None
    }
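
# Client sketch for embed() (illustrative; the endpoint URL is an assumption):
#
#   import base64, requests
#
#   with open("shirt.jpg", "rb") as fh:
#       b64 = base64.b64encode(fh.read()).decode("ascii")
#   payload = {"images_base64": [b64]}  # or {"image_urls": ["https://..."]}
#   resp = requests.post("https://<space-url>/embed", json=payload,
#                        headers={"X-API-Key": "<key>"})
#   vectors = resp.json()["embeddings"]  # one list[float] per loaded image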
def compose(req: ComposeRequest, x_api_key: Optional[str] = Header(None)) -> dict:
    """
    Generate personalized outfit recommendations with expanded tag support.

    Supports both the legacy context dict format and the new tag-based format.
    Tags are processed and prioritized automatically.

    Items can provide:
    - image_url: URL to an image (will be downloaded)
    - image_base64: base64-encoded image
    - embedding: pre-computed embedding (skips ResNet)
    - category: item category (optional; auto-detected if not provided)
    """
    require_api_key(x_api_key)
    # Build items, loading images as needed
    items = []
    errors = []
    for it in req.items:
        item_dict = {
            "id": it.id,
            "embedding": np.array(it.embedding, dtype=np.float32) if it.embedding is not None else None,
            "category": it.category,
        }
        # Load the image from a URL if provided
        if it.image_url:
            img = load_image_from_url(it.image_url, timeout=20, convert_to_rgb=True, raise_on_error=False)
            if img is not None:
                item_dict["image"] = img
            else:
                errors.append(f"Failed to load image from URL for item {it.id}: {it.image_url}")
        # Otherwise load the image from base64 if provided
        elif it.image_base64:
            try:
                image_bytes = base64.b64decode(it.image_base64)
                img = load_image_from_bytes(image_bytes, convert_to_rgb=True, raise_on_error=False)
                if img is not None:
                    item_dict["image"] = img
                else:
                    errors.append(f"Failed to load image from base64 for item {it.id}")
            except Exception as e:
                errors.append(f"Error decoding base64 image for item {it.id}: {str(e)}")
        # If the item has neither an image nor an embedding, skip it
        if item_dict.get("image") is None and item_dict.get("embedding") is None:
            errors.append(f"Item {it.id} has no image or embedding - skipping")
            continue
        items.append(item_dict)
    if not items:
        error_msg = "No valid items provided. All items failed to load or have no images/embeddings."
        if errors:
            error_msg += f" Errors: {', '.join(errors[:5])}"
        raise HTTPException(status_code=400, detail=error_msg)
    # Build the context from the request
    context = req.context or {}
    # Explicit tag fields take precedence over the context dict
    if req.occasion:
        context["occasion"] = req.occasion
    if req.weather:
        context["weather"] = req.weather
    if req.style:
        context["style"] = req.style
    if req.outfit_style:
        context["outfit_style"] = req.outfit_style
        context["style"] = req.outfit_style  # Also set style for consistency
    if req.num_outfits:
        context["num_outfits"] = req.num_outfits
    # Add optional tags
    optional_tags = {
        "color_preference": req.color_preference,
        "fit_preference": req.fit_preference,
        "material_preference": req.material_preference,
        "season": req.season,
        "time_of_day": req.time_of_day,
        "budget": req.budget,
        "personal_style": req.personal_style,
        "age_group": req.age_group,
        "gender": req.gender,
    }
    for tag_name, tag_value in optional_tags.items():
        if tag_value:
            context[tag_name] = tag_value
    # Validate tags
    is_valid, tag_errors = validate_tags(context)
    if not is_valid:
        return JSONResponse(
            status_code=400,
            content={
                "error": "Invalid tags provided",
                "errors": tag_errors,
                "valid_tag_options": get_all_tag_options()
            }
        )
    # Generate recommendations
    try:
        outfits = service.compose_outfits(items, context=context)
        # Check whether compose_outfits returned an error
        if outfits and isinstance(outfits, list) and len(outfits) > 0:
            if isinstance(outfits[0], dict) and "error" in outfits[0]:
                return JSONResponse(
                    status_code=500,
                    content={
                        "error": "Recommendation generation failed",
                        "details": outfits[0].get("details", []),
                        "message": outfits[0].get("message", "Unknown error")
                    }
                )
        return {
            "outfits": outfits,
            "version": service.vit_version,
            "tags_processed": True,
            "context_used": context,
            "items_processed": len(items),
            "warnings": errors if errors else None
        }
    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse(
            status_code=500,
            content={
                "error": "Internal server error during recommendation generation",
                "message": str(e),
                "model_status": service.get_model_status()
            }
        )
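
# Client sketch for compose() (illustrative; the route path is an assumption,
# and the payload mirrors the ComposeRequest example above):
#
#   import requests
#
#   payload = {
#       "items": [
#           {"id": "a", "image_url": "https://example.com/top.jpg"},
#           {"id": "b", "image_url": "https://example.com/jeans.jpg"},
#       ],
#       "occasion": "date", "weather": "warm", "num_outfits": 2,
#   }
#   resp = requests.post("https://<space-url>/compose", json=payload,
#                        headers={"X-API-Key": "<key>"})
#   for outfit in resp.json()["outfits"]:
#       print(outfit)  # each entry carries at least the chosen item_ids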
def artifacts() -> dict:
    # List exported model artifacts available for download
    export_dir = os.getenv("EXPORT_DIR", "models/exports")
    files = []
    if os.path.isdir(export_dir):
        for fn in os.listdir(export_dir):
            if fn.endswith((".pth", ".pt", ".onnx", ".ts", ".json")):
                files.append({
                    "name": fn,
                    "path": f"{export_dir}/{fn}",
                    "url": f"/files/{fn}",
                })
    return {"artifacts": files}
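
# Example artifacts() response (illustrative file name):
#
#   {"artifacts": [
#       {"name": "resnet_item_embedder.pth",
#        "path": "models/exports/resnet_item_embedder.pth",
#        "url": "/files/resnet_item_embedder.pth"}
#   ]}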
# --------- Gradio UI ---------

def _load_images_from_files(files: List[str]) -> List[Image.Image]:
    """
    Load images from file paths with comprehensive format support.
    Supports JPG, PNG, WEBP, GIF, BMP, TIFF, and other major formats.
    """
    return load_images_from_files(files, convert_to_rgb=True, skip_errors=True)


def gradio_embed(files: List[str]):
    if not files:
        return "[]"
    images = _load_images_from_files(files)
    if not images:
        return "[]"
    embs = service.embed_images(images)
    return str([e.tolist() for e in embs])
def _stitch_strip(imgs: List[Image.Image], height: int = 256, pad: int = 6, bg=(245, 245, 245)) -> Image.Image:
    """Stitch images into one horizontal strip of uniform height with padding."""
    if not imgs:
        return Image.new("RGB", (1, height), color=bg)
    resized = []
    for im in imgs:
        if im.mode != "RGB":
            im = im.convert("RGB")
        w, h = im.size
        scale = height / float(h)
        nw = max(1, int(w * scale))
        resized.append(im.resize((nw, height)))
    total_w = sum(im.size[0] for im in resized) + pad * (len(resized) + 1)
    out = Image.new("RGB", (total_w, height + 2 * pad), color=bg)
    x = pad
    for im in resized:
        out.paste(im, (x, pad))
        x += im.size[0] + pad
    return out
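
# Usage sketch: stitching three wardrobe crops into one preview strip
# (file names are hypothetical):
#
#   thumbs = [Image.open(p) for p in ("top.jpg", "jeans.jpg", "shoes.jpg")]
#   preview = _stitch_strip(thumbs, height=128, pad=4)
#   preview.save("outfit_preview.png")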
def gradio_recommend(
    files: List[str],
    occasion: str,
    weather: str,
    num_outfits: int,
    outfit_style: str = "casual",
    color_preference: Optional[str] = None,
    fit_preference: Optional[str] = None,
    material_preference: Optional[str] = None,
    season: Optional[str] = None,
    time_of_day: Optional[str] = None,
    budget: Optional[str] = None,
    personal_style: Optional[str] = None
):
    # Check model status first
    model_status = service.get_model_status()
    if not model_status["can_recommend"]:
        error_msg = "❌ Models not ready for recommendations!\n\n"
        error_msg += "**Model Status:**\n"
        error_msg += f"- ResNet: {'✅ Loaded' if model_status['resnet_loaded'] else '❌ Not loaded'}\n"
        error_msg += f"- ViT: {'✅ Loaded' if model_status['vit_loaded'] else '❌ Not loaded'}\n\n"
        error_msg += "**Errors:**\n"
        for error in model_status["errors"]:
            error_msg += f"- {error}\n\n"
        error_msg += "**Solution:**\n"
        error_msg += "Please train the models first using the 'Simple Training' or 'Advanced Training' tabs, or ensure trained checkpoints are available."
        return [], {"error": error_msg, "model_status": model_status}
    # Return stitched outfit images and a JSON payload with details
    if not files:
        return [], {"error": "No files uploaded"}
    # Enhanced debug: log detailed file information for API troubleshooting
    print(f"🔍 DEBUG: gradio_recommend called with {len(files)} files")
    file_info = []
    for i, f in enumerate(files):
        if isinstance(f, str):
            from pathlib import Path
            path = Path(f)
            file_size = path.stat().st_size if path.exists() else 0
            file_info.append(f"File {i+1}: {path.name} ({file_size} bytes)")
            print(f"🔍 DEBUG: File {i+1}: path={f}, exists={path.exists()}, size={file_size}, name={path.name}")
        else:
            file_info.append(f"Type: {type(f).__name__}, Value: {str(f)[:100]}")
            print(f"🔍 DEBUG: File {i+1}: type={type(f).__name__}, value={str(f)[:100]}")
    try:
        print(f"🔍 DEBUG: Attempting to load {len(files)} images...")
        images = _load_images_from_files(files)
        print(f"🔍 DEBUG: Successfully loaded {len(images)} images from {len(files)} files")
        if not images:
            error_msg = "Could not load images from uploaded files.\n\n"
            error_msg += f"Files received: {len(files)}\n"
            error_msg += f"File details: {', '.join(file_info[:3])}\n\n"
            error_msg += f"Supported formats: {', '.join(get_supported_extensions())}\n"
            error_msg += "Please ensure files are valid image files (JPG, PNG, WEBP, GIF, BMP, TIFF, etc.)"
            print(f"🔍 DEBUG: ERROR - No images loaded. Files: {len(files)}, Images: {len(images)}")
            return [], {"error": error_msg, "files_received": len(files), "file_info": file_info[:5]}
    except Exception as e:
        import traceback
        error_msg = f"Error processing files: {str(e)}\n\n"
        error_msg += f"Files received: {len(files)}\n"
        error_msg += f"File details: {', '.join(file_info[:3])}\n\n"
        error_msg += "Please check that files are valid image files."
        print(f"🔍 DEBUG: EXCEPTION during image loading: {e}")
        traceback.print_exc()
        return [], {"error": error_msg, "exception": str(e), "files_received": len(files)}

    # Tag normalization: map frontend values to backend values.
    # This handles variations and synonyms from different frontend implementations.
    def normalize_tag_value(tag_name: str, value: Optional[str]) -> Optional[str]:
        """Normalize tag values to match backend expectations."""
        if not value or value == "None":
            return None
        value_lower = value.lower().strip()
        # Color preference mappings - normalize variations to standard values
        if tag_name == "color_preference":
            color_mappings = {
                "monochrome": "monochromatic",  # Frontend uses "monochrome", backend uses "monochromatic"
                "mono": "monochromatic",
                "single_color": "monochromatic",
                "one_color": "monochromatic",
            }
            normalized = color_mappings.get(value_lower, value)
            # Validate against the allowed values
            allowed_colors = ["neutral", "monochromatic", "complementary", "bold", "subtle",
                              "bright", "muted", "pastel", "dark", "light", "earth_tones",
                              "jewel_tones", "black_white", "navy_white", "colorful", "minimal_color"]
            if normalized in allowed_colors:
                return normalized
            return value  # Return the original if not in the allowed list
        # Fit preference mappings
        if tag_name == "fit_preference":
            fit_mappings = {
                "slim": "fitted",
                "tight_fit": "fitted",
                "baggy": "loose",
                "wide": "loose",
            }
            normalized = fit_mappings.get(value_lower, value)
            # Validate against the allowed values
            allowed_fits = ["fitted", "loose", "oversized", "relaxed", "comfortable",
                            "structured", "flowy", "tailored", "athletic_fit", "regular_fit"]
            if normalized in allowed_fits:
                return normalized
            return value
        # Season: both "fall" and "autumn" are valid, so no normalization is needed
        if tag_name == "season":
            return value
        # Return the original value if no mapping was found
        return value
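
    # For example, under the mappings above (illustrative):
    #   normalize_tag_value("color_preference", "monochrome") -> "monochromatic"
    #   normalize_tag_value("fit_preference", "baggy")        -> "loose"
    #   normalize_tag_value("season", "None")                 -> None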
    # Normalize all tag values
    occasion = normalize_tag_value("occasion", occasion) or occasion
    weather = normalize_tag_value("weather", weather) or weather
    outfit_style = normalize_tag_value("outfit_style", outfit_style) or outfit_style
    color_preference = normalize_tag_value("color_preference", color_preference) if color_preference else None
    fit_preference = normalize_tag_value("fit_preference", fit_preference) if fit_preference else None
    material_preference = normalize_tag_value("material_preference", material_preference) if material_preference else None
    season = normalize_tag_value("season", season) if season else None
    time_of_day = normalize_tag_value("time_of_day", time_of_day) if time_of_day else None
    budget = normalize_tag_value("budget", budget) if budget else None
    personal_style = normalize_tag_value("personal_style", personal_style) if personal_style else None
    # Build a comprehensive context with all tags
    context = {
        "occasion": occasion,
        "weather": weather,
        "style": outfit_style,
        "outfit_style": outfit_style,  # Backward compatibility
        "num_outfits": int(num_outfits)
    }
    # Add optional tags if provided
    if color_preference and color_preference != "None":
        context["color_preference"] = color_preference
    if fit_preference and fit_preference != "None":
        context["fit_preference"] = fit_preference
    if material_preference and material_preference != "None":
        context["material_preference"] = material_preference
    if season and season != "None":
        context["season"] = season
    if time_of_day and time_of_day != "None":
        context["time_of_day"] = time_of_day
    if budget and budget != "None":
        context["budget"] = budget
    if personal_style and personal_style != "None":
        context["personal_style"] = personal_style
    # Build items that allow on-the-fly embedding in the service
    items = [
        {"id": f"item_{i}", "image": images[i], "category": None}
        for i in range(len(images))
    ]
    print(f"🔍 DEBUG: Calling compose_outfits with {len(items)} items, context={context}")
    res = service.compose_outfits(items, context=context)
    print(f"🔍 DEBUG: compose_outfits returned {len(res) if res else 0} results")
    if res:
        print(f"🔍 DEBUG: First result type: {type(res[0])}, keys: {res[0].keys() if isinstance(res[0], dict) else 'N/A'}")
    # Check whether compose_outfits returned an error
    if res and isinstance(res[0], dict) and "error" in res[0]:
        print(f"🔍 DEBUG: Error in compose_outfits result: {res[0]}")
        return [], res[0]
    # Prepare stitched previews, saved to temp files for Gradio API compatibility
    strips: List[str] = []  # File paths (not PIL images), so Gradio can serve them as URLs
    print(f"🔍 DEBUG: Preparing stitched previews for {len(res)} outfits...")
    for i, r in enumerate(res):
        idxs = []
        item_ids = r.get("item_ids", [])
        print(f"🔍 DEBUG: Outfit {i+1}: item_ids={item_ids}")
        for iid in item_ids:
            try:
                idx = int(str(iid).split("_")[-1])
                idxs.append(idx)
                print(f"🔍 DEBUG: Mapped {iid} -> index {idx}")
            except Exception as e:
                print(f"🔍 DEBUG: Failed to parse {iid}: {e}")
                continue
        imgs = [images[j] for j in idxs if 0 <= j < len(images)]
        print(f"🔍 DEBUG: Extracted {len(imgs)} images from indices {idxs}")
        if imgs:
            strip = _stitch_strip(imgs)
            print(f"🔍 DEBUG: Created stitched image: {strip.size}")
            # Save to a temporary file (Gradio will convert it to a URL)
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir="/tmp")
            strip.save(temp_file.name, "PNG")
            temp_file.close()
            strips.append(temp_file.name)  # Return the file path instead of a PIL Image
            print(f"🔍 DEBUG: Saved to temp file: {temp_file.name}")
        else:
            print(f"⚠️ DEBUG: No images extracted for outfit {i+1}")
    print(f"🔍 DEBUG: Returning {len(strips)} stitched image file paths and {len(res)} outfit results")
    return strips, {"outfits": res}
def start_training_advanced(
    # Dataset size
    dataset_size: str,
    # ResNet parameters
    resnet_epochs: int, resnet_batch_size: int, resnet_lr: float, resnet_optimizer: str,
    resnet_weight_decay: float, resnet_triplet_margin: float, resnet_embedding_dim: int,
    resnet_backbone: str, resnet_use_pretrained: bool, resnet_dropout: float,
    # ViT parameters
    vit_epochs: int, vit_batch_size: int, vit_max_samples: int, vit_lr: float, vit_optimizer: str,
    vit_weight_decay: float, vit_triplet_margin: float, vit_embedding_dim: int,
    vit_num_layers: int, vit_num_heads: int, vit_ff_multiplier: int, vit_dropout: float,
    # Advanced parameters
    use_mixed_precision: bool, channels_last: bool, gradient_clip: float,
    warmup_epochs: int, scheduler_type: str, early_stopping_patience: int,
    mining_strategy: str, augmentation_level: str, seed: int
):
    """Start advanced training with custom parameters."""
    # Fall back to the DATASET_SIZE_LIMIT env default when no size (or "full") is given
    if not dataset_size or dataset_size == "full":
        dataset_size = os.getenv("DATASET_SIZE_LIMIT", "2000")
    if not DATASET_ROOT:
        return "❌ Dataset not ready. Please wait for bootstrap to complete."
    log_message = "🚀 Advanced training started with custom parameters! Check the log below for progress."

    def _runner():
        nonlocal log_message
        try:
            import subprocess
            import json
            export_dir = os.getenv("EXPORT_DIR", "models/exports")
            os.makedirs(export_dir, exist_ok=True)
            # Create custom config files
            resnet_config = {
                "model": {
                    "backbone": resnet_backbone,
                    "embedding_dim": resnet_embedding_dim,
                    "pretrained": resnet_use_pretrained,
                    "dropout": resnet_dropout
                },
                "training": {
                    "batch_size": resnet_batch_size,
                    "epochs": resnet_epochs,
                    "lr": resnet_lr,
                    "weight_decay": resnet_weight_decay,
                    "triplet_margin": resnet_triplet_margin,
                    "optimizer": resnet_optimizer,
                    "scheduler": scheduler_type,
                    "warmup_epochs": warmup_epochs,
                    "early_stopping_patience": early_stopping_patience,
                    "use_amp": use_mixed_precision,
                    "channels_last": channels_last,
                    "gradient_clip": gradient_clip
                },
                "data": {
                    "image_size": 224,
                    "augmentation_level": augmentation_level
                },
                "advanced": {
                    "mining_strategy": mining_strategy,
                    "seed": seed
                }
            }
            vit_config = {
                "model": {
                    "embedding_dim": vit_embedding_dim,
                    "num_layers": vit_num_layers,
                    "num_heads": vit_num_heads,
                    "ff_multiplier": vit_ff_multiplier,
                    "dropout": vit_dropout
                },
                "training": {
                    "batch_size": vit_batch_size,
                    "epochs": vit_epochs,
                    "lr": vit_lr,
                    "weight_decay": vit_weight_decay,
                    "triplet_margin": vit_triplet_margin,
                    "optimizer": vit_optimizer,
                    "scheduler": scheduler_type,
                    "warmup_epochs": warmup_epochs,
                    "early_stopping_patience": early_stopping_patience,
                    "use_amp": use_mixed_precision
                },
                "advanced": {
                    "mining_strategy": mining_strategy,
                    "seed": seed
                }
            }
            # Save the configs
            with open(os.path.join(export_dir, "resnet_config_custom.json"), "w") as f:
                json.dump(resnet_config, f, indent=2)
            with open(os.path.join(export_dir, "vit_config_custom.json"), "w") as f:
                json.dump(vit_config, f, indent=2)
            # Train ResNet with custom parameters
            log_message = "🚀 Starting ResNet training with custom parameters...\n"
            log_message += f"Dataset Size: {dataset_size} samples\n"
            log_message += f"Backbone: {resnet_backbone}, Embedding Dim: {resnet_embedding_dim}\n"
            log_message += f"Epochs: {resnet_epochs}, Batch Size: {resnet_batch_size}, LR: {resnet_lr}\n"
            log_message += f"Optimizer: {resnet_optimizer}, Triplet Margin: {resnet_triplet_margin}\n"
            # Add a dataset size limit unless training on the full dataset
            dataset_args = []
            if dataset_size != "full":
                dataset_args = ["--max_samples", dataset_size]
            resnet_cmd = [
                "python", "training/train_resnet.py",
                "--data_root", DATASET_ROOT,
                "--epochs", str(resnet_epochs),
                "--batch_size", str(resnet_batch_size),
                "--lr", str(resnet_lr),
                "--weight_decay", str(resnet_weight_decay),
                "--triplet_margin", str(resnet_triplet_margin),
                "--embedding_dim", str(resnet_embedding_dim),
                "--out", os.path.join(export_dir, "resnet_item_embedder_custom.pth")
            ] + dataset_args
            if resnet_backbone != "resnet50":
                resnet_cmd.extend(["--backbone", resnet_backbone])
            result = subprocess.run(resnet_cmd, capture_output=True, text=True, check=False)
            if result.returncode == 0:
                log_message += "✅ ResNet training completed successfully!\n"
                log_message += f"📊 ResNet Output:\n{result.stdout}\n\n"
            else:
                log_message += f"❌ ResNet training failed: {result.stderr}\n\n"
                return
            # Wait a moment for the file system to sync and the checkpoint to be fully written
            import time
            time.sleep(3)
            log_message += "⏳ Waiting for ResNet checkpoint to be fully saved...\n"
            # Verify the ResNet checkpoint exists before proceeding
            resnet_checkpoint = os.path.join(export_dir, "resnet_item_embedder_custom.pth")
            if not os.path.exists(resnet_checkpoint):
                log_message += f"❌ ResNet checkpoint not found at {resnet_checkpoint}\n"
                log_message += "Cannot proceed with ViT training without ResNet embeddings.\n"
                return
            log_message += f"✅ ResNet checkpoint verified: {resnet_checkpoint}\n"
            # Train ViT with custom parameters
            log_message += "🚀 Starting ViT training with custom parameters...\n"
            log_message += f"Dataset Size: {dataset_size} samples\n"
            log_message += f"Layers: {vit_num_layers}, Heads: {vit_num_heads}, FF Multiplier: {vit_ff_multiplier}\n"
            log_message += f"Epochs: {vit_epochs}, Batch Size: {vit_batch_size}, LR: {vit_lr}\n"
            log_message += f"Optimizer: {vit_optimizer}, Triplet Margin: {vit_triplet_margin}\n"
            vit_cmd = [
                "python", "training/train_vit.py",
                "--data_root", DATASET_ROOT,
                "--epochs", str(vit_epochs),
                "--batch_size", str(vit_batch_size),
                "--max_samples", str(vit_max_samples),
                "--lr", str(vit_lr),
                "--weight_decay", str(vit_weight_decay),
                "--triplet_margin", str(vit_triplet_margin),
                "--embedding_dim", str(vit_embedding_dim),
                "--export", os.path.join(export_dir, "vit_outfit_model_custom.pth")
            ] + dataset_args
            result = subprocess.run(vit_cmd, capture_output=True, text=True, check=False)
            if result.returncode == 0:
                log_message += "✅ ViT training completed successfully!\n"
                log_message += f"📊 ViT Output:\n{result.stdout}\n\n"
                log_message += "🎉 All training completed! Models saved to models/exports/\n"
                log_message += "🔄 Reloading models for inference...\n"
                service.reload_models()
                # Check whether the models loaded successfully
                model_status = service.get_model_status()
                if model_status["can_recommend"]:
                    log_message += "✅ Models reloaded and ready for inference!\n"
                    log_message += "🎉 You can now generate outfit recommendations!\n"
                else:
                    log_message += "⚠️ Models reloaded but validation failed!\n"
                    log_message += "**Model Status:**\n"
                    log_message += f"- ResNet: {'✅ Loaded' if model_status['resnet_loaded'] else '❌ Failed'}\n"
                    log_message += f"- ViT: {'✅ Loaded' if model_status['vit_loaded'] else '❌ Failed'}\n"
                    for error in model_status["errors"]:
                        log_message += f"- {error}\n"
                # Auto-upload to the HF Hub if a token is available
                hf_token = os.getenv("HF_TOKEN")
                if hf_token:
                    log_message += "📤 Auto-uploading artifacts to Hugging Face Hub...\n"
                    try:
                        from utils.hf_utils import HFModelManager
                        hf = HFModelManager(token=hf_token, username="Stylique")
                        result = hf.upload_model("everything", "dressify-complete")
                        if result.get("success"):
                            log_message += "✅ Successfully uploaded to HF Hub!\n"
                            log_message += "🔗 Models: https://huggingface.co/Stylique/dressify-models\n"
                            log_message += "🔗 Data: https://huggingface.co/datasets/Stylique/Dressify-Helper\n"
                        else:
                            log_message += f"⚠️ Upload failed: {result.get('error', 'Unknown error')}\n"
                    except Exception as e:
                        log_message += f"⚠️ Auto-upload failed: {str(e)}\n"
                else:
                    log_message += "💡 Set the HF_TOKEN env var for automatic uploads\n"
            else:
                log_message += f"❌ ViT training failed: {result.stderr}\n"
        except Exception as e:
            log_message += f"\n❌ Training error: {str(e)}"

    threading.Thread(target=_runner, daemon=True).start()
    return log_message
def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
    """Start simple training with basic parameters."""
    # Fall back to the DATASET_SIZE_LIMIT env default when no size (or "full") is given
    if not dataset_size or dataset_size == "full":
        dataset_size = os.getenv("DATASET_SIZE_LIMIT", "2000")
    log_message = f"Starting training on {dataset_size} samples..."

    def _runner():
        nonlocal log_message
        try:
            import subprocess
            if not DATASET_ROOT:
                log_message = "Dataset not ready."
                return
            export_dir = os.getenv("EXPORT_DIR", "models/exports")
            os.makedirs(export_dir, exist_ok=True)
            log_message = f"Training ResNet on {dataset_size} samples...\n"
            # Add a dataset size limit unless training on the full dataset
            dataset_args = []
            if dataset_size != "full":
                dataset_args = ["--max_samples", dataset_size]
            # Train ResNet first and wait for completion
            log_message += f"\n🚀 Starting ResNet training on {dataset_size} samples...\n"
            resnet_result = subprocess.run([
                "python", "training/train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
                "--batch_size", "4", "--lr", "1e-3", "--early_stopping_patience", "3",
                "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
            ] + dataset_args, capture_output=True, text=True, check=False)
            if resnet_result.returncode == 0:
                log_message += "✅ ResNet training completed successfully!\n"
                log_message += f"📊 ResNet Output:\n{resnet_result.stdout}\n"
            else:
                log_message += f"❌ ResNet training failed: {resnet_result.stderr}\n"
                return
            # Wait a moment for the file system to sync
            import time
            time.sleep(2)
            # Verify the ResNet checkpoint exists before proceeding
            resnet_checkpoint = os.path.join(export_dir, "resnet_item_embedder.pth")
            if not os.path.exists(resnet_checkpoint):
                log_message += f"❌ ResNet checkpoint not found at {resnet_checkpoint}\n"
                log_message += "Cannot proceed with ViT training without ResNet embeddings.\n"
                return
            log_message += f"✅ ResNet checkpoint verified: {resnet_checkpoint}\n"
            log_message += f"\n🚀 Starting ViT training on {dataset_size} samples...\n"
            vit_result = subprocess.run([
                "python", "training/train_vit.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
                "--batch_size", "4", "--lr", "5e-4", "--early_stopping_patience", "5",
                "--max_samples", "5000", "--triplet_margin", "0.5", "--gradient_clip", "1.0",
                "--warmup_epochs", "2", "--export", os.path.join(export_dir, "vit_outfit_model.pth")
            ] + dataset_args, capture_output=True, text=True, check=False)
            if vit_result.returncode == 0:
                log_message += "✅ ViT training completed successfully!\n"
                log_message += f"📊 ViT Output:\n{vit_result.stdout}\n"
            else:
                log_message += f"❌ ViT training failed: {vit_result.stderr}\n"
                return
            service.reload_models()
            # Check whether the models loaded successfully
            model_status = service.get_model_status()
            if model_status["can_recommend"]:
                log_message += "\n✅ Training completed! Models reloaded and ready for inference.\n"
                log_message += "🎉 You can now generate outfit recommendations!\n"
            else:
                log_message += "\n⚠️ Training completed but models failed to load properly!\n"
                log_message += "**Model Status:**\n"
                log_message += f"- ResNet: {'✅ Loaded' if model_status['resnet_loaded'] else '❌ Failed'}\n"
                log_message += f"- ViT: {'✅ Loaded' if model_status['vit_loaded'] else '❌ Failed'}\n"
                for error in model_status["errors"]:
                    log_message += f"- {error}\n"
            log_message += "\nArtifacts saved to models/exports/"
            # Auto-upload to the HF Hub if a token is available
            hf_token = os.getenv("HF_TOKEN")
            if hf_token:
                log_message += "\n📤 Auto-uploading artifacts to Hugging Face Hub...\n"
                try:
                    from utils.hf_utils import HFModelManager
                    hf = HFModelManager(token=hf_token, username="Stylique")
                    result = hf.upload_model("everything", "dressify-complete")
                    if result.get("success"):
                        log_message += "✅ Successfully uploaded to HF Hub!\n"
                        log_message += "🔗 Models: https://huggingface.co/Stylique/dressify-models\n"
                        log_message += "🔗 Data: https://huggingface.co/datasets/Stylique/Dressify-Helper\n"
                    else:
                        log_message += f"⚠️ Upload failed: {result.get('error', 'Unknown error')}\n"
                except Exception as e:
                    log_message += f"⚠️ Auto-upload failed: {str(e)}\n"
            else:
                log_message += "\n💡 Set the HF_TOKEN env var for automatic uploads\n"
        except Exception as e:
            log_message += f"\nError: {e}"

    threading.Thread(target=_runner, daemon=True).start()
    return log_message
with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendation") as demo:
    gr.Markdown("## 👗 Dressify - Advanced Outfit Recommendation System\n*Research-grade, self-contained outfit recommendation with comprehensive training controls*")
    gr.Markdown("💡 **Pro Tip**: Start with 2000 samples for quick testing, then increase to 50000+ for production training!")

    with gr.Tab("🎨 Recommend"):
        gr.Markdown("### 🎯 Personalized Outfit Recommendations\n*Upload your wardrobe and customize recommendations with advanced tag preferences*")
        gr.Markdown(f"**Supported Formats:** {', '.join(get_supported_extensions())} (JPG, PNG, WEBP, GIF, BMP, TIFF, and more)")
        inp2 = gr.Files(
            label="Upload wardrobe images",
            file_count="multiple"
            # Note: file_types removed to allow API client flexibility;
            # validation is handled by image_utils.load_images_from_files()
        )
        with gr.Accordion("🎯 Primary Tags (Required)", open=True):
            with gr.Row():
                occasion = gr.Dropdown(
                    choices=["casual", "business", "formal", "semi_formal", "business_casual", "cocktail",
                             "wedding", "party", "date", "sport", "workout", "travel", "beach", "outdoor",
                             "night_out", "brunch", "dinner", "meeting", "interview", "cultural", "traditional"],
                    value="casual",
                    label="Occasion",
                    info="Select the occasion or event type"
                )
                weather = gr.Dropdown(
                    choices=["any", "hot", "warm", "mild", "cool", "cold", "freezing", "rain", "snow",
                             "windy", "humid", "sunny", "cloudy"],
                    value="any",
                    label="Weather",
                    info="Current or expected weather conditions"
                )
                outfit_style = gr.Dropdown(
                    choices=["casual", "smart_casual", "formal", "sporty", "athletic", "streetwear",
                             "minimalist", "classic", "modern", "elegant", "sophisticated", "traditional", "ethnic"],
                    value="casual",
                    label="Outfit Style",
                    info="Preferred fashion aesthetic"
                )
        with gr.Accordion("🎨 Style & Preference Tags (Optional)", open=False):
            with gr.Row():
                color_preference = gr.Dropdown(
                    choices=["None", "neutral", "monochromatic", "monochrome", "complementary", "bold", "subtle",
                             "bright", "muted", "pastel", "dark", "light", "earth_tones", "jewel_tones",
                             "black_white", "navy_white", "colorful", "minimal_color"],
                    value="None",
                    label="Color Preference",
                    info="Preferred color scheme"
                )
                fit_preference = gr.Dropdown(
                    choices=["None", "fitted", "loose", "oversized", "relaxed", "comfortable",
                             "structured", "flowy", "tailored", "athletic_fit", "regular_fit"],
                    value="None",
                    label="Fit Preference",
                    info="Preferred fit and silhouette"
                )
                material_preference = gr.Dropdown(
                    choices=["None", "cotton", "linen", "silk", "wool", "cashmere", "denim",
                             "leather", "breathable", "waterproof", "moisture_wicking", "sustainable"],
                    value="None",
                    label="Material Preference",
                    info="Preferred fabric or material type"
                )
        with gr.Accordion("📍 Context Tags (Optional)", open=False):
            with gr.Row():
                season = gr.Dropdown(
                    choices=["None", "spring", "summer", "fall", "autumn", "winter", "year_round", "transitional"],
                    value="None",
                    label="Season",
                    info="Current season"
                )
                time_of_day = gr.Dropdown(
                    choices=["None", "morning", "afternoon", "evening", "night", "all_day"],
                    value="None",
                    label="Time of Day",
                    info="When will you wear this outfit?"
                )
                budget = gr.Dropdown(
                    choices=["None", "luxury", "premium", "mid_range", "affordable", "budget", "value"],
                    value="None",
                    label="Budget Preference",
                    info="Price range preference (informational)"
                )
                personal_style = gr.Dropdown(
                    choices=["None", "conservative", "moderate", "bold", "experimental", "traditional",
                             "trendy", "timeless", "fashion_forward", "classic", "eclectic"],
                    value="None",
                    label="Personal Style",
                    info="Your personal style preference"
                )
        with gr.Row():
            num_outfits = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of Outfits", info="How many outfit recommendations to generate")
        out_gallery = gr.Gallery(label="Recommended Outfits", columns=1, height=400, show_label=True)
        out_json = gr.JSON(label="Outfit Details & Tag Analysis", show_label=True)
        btn2 = gr.Button("✨ Generate Personalized Outfits", variant="primary", size="lg")
        btn2.click(
            fn=gradio_recommend,
            inputs=[inp2, occasion, weather, num_outfits, outfit_style,
                    color_preference, fit_preference, material_preference,
                    season, time_of_day, budget, personal_style],
            outputs=[out_gallery, out_json]
        )
| with gr.Tab("π¬ Advanced Training"): | |
| gr.Markdown("### π― Comprehensive Training Parameter Control\nCustomize every aspect of model training for research and experimentation.") | |
| # Dataset Preparation Section | |
| with gr.Accordion("π¦ Dataset Preparation (Optional)", open=False): | |
| gr.Markdown("**Note**: Dataset preparation is now manual only. Click the button below to download and prepare the dataset when needed.") | |
| with gr.Row(): | |
| prepare_dataset_btn = gr.Button("π₯ Download & Prepare Dataset", variant="secondary") | |
| prepare_status = gr.Textbox(label="Dataset Preparation Status", value="Dataset will be prepared if missing", interactive=False) | |
| def prepare_dataset_manual(): | |
| """Manually trigger dataset preparation.""" | |
| global DATASET_ROOT, BOOT_STATUS | |
| try: | |
| BOOT_STATUS = "preparing-dataset" | |
| # Check if dataset already exists | |
| root = os.path.abspath(os.path.join(os.getcwd(), "data", "Polyvore")) | |
| images_dir = os.path.join(root, "images") | |
| has_images = os.path.isdir(images_dir) and any(os.listdir(images_dir)) | |
| if has_images: | |
| print("β Images already exist, skipping download/extraction") | |
| ds_root = root | |
| else: | |
| print("π₯ Downloading and extracting dataset...") | |
| ds_root = ensure_dataset_ready() | |
| DATASET_ROOT = ds_root | |
| if not ds_root: | |
| BOOT_STATUS = "dataset-not-prepared" | |
| return "β Failed to prepare dataset" | |
| # Prepare splits if missing | |
| splits_dir = os.path.join(ds_root, "splits") | |
| has_splits = ( | |
| os.path.isfile(os.path.join(splits_dir, "train.json")) or | |
| os.path.isfile(os.path.join(splits_dir, "outfit_triplets_train.json")) | |
| ) | |
| if not has_splits: | |
| os.makedirs(splits_dir, exist_ok=True) | |
| from scripts.prepare_polyvore import main as prepare_main | |
| os.environ.setdefault("PYTHONWARNINGS", "ignore") | |
| import sys | |
| argv_bak = sys.argv | |
| try: | |
| sys.argv = ["prepare_polyvore.py", "--root", ds_root, "--max_samples", "500"] | |
| prepare_main() | |
| BOOT_STATUS = "ready" | |
| return "β Dataset and splits prepared successfully!" | |
| finally: | |
| sys.argv = argv_bak | |
| else: | |
| BOOT_STATUS = "ready" | |
| return "β Dataset already prepared (images and splits exist)" | |
| except Exception as e: | |
| BOOT_STATUS = "error" | |
| import traceback | |
| return f"β Error: {str(e)}\n{traceback.format_exc()}" | |
| prepare_dataset_btn.click(fn=prepare_dataset_manual, inputs=[], outputs=prepare_status) | |
| # Global Dataset Size Control | |
| with gr.Row(): | |
| gr.Markdown("#### π― **Global Dataset Size Control**") | |
| gr.Markdown("**Note**: Use 'Apply' button to regenerate splits with different size limits.") | |
| with gr.Row(): | |
| gr.Markdown("#### π **Current Behavior**") | |
| gr.Markdown("β’ **Bootstrap**: Downloads full dataset (53K outfits) + generates splits with **500 samples by default**\nβ’ **Training**: Uses 500 samples (ultra-fast training!)\nβ’ **Apply Button**: Regenerates splits with your selected size limit") | |
| with gr.Row(): | |
| global_dataset_size = gr.Dropdown( | |
| choices=["160", "500", "2000", "5000", "10000", "25000", "50000", "full"], | |
| value="500", | |
| label="Global Dataset Size (Affects Prep + Training)" | |
| ) | |
| gr.Markdown("**160**: Ultra-fast testing (~30 sec prep, ~1-2 min training)\n**2000**: Fast testing (~1-2 min prep, ~2-5 min training)\n**5000**: Fast testing (~2-3 min prep, ~5-10 min training)\n**10000**: Good testing (~3-5 min prep, ~10-20 min training)\n**full**: Production (~5-10 min prep, ~1-4 hours training)") | |
| with gr.Row(): | |
| # Apply dataset size button | |
| apply_size_btn = gr.Button("π Apply Dataset Size & Regenerate Splits", variant="primary") | |
| size_status = gr.Textbox(label="Dataset Size Status", value="Dataset size: 500 samples (click Apply to regenerate splits)", interactive=False) | |
| # Current dataset info | |
| gr.Markdown("#### π **Current Dataset Status**") | |
| gr.Markdown("β’ **Full dataset downloaded**: 53,306 outfits (required for system)\nβ’ **Splits generated**: **500 samples by default** (ultra-fast training!)\nβ’ **Training will use**: 500 samples (ultra-fast training!)\nβ’ **Scale up**: Use Apply button to increase to larger sizes") | |
| def apply_dataset_size(size: str): | |
| """Apply global dataset size and regenerate splits.""" | |
| try: | |
| if size == "full": | |
| return f"β Using full dataset ({size}) - no size limit applied" | |
| # Call the dataset preparation with size limit | |
| import subprocess  # os is already imported at module level | |
| # Set environment variable for dataset size | |
| os.environ["DATASET_SIZE_LIMIT"] = size | |
| # Check if script exists | |
| script_path = "scripts/prepare_polyvore.py" | |
| if not os.path.exists(script_path): | |
| return f"β Script not found: {script_path}" | |
| # Regenerate splits with the size limit in a subprocess; use sys.executable | |
| # so the script runs under the same interpreter as this app | |
| import sys | |
| app_root = os.getcwd() | |
| data_root = os.path.join(app_root, "data", "Polyvore") | |
| cmd = [ | |
| sys.executable, script_path, | |
| "--root", data_root, | |
| "--out", os.path.join(data_root, "splits"), | |
| "--max_samples", size | |
| ] | |
| print(f"Running command: {' '.join(cmd)}") | |
| print(f"Current working directory: {os.getcwd()}") | |
| # Run from the app root so the relative script path and data paths resolve | |
| result = subprocess.run(cmd, capture_output=True, text=True, check=False, cwd=app_root) | |
| if result.returncode == 0: | |
| return f"✅ Successfully regenerated splits with a {size}-sample limit" | |
| else: | |
| error_msg = "❌ Failed to regenerate splits:\n" | |
| error_msg += f"Return code: {result.returncode}\n" | |
| error_msg += f"STDOUT: {result.stdout}\n" | |
| error_msg += f"STDERR: {result.stderr}" | |
| return error_msg | |
| except Exception as e: | |
| return f"β Failed to apply dataset size: {str(e)}" | |
| apply_size_btn.click(fn=apply_dataset_size, inputs=[global_dataset_size], outputs=[size_status]) | |
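| # The same regeneration can be done from a terminal, e.g. for a 2000-sample limit: | |
| #   python scripts/prepare_polyvore.py --root data/Polyvore --out data/Polyvore/splits --max_samples 2000 | |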
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("#### π Dataset Size Control") | |
| gr.Markdown("Start small for testing, increase for production training") | |
| dataset_size = gr.Dropdown( | |
| choices=["160", "500", "2000", "5000", "10000", "25000", "50000", "full"], | |
| value="500", | |
| label="Training Dataset Size" | |
| ) | |
| gr.Markdown("**2000**: Quick testing (~2-5 min)\n**5000**: Fast validation (~5-10 min)\n**10000**: Good validation (~10-20 min)\n**25000+**: Production training") | |
| with gr.Column(scale=1): | |
| gr.Markdown("#### πΌοΈ ResNet Item Embedder") | |
| # Model architecture | |
| resnet_backbone = gr.Dropdown( | |
| choices=["resnet50", "resnet101"], | |
| value="resnet50", | |
| label="Backbone Architecture" | |
| ) | |
| resnet_embedding_dim = gr.Slider(128, 1024, value=512, step=128, label="Embedding Dimension") | |
| resnet_use_pretrained = gr.Checkbox(value=True, label="Use ImageNet Pretrained") | |
| resnet_dropout = gr.Slider(0.0, 0.5, value=0.1, step=0.05, label="Dropout Rate") | |
| # Training parameters | |
| resnet_epochs = gr.Slider(1, 100, value=20, step=1, label="Epochs") | |
| resnet_batch_size = gr.Slider(4, 128, value=4, step=4, label="Batch Size") | |
| resnet_lr = gr.Slider(1e-5, 1e-2, value=1e-3, step=1e-5, label="Learning Rate") | |
| resnet_optimizer = gr.Dropdown( | |
| choices=["adamw", "adam", "sgd", "rmsprop"], | |
| value="adamw", | |
| label="Optimizer" | |
| ) | |
| resnet_weight_decay = gr.Slider(1e-6, 1e-2, value=1e-4, step=1e-6, label="Weight Decay") | |
| resnet_triplet_margin = gr.Slider(0.1, 1.0, value=0.2, step=0.05, label="Triplet Margin") | |
| with gr.Column(scale=1): | |
| gr.Markdown("#### π§ ViT Outfit Encoder") | |
| # Model architecture | |
| vit_embedding_dim = gr.Slider(128, 1024, value=512, step=128, label="Embedding Dimension") | |
| vit_num_layers = gr.Slider(2, 12, value=6, step=1, label="Transformer Layers") | |
| vit_num_heads = gr.Slider(4, 16, value=8, step=2, label="Attention Heads") | |
| vit_ff_multiplier = gr.Slider(2, 8, value=4, step=1, label="Feed-Forward Multiplier") | |
| vit_dropout = gr.Slider(0.0, 0.5, value=0.1, step=0.05, label="Dropout Rate") | |
| # Training parameters | |
| vit_epochs = gr.Slider(1, 100, value=30, step=1, label="Epochs") | |
| vit_batch_size = gr.Slider(2, 64, value=4, step=2, label="Batch Size") | |
| vit_max_samples = gr.Slider(100, 5000, value=500, step=100, label="Max Training Samples") | |
| vit_lr = gr.Slider(1e-5, 1e-2, value=5e-4, step=1e-5, label="Learning Rate") | |
| vit_optimizer = gr.Dropdown( | |
| choices=["adamw", "adam", "sgd", "rmsprop"], | |
| value="adamw", | |
| label="Optimizer" | |
| ) | |
| vit_weight_decay = gr.Slider(1e-4, 1e-1, value=5e-2, step=1e-4, label="Weight Decay") | |
| vit_triplet_margin = gr.Slider(0.1, 1.0, value=0.3, step=0.05, label="Triplet Margin") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("#### βοΈ Advanced Training Settings") | |
| # Hardware optimization | |
| use_mixed_precision = gr.Checkbox(value=True, label="Mixed Precision (AMP)") | |
| channels_last = gr.Checkbox(value=True, label="Channels Last Memory Format") | |
| gradient_clip = gr.Slider(0.1, 5.0, value=1.0, step=0.1, label="Gradient Clipping") | |
| # Learning rate scheduling | |
| warmup_epochs = gr.Slider(0, 10, value=3, step=1, label="Warmup Epochs") | |
| scheduler_type = gr.Dropdown( | |
| choices=["cosine", "step", "plateau", "linear"], | |
| value="cosine", | |
| label="Learning Rate Scheduler" | |
| ) | |
| early_stopping_patience = gr.Slider(5, 20, value=10, step=1, label="Early Stopping Patience") | |
| # Training strategy | |
| mining_strategy = gr.Dropdown( | |
| choices=["semi_hard", "hardest", "random"], | |
| value="semi_hard", | |
| label="Triplet Mining Strategy" | |
| ) | |
| augmentation_level = gr.Dropdown( | |
| choices=["minimal", "standard", "aggressive"], | |
| value="standard", | |
| label="Data Augmentation Level" | |
| ) | |
| seed = gr.Slider(0, 9999, value=42, step=1, label="Random Seed") | |
| with gr.Column(scale=1): | |
| gr.Markdown("#### π Training Control") | |
| # Quick training | |
| gr.Markdown("**Quick Training (Basic Parameters)**") | |
| epochs_res = gr.Slider(1, 50, value=3, step=1, label="ResNet epochs") | |
| epochs_vit = gr.Slider(1, 100, value=3, step=1, label="ViT epochs") | |
| start_btn = gr.Button("🚀 Start Quick Training", variant="secondary") | |
| # Advanced training | |
| gr.Markdown("**Advanced Training (Custom Parameters)**") | |
| start_advanced_btn = gr.Button("🎯 Start Advanced Training", variant="primary") | |
| # Training log | |
| train_log = gr.Textbox(label="Training Log", lines=15, max_lines=20) | |
| # Status | |
| gr.Markdown("**Training Status**") | |
| training_status = gr.Textbox(label="Status", value="Ready to train", interactive=False) | |
| # Event handlers | |
| start_btn.click( | |
| fn=start_training_simple, | |
| inputs=[dataset_size, epochs_res, epochs_vit], | |
| outputs=train_log | |
| ) | |
| start_advanced_btn.click( | |
| fn=start_training_advanced, | |
| inputs=[ | |
| # Dataset size | |
| dataset_size, | |
| # ResNet parameters | |
| resnet_epochs, resnet_batch_size, resnet_lr, resnet_optimizer, | |
| resnet_weight_decay, resnet_triplet_margin, resnet_embedding_dim, | |
| resnet_backbone, resnet_use_pretrained, resnet_dropout, | |
| # ViT parameters | |
| vit_epochs, vit_batch_size, vit_max_samples, vit_lr, vit_optimizer, | |
| vit_weight_decay, vit_triplet_margin, vit_embedding_dim, | |
| vit_num_layers, vit_num_heads, vit_ff_multiplier, vit_dropout, | |
| # Advanced parameters | |
| use_mixed_precision, channels_last, gradient_clip, | |
| warmup_epochs, scheduler_type, early_stopping_patience, | |
| mining_strategy, augmentation_level, seed | |
| ], | |
| outputs=train_log | |
| ) | |
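| # Note: Gradio passes these inputs positionally, so the list above must match | |
| # the parameter order of start_training_advanced exactly. | |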
| with gr.Tab("π¦ Artifact Management"): | |
| gr.Markdown("### π― Comprehensive Artifact Management\nManage, package, and upload all system artifacts to Hugging Face Hub.") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("#### π Artifact Overview") | |
| artifact_overview = gr.JSON(label="System Artifacts", value=get_artifact_overview) | |
| refresh_overview = gr.Button("🔄 Refresh Overview") | |
| refresh_overview.click(fn=get_artifact_overview, inputs=[], outputs=artifact_overview) | |
| gr.Markdown("#### π¦ Create Packages") | |
| package_type = gr.Dropdown( | |
| choices=["complete", "splits_only", "models_only"], | |
| value="complete", | |
| label="Package Type" | |
| ) | |
| create_package_btn = gr.Button("📦 Create Package") | |
| package_result = gr.Textbox(label="Package Result", interactive=False) | |
| available_packages = gr.JSON(label="Available Packages", value=get_available_packages) | |
| create_package_btn.click( | |
| fn=create_download_package, | |
| inputs=[package_type], | |
| outputs=[package_result, available_packages] | |
| ) | |
| with gr.Column(scale=1): | |
| gr.Markdown("#### π Hugging Face Hub Integration") | |
| gr.Markdown("π‘ **Pro Tip**: Set `HF_TOKEN` environment variable for automatic uploads after training!") | |
| hf_token = gr.Textbox(label="HF Token", type="password", placeholder="hf_...") | |
| hf_username = gr.Textbox(label="Username", placeholder="your-username") | |
| with gr.Row(): | |
| push_splits_btn = gr.Button("📤 Push Splits", variant="secondary") | |
| push_models_btn = gr.Button("📤 Push Models", variant="secondary") | |
| push_everything_btn = gr.Button("📤 Push Everything", variant="primary") | |
| hf_result = gr.Textbox(label="Upload Result", interactive=False, lines=3) | |
| push_splits_btn.click(fn=push_splits_to_hf, inputs=[hf_token, hf_username], outputs=hf_result) | |
| push_models_btn.click(fn=push_models_to_hf, inputs=[hf_token, hf_username], outputs=hf_result) | |
| push_everything_btn.click(fn=push_everything_to_hf, inputs=[hf_token, hf_username], outputs=hf_result) | |
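| # Manual upload alternative (a sketch using huggingface_hub; the repo_id here is | |
| # a placeholder, not one of the app's actual repos): | |
| #   from huggingface_hub import HfApi | |
| #   HfApi(token=os.environ["HF_TOKEN"]).upload_folder( | |
| #       folder_path="models/exports", repo_id="your-username/outfit-artifacts", repo_type="dataset") | |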
| gr.Markdown("#### π₯ Download Management") | |
| individual_files = gr.JSON(label="Individual Files", value=get_individual_files) | |
| download_all_btn = gr.Button("📥 Download All as ZIP") | |
| download_result = gr.Textbox(label="Download Result", interactive=False) | |
| download_all_btn.click(fn=download_all_files, inputs=[], outputs=download_result) | |
| with gr.Tab("π Status"): | |
| gr.Markdown("### π¦ System Status and Monitoring\nReal-time status of dataset preparation, training, and system health.") | |
| status = gr.Textbox(label="Bootstrap Status", value=lambda: BOOT_STATUS) | |
| refresh_status = gr.Button("🔄 Refresh Status") | |
| refresh_status.click(fn=lambda: BOOT_STATUS, inputs=[], outputs=status) | |
| # Model Status | |
| gr.Markdown("#### π€ Model Status") | |
| model_status = gr.JSON(label="Model Loading Status", value=lambda: service.get_model_status()) | |
| refresh_models = gr.Button("🔄 Refresh Model Status") | |
| refresh_models.click(fn=lambda: service.get_model_status(), inputs=[], outputs=model_status) | |
| # System info | |
| gr.Markdown("#### π» System Information") | |
| device_info = gr.Textbox(label="Device", value=lambda: f"Device: {service.device}") | |
| resnet_version = gr.Textbox(label="ResNet Version", value=lambda: f"ResNet: {service.resnet_version}") | |
| vit_version = gr.Textbox(label="ViT Version", value=lambda: f"ViT: {service.vit_version}") | |
| # Health check | |
| gr.Markdown("#### π₯ Health Check") | |
| health_btn = gr.Button("π Check Health") | |
| health_status = gr.Textbox(label="Health Status", value="Click to check") | |
| def check_health(): | |
| try: | |
| # app.get("/health") only returns a route decorator, not a response; use a | |
| # test client to actually issue a request against the mounted FastAPI app | |
| from starlette.testclient import TestClient | |
| health = TestClient(app).get("/health").json() | |
| return f"✅ System Healthy - {health}" | |
| except Exception as e: | |
| return f"❌ Health Check Failed: {str(e)}" | |
| health_btn.click(fn=check_health, inputs=[], outputs=health_status) | |
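| # The same endpoint can also be probed from outside the process, e.g.: | |
| #   curl http://localhost:7860/health   (7860 assumed; it is the default Gradio/Spaces port) | |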
| try: | |
| # Mount Gradio onto the FastAPI root path; SSR is disabled at launch below to avoid stray port fetches | |
| demo.queue() | |
| app = gr.mount_gradio_app(app, demo, path="/") | |
| except Exception: | |
| # In case mounting fails in certain runners, we still want FastAPI to be available | |
| pass | |
| # Mount static files for direct artifact download | |
| export_dir = os.getenv("EXPORT_DIR", "models/exports") | |
| os.makedirs(export_dir, exist_ok=True) | |
| try: | |
| app.mount("/files", StaticFiles(directory=export_dir), name="files") | |
| except Exception: | |
| pass | |
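| # Anything written into EXPORT_DIR becomes directly downloadable under /files/, | |
| # e.g. GET /files/system_summary.json for the exported artifact summary. | |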
| if __name__ == "__main__": | |
| # Local/Space run | |
| demo.queue().launch(ssr_mode=False) | |
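| # Local run note (a sketch; assumes this file is app.py per the Spaces convention): | |
| #   python app.py                  # serves the Gradio UI via demo.launch() | |
| #   uvicorn app:app --port 7860    # alternative: serve the FastAPI app with the mounted Gradio UI | |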