"""FASHN VTON v1.5 HuggingFace Space Demo.""" import os import platform import gradio as gr import spaces import torch from huggingface_hub import hf_hub_download from PIL import Image # ----------------- CONFIG ----------------- # SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) ASSETS_DIR = os.path.join(SCRIPT_DIR, "assets") WEIGHTS_DIR = os.path.join(SCRIPT_DIR, "weights") CATEGORIES = ["tops", "bottoms", "one-pieces"] GARMENT_PHOTO_TYPES = ["model", "flat-lay"] # Global pipeline instance (lazy loaded) _pipeline = None # ----------------- HELPERS ----------------- # def download_weights(): """Download model weights from HuggingFace Hub.""" os.makedirs(WEIGHTS_DIR, exist_ok=True) dwpose_dir = os.path.join(WEIGHTS_DIR, "dwpose") os.makedirs(dwpose_dir, exist_ok=True) # Download TryOnModel weights tryon_path = os.path.join(WEIGHTS_DIR, "model.safetensors") if not os.path.exists(tryon_path): print("Downloading TryOnModel weights...") hf_hub_download( repo_id="fashn-ai/fashn-vton-1.5", filename="model.safetensors", local_dir=WEIGHTS_DIR, ) # Download DWPose models dwpose_files = ["yolox_l.onnx", "dw-ll_ucoco_384.onnx"] for filename in dwpose_files: filepath = os.path.join(dwpose_dir, filename) if not os.path.exists(filepath): print(f"Downloading DWPose/{filename}...") hf_hub_download( repo_id="fashn-ai/DWPose", filename=filename, local_dir=dwpose_dir, ) print("Weights downloaded successfully!") # ----------------- MODEL LOADING ----------------- # def get_pipeline(): """Lazy-load the pipeline on first use (ensures GPU is available on ZeroGPU).""" global _pipeline if _pipeline is None: # Check CUDA availability (will be true inside @spaces.GPU context) if not torch.cuda.is_available(): raise gr.Error( "CUDA is not available. This demo requires a GPU to run. " "If you're on HuggingFace Spaces, please try again in a moment." ) # ---------------------------------- Diagnostics ---------------------------------- # print(f"Python : {platform.python_version()}") print(f"PyTorch : {torch.__version__}") print(f" • built for CUDA : {torch.version.cuda}") if torch.backends.cudnn.is_available(): print(f" • built for cuDNN: {torch.backends.cudnn.version()}") print(f"torch.cuda.is_available(): {torch.cuda.is_available()}") if torch.cuda.is_available(): dev = torch.cuda.current_device() cc = torch.cuda.get_device_capability(dev) print(f"GPU {dev}: {torch.cuda.get_device_name(dev)} (compute {cc[0]}.{cc[1]})") # Enable TF32 for faster computation on Ampere+ GPUs if torch.cuda.get_device_properties(0).major >= 8: torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True print("Downloading weights (if needed)...") download_weights() print("Loading pipeline...") from fashn_vton import TryOnPipeline _pipeline = TryOnPipeline(weights_dir=WEIGHTS_DIR, device="cuda") print("Pipeline loaded on CUDA!") return _pipeline # ----------------- INFERENCE ----------------- # @spaces.GPU def try_on( person_image: Image.Image, garment_image: Image.Image, category: str, garment_photo_type: str, num_timesteps: int, guidance_scale: float, seed: int, segmentation_free: bool, ) -> Image.Image: """Run virtual try-on inference.""" if person_image is None: raise gr.Error("Please upload a person image") if garment_image is None: raise gr.Error("Please upload a garment image") # Handle seed (guard against None or invalid values) if seed is None or seed < 0: seed = 42 # Convert to RGB if needed if person_image.mode != "RGB": person_image = person_image.convert("RGB") if garment_image.mode != "RGB": garment_image = garment_image.convert("RGB") # Get pipeline (lazy loads on first call) pipeline = get_pipeline() # Run inference result = pipeline( person_image=person_image, garment_image=garment_image, category=category, garment_photo_type=garment_photo_type, num_samples=1, num_timesteps=num_timesteps, guidance_scale=guidance_scale, seed=int(seed), segmentation_free=segmentation_free, ) return result.images[0] # ----------------- UI ----------------- # # Custom CSS CUSTOM_CSS = """ .contain img { object-fit: contain !important; max-height: 856px !important; max-width: 576px !important; } """ # Load HTML content with open(os.path.join(SCRIPT_DIR, "banner.html"), "r") as f: banner_html = f.read() with open(os.path.join(SCRIPT_DIR, "tips.html"), "r") as f: tips_html = f.read() # Build example paths examples_dir = os.path.join(ASSETS_DIR, "examples") # Paired examples: [person_path, garment_path, category, garment_photo_type] paired_examples = [ [os.path.join(examples_dir, "person1.png"), os.path.join(examples_dir, "garment1.jpeg"), "one-pieces", "model"], [os.path.join(examples_dir, "person2.png"), os.path.join(examples_dir, "garment2.webp"), "tops", "model"], [os.path.join(examples_dir, "person3.png"), os.path.join(examples_dir, "garment3.jpeg"), "tops", "flat-lay"], [os.path.join(examples_dir, "person4.png"), os.path.join(examples_dir, "garment4.webp"), "tops", "model"], [os.path.join(examples_dir, "person5.png"), os.path.join(examples_dir, "garment5.jpeg"), "bottoms", "flat-lay"], [os.path.join(examples_dir, "person6.png"), os.path.join(examples_dir, "garment6.webp"), "one-pieces", "model"], ] # Individual examples (classic from repo) person_only_examples = [os.path.join(examples_dir, "person0.png")] # Garment examples with their settings: (image_path, category, photo_type) # Order matters - index in Gallery corresponds to this list garment_examples_data = [ (os.path.join(examples_dir, "garment0.png"), "tops", "model"), (os.path.join(examples_dir, "garment7.jpg"), "tops", "flat-lay"), ] garment_gallery_images = [item[0] for item in garment_examples_data] def on_garment_gallery_select(evt: gr.SelectData): """Handle garment gallery selection - load image and update dropdowns.""" idx = evt.index if idx < len(garment_examples_data): image_path, cat, photo_type = garment_examples_data[idx] return Image.open(image_path), cat, photo_type return None, "tops", "model" # Build UI with gr.Blocks(css=CUSTOM_CSS) as demo: # Header gr.HTML(banner_html) gr.HTML(tips_html) with gr.Row(equal_height=False): # Column 1: Person with gr.Column(scale=1): person_image = gr.Image( label="Person Image", type="pil", sources=["upload", "clipboard"], elem_classes=["contain"], ) # Individual person examples gr.Examples( examples=person_only_examples, inputs=person_image, label="Person Examples", ) # Column 2: Garment with gr.Column(scale=1): garment_image = gr.Image( label="Garment Image", type="pil", sources=["upload", "clipboard"], elem_classes=["contain"], ) with gr.Row(): category = gr.Dropdown( choices=CATEGORIES, value="tops", label="Category", ) garment_photo_type = gr.Dropdown( choices=GARMENT_PHOTO_TYPES, value="model", label="Photo Type", ) # Garment examples as clickable gallery gr.Markdown("**Garment Examples** (click to load with settings)") garment_gallery = gr.Gallery( value=garment_gallery_images, columns=2, rows=1, height="auto", object_fit="contain", show_label=False, allow_preview=False, ) # Column 3: Result with gr.Column(scale=1): result_image = gr.Image( label="Try-On Result", type="pil", interactive=False, elem_classes=["contain"], ) run_button = gr.Button("Try On", variant="primary", size="lg") # Advanced settings with gr.Accordion("Advanced Settings", open=False): num_timesteps = gr.Slider( minimum=10, maximum=50, value=50, step=5, label="Sampling Steps", info="Higher = better quality, slower.", ) guidance_scale = gr.Slider( minimum=1.0, maximum=3.0, value=1.5, step=0.1, label="Guidance Scale", info="How closely to follow the garment. 1.5 recommended.", ) seed = gr.Number( value=42, label="Seed", info="Random seed for reproducibility.", precision=0, ) segmentation_free = gr.Checkbox( value=True, label="Segmentation Free", info="Preserves body features and allows unconstrained garment volume. Disable for tighter garment fitting.", ) # Paired examples at the bottom gr.Examples( examples=paired_examples, inputs=[person_image, garment_image, category, garment_photo_type], label="Complete Examples (click to load person + garment + settings)", ) # Event handlers run_button.click( fn=try_on, inputs=[ person_image, garment_image, category, garment_photo_type, num_timesteps, guidance_scale, seed, segmentation_free, ], outputs=[result_image], ) # Garment gallery selection - loads image and updates dropdowns garment_gallery.select( fn=on_garment_gallery_select, inputs=None, outputs=[garment_image, category, garment_photo_type], ) # Configure queue with concurrency limit to prevent GPU OOM demo.queue(default_concurrency_limit=1, max_size=30) if __name__ == "__main__": demo.launch(share=False)