import os import gc import threading from typing import Optional import gradio as gr import numpy as np import torch from huggingface_hub import hf_hub_download from PIL import Image # ─────────────────────────── CONFIG ──────────────────────────── # SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) WEIGHTS_DIR = os.path.join(SCRIPT_DIR, "weights") EXAMPLES_DIR = os.path.join(SCRIPT_DIR, "examples", "data") CATEGORIES = ["tops", "bottoms", "one-pieces"] GARMENT_PHOTO_TYPES = ["model", "flat-lay"] # ──────────────────────── WEIGHT DOWNLOAD ────────────────────── # def download_weights(): """Download model weights from HuggingFace Hub (skips if already present).""" os.makedirs(WEIGHTS_DIR, exist_ok=True) dwpose_dir = os.path.join(WEIGHTS_DIR, "dwpose") os.makedirs(dwpose_dir, exist_ok=True) tryon_path = os.path.join(WEIGHTS_DIR, "model.safetensors") if not os.path.exists(tryon_path): print("Downloading TryOnModel weights...") hf_hub_download( repo_id="fashn-ai/fashn-vton-1.5", filename="model.safetensors", local_dir=WEIGHTS_DIR, ) for filename in ["yolox_l.onnx", "dw-ll_ucoco_384.onnx"]: filepath = os.path.join(dwpose_dir, filename) if not os.path.exists(filepath): print(f"Downloading DWPose/{filename}...") hf_hub_download( repo_id="fashn-ai/DWPose", filename=filename, local_dir=dwpose_dir, ) print("All weights ready!") # Download weights at startup download_weights() # ──────────────────────── PIPELINE LOADER ────────────────────── # _pipeline_lock = threading.Lock() _pipeline: Optional[object] = None def get_pipeline(): """Thread-safe lazy pipeline loader (CPU mode).""" global _pipeline with _pipeline_lock: if _pipeline is None: from fashn_vton import TryOnPipeline print("Loading pipeline on CPU...") _pipeline = TryOnPipeline(weights_dir=WEIGHTS_DIR, device="cpu") print("Pipeline ready!") return _pipeline # ─────────────────────────── INFERENCE ───────────────────────── # def try_on( person_image, garment_image, category: str, garment_photo_type: str, num_timesteps: int, guidance_scale: float, seed: int, segmentation_free: bool, ): """Run virtual try-on inference.""" if person_image is None: raise gr.Error("Please upload a person image.") if garment_image is None: raise gr.Error("Please upload a garment image.") # Normalise seed if seed is None or seed < 0: seed = 42 seed = int(seed) # Ensure PIL RGB def to_pil(x): if isinstance(x, np.ndarray): x = Image.fromarray(x) if isinstance(x, Image.Image): return x.convert("RGB") return Image.open(x).convert("RGB") person_img = to_pil(person_image) garment_img = to_pil(garment_image) pipeline = get_pipeline() try: result = pipeline( person_image=person_img, garment_image=garment_img, category=category, garment_photo_type=garment_photo_type, num_samples=1, num_timesteps=num_timesteps, guidance_scale=guidance_scale, seed=seed, segmentation_free=segmentation_free, ) return result.images[0], "✅ Done!" except Exception as e: return None, f"❌ Error: {e}" # ─────────────────────────── GRADIO UI ───────────────────────── # CUSTOM_CSS = """ body { font-family: 'Inter', sans-serif; } .contain img { object-fit: contain !important; max-height: 520px !important; } #run-btn { background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%) !important; border: none !important; color: white !important; font-size: 1.1rem !important; font-weight: 600 !important; padding: 0.75rem !important; border-radius: 12px !important; transition: opacity 0.2s; } #run-btn:hover { opacity: 0.85; } .status-box textarea { font-size: 0.9rem !important; color: #a3e635 !important; background: #1e1e2e !important; border-radius: 8px !important; } .gr-accordion { border-radius: 10px !important; } """ BANNER_MD = """ # 👗 FASHN VTON — Virtual Try-On Upload a **person image** and a **garment image**, choose the garment category and hit **Try On**! > ⚠️ Running on **CPU** — inference may take a few minutes. Reduce *Sampling Steps* for faster results. """ TIPS_HTML = """