# VTON / app.py
# Hugging Face Space upload metadata: uploaded by Zy131 ("Upload 7 files",
# commit a577a4e, verified).
"""FASHN VTON v1.5 HuggingFace Space Demo."""
import os
import platform
import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
# ----------------- CONFIG ----------------- #
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
ASSETS_DIR = os.path.join(SCRIPT_DIR, "assets")
WEIGHTS_DIR = os.path.join(SCRIPT_DIR, "weights")
CATEGORIES = ["tops", "bottoms", "one-pieces"]
GARMENT_PHOTO_TYPES = ["model", "flat-lay"]
# Global pipeline instance (lazy loaded)
_pipeline = None
# ----------------- HELPERS ----------------- #
def download_weights():
    """Download model weights from the HuggingFace Hub into WEIGHTS_DIR.

    Fetches the TryOnModel safetensors checkpoint and the two DWPose ONNX
    models. Files that already exist locally are skipped, so repeated calls
    are cheap.
    """
    os.makedirs(WEIGHTS_DIR, exist_ok=True)
    dwpose_dir = os.path.join(WEIGHTS_DIR, "dwpose")
    os.makedirs(dwpose_dir, exist_ok=True)
    # Download TryOnModel weights
    tryon_path = os.path.join(WEIGHTS_DIR, "model.safetensors")
    if not os.path.exists(tryon_path):
        print("Downloading TryOnModel weights...")
        hf_hub_download(
            repo_id="fashn-ai/fashn-vton-1.5",
            filename="model.safetensors",
            local_dir=WEIGHTS_DIR,
        )
    # Download DWPose models (person detector + pose estimator)
    dwpose_files = ["yolox_l.onnx", "dw-ll_ucoco_384.onnx"]
    for filename in dwpose_files:
        filepath = os.path.join(dwpose_dir, filename)
        if not os.path.exists(filepath):
            # FIX: the original printed the literal "DWPose/(unknown)" because
            # the f-string had no placeholder; interpolate the actual filename.
            print(f"Downloading DWPose/{filename}...")
            hf_hub_download(
                repo_id="fashn-ai/DWPose",
                filename=filename,
                local_dir=dwpose_dir,
            )
    print("Weights downloaded successfully!")
# ----------------- MODEL LOADING ----------------- #
def get_pipeline():
    """Lazy-load the try-on pipeline on first use.

    Deferring construction ensures CUDA is initialized inside the
    ``@spaces.GPU`` context on ZeroGPU Spaces. The instance is cached in the
    module-level ``_pipeline`` global, so subsequent calls return immediately.

    Returns:
        The loaded ``TryOnPipeline`` on the "cuda" device.

    Raises:
        gr.Error: if no CUDA device is available.
    """
    global _pipeline
    if _pipeline is None:
        # Check CUDA availability (will be true inside @spaces.GPU context)
        if not torch.cuda.is_available():
            raise gr.Error(
                "CUDA is not available. This demo requires a GPU to run. "
                "If you're on HuggingFace Spaces, please try again in a moment."
            )
        # ---------------------------------- Diagnostics ---------------------------------- #
        print(f"Python : {platform.python_version()}")
        print(f"PyTorch : {torch.__version__}")
        print(f" • built for CUDA : {torch.version.cuda}")
        if torch.backends.cudnn.is_available():
            print(f" • built for cuDNN: {torch.backends.cudnn.version()}")
        print(f"torch.cuda.is_available(): {torch.cuda.is_available()}")
        # NOTE: the `if torch.cuda.is_available():` guard that originally
        # wrapped the lines below was removed — it could never be False after
        # the raise above.
        dev = torch.cuda.current_device()
        cc = torch.cuda.get_device_capability(dev)
        print(f"GPU {dev}: {torch.cuda.get_device_name(dev)} (compute {cc[0]}.{cc[1]})")
        # Enable TF32 for faster computation on Ampere+ GPUs (compute >= 8.0).
        # FIX: query the current device instead of hard-coded device 0, for
        # consistency with the diagnostics above.
        if torch.cuda.get_device_properties(dev).major >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        print("Downloading weights (if needed)...")
        download_weights()
        print("Loading pipeline...")
        # Local import: deferred so the app can start (and fail gracefully)
        # before the model package and GPU are actually needed.
        from fashn_vton import TryOnPipeline
        _pipeline = TryOnPipeline(weights_dir=WEIGHTS_DIR, device="cuda")
        print("Pipeline loaded on CUDA!")
    return _pipeline
# ----------------- INFERENCE ----------------- #
@spaces.GPU
def try_on(
    person_image: Image.Image,
    garment_image: Image.Image,
    category: str,
    garment_photo_type: str,
    num_timesteps: int,
    guidance_scale: float,
    seed: int,
    segmentation_free: bool,
) -> Image.Image:
    """Run a single virtual try-on inference pass and return the result image.

    Raises:
        gr.Error: if either input image is missing.
    """
    # Validate required inputs up front.
    if person_image is None:
        raise gr.Error("Please upload a person image")
    if garment_image is None:
        raise gr.Error("Please upload a garment image")
    # Fall back to a fixed default when the seed field is empty or negative.
    seed = 42 if seed is None or seed < 0 else seed
    # The pipeline works on 3-channel RGB inputs.
    person_image = person_image if person_image.mode == "RGB" else person_image.convert("RGB")
    garment_image = garment_image if garment_image.mode == "RGB" else garment_image.convert("RGB")
    # Lazily obtain the (cached) pipeline and run inference.
    result = get_pipeline()(
        person_image=person_image,
        garment_image=garment_image,
        category=category,
        garment_photo_type=garment_photo_type,
        num_samples=1,
        num_timesteps=num_timesteps,
        guidance_scale=guidance_scale,
        seed=int(seed),
        segmentation_free=segmentation_free,
    )
    return result.images[0]
# ----------------- UI ----------------- #
# Custom CSS
CUSTOM_CSS = """
.contain img {
object-fit: contain !important;
max-height: 856px !important;
max-width: 576px !important;
}
"""
# Load static HTML fragments shipped alongside this script.
# FIX: read as UTF-8 explicitly so the decoded text does not depend on the
# host's locale-default encoding.
with open(os.path.join(SCRIPT_DIR, "banner.html"), "r", encoding="utf-8") as f:
    banner_html = f.read()
with open(os.path.join(SCRIPT_DIR, "tips.html"), "r", encoding="utf-8") as f:
    tips_html = f.read()
# Build example paths
examples_dir = os.path.join(ASSETS_DIR, "examples")

# Paired example specs: (person_file, garment_file, category, photo_type).
_PAIRED_SPECS = [
    ("person1.png", "garment1.jpeg", "one-pieces", "model"),
    ("person2.png", "garment2.webp", "tops", "model"),
    ("person3.png", "garment3.jpeg", "tops", "flat-lay"),
    ("person4.png", "garment4.webp", "tops", "model"),
    ("person5.png", "garment5.jpeg", "bottoms", "flat-lay"),
    ("person6.png", "garment6.webp", "one-pieces", "model"),
]
# Expanded to [person_path, garment_path, category, garment_photo_type] rows
# as expected by gr.Examples.
paired_examples = [
    [os.path.join(examples_dir, person), os.path.join(examples_dir, garment), cat, ptype]
    for person, garment, cat, ptype in _PAIRED_SPECS
]

# Individual examples (classic from repo)
person_only_examples = [os.path.join(examples_dir, "person0.png")]

# Garment examples with their settings: (image_path, category, photo_type).
# Order matters - index in Gallery corresponds to this list.
garment_examples_data = [
    (os.path.join(examples_dir, "garment0.png"), "tops", "model"),
    (os.path.join(examples_dir, "garment7.jpg"), "tops", "flat-lay"),
]
garment_gallery_images = [path for path, _cat, _ptype in garment_examples_data]
def on_garment_gallery_select(evt: gr.SelectData):
    """Handle a click on the garment example gallery.

    Returns a (PIL image, category, photo type) triple for the clicked entry,
    or (None, "tops", "model") defaults when the index is out of range.
    """
    selected = evt.index
    if selected >= len(garment_examples_data):
        return None, "tops", "model"
    path, selected_category, selected_photo_type = garment_examples_data[selected]
    return Image.open(path), selected_category, selected_photo_type
# Build UI
# NOTE(review): SOURCE arrived with indentation stripped; the nesting below
# (notably the Advanced Settings accordion living in the result column) was
# reconstructed from statement order — confirm against the original layout.
with gr.Blocks(css=CUSTOM_CSS) as demo:
    # Header
    gr.HTML(banner_html)
    gr.HTML(tips_html)
    # Three-column layout: person input | garment input | result.
    with gr.Row(equal_height=False):
        # Column 1: Person
        with gr.Column(scale=1):
            person_image = gr.Image(
                label="Person Image",
                type="pil",
                sources=["upload", "clipboard"],
                elem_classes=["contain"],
            )
            # Individual person examples
            gr.Examples(
                examples=person_only_examples,
                inputs=person_image,
                label="Person Examples",
            )
        # Column 2: Garment
        with gr.Column(scale=1):
            garment_image = gr.Image(
                label="Garment Image",
                type="pil",
                sources=["upload", "clipboard"],
                elem_classes=["contain"],
            )
            # Category / photo-type selectors feed straight into try_on.
            with gr.Row():
                category = gr.Dropdown(
                    choices=CATEGORIES,
                    value="tops",
                    label="Category",
                )
                garment_photo_type = gr.Dropdown(
                    choices=GARMENT_PHOTO_TYPES,
                    value="model",
                    label="Photo Type",
                )
            # Garment examples as clickable gallery
            gr.Markdown("**Garment Examples** (click to load with settings)")
            garment_gallery = gr.Gallery(
                value=garment_gallery_images,
                columns=2,
                rows=1,
                height="auto",
                object_fit="contain",
                show_label=False,
                allow_preview=False,
            )
        # Column 3: Result
        with gr.Column(scale=1):
            result_image = gr.Image(
                label="Try-On Result",
                type="pil",
                interactive=False,
                elem_classes=["contain"],
            )
            run_button = gr.Button("Try On", variant="primary", size="lg")
            # Advanced settings
            with gr.Accordion("Advanced Settings", open=False):
                num_timesteps = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=50,
                    step=5,
                    label="Sampling Steps",
                    info="Higher = better quality, slower.",
                )
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=3.0,
                    value=1.5,
                    step=0.1,
                    label="Guidance Scale",
                    info="How closely to follow the garment. 1.5 recommended.",
                )
                seed = gr.Number(
                    value=42,
                    label="Seed",
                    info="Random seed for reproducibility.",
                    precision=0,
                )
                segmentation_free = gr.Checkbox(
                    value=True,
                    label="Segmentation Free",
                    info="Preserves body features and allows unconstrained garment volume. Disable for tighter garment fitting.",
                )
    # Paired examples at the bottom
    gr.Examples(
        examples=paired_examples,
        inputs=[person_image, garment_image, category, garment_photo_type],
        label="Complete Examples (click to load person + garment + settings)",
    )
    # Event handlers
    run_button.click(
        fn=try_on,
        inputs=[
            person_image,
            garment_image,
            category,
            garment_photo_type,
            num_timesteps,
            guidance_scale,
            seed,
            segmentation_free,
        ],
        outputs=[result_image],
    )
    # Garment gallery selection - loads image and updates dropdowns
    garment_gallery.select(
        fn=on_garment_gallery_select,
        inputs=None,
        outputs=[garment_image, category, garment_photo_type],
    )
# Configure queue with concurrency limit to prevent GPU OOM
demo.queue(default_concurrency_limit=1, max_size=30)

if __name__ == "__main__":
    # share=False: the app is served by the Space itself, no public tunnel.
    demo.launch(share=False)