"""Gradio demo app for FASHN VTON virtual try-on, running on CPU.

Importing this module downloads any missing model weights from the
HuggingFace Hub, builds the Gradio ``demo`` Blocks app and configures its
request queue. Running it as a script launches the server on 0.0.0.0.
"""

import gc  # noqa: F401  (not referenced in this module)
import os
import threading
import traceback
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import torch  # noqa: F401  (not referenced in this module)
from huggingface_hub import hf_hub_download
from PIL import Image

# ─────────────────────────── CONFIG ──────────────────────────── #

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
WEIGHTS_DIR = os.path.join(SCRIPT_DIR, "weights")
EXAMPLES_DIR = os.path.join(SCRIPT_DIR, "examples", "data")

# Choices offered by the UI dropdowns; forwarded unchanged to the pipeline.
CATEGORIES = ["tops", "bottoms", "one-pieces"]
GARMENT_PHOTO_TYPES = ["model", "flat-lay"]


# ──────────────────────── WEIGHT DOWNLOAD ────────────────────── #

def download_weights():
    """Download model weights from HuggingFace Hub (skips if already present).

    Fetches ``model.safetensors`` from ``fashn-ai/fashn-vton-1.5`` into
    ``WEIGHTS_DIR`` and the two DWPose ONNX files from ``fashn-ai/DWPose``
    into ``WEIGHTS_DIR/dwpose``. Each file is downloaded only when its
    target path does not exist yet.
    """
    os.makedirs(WEIGHTS_DIR, exist_ok=True)
    dwpose_dir = os.path.join(WEIGHTS_DIR, "dwpose")
    os.makedirs(dwpose_dir, exist_ok=True)

    tryon_path = os.path.join(WEIGHTS_DIR, "model.safetensors")
    if not os.path.exists(tryon_path):
        print("Downloading TryOnModel weights...")
        hf_hub_download(
            repo_id="fashn-ai/fashn-vton-1.5",
            filename="model.safetensors",
            local_dir=WEIGHTS_DIR,
        )

    for filename in ["yolox_l.onnx", "dw-ll_ucoco_384.onnx"]:
        filepath = os.path.join(dwpose_dir, filename)
        if not os.path.exists(filepath):
            print(f"Downloading DWPose/{filename}...")
            hf_hub_download(
                repo_id="fashn-ai/DWPose",
                filename=filename,
                local_dir=dwpose_dir,
            )

    print("All weights ready!")


# Download weights at import time, before the UI is built.
download_weights()


# ──────────────────────── PIPELINE LOADER ────────────────────── #

_pipeline_lock = threading.Lock()
_pipeline: Optional[object] = None


def get_pipeline():
    """Thread-safe lazy pipeline loader (CPU mode).

    The first caller imports ``fashn_vton`` and builds a ``TryOnPipeline``
    from ``WEIGHTS_DIR``; later callers get the cached instance. The lock
    ensures concurrent first calls build the pipeline only once. If
    construction raises, nothing is cached and the next call retries.
    """
    global _pipeline
    with _pipeline_lock:
        if _pipeline is None:
            # Imported lazily so the UI can come up before the model code loads.
            from fashn_vton import TryOnPipeline

            print("Loading pipeline on CPU...")
            _pipeline = TryOnPipeline(weights_dir=WEIGHTS_DIR, device="cpu")
            print("Pipeline ready!")
    return _pipeline


# ─────────────────────────── INFERENCE ───────────────────────── #

def _to_pil(image) -> Image.Image:
    """Coerce a numpy array, PIL image, or path/file object into an RGB PIL image."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    return Image.open(image).convert("RGB")


def try_on(
    person_image,
    garment_image,
    category: str,
    garment_photo_type: str,
    num_timesteps: int,
    guidance_scale: float,
    seed: int,
    segmentation_free: bool,
) -> Tuple[Optional[Image.Image], str]:
    """Run virtual try-on inference.

    Args:
        person_image: Photo of the person (PIL image, numpy array, or path).
        garment_image: Photo of the garment (same accepted forms).
        category: One of ``CATEGORIES``.
        garment_photo_type: One of ``GARMENT_PHOTO_TYPES``.
        num_timesteps: Number of sampling steps; coerced to ``int``.
        guidance_scale: Guidance strength; coerced to ``float``.
        seed: Random seed. ``None`` or a negative value falls back to 42.
        segmentation_free: Forwarded to the pipeline unchanged.

    Returns:
        ``(image, "✅ Done!")`` with the first generated image on success.
        On failure, ``(None, "❌ Error: ...")`` if loading the pipeline or
        running inference raised.

    Raises:
        gr.Error: If either input image is missing.
    """
    if person_image is None:
        raise gr.Error("Please upload a person image.")
    if garment_image is None:
        raise gr.Error("Please upload a garment image.")

    # Normalise seed: an empty or negative value falls back to the default.
    if seed is None or seed < 0:
        seed = 42
    seed = int(seed)

    person_img = _to_pil(person_image)
    garment_img = _to_pil(garment_image)

    try:
        # Loading is inside the try so that load failures use the same
        # status-box reporting as inference failures.
        pipeline = get_pipeline()
        result = pipeline(
            person_image=person_img,
            garment_image=garment_img,
            category=category,
            garment_photo_type=garment_photo_type,
            num_samples=1,
            num_timesteps=int(num_timesteps),
            guidance_scale=float(guidance_scale),
            seed=seed,
            segmentation_free=segmentation_free,
        )
        return result.images[0], "✅ Done!"
    except Exception as e:
        # Show a short message in the UI, but keep the full traceback in the
        # server log so failures can actually be debugged.
        traceback.print_exc()
        return None, f"❌ Error: {e}"


# ─────────────────────────── GRADIO UI ───────────────────────── #

CUSTOM_CSS = """
body { font-family: 'Inter', sans-serif; }
.contain img { object-fit: contain !important; max-height: 520px !important; }
#run-btn {
    background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%) !important;
    border: none !important;
    color: white !important;
    font-size: 1.1rem !important;
    font-weight: 600 !important;
    padding: 0.75rem !important;
    border-radius: 12px !important;
    transition: opacity 0.2s;
}
#run-btn:hover { opacity: 0.85; }
.status-box textarea {
    font-size: 0.9rem !important;
    color: #a3e635 !important;
    background: #1e1e2e !important;
    border-radius: 8px !important;
}
.gr-accordion { border-radius: 10px !important; }
"""

BANNER_MD = """
# 👗 FASHN VTON — Virtual Try-On

Upload a **person image** and a **garment image**, choose the garment category and hit **Try On**!

> ⚠️ Running on **CPU** — inference may take a few minutes. Reduce *Sampling Steps* for faster results.
"""

# Rendered via gr.HTML; the line breaks collapse to spaces, so the tips
# appear on one line separated by "|".
TIPS_HTML = """
💡 Tips for best results:
👤 Single person, clearly visible
|
👕 Match category to garment type
|
📸 Use "flat-lay" for product shots
|
📐 2:3 aspect ratio optimal
"""

person_example = os.path.join(EXAMPLES_DIR, "model.jpeg")
garment_example = os.path.join(EXAMPLES_DIR, "garment.jpeg")

with gr.Blocks(css=CUSTOM_CSS, title="FASHN VTON — Virtual Try-On") as demo:
    gr.Markdown(BANNER_MD)
    gr.HTML(TIPS_HTML)

    with gr.Row(equal_height=False):
        # ── Column 1 : Person ──────────────────────────────────
        with gr.Column(scale=1):
            person_in = gr.Image(
                label="Person Image",
                type="pil",
                sources=["upload", "clipboard"],
                elem_classes=["contain"],
            )
            # Example images are optional; only offered when present on disk.
            if os.path.exists(person_example):
                gr.Examples(
                    examples=[[person_example]],
                    inputs=[person_in],
                    label="Person Example",
                )

        # ── Column 2 : Garment ─────────────────────────────────
        with gr.Column(scale=1):
            garment_in = gr.Image(
                label="Garment Image",
                type="pil",
                sources=["upload", "clipboard"],
                elem_classes=["contain"],
            )
            with gr.Row():
                category = gr.Dropdown(
                    choices=CATEGORIES,
                    value="tops",
                    label="Category",
                )
                garment_photo_type = gr.Dropdown(
                    choices=GARMENT_PHOTO_TYPES,
                    value="model",
                    label="Photo Type",
                )
            if os.path.exists(garment_example):
                gr.Examples(
                    examples=[[garment_example]],
                    inputs=[garment_in],
                    label="Garment Example",
                )

        # ── Column 3 : Result ──────────────────────────────────
        with gr.Column(scale=1):
            result_img = gr.Image(
                label="Try-On Result",
                type="pil",
                interactive=False,
                elem_classes=["contain"],
            )
            status = gr.Textbox(
                value="Ready",
                label="Status",
                interactive=False,
                elem_classes=["status-box"],
            )
            run_btn = gr.Button("👗 Try On", variant="primary", elem_id="run-btn")

            with gr.Accordion("⚙️ Advanced Settings", open=False):
                num_timesteps = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=30,
                    step=5,
                    label="Sampling Steps",
                    info="Higher = better quality but slower. 30 is a good balance.",
                )
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=3.0,
                    value=1.5,
                    step=0.1,
                    label="Guidance Scale",
                    info="How closely to follow the garment details. 1.5 recommended.",
                )
                seed = gr.Number(
                    value=42,
                    label="Seed",
                    precision=0,
                    info="Change seed to get a different variation of the result.",
                )
                segmentation_free = gr.Checkbox(
                    value=True,
                    label="Segmentation-Free (Recommended)",
                    info="Preserves body features and allows unconstrained garment volume.",
                )

    # ── Event ──────────────────────────────────────────────────
    run_btn.click(
        fn=try_on,
        inputs=[
            person_in,
            garment_in,
            category,
            garment_photo_type,
            num_timesteps,
            guidance_scale,
            seed,
            segmentation_free,
        ],
        outputs=[result_img, status],
    )

# Process one request at a time (CPU inference) and cap the waiting queue at 10.
demo.queue(default_concurrency_limit=1, max_size=10)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", share=False)