| """FASHN VTON v1.5 HuggingFace Space Demo."""
|
|
|
| import os
|
| import platform
|
|
|
| import gradio as gr
|
| import spaces
|
| import torch
|
| from huggingface_hub import hf_hub_download
|
| from PIL import Image
|
|
|
|
|
|
|
# Resolve all paths relative to this file so the app works regardless of CWD.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
ASSETS_DIR = os.path.join(SCRIPT_DIR, "assets")  # bundled example images / HTML
WEIGHTS_DIR = os.path.join(SCRIPT_DIR, "weights")  # downloaded model weights land here

# Choice lists for the two garment dropdowns; values are passed through to the pipeline.
CATEGORIES = ["tops", "bottoms", "one-pieces"]
GARMENT_PHOTO_TYPES = ["model", "flat-lay"]

# Lazily-initialized pipeline singleton — populated on first call to get_pipeline().
_pipeline = None
|
|
|
|
|
|
|
|
|
|
|
def download_weights():
    """Download model weights from HuggingFace Hub into WEIGHTS_DIR.

    Skips any file that already exists on disk, so repeated calls
    (e.g. across worker restarts) are cheap no-ops.
    """
    os.makedirs(WEIGHTS_DIR, exist_ok=True)
    dwpose_dir = os.path.join(WEIGHTS_DIR, "dwpose")
    os.makedirs(dwpose_dir, exist_ok=True)

    # Main try-on model weights.
    tryon_path = os.path.join(WEIGHTS_DIR, "model.safetensors")
    if not os.path.exists(tryon_path):
        print("Downloading TryOnModel weights...")
        hf_hub_download(
            repo_id="fashn-ai/fashn-vton-1.5",
            filename="model.safetensors",
            local_dir=WEIGHTS_DIR,
        )

    # DWPose ONNX models: person detector + keypoint estimator.
    dwpose_files = ["yolox_l.onnx", "dw-ll_ucoco_384.onnx"]
    for filename in dwpose_files:
        filepath = os.path.join(dwpose_dir, filename)
        if not os.path.exists(filepath):
            # Fix: interpolate the actual filename (was a mangled literal "(unknown)").
            print(f"Downloading DWPose/{filename}...")
            hf_hub_download(
                repo_id="fashn-ai/DWPose",
                filename=filename,
                local_dir=dwpose_dir,
            )

    print("Weights downloaded successfully!")
|
|
|
|
|
|
|
|
|
|
|
def get_pipeline():
    """Lazy-load the pipeline on first use (ensures GPU is available on ZeroGPU)."""
    global _pipeline

    # Already initialized — return the cached singleton.
    if _pipeline is not None:
        return _pipeline

    # The model requires CUDA; fail early with a user-visible error otherwise.
    if not torch.cuda.is_available():
        raise gr.Error(
            "CUDA is not available. This demo requires a GPU to run. "
            "If you're on HuggingFace Spaces, please try again in a moment."
        )

    # Log environment diagnostics to help debug Space runtime issues.
    print(f"Python : {platform.python_version()}")
    print(f"PyTorch : {torch.__version__}")
    print(f" • built for CUDA : {torch.version.cuda}")
    if torch.backends.cudnn.is_available():
        print(f" • built for cuDNN: {torch.backends.cudnn.version()}")
    print(f"torch.cuda.is_available(): {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        device_idx = torch.cuda.current_device()
        capability = torch.cuda.get_device_capability(device_idx)
        print(f"GPU {device_idx}: {torch.cuda.get_device_name(device_idx)} (compute {capability[0]}.{capability[1]})")

    # Enable TF32 matmuls on Ampere (compute capability 8.x) or newer GPUs.
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    print("Downloading weights (if needed)...")
    download_weights()

    print("Loading pipeline...")
    # Imported lazily so module import doesn't require the model package.
    from fashn_vton import TryOnPipeline

    _pipeline = TryOnPipeline(weights_dir=WEIGHTS_DIR, device="cuda")
    print("Pipeline loaded on CUDA!")

    return _pipeline
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU
def try_on(
    person_image: Image.Image,
    garment_image: Image.Image,
    category: str,
    garment_photo_type: str,
    num_timesteps: int,
    guidance_scale: float,
    seed: int,
    segmentation_free: bool,
) -> Image.Image:
    """Run virtual try-on inference and return the generated image."""
    # Guard clauses: both images are required inputs.
    if person_image is None:
        raise gr.Error("Please upload a person image")
    if garment_image is None:
        raise gr.Error("Please upload a garment image")

    # Fall back to a fixed default seed for missing or invalid values.
    if seed is None or seed < 0:
        seed = 42

    # The pipeline expects 3-channel RGB inputs (e.g. strip alpha from PNGs).
    person_image = person_image if person_image.mode == "RGB" else person_image.convert("RGB")
    garment_image = garment_image if garment_image.mode == "RGB" else garment_image.convert("RGB")

    pipeline = get_pipeline()

    result = pipeline(
        person_image=person_image,
        garment_image=garment_image,
        category=category,
        garment_photo_type=garment_photo_type,
        num_samples=1,
        num_timesteps=num_timesteps,
        guidance_scale=guidance_scale,
        seed=int(seed),
        segmentation_free=segmentation_free,
    )

    # num_samples=1, so the single generated image is the result.
    return result.images[0]
|
|
|
|
|
|
|
|
|
|
|
# Constrain preview images so tall/wide photos fit their panel without cropping.
CUSTOM_CSS = """
.contain img {
object-fit: contain !important;
max-height: 856px !important;
max-width: 576px !important;
}
"""
|
|
|
|
|
# Static HTML fragments rendered at the top of the demo page.
with open(os.path.join(SCRIPT_DIR, "banner.html"), "r") as f:
    banner_html = f.read()
with open(os.path.join(SCRIPT_DIR, "tips.html"), "r") as f:
    tips_html = f.read()

examples_dir = os.path.join(ASSETS_DIR, "examples")

# [person_path, garment_path, category, photo_type] rows for the paired gr.Examples.
paired_examples = [
    [os.path.join(examples_dir, "person1.png"), os.path.join(examples_dir, "garment1.jpeg"), "one-pieces", "model"],
    [os.path.join(examples_dir, "person2.png"), os.path.join(examples_dir, "garment2.webp"), "tops", "model"],
    [os.path.join(examples_dir, "person3.png"), os.path.join(examples_dir, "garment3.jpeg"), "tops", "flat-lay"],
    [os.path.join(examples_dir, "person4.png"), os.path.join(examples_dir, "garment4.webp"), "tops", "model"],
    [os.path.join(examples_dir, "person5.png"), os.path.join(examples_dir, "garment5.jpeg"), "bottoms", "flat-lay"],
    [os.path.join(examples_dir, "person6.png"), os.path.join(examples_dir, "garment6.webp"), "one-pieces", "model"],
]

# Standalone person examples (garment chosen separately by the user).
person_only_examples = [os.path.join(examples_dir, "person0.png")]

# (image_path, category, photo_type) triples backing the clickable garment gallery;
# on_garment_gallery_select looks entries up here by gallery index.
garment_examples_data = [
    (os.path.join(examples_dir, "garment0.png"), "tops", "model"),
    (os.path.join(examples_dir, "garment7.jpg"), "tops", "flat-lay"),
]
garment_gallery_images = [item[0] for item in garment_examples_data]
|
|
|
|
|
def on_garment_gallery_select(evt: gr.SelectData):
    """Handle garment gallery selection - load image and update dropdowns."""
    selected = evt.index
    if selected >= len(garment_examples_data):
        # Index outside known examples: clear the image, reset dropdown defaults.
        return None, "tops", "model"
    image_path, cat, photo_type = garment_examples_data[selected]
    return Image.open(image_path), cat, photo_type
|
|
|
|
|
|
|
# UI layout: three columns (person input | garment input + settings | result),
# followed by paired examples and event wiring.
# NOTE(review): indentation reconstructed — nesting (esp. the Accordion inside
# the result column) inferred from statement order; confirm against the live app.
with gr.Blocks(css=CUSTOM_CSS) as demo:
    gr.HTML(banner_html)
    gr.HTML(tips_html)

    with gr.Row(equal_height=False):

        # --- Column 1: person photo input + standalone examples ---
        with gr.Column(scale=1):
            person_image = gr.Image(
                label="Person Image",
                type="pil",
                sources=["upload", "clipboard"],
                elem_classes=["contain"],
            )

            gr.Examples(
                examples=person_only_examples,
                inputs=person_image,
                label="Person Examples",
            )

        # --- Column 2: garment input, category/photo-type dropdowns, gallery ---
        with gr.Column(scale=1):
            garment_image = gr.Image(
                label="Garment Image",
                type="pil",
                sources=["upload", "clipboard"],
                elem_classes=["contain"],
            )

            with gr.Row():
                category = gr.Dropdown(
                    choices=CATEGORIES,
                    value="tops",
                    label="Category",
                )
                garment_photo_type = gr.Dropdown(
                    choices=GARMENT_PHOTO_TYPES,
                    value="model",
                    label="Photo Type",
                )

            gr.Markdown("**Garment Examples** (click to load with settings)")
            # Clicking a gallery item loads the garment AND its dropdown settings
            # (wired to on_garment_gallery_select below).
            garment_gallery = gr.Gallery(
                value=garment_gallery_images,
                columns=2,
                rows=1,
                height="auto",
                object_fit="contain",
                show_label=False,
                allow_preview=False,
            )

        # --- Column 3: result display, run button, advanced settings ---
        with gr.Column(scale=1):
            result_image = gr.Image(
                label="Try-On Result",
                type="pil",
                interactive=False,
                elem_classes=["contain"],
            )

            run_button = gr.Button("Try On", variant="primary", size="lg")

            with gr.Accordion("Advanced Settings", open=False):
                num_timesteps = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=50,
                    step=5,
                    label="Sampling Steps",
                    info="Higher = better quality, slower.",
                )
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=3.0,
                    value=1.5,
                    step=0.1,
                    label="Guidance Scale",
                    info="How closely to follow the garment. 1.5 recommended.",
                )
                seed = gr.Number(
                    value=42,
                    label="Seed",
                    info="Random seed for reproducibility.",
                    precision=0,
                )
                segmentation_free = gr.Checkbox(
                    value=True,
                    label="Segmentation Free",
                    info="Preserves body features and allows unconstrained garment volume. Disable for tighter garment fitting.",
                )

    # Paired examples populate person + garment + both dropdowns in one click.
    gr.Examples(
        examples=paired_examples,
        inputs=[person_image, garment_image, category, garment_photo_type],
        label="Complete Examples (click to load person + garment + settings)",
    )

    # Input order must match try_on's parameter order (values passed positionally).
    run_button.click(
        fn=try_on,
        inputs=[
            person_image,
            garment_image,
            category,
            garment_photo_type,
            num_timesteps,
            guidance_scale,
            seed,
            segmentation_free,
        ],
        outputs=[result_image],
    )

    # Gallery clicks load the garment image and matching dropdown settings;
    # the handler receives the selection via gr.SelectData, so inputs=None.
    garment_gallery.select(
        fn=on_garment_gallery_select,
        inputs=None,
        outputs=[garment_image, category, garment_photo_type],
    )
|
|
|
|
|
# Serialize GPU inference: one request at a time, up to 30 queued.
demo.queue(default_concurrency_limit=1, max_size=30)

if __name__ == "__main__":
    demo.launch(share=False)
|
|
|