recomendation / app.py
Ali Mohsin
folder reorganise
72af8c3
import os
import base64
import io
import tempfile
from typing import List, Optional, Any, Dict
import gradio as gr
import numpy as np
import requests
import torch
from fastapi import FastAPI, Header, HTTPException, UploadFile, File, Form
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import List as TypingList
from PIL import Image
from starlette.staticfiles import StaticFiles
import threading
import json
from inference import InferenceService
from utils.data_fetch import ensure_dataset_ready
from utils.tag_system import get_all_tag_options, validate_tags, TagProcessor
from utils.image_utils import (
load_images_from_files,
load_image_from_bytes,
load_image_from_url,
is_image_file,
get_supported_formats,
get_supported_extensions,
ensure_rgb_image
)
# Global state
# Boot status string; re-assigned to "idle" further down in this module and
# later set to "ready" or "error: ..." by _background_bootstrap.
BOOT_STATUS = "starting"
# Absolute path of data/Polyvore once _background_bootstrap finds images there,
# otherwise None. Re-declared (as None) further down in this module.
DATASET_ROOT: Optional[str] = None
def get_artifact_overview():
    """Return the artifact manager's summary.

    Any failure (including the artifact manager module being unavailable) is
    reported as ``{"error": <message>}`` instead of raising.
    """
    try:
        from utils.artifact_manager import create_artifact_manager

        summary = create_artifact_manager().get_artifact_summary()
    except Exception as exc:
        return {"error": str(exc)}
    return summary
def export_artifact_summary():
    """Write the artifact summary to ``<EXPORT_DIR>/system_summary.json``.

    ``EXPORT_DIR`` defaults to ``models/exports`` and is created if missing.

    Returns:
        The path of the written JSON file, or ``None`` if building or writing
        the summary failed. The failure is printed instead of being silently
        discarded, so the cause is visible in the server log.
    """
    try:
        from utils.artifact_manager import create_artifact_manager
        manager = create_artifact_manager()
        summary = manager.get_artifact_summary()
        # Save to exports directory
        export_dir = os.getenv("EXPORT_DIR", "models/exports")
        os.makedirs(export_dir, exist_ok=True)
        summary_path = os.path.join(export_dir, "system_summary.json")
        with open(summary_path, "w", encoding="utf-8") as f:
            json.dump(summary, f, indent=2)
        return summary_path
    except Exception as e:
        print(f"Failed to export artifact summary: {e}")
        return None
def create_download_package(package_type: str):
    """Build a downloadable archive of the requested kind via the artifact manager.

    ``package_type`` is the dropdown label. It is matched by substring against
    "complete", "splits_only" and "models_only", checked in that order.

    Returns:
        A ``(status_message, packages)`` tuple, where ``packages`` is the
        refreshed result of ``get_available_packages()``.
    """
    try:
        from utils.artifact_manager import create_artifact_manager

        manager = create_artifact_manager()
        # Extract package type from the dropdown choice
        for candidate in ("complete", "splits_only", "models_only"):
            if candidate in package_type:
                pkg_type = candidate
                break
        else:
            return f"❌ Invalid package type: {package_type}", get_available_packages()
        created_path = manager.create_download_package(pkg_type)
        return f"βœ… Package created: {os.path.basename(created_path)}", get_available_packages()
    except Exception as e:
        return f"❌ Failed to create package: {e}", get_available_packages()
def get_available_packages():
    """List packaged exports (``.tar.gz`` / ``.zip``) in the export directory.

    The directory comes from the ``EXPORT_DIR`` environment variable (default
    ``models/exports``). Entries are returned sorted by name so the listing is
    stable between calls. Only regular files are listed, so a directory whose
    name happens to end in ``.zip`` is ignored. A missing export directory (or
    one that is not a directory) yields an empty list.

    Returns:
        ``{"packages": [...]}``. Each entry has ``name``, ``size_mb`` (rounded
        to 2 decimals), ``path`` and ``url`` (``/files/<name>``).
        On an unexpected failure, returns ``{"error": <message>}`` instead.
    """
    try:
        export_dir = os.getenv("EXPORT_DIR", "models/exports")
        packages = []
        if os.path.isdir(export_dir):
            for name in sorted(os.listdir(export_dir)):
                if not name.endswith((".tar.gz", ".zip")):
                    continue
                file_path = os.path.join(export_dir, name)
                if not os.path.isfile(file_path):
                    continue
                packages.append({
                    "name": name,
                    "size_mb": round(os.path.getsize(file_path) / (1024 * 1024), 2),
                    "path": file_path,
                    "url": f"/files/{name}",
                })
        return {"packages": packages}
    except Exception as e:
        return {"error": str(e)}
def get_individual_files():
    """Return the artifact manager's downloadable files grouped by ``category``.

    The result maps each category to the list of file entries in that
    category, in the order the manager returned them. Any failure is reported
    as ``{"error": <message>}``.
    """
    try:
        from utils.artifact_manager import create_artifact_manager

        grouped = {}
        for entry in create_artifact_manager().get_downloadable_files():
            grouped.setdefault(entry["category"], []).append(entry)
    except Exception as exc:
        return {"error": str(exc)}
    return grouped
def _zip_files(files, zip_path):
    """Write each existing ``file["path"]`` into a deflate-compressed ZIP at *zip_path*.

    Every entry is stored under ``file["name"]``. An entry is skipped when its
    source path does not exist, or when its archive name was already written;
    writing a repeated name would create a duplicate ZIP member.

    Returns:
        The number of members written.
    """
    import zipfile

    written = set()
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zipf:
        for file in files:
            path, arcname = file["path"], file["name"]
            if arcname in written or not os.path.exists(path):
                continue
            zipf.write(path, arcname)
            written.add(arcname)
    return len(written)


def download_all_files():
    """Bundle every downloadable artifact into ``<EXPORT_DIR>/all_artifacts.zip``.

    ``EXPORT_DIR`` defaults to ``models/exports`` and is created if missing.

    Returns:
        The ZIP path, or ``None`` if listing or archiving failed. The failure
        is printed rather than silently discarded.
    """
    try:
        from utils.artifact_manager import create_artifact_manager
        manager = create_artifact_manager()
        files = manager.get_downloadable_files()
        # Create ZIP with all files
        export_dir = os.getenv("EXPORT_DIR", "models/exports")
        os.makedirs(export_dir, exist_ok=True)
        zip_path = os.path.join(export_dir, "all_artifacts.zip")
        _zip_files(files, zip_path)
        return zip_path
    except Exception as e:
        print(f"Failed to build artifact ZIP: {e}")
        return None
def get_training_status():
    """Return the current training status reported by the UI monitor.

    A falsy status from the monitor becomes ``{"status": "no-training"}``.
    Any failure becomes ``{"status": "error", "error": <message>}``.
    """
    try:
        from ui.monitor import create_monitor

        current = create_monitor().get_status()
    except Exception as exc:
        return {"status": "error", "error": str(exc)}
    return current or {"status": "no-training"}
def _push_to_hf(token, username, what, repo_name, success_template, label):
    """Upload one artifact group with ``HFModelManager`` and describe the outcome.

    Args:
        token: Hugging Face access token; required.
        username: Hugging Face account name; required.
        what: artifact group passed to ``upload_model`` ("splits", "models" or "everything").
        repo_name: repository name passed to ``upload_model``.
        success_template: message returned on success; ``{username}`` is substituted.
        label: noun used in the failure message ("Failed to upload <label>: ...").

    Returns:
        A user-facing status string. Errors are reported in the string, not raised.
    """
    if not token or not username:
        return "❌ Please provide HF token and username"
    try:
        from utils.hf_utils import HFModelManager
        hf = HFModelManager(token=token, username=username)
        result = hf.upload_model(what, repo_name)
        if result.get("success"):
            return success_template.format(username=username)
        return f"❌ Failed to upload {label}: {result.get('error', 'Unknown error')}"
    except Exception as e:
        return f"❌ Upload failed: {e}"


def push_splits_to_hf(token, username):
    """Push splits to HF Hub (repository ``<username>/Dressify-Helper``)."""
    return _push_to_hf(
        token, username, "splits", "Dressify-Helper",
        "βœ… Successfully uploaded splits to {username}/Dressify-Helper", "splits",
    )


def push_models_to_hf(token, username):
    """Push models to HF Hub (repository ``<username>/dressify-models``)."""
    return _push_to_hf(
        token, username, "models", "dressify-models",
        "βœ… Successfully uploaded models to {username}/dressify-models", "models",
    )


def push_everything_to_hf(token, username):
    """Push everything to HF Hub (repository name ``dressify-complete``)."""
    return _push_to_hf(
        token, username, "everything", "dressify-complete",
        "βœ… Successfully uploaded everything to HF Hub", "everything",
    )
AI_API_KEY = os.getenv("AI_API_KEY")


def require_api_key(x_api_key: Optional[str]):
    """Reject the request unless it carries the configured API key.

    Authentication is disabled when ``AI_API_KEY`` is unset or empty. The
    comparison uses ``hmac.compare_digest``, so response timing does not
    reveal how much of a guessed key was correct.

    Raises:
        HTTPException: 401 when a key is configured and ``x_api_key`` is
            missing or does not match.
    """
    import hmac

    if not AI_API_KEY:
        return
    # Compare as bytes: compare_digest rejects non-ASCII str arguments.
    if x_api_key is None or not hmac.compare_digest(
        x_api_key.encode("utf-8"), AI_API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid API key")
class EmbedRequest(BaseModel):
    """Request body for POST /embed.

    Images may be given as URLs, as base64 strings, or both. The /embed
    handler loads the URL images first, then the base64 images.
    """
    # Image URLs, fetched server-side by the /embed handler.
    image_urls: Optional[List[str]] = None
    # Base64-encoded image file contents.
    images_base64: Optional[List[str]] = None
class Item(BaseModel):
    """One candidate wardrobe item in a /compose request.

    The /compose handler loads ``image_url`` if set, otherwise
    ``image_base64``. An item that ends up with neither an image nor an
    ``embedding`` is skipped and reported in the response warnings.
    """
    # Client-supplied identifier for the item.
    id: str
    # Precomputed embedding; /compose converts it to a float32 numpy array.
    embedding: Optional[List[float]] = None
    # Item category; optional (the /compose docstring says it is auto-detected when omitted).
    category: Optional[str] = None
    # URL of the item image, fetched server-side by /compose.
    image_url: Optional[str] = None
    image_base64: Optional[str] = None # Base64 encoded image
class ComposeRequest(BaseModel):
    """Request body for POST /compose.

    ``items`` are the candidate pieces. Tags can be sent inside the free-form
    ``context`` dict (legacy format), as the top-level fields below, or both.
    The /compose handler merges the top-level fields into ``context``.
    """
    # Candidate wardrobe items to combine into outfits.
    items: List[Item]
    # Free-form tag dict (legacy format).
    context: Optional[Dict[str, Any]] = None
    # Expanded tag fields for better API usability
    # NOTE: occasion, weather, style and num_outfits have non-None defaults, so
    # they are always populated even when the client omits them.
    occasion: Optional[str] = "casual"
    weather: Optional[str] = "any"
    style: Optional[str] = "casual"
    outfit_style: Optional[str] = None # Alias for style
    num_outfits: Optional[int] = 5
    # Optional tags
    color_preference: Optional[str] = None
    fit_preference: Optional[str] = None
    material_preference: Optional[str] = None
    season: Optional[str] = None
    time_of_day: Optional[str] = None
    budget: Optional[str] = None
    personal_style: Optional[str] = None
    age_group: Optional[str] = None
    gender: Optional[str] = None
    # Pydantic v1-style config; unknown top-level fields are accepted instead of rejected.
    class Config:
        extra = "allow" # Allow additional fields for flexibility
app = FastAPI(title="Dressify Recommendation Service")
# Single shared InferenceService instance used by the HTTP endpoints and the
# Gradio callbacks in this module.
service = InferenceService()
# Non-blocking bootstrap: the thread started below runs _background_bootstrap,
# which only checks whether data/Polyvore (images and splits) already exists on
# disk. It does not download data, prepare splits, or train.
# The two assignments below re-declare globals defined near the top of the
# module (BOOT_STATUS changes from "starting" to "idle" here).
BOOT_STATUS = "idle"
DATASET_ROOT: Optional[str] = None
def _has_entries(path):
    """Return True if *path* is a directory containing at least one entry.

    Stops at the first ``os.scandir`` entry, so a large image folder is not
    listed in full just to test for emptiness.
    """
    if not os.path.isdir(path):
        return False
    with os.scandir(path) as entries:
        return next(entries, None) is not None


def _background_bootstrap():
    """Detect an already-prepared Polyvore dataset and mark the service ready.

    Looks under ``<cwd>/data/Polyvore``. It does NOT download or prepare the
    dataset and never starts training; that is done manually through the UI.

    Side effects:
        ``DATASET_ROOT`` is set to the dataset root when its ``images`` folder
        is non-empty (with or without split files), otherwise ``None``.
        ``BOOT_STATUS`` becomes ``"ready"``, or ``"error: <message>"`` if the
        check itself fails.
    """
    global BOOT_STATUS
    global DATASET_ROOT
    try:
        root = os.path.abspath(os.path.join(os.getcwd(), "data", "Polyvore"))
        images_dir = os.path.join(root, "images")
        splits_dir = os.path.join(root, "splits")
        has_images = _has_entries(images_dir)
        has_splits = (
            os.path.isfile(os.path.join(splits_dir, "train.json"))
            or os.path.isfile(os.path.join(splits_dir, "outfit_triplets_train.json"))
        )
        if has_images and has_splits:
            print("βœ… Dataset and splits already prepared")
            DATASET_ROOT = root
        elif has_images:
            print("βœ… Dataset images exist, but splits may be missing (use Advanced Training to prepare)")
            DATASET_ROOT = root
        else:
            print("ℹ️ Dataset not prepared. Use 'Download & Prepare Dataset' button in Advanced Training tab if needed.")
            DATASET_ROOT = None
        # The system is usable in every case above; only the dataset path differs.
        BOOT_STATUS = "ready"
    except Exception as e:
        BOOT_STATUS = f"error: {e}"


threading.Thread(target=_background_bootstrap, daemon=True).start()
@app.get("/health")
def health() -> dict:
    """Health check reporting the inference device and the loaded model versions."""
    payload = {"status": "ok"}
    payload["device"] = service.device
    payload["resnet"] = service.resnet_version
    payload["vit"] = service.vit_version
    return payload
@app.get("/tags")
def get_tags() -> dict:
    """
    List every tag category and its accepted values for API clients.
    Also states which tags are primary and which are optional.
    """
    primary = ["occasion", "weather", "style"]
    optional = [
        "color_preference", "fit_preference", "material_preference",
        "season", "time_of_day", "budget", "personal_style",
        "age_group", "gender",
    ]
    return {
        "tag_categories": get_all_tag_options(),
        "description": "Available tags for personalized outfit recommendations",
        "usage": {"primary_tags": primary, "optional_tags": optional},
    }
@app.get("/image-formats")
def get_image_formats() -> dict:
    """
    Describe the image formats and file extensions accepted by the API.
    """
    response = {}
    response["supported_formats"] = get_supported_formats()
    response["supported_extensions"] = get_supported_extensions()
    response["description"] = "All major image formats are supported including JPG, PNG, WEBP, GIF, BMP, TIFF, and more"
    response["note"] = "Images are automatically converted to RGB mode for model processing"
    return response
@app.post("/compose/upload")
async def compose_with_upload(
    files: TypingList[UploadFile] = File(...),
    occasion: str = Form("casual"),
    weather: str = Form("any"),
    style: str = Form("casual"),
    num_outfits: int = Form(5),
    color_preference: Optional[str] = Form(None),
    fit_preference: Optional[str] = Form(None),
    material_preference: Optional[str] = Form(None),
    season: Optional[str] = Form(None),
    time_of_day: Optional[str] = Form(None),
    budget: Optional[str] = Form(None),
    personal_style: Optional[str] = Form(None),
    x_api_key: Optional[str] = Header(None),
    age_group: Optional[str] = Form(None),
    gender: Optional[str] = Form(None),
) -> dict:
    """
    Generate outfit recommendations from uploaded image files.
    This endpoint accepts multipart/form-data with:
    - files: Image files (JPG, PNG, WEBP, GIF, BMP, TIFF, etc.)
    - All tag parameters as form fields (including age_group and gender,
      matching the JSON /compose endpoint)
    Returns personalized outfit recommendations.
    """
    from starlette.concurrency import run_in_threadpool

    require_api_key(x_api_key)
    if not files or len(files) < 2:
        raise HTTPException(status_code=400, detail="At least 2 images required for outfit recommendations")

    # Load images from uploaded files; unreadable files become warnings, not failures.
    items = []
    errors = []
    for i, file in enumerate(files):
        try:
            contents = await file.read()
            img = load_image_from_bytes(contents, convert_to_rgb=True, raise_on_error=False)
            if img is None:
                errors.append(f"Failed to load image from file {file.filename}")
                continue
            items.append({
                "id": f"item_{i}",
                "image": img,
                "category": None,  # Will be auto-detected
            })
        except Exception as e:
            errors.append(f"Error processing file {file.filename}: {str(e)}")

    if len(items) < 2:
        error_msg = f"Not enough valid images. Need at least 2, got {len(items)}."
        if errors:
            error_msg += f" Errors: {', '.join(errors[:3])}"
        raise HTTPException(status_code=400, detail=error_msg)

    # Build context
    context = {
        "occasion": occasion,
        "weather": weather,
        "style": style,
        "outfit_style": style,
        "num_outfits": num_outfits
    }
    # Add optional tags; the literal string "None" is treated as "not set".
    optional_tags = {
        "color_preference": color_preference,
        "fit_preference": fit_preference,
        "material_preference": material_preference,
        "season": season,
        "time_of_day": time_of_day,
        "budget": budget,
        "personal_style": personal_style,
        "age_group": age_group,
        "gender": gender,
    }
    for tag_name, tag_value in optional_tags.items():
        if tag_value and tag_value != "None":
            context[tag_name] = tag_value

    # Validate tags
    is_valid, tag_errors = validate_tags(context)
    if not is_valid:
        return JSONResponse(
            status_code=400,
            content={
                "error": "Invalid tags provided",
                "errors": tag_errors,
                "valid_tag_options": get_all_tag_options()
            }
        )

    # Generate recommendations
    try:
        # compose_outfits is a synchronous call; running it in a worker thread
        # keeps the event loop free to serve other requests while it runs.
        outfits = await run_in_threadpool(service.compose_outfits, items, context=context)
        # Check for errors
        if outfits and isinstance(outfits, list) and len(outfits) > 0:
            if isinstance(outfits[0], dict) and "error" in outfits[0]:
                return JSONResponse(
                    status_code=500,
                    content={
                        "error": "Recommendation generation failed",
                        "details": outfits[0].get("details", []),
                        "message": outfits[0].get("message", "Unknown error")
                    }
                )
        return {
            "outfits": outfits,
            "version": service.vit_version,
            "tags_processed": True,
            "context_used": context,
            "items_processed": len(items),
            "warnings": errors if errors else None
        }
    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse(
            status_code=500,
            content={
                "error": "Internal server error",
                "message": str(e),
                "model_status": service.get_model_status()
            }
        )
@app.post("/tags/validate")
def validate_request_tags(tags: Dict[str, Any], x_api_key: Optional[str] = Header(None)) -> dict:
    """
    Check tag values without requesting recommendations.
    Lets API clients verify their tags before calling a compose endpoint.
    """
    require_api_key(x_api_key)
    ok, problems = validate_tags(tags)
    if ok:
        return {"valid": ok, "errors": [], "validated_tags": tags}
    return {"valid": ok, "errors": problems, "validated_tags": None}
@app.get("/model-status")
def model_status() -> dict:
    """Report the inference service's model-loading status as-is."""
    status = service.get_model_status()
    return status
@app.post("/reload-models")
def reload_models(x_api_key: Optional[str] = Header(None)) -> dict:
    """Force reload models - useful for debugging.

    Requires the API key when ``AI_API_KEY`` is configured, the same as
    /embed and /compose. Without it, any caller could repeatedly trigger a
    reload of the shared service.
    """
    require_api_key(x_api_key)
    try:
        service.force_reload_models()
        return {"status": "success", "message": "Models reloaded successfully"}
    except Exception as e:
        return {"status": "error", "message": str(e)}
@app.post("/test-recommend")
def test_recommend(x_api_key: Optional[str] = Header(None)) -> dict:
    """Run compose_outfits on three image-less dummy items to debug the pipeline.

    Requires the API key when ``AI_API_KEY`` is configured, the same as
    /embed and /compose, because it invokes ``service.compose_outfits``.
    """
    require_api_key(x_api_key)
    try:
        # Create dummy items for testing
        dummy_items = [
            {"id": "test_1", "image": None, "category": "shirt"},
            {"id": "test_2", "image": None, "category": "pants"},
            {"id": "test_3", "image": None, "category": "shoes"}
        ]
        result = service.compose_outfits(dummy_items, {"num_outfits": 1})
        return {
            "status": "success",
            "model_status": service.get_model_status(),
            "result": result,
            "result_length": len(result) if result else 0
        }
    except Exception as e:
        return {"status": "error", "message": str(e), "model_status": service.get_model_status()}
@app.post("/embed")
def embed(req: EmbedRequest, x_api_key: Optional[str] = Header(None)) -> dict:
    """
    Generate embeddings for images with comprehensive format support.
    Supports JPG, PNG, WEBP, GIF, BMP, TIFF, and other major formats.
    Base64 entries may be raw base64 or a data URI ("data:image/png;base64,...").
    """
    require_api_key(x_api_key)
    images: List[Image.Image] = []
    errors = []
    # Load from URLs.
    # NOTE(review): URLs are fetched server-side exactly as supplied by the
    # client; restrict hosts/schemes if untrusted callers can reach this
    # endpoint (SSRF).
    if req.image_urls:
        for url in req.image_urls:
            img = load_image_from_url(url, timeout=20, convert_to_rgb=True, raise_on_error=False)
            if img is not None:
                images.append(img)
            else:
                errors.append(f"Failed to load image from URL: {url}")
    # Load from base64
    if req.images_base64:
        for b64 in req.images_base64:
            try:
                # Browsers commonly send data URIs; the base64 alphabet has no
                # ':' so a raw payload never starts with "data:".
                payload = b64.partition(",")[2] if b64.startswith("data:") else b64
                image_bytes = base64.b64decode(payload)
                img = load_image_from_bytes(image_bytes, convert_to_rgb=True, raise_on_error=False)
                if img is not None:
                    images.append(img)
                else:
                    errors.append("Failed to load image from base64")
            except Exception as e:
                errors.append(f"Error decoding base64 image: {str(e)}")
    if not images:
        error_msg = "No images provided or all images failed to load"
        if errors:
            error_msg += f". Errors: {', '.join(errors[:3])}"
        raise HTTPException(status_code=400, detail=error_msg)
    # Ensure all images are RGB
    images = [ensure_rgb_image(img) for img in images]
    embs = service.embed_images(images)
    return {
        "embeddings": [e.tolist() for e in embs],
        "model_version": service.resnet_version,
        "images_loaded": len(images),
        "errors": errors if errors else None
    }
@app.post("/compose")
def compose(req: ComposeRequest, x_api_key: Optional[str] = Header(None)) -> dict:
    """
    Generate personalized outfit recommendations with expanded tag support.
    Supports both legacy context dict format and new tag-based format.
    Tags are processed and prioritized automatically.
    Items can provide:
    - image_url: URL to image (will be downloaded)
    - image_base64: Base64 encoded image (raw or as a data: URI)
    - embedding: Pre-computed embedding (skips ResNet)
    - category: Item category (optional, auto-detected if not provided)
    Tag fields sent explicitly in the request body override the same keys in
    `context`; tag fields left at their defaults only fill keys that
    `context` does not already provide.
    """
    require_api_key(x_api_key)
    # Build items with image loading
    items = []
    errors = []
    for it in req.items:
        item_dict = {
            "id": it.id,
            "embedding": np.array(it.embedding, dtype=np.float32) if it.embedding is not None else None,
            "category": it.category,
        }
        # Load image from URL if provided (takes precedence over image_base64).
        # NOTE(review): the URL is fetched server-side exactly as supplied;
        # restrict hosts/schemes if untrusted callers can reach this API (SSRF).
        if it.image_url:
            img = load_image_from_url(it.image_url, timeout=20, convert_to_rgb=True, raise_on_error=False)
            if img is not None:
                item_dict["image"] = img
            else:
                errors.append(f"Failed to load image from URL for item {it.id}: {it.image_url}")
        # Load image from base64 if provided
        elif it.image_base64:
            try:
                payload = it.image_base64
                if payload.startswith("data:"):
                    # Accept data URIs ("data:image/png;base64,<payload>") as sent by browsers.
                    payload = payload.partition(",")[2]
                image_bytes = base64.b64decode(payload)
                img = load_image_from_bytes(image_bytes, convert_to_rgb=True, raise_on_error=False)
                if img is not None:
                    item_dict["image"] = img
                else:
                    errors.append(f"Failed to load image from base64 for item {it.id}")
            except Exception as e:
                errors.append(f"Error decoding base64 image for item {it.id}: {str(e)}")
        # If no image and no embedding, skip this item
        if item_dict.get("image") is None and item_dict.get("embedding") is None:
            errors.append(f"Item {it.id} has no image or embedding - skipping")
            continue
        items.append(item_dict)

    if not items:
        error_msg = "No valid items provided. All items failed to load or have no images/embeddings."
        if errors:
            error_msg += f" Errors: {', '.join(errors[:5])}"
        raise HTTPException(status_code=400, detail=error_msg)

    # Build context from request. Copy it so the request model is not mutated.
    context = dict(req.context or {})
    # Fields the client actually sent: pydantic v2 exposes them as
    # `model_fields_set`, pydantic v1 as `__fields_set__`.
    explicit_fields = getattr(req, "model_fields_set", None)
    if explicit_fields is None:
        explicit_fields = getattr(req, "__fields_set__", set())

    def _apply(name, value):
        # Explicit fields take precedence over the context dict; defaulted
        # fields (e.g. occasion="casual") must not clobber context values.
        if value and (name in explicit_fields or name not in context):
            context[name] = value

    _apply("occasion", req.occasion)
    _apply("weather", req.weather)
    _apply("style", req.style)
    if req.outfit_style:
        context["outfit_style"] = req.outfit_style
        context["style"] = req.outfit_style  # Also set style for consistency
    _apply("num_outfits", req.num_outfits)

    # Add optional tags
    optional_tags = {
        "color_preference": req.color_preference,
        "fit_preference": req.fit_preference,
        "material_preference": req.material_preference,
        "season": req.season,
        "time_of_day": req.time_of_day,
        "budget": req.budget,
        "personal_style": req.personal_style,
        "age_group": req.age_group,
        "gender": req.gender,
    }
    for tag_name, tag_value in optional_tags.items():
        if tag_value:
            context[tag_name] = tag_value

    # Validate tags
    is_valid, tag_errors = validate_tags(context)
    if not is_valid:
        return JSONResponse(
            status_code=400,
            content={
                "error": "Invalid tags provided",
                "errors": tag_errors,
                "valid_tag_options": get_all_tag_options()
            }
        )

    # Generate recommendations
    try:
        outfits = service.compose_outfits(items, context=context)
        # Check if compose_outfits returned an error
        if outfits and isinstance(outfits, list) and len(outfits) > 0:
            if isinstance(outfits[0], dict) and "error" in outfits[0]:
                return JSONResponse(
                    status_code=500,
                    content={
                        "error": "Recommendation generation failed",
                        "details": outfits[0].get("details", []),
                        "message": outfits[0].get("message", "Unknown error")
                    }
                )
        return {
            "outfits": outfits,
            "version": service.vit_version,
            "tags_processed": True,
            "context_used": context,
            "items_processed": len(items),
            "warnings": errors if errors else None
        }
    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse(
            status_code=500,
            content={
                "error": "Internal server error during recommendation generation",
                "message": str(e),
                "model_status": service.get_model_status()
            }
        )
@app.get("/artifacts")
def artifacts() -> dict:
    """List exported model artifacts available for download.

    Scans ``EXPORT_DIR`` (default ``models/exports``) for regular files ending
    in ``.pth``, ``.pt``, ``.onnx``, ``.ts`` or ``.json``. Results are sorted by
    name so the listing is stable between calls, and each entry carries a
    ``/files/<name>`` download URL.
    """
    export_dir = os.getenv("EXPORT_DIR", "models/exports")
    files = []
    if os.path.isdir(export_dir):
        for fn in sorted(os.listdir(export_dir)):
            if not fn.endswith((".pth", ".pt", ".onnx", ".ts", ".json")):
                continue
            # Skip directories/special files that merely share an artifact suffix.
            if not os.path.isfile(os.path.join(export_dir, fn)):
                continue
            files.append({
                "name": fn,
                "path": f"{export_dir}/{fn}",
                "url": f"/files/{fn}",
            })
    return {"artifacts": files}
# --------- Gradio UI ---------
def _load_images_from_files(files: List[str]) -> List[Image.Image]:
    """
    Decode image files of any format the image utilities support into RGB PIL images.

    ``skip_errors=True`` is passed, so presumably files that cannot be decoded
    are dropped rather than raising; the result may be shorter than ``files``.
    """
    return load_images_from_files(files, skip_errors=True, convert_to_rgb=True)
def gradio_embed(files: List[str]):
    """Gradio callback: embed the uploaded images and return the vectors as text.

    Returns ``"[]"`` when nothing was uploaded or no file could be decoded.
    Otherwise returns ``str()`` of the list of embedding lists.
    """
    loaded = _load_images_from_files(files) if files else []
    if not loaded:
        return "[]"
    vectors = service.embed_images(loaded)
    return str([v.tolist() for v in vectors])
def _stitch_strip(imgs: List[Image.Image], height: int = 256, pad: int = 6, bg=(245, 245, 245)) -> Image.Image:
    """Lay images side by side on a padded background strip.

    Each image is converted to RGB and scaled to ``height`` pixels tall,
    keeping its aspect ratio (width at least 1 px). Images are separated and
    surrounded by ``pad`` pixels of ``bg`` colour. An empty input yields a
    1-pixel-wide blank image of height ``height``.
    """
    if not imgs:
        return Image.new("RGB", (1, height), color=bg)

    def _fit(im):
        rgb = im if im.mode == "RGB" else im.convert("RGB")
        width, h = rgb.size
        new_width = max(1, int(width * (height / float(h))))
        return rgb.resize((new_width, height))

    tiles = list(map(_fit, imgs))
    canvas_width = pad * (len(tiles) + 1) + sum(tile.width for tile in tiles)
    canvas = Image.new("RGB", (canvas_width, height + 2 * pad), color=bg)
    cursor = pad
    for tile in tiles:
        canvas.paste(tile, (cursor, pad))
        cursor += tile.width + pad
    return canvas
def _normalize_tag_value(tag_name: str, value: str) -> str:
    """Normalize a UI tag value to the value the backend expects.

    Returns ``None`` for empty values and for the literal string ``"None"``.
    For ``color_preference`` and ``fit_preference``, known synonyms are
    mapped, and the value is matched case-insensitively (after stripping
    whitespace) against the allowed list, returning the lowercase canonical
    form. Anything unrecognised, and every other tag, is returned unchanged.
    """
    if not value or value == "None":
        return None
    value_lower = value.lower().strip()

    # Color preference mappings - normalize variations to standard values
    if tag_name == "color_preference":
        color_mappings = {
            "monochrome": "monochromatic",  # Frontend uses "monochrome", backend uses "monochromatic"
            "mono": "monochromatic",
            "single_color": "monochromatic",
            "one_color": "monochromatic",
        }
        allowed_colors = ["neutral", "monochromatic", "complementary", "bold", "subtle",
                          "bright", "muted", "pastel", "dark", "light", "earth_tones",
                          "jewel_tones", "black_white", "navy_white", "colorful", "minimal_color"]
        # Fall back to the lowercased value so e.g. "Neutral" still matches.
        normalized = color_mappings.get(value_lower, value_lower)
        return normalized if normalized in allowed_colors else value

    # Fit preference mappings
    if tag_name == "fit_preference":
        fit_mappings = {
            "slim": "fitted",
            "tight_fit": "fitted",
            "baggy": "loose",
            "wide": "loose",
        }
        allowed_fits = ["fitted", "loose", "oversized", "relaxed", "comfortable",
                        "structured", "flowy", "tailored", "athletic_fit", "regular_fit"]
        normalized = fit_mappings.get(value_lower, value_lower)
        return normalized if normalized in allowed_fits else value

    # Season: both "fall" and "autumn" are accepted as-is; other tags have no mapping.
    return value


def gradio_recommend(
    files: List[str],
    occasion: str,
    weather: str,
    num_outfits: int,
    outfit_style: str = "casual",
    color_preference: str = None,
    fit_preference: str = None,
    material_preference: str = None,
    season: str = None,
    time_of_day: str = None,
    budget: str = None,
    personal_style: str = None
):
    """Gradio callback: compose outfits from uploaded files and render previews.

    Returns:
        A ``(preview_paths, details)`` pair. ``preview_paths`` lists PNG files,
        one horizontal strip per outfit. ``details`` holds either
        ``{"outfits": [...]}`` or an ``"error"`` message plus diagnostics.
    """
    from pathlib import Path

    # Check model status first
    model_status = service.get_model_status()
    if not model_status["can_recommend"]:
        error_msg = "❌ Models not ready for recommendations!\n\n"
        error_msg += "**Model Status:**\n"
        error_msg += f"- ResNet: {'βœ… Loaded' if model_status['resnet_loaded'] else '❌ Not loaded'}\n"
        error_msg += f"- ViT: {'βœ… Loaded' if model_status['vit_loaded'] else '❌ Not loaded'}\n\n"
        error_msg += "**Errors:**\n"
        for error in model_status["errors"]:
            error_msg += f"- {error}\n\n"
        error_msg += "**Solution:**\n"
        error_msg += "Please train the models first using the 'Simple Training' or 'Advanced Training' tabs, or ensure trained checkpoints are available."
        return [], {"error": error_msg, "model_status": model_status}

    if not files:
        return [], {"error": "No files uploaded"}

    # Log detailed file information for troubleshooting uploads.
    print(f"πŸ” DEBUG: gradio_recommend called with {len(files)} files")
    file_info = []
    for i, f in enumerate(files):
        if isinstance(f, str):
            path = Path(f)
            file_size = path.stat().st_size if path.exists() else 0
            file_info.append(f"File {i+1}: {path.name} ({file_size} bytes)")
            print(f"πŸ” DEBUG: File {i+1}: path={f}, exists={path.exists()}, size={file_size}, name={path.name}")
        else:
            file_info.append(f"Type: {type(f).__name__}, Value: {str(f)[:100]}")
            print(f"πŸ” DEBUG: File {i+1}: type={type(f).__name__}, value={str(f)[:100]}")

    try:
        print(f"πŸ” DEBUG: Attempting to load {len(files)} images...")
        images = _load_images_from_files(files)
        print(f"πŸ” DEBUG: Successfully loaded {len(images)} images from {len(files)} files")
        if not images:
            error_msg = "Could not load images from uploaded files.\n\n"
            error_msg += f"Files received: {len(files)}\n"
            error_msg += f"File details: {', '.join(file_info[:3])}\n\n"
            error_msg += f"Supported formats: {', '.join(get_supported_extensions())}\n"
            error_msg += "Please ensure files are valid image files (JPG, PNG, WEBP, GIF, BMP, TIFF, etc.)"
            print(f"πŸ” DEBUG: ERROR - No images loaded. Files: {len(files)}, Images: {len(images)}")
            return [], {"error": error_msg, "files_received": len(files), "file_info": file_info[:5]}
    except Exception as e:
        import traceback
        error_msg = f"Error processing files: {str(e)}\n\n"
        error_msg += f"Files received: {len(files)}\n"
        error_msg += f"File details: {', '.join(file_info[:3])}\n\n"
        error_msg += "Please check that files are valid image files."
        print(f"πŸ” DEBUG: EXCEPTION during image loading: {e}")
        traceback.print_exc()
        return [], {"error": error_msg, "exception": str(e), "files_received": len(files)}

    # Normalize all tag values (maps frontend synonyms to backend values)
    occasion = _normalize_tag_value("occasion", occasion) or occasion
    weather = _normalize_tag_value("weather", weather) or weather
    outfit_style = _normalize_tag_value("outfit_style", outfit_style) or outfit_style
    color_preference = _normalize_tag_value("color_preference", color_preference) if color_preference else None
    fit_preference = _normalize_tag_value("fit_preference", fit_preference) if fit_preference else None
    material_preference = _normalize_tag_value("material_preference", material_preference) if material_preference else None
    season = _normalize_tag_value("season", season) if season else None
    time_of_day = _normalize_tag_value("time_of_day", time_of_day) if time_of_day else None
    budget = _normalize_tag_value("budget", budget) if budget else None
    personal_style = _normalize_tag_value("personal_style", personal_style) if personal_style else None

    # Build comprehensive context with all tags
    context = {
        "occasion": occasion,
        "weather": weather,
        "style": outfit_style,
        "outfit_style": outfit_style,  # Backward compatibility
        "num_outfits": int(num_outfits)
    }
    optional_tags = {
        "color_preference": color_preference,
        "fit_preference": fit_preference,
        "material_preference": material_preference,
        "season": season,
        "time_of_day": time_of_day,
        "budget": budget,
        "personal_style": personal_style,
    }
    for tag_name, tag_value in optional_tags.items():
        if tag_value and tag_value != "None":
            context[tag_name] = tag_value

    # Item ids encode the image index ("item_<i>") so outfits can be mapped back to images.
    items = [
        {"id": f"item_{i}", "image": img, "category": None}
        for i, img in enumerate(images)
    ]
    print(f"πŸ” DEBUG: Calling compose_outfits with {len(items)} items, context={context}")
    res = service.compose_outfits(items, context=context)
    print(f"πŸ” DEBUG: compose_outfits returned {len(res) if res else 0} results")
    if res:
        print(f"πŸ” DEBUG: First result type: {type(res[0])}, keys: {res[0].keys() if isinstance(res[0], dict) else 'N/A'}")
    # Check if compose_outfits returned an error
    if res and isinstance(res[0], dict) and "error" in res[0]:
        print(f"πŸ” DEBUG: Error in compose_outfits result: {res[0]}")
        return [], res[0]
    # Treat a None/empty result as "no outfits" instead of failing on len(None).
    res = res or []

    # Prepare stitched previews as PNG file paths (Gradio serves them by path)
    strips: List[str] = []
    print(f"πŸ” DEBUG: Preparing stitched previews for {len(res)} outfits...")
    for i, r in enumerate(res):
        idxs = []
        item_ids = r.get("item_ids", [])
        print(f"πŸ” DEBUG: Outfit {i+1}: item_ids={item_ids}")
        for iid in item_ids:
            try:
                idx = int(str(iid).split("_")[-1])
                idxs.append(idx)
                print(f"πŸ” DEBUG: Mapped {iid} -> index {idx}")
            except Exception as e:
                print(f"πŸ” DEBUG: Failed to parse {iid}: {e}")
                continue
        imgs = [images[j] for j in idxs if 0 <= j < len(images)]
        print(f"πŸ” DEBUG: Extracted {len(imgs)} images from indices {idxs}")
        if imgs:
            strip = _stitch_strip(imgs)
            print(f"πŸ” DEBUG: Created stitched image: {strip.size}")
            # Use the platform temp dir instead of a hard-coded /tmp, and close
            # the descriptor before saving: reopening a still-open temp file by
            # name fails on Windows.
            fd, temp_path = tempfile.mkstemp(suffix=".png")
            os.close(fd)
            strip.save(temp_path, "PNG")
            strips.append(temp_path)
            print(f"πŸ” DEBUG: Saved to temp file: {temp_path}")
        else:
            print(f"⚠️ DEBUG: No images extracted for outfit {i+1}")
    print(f"πŸ” DEBUG: Returning {len(strips)} stitched image file paths and {len(res)} outfit results")
    return strips, {"outfits": res}
def _effective_max_samples(dataset_size, cap=None):
    """Resolve the ``--max_samples`` value to pass to a training subprocess.

    Args:
        dataset_size: Selected dataset size. ``"full"`` or an empty value means
            "no global limit"; anything else must be an integer (str or int).
        cap: Optional additional upper bound, e.g. the ViT "Max Training Samples"
            slider. When both a size and a cap are given, the smaller one wins.

    Returns:
        The limit as a string ready for argv, or ``None`` when no limit applies
        (the caller should then omit the flag entirely).

    Raises:
        ValueError: If ``dataset_size`` is neither "full"/empty nor an integer.
    """
    limit = None if not dataset_size or dataset_size == "full" else int(dataset_size)
    if cap is not None:
        limit = int(cap) if limit is None else min(limit, int(cap))
    return None if limit is None else str(limit)


def start_training_advanced(
    # Dataset size
    dataset_size: str,
    # ResNet parameters
    resnet_epochs: int, resnet_batch_size: int, resnet_lr: float, resnet_optimizer: str,
    resnet_weight_decay: float, resnet_triplet_margin: float, resnet_embedding_dim: int,
    resnet_backbone: str, resnet_use_pretrained: bool, resnet_dropout: float,
    # ViT parameters
    vit_epochs: int, vit_batch_size: int, vit_max_samples: int, vit_lr: float, vit_optimizer: str,
    vit_weight_decay: float, vit_triplet_margin: float, vit_embedding_dim: int,
    vit_num_layers: int, vit_num_heads: int, vit_ff_multiplier: int, vit_dropout: float,
    # Advanced parameters
    use_mixed_precision: bool, channels_last: bool, gradient_clip: float,
    warmup_epochs: int, scheduler_type: str, early_stopping_patience: int,
    mining_strategy: str, augmentation_level: str, seed: int
):
    """Start ResNet + ViT training in a background thread with custom parameters.

    The complete parameter set is written to ``resnet_config_custom.json`` and
    ``vit_config_custom.json`` under ``EXPORT_DIR``. Only the parameters mapped
    to CLI flags below are forwarded to ``training/train_resnet.py`` and
    ``training/train_vit.py``. ViT training runs only if ResNet succeeded and its
    checkpoint exists. After a successful run, the inference ``service`` reloads
    its models and, when ``HF_TOKEN`` is set, artifacts are uploaded to the Hub.

    Returns:
        An immediate status string for the UI (or an error if the dataset is not
        ready). The detailed training log is printed to stdout when the
        background run finishes, because the Gradio output only receives this
        return value.
    """
    # "full"/empty falls back to the global limit chosen via "Apply Dataset Size".
    if not dataset_size or dataset_size == "full":
        dataset_size = os.getenv("DATASET_SIZE_LIMIT", "2000")
    if not DATASET_ROOT:
        return "❌ Dataset not ready. Please wait for bootstrap to complete."
    start_message = "πŸš€ Advanced training started with custom parameters! Check the log below for progress."

    def _runner():
        # The log lives only in this worker thread. Sharing it with the caller
        # through ``nonlocal`` made the returned message depend on scheduling.
        import subprocess
        import sys
        import time

        log_message = ""
        try:
            export_dir = os.getenv("EXPORT_DIR", "models/exports")
            os.makedirs(export_dir, exist_ok=True)
            resnet_checkpoint = os.path.join(export_dir, "resnet_item_embedder_custom.pth")
            vit_checkpoint = os.path.join(export_dir, "vit_outfit_model_custom.pth")

            # Persist the full configuration, including options that the training
            # scripts are not given as flags, next to the exported models.
            resnet_config = {
                "model": {
                    "backbone": resnet_backbone,
                    "embedding_dim": resnet_embedding_dim,
                    "pretrained": resnet_use_pretrained,
                    "dropout": resnet_dropout
                },
                "training": {
                    "batch_size": resnet_batch_size,
                    "epochs": resnet_epochs,
                    "lr": resnet_lr,
                    "weight_decay": resnet_weight_decay,
                    "triplet_margin": resnet_triplet_margin,
                    "optimizer": resnet_optimizer,
                    "scheduler": scheduler_type,
                    "warmup_epochs": warmup_epochs,
                    "early_stopping_patience": early_stopping_patience,
                    "use_amp": use_mixed_precision,
                    "channels_last": channels_last,
                    "gradient_clip": gradient_clip
                },
                "data": {
                    "image_size": 224,
                    "augmentation_level": augmentation_level
                },
                "advanced": {
                    "mining_strategy": mining_strategy,
                    "seed": seed
                }
            }
            vit_config = {
                "model": {
                    "embedding_dim": vit_embedding_dim,
                    "num_layers": vit_num_layers,
                    "num_heads": vit_num_heads,
                    "ff_multiplier": vit_ff_multiplier,
                    "dropout": vit_dropout
                },
                "training": {
                    "batch_size": vit_batch_size,
                    "epochs": vit_epochs,
                    "lr": vit_lr,
                    "weight_decay": vit_weight_decay,
                    "triplet_margin": vit_triplet_margin,
                    "optimizer": vit_optimizer,
                    "scheduler": scheduler_type,
                    "warmup_epochs": warmup_epochs,
                    "early_stopping_patience": early_stopping_patience,
                    "use_amp": use_mixed_precision
                },
                "advanced": {
                    "mining_strategy": mining_strategy,
                    "seed": seed
                }
            }
            with open(os.path.join(export_dir, "resnet_config_custom.json"), "w") as f:
                json.dump(resnet_config, f, indent=2)
            with open(os.path.join(export_dir, "vit_config_custom.json"), "w") as f:
                json.dump(vit_config, f, indent=2)

            # ---- Stage 1: ResNet item embedder ----
            log_message = "πŸš€ Starting ResNet training with custom parameters...\n"
            log_message += f"Dataset Size: {dataset_size} samples\n"
            log_message += f"Backbone: {resnet_backbone}, Embedding Dim: {resnet_embedding_dim}\n"
            log_message += f"Epochs: {resnet_epochs}, Batch Size: {resnet_batch_size}, LR: {resnet_lr}\n"
            log_message += f"Optimizer: {resnet_optimizer}, Triplet Margin: {resnet_triplet_margin}\n"

            # sys.executable guarantees the same interpreter/venv as this app.
            resnet_cmd = [
                sys.executable, "training/train_resnet.py",
                "--data_root", DATASET_ROOT,
                "--epochs", str(resnet_epochs),
                "--batch_size", str(resnet_batch_size),
                "--lr", str(resnet_lr),
                "--weight_decay", str(resnet_weight_decay),
                "--triplet_margin", str(resnet_triplet_margin),
                "--embedding_dim", str(resnet_embedding_dim),
                "--out", resnet_checkpoint,
            ]
            resnet_limit = _effective_max_samples(dataset_size)
            if resnet_limit is not None:
                resnet_cmd += ["--max_samples", resnet_limit]
            if resnet_backbone != "resnet50":
                resnet_cmd.extend(["--backbone", resnet_backbone])

            resnet_result = subprocess.run(resnet_cmd, capture_output=True, text=True, check=False)
            if resnet_result.returncode != 0:
                log_message += f"❌ ResNet training failed: {resnet_result.stderr}\n\n"
                return
            log_message += "βœ… ResNet training completed successfully!\n"
            log_message += f"πŸ“Š ResNet Output:\n{resnet_result.stdout}\n\n"

            # Give the file system a moment to flush the checkpoint before checking it.
            log_message += "⏳ Waiting for ResNet checkpoint to be fully saved...\n"
            time.sleep(3)
            if not os.path.exists(resnet_checkpoint):
                log_message += f"❌ ResNet checkpoint not found at {resnet_checkpoint}\n"
                log_message += "Cannot proceed with ViT training without ResNet embeddings.\n"
                return
            log_message += f"βœ… ResNet checkpoint verified: {resnet_checkpoint}\n"

            # ---- Stage 2: ViT outfit encoder ----
            log_message += "πŸš€ Starting ViT training with custom parameters...\n"
            log_message += f"Dataset Size: {dataset_size} samples\n"
            log_message += f"Layers: {vit_num_layers}, Heads: {vit_num_heads}, FF Multiplier: {vit_ff_multiplier}\n"
            log_message += f"Epochs: {vit_epochs}, Batch Size: {vit_batch_size}, LR: {vit_lr}\n"
            log_message += f"Optimizer: {vit_optimizer}, Triplet Margin: {vit_triplet_margin}\n"

            vit_cmd = [
                sys.executable, "training/train_vit.py",
                "--data_root", DATASET_ROOT,
                "--epochs", str(vit_epochs),
                "--batch_size", str(vit_batch_size),
                "--lr", str(vit_lr),
                "--weight_decay", str(vit_weight_decay),
                "--triplet_margin", str(vit_triplet_margin),
                "--embedding_dim", str(vit_embedding_dim),
                "--export", vit_checkpoint,
            ]
            # Pass --max_samples exactly once. Previously the global size was appended
            # after the slider value, and argparse's last-wins rule ignored the slider.
            vit_limit = _effective_max_samples(dataset_size, cap=vit_max_samples)
            if vit_limit is not None:
                vit_cmd += ["--max_samples", vit_limit]

            vit_result = subprocess.run(vit_cmd, capture_output=True, text=True, check=False)
            if vit_result.returncode != 0:
                log_message += f"❌ ViT training failed: {vit_result.stderr}\n"
                return
            log_message += "βœ… ViT training completed successfully!\n"
            log_message += f"πŸ“Š ViT Output:\n{vit_result.stdout}\n\n"
            log_message += f"πŸŽ‰ All training completed! Models saved to {export_dir}/\n"

            # ---- Reload models for inference ----
            log_message += "πŸ”„ Reloading models for inference...\n"
            service.reload_models()
            model_status = service.get_model_status()
            if model_status["can_recommend"]:
                log_message += "βœ… Models reloaded and ready for inference!\n"
                log_message += "πŸŽ‰ You can now generate outfit recommendations!\n"
            else:
                log_message += "⚠️ Models reloaded but validation failed!\n"
                log_message += "**Model Status:**\n"
                log_message += f"- ResNet: {'βœ… Loaded' if model_status['resnet_loaded'] else '❌ Failed'}\n"
                log_message += f"- ViT: {'βœ… Loaded' if model_status['vit_loaded'] else '❌ Failed'}\n"
                for error in model_status["errors"]:
                    log_message += f"- {error}\n"

            # ---- Optional auto-upload to the Hugging Face Hub ----
            hf_token = os.getenv("HF_TOKEN")
            if hf_token:
                log_message += "πŸ“€ Auto-uploading artifacts to Hugging Face Hub...\n"
                try:
                    from utils.hf_utils import HFModelManager
                    hf = HFModelManager(token=hf_token, username="Stylique")
                    upload_result = hf.upload_model("everything", "dressify-complete")
                    if upload_result.get("success"):
                        log_message += "βœ… Successfully uploaded to HF Hub!\n"
                        log_message += "πŸ”— Models: https://huggingface.co/Stylique/dressify-models\n"
                        log_message += "πŸ”— Data: https://huggingface.co/datasets/Stylique/Dressify-Helper\n"
                    else:
                        log_message += f"⚠️ Upload failed: {upload_result.get('error', 'Unknown error')}\n"
                except Exception as e:
                    log_message += f"⚠️ Auto-upload failed: {str(e)}\n"
            else:
                log_message += "πŸ’‘ Set HF_TOKEN env var for automatic uploads\n"
        except Exception as e:
            log_message += f"\n❌ Training error: {str(e)}"
        finally:
            # The UI never sees this log, so surface it in the server output.
            print(log_message)

    threading.Thread(target=_runner, daemon=True).start()
    return start_message
def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
    """Start quick ResNet + ViT training in a background thread with fixed settings.

    ResNet settings:
        batch size 4, lr 1e-3, early-stopping patience 3.

    ViT settings:
        batch size 4, lr 5e-4, patience 5, triplet margin 0.5, gradient clip 1.0,
        2 warmup epochs, and at most 5000 samples.

    Checkpoints go to ``EXPORT_DIR``. On success the inference ``service``
    reloads its models, and artifacts are uploaded to the Hub when ``HF_TOKEN``
    is set.

    Args:
        dataset_size: Sample limit as a string. ``"full"`` or empty falls back to
            the ``DATASET_SIZE_LIMIT`` environment variable (default ``"2000"``).
        res_epochs: Number of ResNet epochs.
        vit_epochs: Number of ViT epochs.

    Returns:
        ``"Dataset not ready."`` if no dataset has been prepared, otherwise an
        immediate start message. The detailed log is printed to stdout when the
        background run finishes.
    """
    # Use global dataset size if not specified
    if not dataset_size or dataset_size == "full":
        dataset_size = os.getenv("DATASET_SIZE_LIMIT", "2000")
    # Check readiness before starting the thread so the user actually sees the
    # problem. From inside the worker the message could never reach the UI.
    if not DATASET_ROOT:
        return "Dataset not ready."
    start_message = f"Starting training on {dataset_size} samples..."

    def _runner():
        import subprocess
        import sys
        import time

        log_message = ""
        try:
            export_dir = os.getenv("EXPORT_DIR", "models/exports")
            os.makedirs(export_dir, exist_ok=True)
            resnet_checkpoint = os.path.join(export_dir, "resnet_item_embedder.pth")
            log_message = f"Training ResNet on {dataset_size} samples...\n"

            # Pass --max_samples once per script. ViT keeps its 5000-sample cap but
            # never exceeds the selected size. Previously two conflicting flags were
            # passed, and argparse's last-wins rule made the cap dead code.
            size_limit = None if dataset_size == "full" else int(dataset_size)
            resnet_limit_args = [] if size_limit is None else ["--max_samples", str(size_limit)]
            vit_max_samples = 5000 if size_limit is None else min(5000, size_limit)

            # Train ResNet first and wait for completion
            log_message += f"\nπŸš€ Starting ResNet training on {dataset_size} samples...\n"
            resnet_result = subprocess.run(
                [
                    sys.executable, "training/train_resnet.py",
                    "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
                    "--batch_size", "4", "--lr", "1e-3", "--early_stopping_patience", "3",
                    "--out", resnet_checkpoint,
                ] + resnet_limit_args,
                capture_output=True, text=True, check=False,
            )
            if resnet_result.returncode != 0:
                log_message += f"❌ ResNet training failed: {resnet_result.stderr}\n"
                return
            log_message += "βœ… ResNet training completed successfully!\n"
            log_message += f"πŸ“Š ResNet Output:\n{resnet_result.stdout}\n"

            # Give the file system a moment to flush the checkpoint before checking it.
            time.sleep(2)
            if not os.path.exists(resnet_checkpoint):
                log_message += f"❌ ResNet checkpoint not found at {resnet_checkpoint}\n"
                log_message += "Cannot proceed with ViT training without ResNet embeddings.\n"
                return
            log_message += f"βœ… ResNet checkpoint verified: {resnet_checkpoint}\n"

            log_message += f"\nπŸš€ Starting ViT training on {dataset_size} samples...\n"
            vit_result = subprocess.run(
                [
                    sys.executable, "training/train_vit.py",
                    "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
                    "--batch_size", "4", "--lr", "5e-4", "--early_stopping_patience", "5",
                    "--max_samples", str(vit_max_samples), "--triplet_margin", "0.5",
                    "--gradient_clip", "1.0", "--warmup_epochs", "2",
                    "--export", os.path.join(export_dir, "vit_outfit_model.pth"),
                ],
                capture_output=True, text=True, check=False,
            )
            if vit_result.returncode != 0:
                log_message += f"❌ ViT training failed: {vit_result.stderr}\n"
                return
            log_message += "βœ… ViT training completed successfully!\n"
            log_message += f"πŸ“Š ViT Output:\n{vit_result.stdout}\n"

            service.reload_models()
            model_status = service.get_model_status()
            if model_status["can_recommend"]:
                log_message += "\nβœ… Training completed! Models reloaded and ready for inference.\n"
                log_message += "πŸŽ‰ You can now generate outfit recommendations!\n"
            else:
                log_message += "\n⚠️ Training completed but models failed to load properly!\n"
                log_message += "**Model Status:**\n"
                log_message += f"- ResNet: {'βœ… Loaded' if model_status['resnet_loaded'] else '❌ Failed'}\n"
                log_message += f"- ViT: {'βœ… Loaded' if model_status['vit_loaded'] else '❌ Failed'}\n"
                for error in model_status["errors"]:
                    log_message += f"- {error}\n"
            log_message += f"\nArtifacts saved to {export_dir}/"

            # Auto-upload to HF Hub if token is available
            hf_token = os.getenv("HF_TOKEN")
            if hf_token:
                log_message += "\nπŸ“€ Auto-uploading artifacts to Hugging Face Hub...\n"
                try:
                    from utils.hf_utils import HFModelManager
                    hf = HFModelManager(token=hf_token, username="Stylique")
                    upload_result = hf.upload_model("everything", "dressify-complete")
                    if upload_result.get("success"):
                        log_message += "βœ… Successfully uploaded to HF Hub!\n"
                        log_message += "πŸ”— Models: https://huggingface.co/Stylique/dressify-models\n"
                        log_message += "πŸ”— Data: https://huggingface.co/datasets/Stylique/Dressify-Helper\n"
                    else:
                        log_message += f"⚠️ Upload failed: {upload_result.get('error', 'Unknown error')}\n"
                except Exception as e:
                    log_message += f"⚠️ Auto-upload failed: {str(e)}\n"
            else:
                log_message += "\nπŸ’‘ Set HF_TOKEN env var for automatic uploads\n"
        except Exception as e:
            log_message += f"\nError: {e}"
        finally:
            # The UI never sees this log, so surface it in the server output.
            print(log_message)

    threading.Thread(target=_runner, daemon=True).start()
    return start_message
with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendation") as demo:
    # Page header shown above all tabs.
    gr.Markdown("## πŸ† Dressify – Advanced Outfit Recommendation System\n*Research-grade, self-contained outfit recommendation with comprehensive training controls*")
    gr.Markdown("πŸ’‘ **Pro Tip**: Start with 2000 samples for quick testing, then increase to 50000+ for production training!")

    # ------------------------------------------------------------------
    # Tab: wardrobe upload + tag-conditioned outfit recommendation
    # ------------------------------------------------------------------
    with gr.Tab("🎨 Recommend"):
        gr.Markdown("### 🎯 Personalized Outfit Recommendations\n*Upload your wardrobe and customize recommendations with advanced tag preferences*")
        gr.Markdown(f"**Supported Formats:** {', '.join(get_supported_extensions())} (JPG, PNG, WEBP, GIF, BMP, TIFF, and more)")
        inp2 = gr.Files(
            label="Upload wardrobe images",
            file_count="multiple"
            # Note: file_types removed to allow API client flexibility
            # Validation is handled by our image_utils.load_images_from_files()
        )
        with gr.Accordion("🎯 Primary Tags (Required)", open=True):
            with gr.Row():
                occasion = gr.Dropdown(
                    choices=["casual", "business", "formal", "semi_formal", "business_casual", "cocktail",
                             "wedding", "party", "date", "sport", "workout", "travel", "beach", "outdoor",
                             "night_out", "brunch", "dinner", "meeting", "interview", "cultural", "traditional"],
                    value="casual",
                    label="Occasion",
                    info="Select the occasion or event type"
                )
                weather = gr.Dropdown(
                    choices=["any", "hot", "warm", "mild", "cool", "cold", "freezing", "rain", "snow",
                             "windy", "humid", "sunny", "cloudy"],
                    value="any",
                    label="Weather",
                    info="Current or expected weather conditions"
                )
                outfit_style = gr.Dropdown(
                    choices=["casual", "smart_casual", "formal", "sporty", "athletic", "streetwear",
                             "minimalist", "classic", "modern", "elegant", "sophisticated", "traditional", "ethnic"],
                    value="casual",
                    label="Outfit Style",
                    info="Preferred fashion aesthetic"
                )
        # "None" in the optional dropdowns means "no preference".
        with gr.Accordion("🎨 Style & Preference Tags (Optional)", open=False):
            with gr.Row():
                color_preference = gr.Dropdown(
                    choices=["None", "neutral", "monochromatic", "monochrome", "complementary", "bold", "subtle",
                             "bright", "muted", "pastel", "dark", "light", "earth_tones", "jewel_tones",
                             "black_white", "navy_white", "colorful", "minimal_color"],
                    value="None",
                    label="Color Preference",
                    info="Preferred color scheme"
                )
                fit_preference = gr.Dropdown(
                    choices=["None", "fitted", "loose", "oversized", "relaxed", "comfortable",
                             "structured", "flowy", "tailored", "athletic_fit", "regular_fit"],
                    value="None",
                    label="Fit Preference",
                    info="Preferred fit and silhouette"
                )
                material_preference = gr.Dropdown(
                    choices=["None", "cotton", "linen", "silk", "wool", "cashmere", "denim",
                             "leather", "breathable", "waterproof", "moisture_wicking", "sustainable"],
                    value="None",
                    label="Material Preference",
                    info="Preferred fabric or material type"
                )
        with gr.Accordion("πŸ“… Context Tags (Optional)", open=False):
            with gr.Row():
                season = gr.Dropdown(
                    choices=["None", "spring", "summer", "fall", "autumn", "winter", "year_round", "transitional"],
                    value="None",
                    label="Season",
                    info="Current season"
                )
                time_of_day = gr.Dropdown(
                    choices=["None", "morning", "afternoon", "evening", "night", "all_day"],
                    value="None",
                    label="Time of Day",
                    info="When will you wear this outfit?"
                )
                budget = gr.Dropdown(
                    choices=["None", "luxury", "premium", "mid_range", "affordable", "budget", "value"],
                    value="None",
                    label="Budget Preference",
                    info="Price range preference (informational)"
                )
                personal_style = gr.Dropdown(
                    choices=["None", "conservative", "moderate", "bold", "experimental", "traditional",
                             "trendy", "timeless", "fashion_forward", "classic", "eclectic"],
                    value="None",
                    label="Personal Style",
                    info="Your personal style preference"
                )
        with gr.Row():
            num_outfits = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of Outfits", info="How many outfit recommendations to generate")
        out_gallery = gr.Gallery(label="Recommended Outfits", columns=1, height=400, show_label=True)
        out_json = gr.JSON(label="Outfit Details & Tag Analysis", show_label=True)
        btn2 = gr.Button("✨ Generate Personalized Outfits", variant="primary", size="lg")
        # The input order must match gradio_recommend's positional parameters.
        btn2.click(
            fn=gradio_recommend,
            inputs=[inp2, occasion, weather, num_outfits, outfit_style,
                    color_preference, fit_preference, material_preference,
                    season, time_of_day, budget, personal_style],
            outputs=[out_gallery, out_json]
        )

    # ------------------------------------------------------------------
    # Tab: dataset preparation and training controls
    # ------------------------------------------------------------------
    with gr.Tab("πŸ”¬ Advanced Training"):
        gr.Markdown("### 🎯 Comprehensive Training Parameter Control\nCustomize every aspect of model training for research and experimentation.")
        # Dataset Preparation Section
        with gr.Accordion("πŸ“¦ Dataset Preparation (Optional)", open=False):
            gr.Markdown("**Note**: Dataset preparation is now manual only. Click the button below to download and prepare the dataset when needed.")
            with gr.Row():
                prepare_dataset_btn = gr.Button("πŸ“₯ Download & Prepare Dataset", variant="secondary")
                prepare_status = gr.Textbox(label="Dataset Preparation Status", value="Dataset will be prepared if missing", interactive=False)

        def prepare_dataset_manual():
            """Download/extract the Polyvore dataset if needed and generate splits.

            Images are looked for under ``<cwd>/data/Polyvore/images``. When
            absent, ``ensure_dataset_ready()`` is called. Splits (500 samples) are
            generated only if neither ``train.json`` nor
            ``outfit_triplets_train.json`` exists.

            Side effects:
                Updates the module-level ``DATASET_ROOT`` and ``BOOT_STATUS``.

            Returns:
                A status string for the ``prepare_status`` textbox.
            """
            global DATASET_ROOT, BOOT_STATUS
            try:
                BOOT_STATUS = "preparing-dataset"
                # Check if dataset already exists
                root = os.path.abspath(os.path.join(os.getcwd(), "data", "Polyvore"))
                images_dir = os.path.join(root, "images")
                has_images = os.path.isdir(images_dir) and any(os.listdir(images_dir))
                if has_images:
                    print("βœ… Images already exist, skipping download/extraction")
                    ds_root = root
                else:
                    print("πŸ“₯ Downloading and extracting dataset...")
                    ds_root = ensure_dataset_ready()
                DATASET_ROOT = ds_root
                if not ds_root:
                    BOOT_STATUS = "dataset-not-prepared"
                    return "❌ Failed to prepare dataset"
                # Prepare splits if missing
                splits_dir = os.path.join(ds_root, "splits")
                has_splits = (
                    os.path.isfile(os.path.join(splits_dir, "train.json")) or
                    os.path.isfile(os.path.join(splits_dir, "outfit_triplets_train.json"))
                )
                if not has_splits:
                    os.makedirs(splits_dir, exist_ok=True)
                    from scripts.prepare_polyvore import main as prepare_main
                    os.environ.setdefault("PYTHONWARNINGS", "ignore")
                    import sys
                    # prepare_main() parses sys.argv, so swap it in temporarily.
                    argv_bak = sys.argv
                    try:
                        sys.argv = ["prepare_polyvore.py", "--root", ds_root, "--max_samples", "500"]
                        prepare_main()
                        BOOT_STATUS = "ready"
                        return "βœ… Dataset and splits prepared successfully!"
                    finally:
                        sys.argv = argv_bak
                else:
                    BOOT_STATUS = "ready"
                    return "βœ… Dataset already prepared (images and splits exist)"
            except Exception as e:
                BOOT_STATUS = "error"
                import traceback
                return f"❌ Error: {str(e)}\n{traceback.format_exc()}"

        prepare_dataset_btn.click(fn=prepare_dataset_manual, inputs=[], outputs=prepare_status)

        # Global Dataset Size Control
        with gr.Row():
            gr.Markdown("#### 🎯 **Global Dataset Size Control**")
            gr.Markdown("**Note**: Use 'Apply' button to regenerate splits with different size limits.")
        with gr.Row():
            gr.Markdown("#### πŸ“Š **Current Behavior**")
            gr.Markdown("β€’ **Bootstrap**: Downloads full dataset (53K outfits) + generates splits with **500 samples by default**\nβ€’ **Training**: Uses 500 samples (ultra-fast training!)\nβ€’ **Apply Button**: Regenerates splits with your selected size limit")
        with gr.Row():
            global_dataset_size = gr.Dropdown(
                choices=["160", "500", "2000", "5000", "10000", "25000", "50000", "full"],
                value="500",
                label="Global Dataset Size (Affects Prep + Training)"
            )
            gr.Markdown("**160**: Ultra-fast testing (~30 sec prep, ~1-2 min training)\n**2000**: Fast testing (~1-2 min prep, ~2-5 min training)\n**5000**: Fast testing (~2-3 min prep, ~5-10 min training)\n**10000**: Good testing (~3-5 min prep, ~10-20 min training)\n**full**: Production (~5-10 min prep, ~1-4 hours training)")
        with gr.Row():
            # Apply dataset size button
            apply_size_btn = gr.Button("πŸ”„ Apply Dataset Size & Regenerate Splits", variant="primary")
            size_status = gr.Textbox(label="Dataset Size Status", value="Dataset size: 500 samples (click Apply to regenerate splits)", interactive=False)
        # Current dataset info
        gr.Markdown("#### πŸ“Š **Current Dataset Status**")
        gr.Markdown("β€’ **Full dataset downloaded**: 53,306 outfits (required for system)\nβ€’ **Splits generated**: **500 samples by default** (ultra-fast training!)\nβ€’ **Training will use**: 500 samples (ultra-fast training!)\nβ€’ **Scale up**: Use Apply button to increase to larger sizes")

        def apply_dataset_size(size: str):
            """Record ``size`` as the global sample limit and regenerate the splits.

            Stores the choice in ``DATASET_SIZE_LIMIT``, which the training entry
            points fall back to when their size is "full". For a numeric size,
            runs ``scripts/prepare_polyvore.py --max_samples <size>`` in a
            subprocess.

            Args:
                size: A numeric string, or ``"full"`` for no limit. For ``"full"``
                    the splits are not regenerated.

            Returns:
                A status string for the ``size_status`` textbox.
            """
            try:
                if size == "full":
                    # Record the choice; otherwise a previously applied numeric
                    # limit would keep silently capping training.
                    os.environ["DATASET_SIZE_LIMIT"] = size
                    return f"βœ… Using full dataset ({size}) - no size limit applied"
                import subprocess
                import sys
                # Set environment variable for dataset size
                os.environ["DATASET_SIZE_LIMIT"] = size
                # Resolve paths from the prepared dataset / working directory, as in
                # prepare_dataset_manual(), instead of a hard-coded Space path.
                app_dir = os.getcwd()
                ds_root = DATASET_ROOT or os.path.join(app_dir, "data", "Polyvore")
                script_path = os.path.join(app_dir, "scripts", "prepare_polyvore.py")
                if not os.path.exists(script_path):
                    return f"❌ Script not found: {script_path}"
                # Regenerate splits with size limit using subprocess
                cmd = [
                    sys.executable, script_path,
                    "--root", ds_root,
                    "--out", os.path.join(ds_root, "splits"),
                    "--max_samples", size
                ]
                print(f"Running command: {' '.join(cmd)}")
                print(f"Current working directory: {app_dir}")
                result = subprocess.run(cmd, capture_output=True, text=True, check=False, cwd=app_dir)
                if result.returncode == 0:
                    return f"βœ… Successfully regenerated splits with {size} samples limit"
                error_msg = "❌ Failed to regenerate splits:\n"
                error_msg += f"Return code: {result.returncode}\n"
                error_msg += f"STDOUT: {result.stdout}\n"
                error_msg += f"STDERR: {result.stderr}"
                return error_msg
            except Exception as e:
                return f"❌ Failed to apply dataset size: {str(e)}"

        apply_size_btn.click(fn=apply_dataset_size, inputs=[global_dataset_size], outputs=[size_status])

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("#### πŸ“Š Dataset Size Control")
                gr.Markdown("Start small for testing, increase for production training")
                dataset_size = gr.Dropdown(
                    choices=["160", "500", "2000", "5000", "10000", "25000", "50000", "full"],
                    value="500",
                    label="Training Dataset Size"
                )
                gr.Markdown("**2000**: Quick testing (~2-5 min)\n**5000**: Fast validation (~5-10 min)\n**10000**: Good validation (~10-20 min)\n**25000+**: Production training")
            with gr.Column(scale=1):
                gr.Markdown("#### πŸ–ΌοΈ ResNet Item Embedder")
                # Model architecture
                resnet_backbone = gr.Dropdown(
                    choices=["resnet50", "resnet101"],
                    value="resnet50",
                    label="Backbone Architecture"
                )
                resnet_embedding_dim = gr.Slider(128, 1024, value=512, step=128, label="Embedding Dimension")
                resnet_use_pretrained = gr.Checkbox(value=True, label="Use ImageNet Pretrained")
                resnet_dropout = gr.Slider(0.0, 0.5, value=0.1, step=0.05, label="Dropout Rate")
                # Training parameters
                resnet_epochs = gr.Slider(1, 100, value=20, step=1, label="Epochs")
                resnet_batch_size = gr.Slider(4, 128, value=4, step=4, label="Batch Size")
                resnet_lr = gr.Slider(1e-5, 1e-2, value=1e-3, step=1e-5, label="Learning Rate")
                resnet_optimizer = gr.Dropdown(
                    choices=["adamw", "adam", "sgd", "rmsprop"],
                    value="adamw",
                    label="Optimizer"
                )
                resnet_weight_decay = gr.Slider(1e-6, 1e-2, value=1e-4, step=1e-6, label="Weight Decay")
                resnet_triplet_margin = gr.Slider(0.1, 1.0, value=0.2, step=0.05, label="Triplet Margin")
            with gr.Column(scale=1):
                gr.Markdown("#### 🧠 ViT Outfit Encoder")
                # Model architecture
                vit_embedding_dim = gr.Slider(128, 1024, value=512, step=128, label="Embedding Dimension")
                vit_num_layers = gr.Slider(2, 12, value=6, step=1, label="Transformer Layers")
                vit_num_heads = gr.Slider(4, 16, value=8, step=2, label="Attention Heads")
                vit_ff_multiplier = gr.Slider(2, 8, value=4, step=1, label="Feed-Forward Multiplier")
                vit_dropout = gr.Slider(0.0, 0.5, value=0.1, step=0.05, label="Dropout Rate")
                # Training parameters
                vit_epochs = gr.Slider(1, 100, value=30, step=1, label="Epochs")
                vit_batch_size = gr.Slider(2, 64, value=4, step=2, label="Batch Size")
                vit_max_samples = gr.Slider(100, 5000, value=500, step=100, label="Max Training Samples")
                vit_lr = gr.Slider(1e-5, 1e-2, value=5e-4, step=1e-5, label="Learning Rate")
                vit_optimizer = gr.Dropdown(
                    choices=["adamw", "adam", "sgd", "rmsprop"],
                    value="adamw",
                    label="Optimizer"
                )
                vit_weight_decay = gr.Slider(1e-4, 1e-1, value=5e-2, step=1e-4, label="Weight Decay")
                vit_triplet_margin = gr.Slider(0.1, 1.0, value=0.3, step=0.05, label="Triplet Margin")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("#### βš™οΈ Advanced Training Settings")
                # Hardware optimization
                use_mixed_precision = gr.Checkbox(value=True, label="Mixed Precision (AMP)")
                channels_last = gr.Checkbox(value=True, label="Channels Last Memory Format")
                gradient_clip = gr.Slider(0.1, 5.0, value=1.0, step=0.1, label="Gradient Clipping")
                # Learning rate scheduling
                warmup_epochs = gr.Slider(0, 10, value=3, step=1, label="Warmup Epochs")
                scheduler_type = gr.Dropdown(
                    choices=["cosine", "step", "plateau", "linear"],
                    value="cosine",
                    label="Learning Rate Scheduler"
                )
                early_stopping_patience = gr.Slider(5, 20, value=10, step=1, label="Early Stopping Patience")
                # Training strategy
                mining_strategy = gr.Dropdown(
                    choices=["semi_hard", "hardest", "random"],
                    value="semi_hard",
                    label="Triplet Mining Strategy"
                )
                augmentation_level = gr.Dropdown(
                    choices=["minimal", "standard", "aggressive"],
                    value="standard",
                    label="Data Augmentation Level"
                )
                seed = gr.Slider(0, 9999, value=42, step=1, label="Random Seed")
            with gr.Column(scale=1):
                gr.Markdown("#### πŸš€ Training Control")
                # Quick training
                gr.Markdown("**Quick Training (Basic Parameters)**")
                epochs_res = gr.Slider(1, 50, value=3, step=1, label="ResNet epochs")
                epochs_vit = gr.Slider(1, 100, value=3, step=1, label="ViT epochs")
                start_btn = gr.Button("πŸš€ Start Quick Training", variant="secondary")
                # Advanced training
                gr.Markdown("**Advanced Training (Custom Parameters)**")
                start_advanced_btn = gr.Button("🎯 Start Advanced Training", variant="primary")
                # Training log
                train_log = gr.Textbox(label="Training Log", lines=15, max_lines=20)
                # Status
                gr.Markdown("**Training Status**")
                training_status = gr.Textbox(label="Status", value="Ready to train", interactive=False)
        # Event handlers
        start_btn.click(
            fn=start_training_simple,
            inputs=[dataset_size, epochs_res, epochs_vit],
            outputs=train_log
        )
        # The input order must match start_training_advanced's positional parameters.
        start_advanced_btn.click(
            fn=start_training_advanced,
            inputs=[
                # Dataset size
                dataset_size,
                # ResNet parameters
                resnet_epochs, resnet_batch_size, resnet_lr, resnet_optimizer,
                resnet_weight_decay, resnet_triplet_margin, resnet_embedding_dim,
                resnet_backbone, resnet_use_pretrained, resnet_dropout,
                # ViT parameters
                vit_epochs, vit_batch_size, vit_max_samples, vit_lr, vit_optimizer,
                vit_weight_decay, vit_triplet_margin, vit_embedding_dim,
                vit_num_layers, vit_num_heads, vit_ff_multiplier, vit_dropout,
                # Advanced parameters
                use_mixed_precision, channels_last, gradient_clip,
                warmup_epochs, scheduler_type, early_stopping_patience,
                mining_strategy, augmentation_level, seed
            ],
            outputs=train_log
        )

    # ------------------------------------------------------------------
    # Tab: artifact packaging, Hub upload and downloads
    # ------------------------------------------------------------------
    with gr.Tab("πŸ“¦ Artifact Management"):
        gr.Markdown("### 🎯 Comprehensive Artifact Management\nManage, package, and upload all system artifacts to Hugging Face Hub.")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("#### πŸ“Š Artifact Overview")
                artifact_overview = gr.JSON(label="System Artifacts", value=get_artifact_overview)
                refresh_overview = gr.Button("πŸ”„ Refresh Overview")
                refresh_overview.click(fn=get_artifact_overview, inputs=[], outputs=artifact_overview)
                gr.Markdown("#### πŸ“¦ Create Packages")
                package_type = gr.Dropdown(
                    choices=["complete", "splits_only", "models_only"],
                    value="complete",
                    label="Package Type"
                )
                create_package_btn = gr.Button("πŸ“¦ Create Package")
                package_result = gr.Textbox(label="Package Result", interactive=False)
                available_packages = gr.JSON(label="Available Packages", value=get_available_packages)
                create_package_btn.click(
                    fn=create_download_package,
                    inputs=[package_type],
                    outputs=[package_result, available_packages]
                )
            with gr.Column(scale=1):
                gr.Markdown("#### πŸš€ Hugging Face Hub Integration")
                gr.Markdown("πŸ’‘ **Pro Tip**: Set `HF_TOKEN` environment variable for automatic uploads after training!")
                hf_token = gr.Textbox(label="HF Token", type="password", placeholder="hf_...")
                hf_username = gr.Textbox(label="Username", placeholder="your-username")
                with gr.Row():
                    push_splits_btn = gr.Button("πŸ“€ Push Splits", variant="secondary")
                    push_models_btn = gr.Button("πŸ“€ Push Models", variant="secondary")
                    push_everything_btn = gr.Button("πŸ“€ Push Everything", variant="primary")
                hf_result = gr.Textbox(label="Upload Result", interactive=False, lines=3)
                push_splits_btn.click(fn=push_splits_to_hf, inputs=[hf_token, hf_username], outputs=hf_result)
                push_models_btn.click(fn=push_models_to_hf, inputs=[hf_token, hf_username], outputs=hf_result)
                push_everything_btn.click(fn=push_everything_to_hf, inputs=[hf_token, hf_username], outputs=hf_result)
                gr.Markdown("#### πŸ“₯ Download Management")
                individual_files = gr.JSON(label="Individual Files", value=get_individual_files)
                download_all_btn = gr.Button("πŸ“₯ Download All as ZIP")
                download_result = gr.Textbox(label="Download Result", interactive=False)
                download_all_btn.click(fn=download_all_files, inputs=[], outputs=download_result)

    # ------------------------------------------------------------------
    # Tab: bootstrap/model status and health check
    # ------------------------------------------------------------------
    with gr.Tab("πŸ“ˆ Status"):
        gr.Markdown("### 🚦 System Status and Monitoring\nReal-time status of dataset preparation, training, and system health.")
        # Callables as `value` are re-evaluated on page load rather than frozen at build time.
        status = gr.Textbox(label="Bootstrap Status", value=lambda: BOOT_STATUS)
        refresh_status = gr.Button("πŸ”„ Refresh Status")
        refresh_status.click(fn=lambda: BOOT_STATUS, inputs=[], outputs=status)
        # Model Status
        gr.Markdown("#### πŸ€– Model Status")
        model_status = gr.JSON(label="Model Loading Status", value=lambda: service.get_model_status())
        refresh_models = gr.Button("πŸ”„ Refresh Model Status")
        refresh_models.click(fn=lambda: service.get_model_status(), inputs=[], outputs=model_status)
        # System info
        gr.Markdown("#### πŸ’» System Information")
        device_info = gr.Textbox(label="Device", value=lambda: f"Device: {service.device}")
        resnet_version = gr.Textbox(label="ResNet Version", value=lambda: f"ResNet: {service.resnet_version}")
        vit_version = gr.Textbox(label="ViT Version", value=lambda: f"ViT: {service.vit_version}")
        # Health check
        gr.Markdown("#### πŸ₯ Health Check")
        health_btn = gr.Button("πŸ” Check Health")
        health_status = gr.Textbox(label="Health Status", value="Click to check")

        def check_health():
            """Run the FastAPI ``/health`` handler in-process and report its result.

            ``app.get("/health")`` only builds a route decorator and performs no
            request, so instead the registered ``/health`` route is looked up on
            ``app`` and its endpoint is called directly.

            NOTE(review): this assumes the endpoint takes no required arguments.
            If it does, the call raises and is reported as a failure.

            Returns:
                A status string for the ``health_status`` textbox.
            """
            import asyncio
            import inspect

            try:
                endpoint = next(
                    (
                        route.endpoint
                        for route in app.routes
                        if getattr(route, "path", None) == "/health"
                    ),
                    None,
                )
                if endpoint is None:
                    return "❌ Health Check Failed: no /health route is registered"
                health = endpoint()
                # An async handler returns a coroutine. Presumably Gradio runs sync
                # handlers in a worker thread with no running loop, so asyncio.run
                # works; if not, the RuntimeError is reported below.
                if inspect.iscoroutine(health):
                    health = asyncio.run(health)
                return f"βœ… System Healthy - {health}"
            except Exception as e:
                return f"❌ Health Check Failed: {str(e)}"

        health_btn.click(fn=check_health, inputs=[], outputs=health_status)
try:
    # Mount the Gradio UI onto the FastAPI app at the root path so the UI and the
    # API routes share one server. NOTE(review): no SSR option is passed here;
    # SSR is only disabled by the local `launch(ssr_mode=False)` entry point.
    demo.queue()
    app = gr.mount_gradio_app(app, demo, path="/")
except Exception as exc:
    # Best effort: if mounting fails in certain runners, FastAPI stays available
    # without the UI. Report the reason instead of failing silently.
    print(f"⚠️ Could not mount Gradio app on FastAPI: {exc}")
# Mount the export directory at /files so artifacts can be downloaded directly.
export_dir = os.getenv("EXPORT_DIR", "models/exports")
os.makedirs(export_dir, exist_ok=True)
try:
    app.mount("/files", StaticFiles(directory=export_dir), name="files")
except Exception as exc:
    # Best effort: the app still works without the download route, but report why.
    print(f"⚠️ Could not mount /files static route: {exc}")
if __name__ == "__main__":
    # Direct execution (local or as a Space): serve only the Gradio UI, with
    # request queuing enabled and server-side rendering turned off.
    queued_demo = demo.queue()
    queued_demo.launch(ssr_mode=False)