# virtual-tryon/app.py
import os
import gc
import threading
from typing import Optional
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
# ─────────────────────────── CONFIG ──────────────────────────── #
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
WEIGHTS_DIR = os.path.join(SCRIPT_DIR, "weights")
EXAMPLES_DIR = os.path.join(SCRIPT_DIR, "examples", "data")
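
# Dropdown choices shown in the UI; the selected values are passed straight
# through to the try-on pipeline as its `category` and `garment_photo_type` arguments.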
CATEGORIES = ["tops", "bottoms", "one-pieces"]
GARMENT_PHOTO_TYPES = ["model", "flat-lay"]
# ──────────────────────── WEIGHT DOWNLOAD ────────────────────── #
def download_weights():
    """Download model weights from HuggingFace Hub (skips if already present)."""
    os.makedirs(WEIGHTS_DIR, exist_ok=True)
    dwpose_dir = os.path.join(WEIGHTS_DIR, "dwpose")
    os.makedirs(dwpose_dir, exist_ok=True)

    tryon_path = os.path.join(WEIGHTS_DIR, "model.safetensors")
    if not os.path.exists(tryon_path):
        print("Downloading TryOnModel weights...")
        hf_hub_download(
            repo_id="fashn-ai/fashn-vton-1.5",
            filename="model.safetensors",
            local_dir=WEIGHTS_DIR,
        )

    for filename in ["yolox_l.onnx", "dw-ll_ucoco_384.onnx"]:
        filepath = os.path.join(dwpose_dir, filename)
        if not os.path.exists(filepath):
            print(f"Downloading DWPose/{filename}...")
            hf_hub_download(
                repo_id="fashn-ai/DWPose",
                filename=filename,
                local_dir=dwpose_dir,
            )

    print("All weights ready!")
# Download weights at startup
download_weights()
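# Fetching weights at import time means the first Gradio request never waits on
# a download; thanks to the os.path.exists checks above, later restarts that
# already have ./weights populated make this call a near no-op.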
# ──────────────────────── PIPELINE LOADER ────────────────────── #
_pipeline_lock = threading.Lock()
_pipeline: Optional[object] = None
def get_pipeline():
    """Thread-safe lazy pipeline loader (CPU mode)."""
    global _pipeline
    with _pipeline_lock:
        if _pipeline is None:
            from fashn_vton import TryOnPipeline

            print("Loading pipeline on CPU...")
            _pipeline = TryOnPipeline(weights_dir=WEIGHTS_DIR, device="cpu")
            print("Pipeline ready!")
    return _pipeline
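
# Optional helper (not wired into the UI): a minimal sketch of how the cached
# pipeline could be released to reclaim memory between runs. `unload_pipeline`
# is an illustrative addition rather than part of the original app; the CUDA
# cache call is a no-op in this CPU-only setup and is included only for
# completeness.
def unload_pipeline():
    """Drop the cached pipeline and ask Python/PyTorch to release memory."""
    global _pipeline
    with _pipeline_lock:
        _pipeline = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()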
# ─────────────────────────── INFERENCE ───────────────────────── #
def try_on(
    person_image,
    garment_image,
    category: str,
    garment_photo_type: str,
    num_timesteps: int,
    guidance_scale: float,
    seed: int,
    segmentation_free: bool,
):
    """Run virtual try-on inference."""
    if person_image is None:
        raise gr.Error("Please upload a person image.")
    if garment_image is None:
        raise gr.Error("Please upload a garment image.")

    # Normalise seed
    if seed is None or seed < 0:
        seed = 42
    seed = int(seed)

    # Ensure PIL RGB
    def to_pil(x):
        if isinstance(x, np.ndarray):
            x = Image.fromarray(x)
        if isinstance(x, Image.Image):
            return x.convert("RGB")
        return Image.open(x).convert("RGB")

    person_img = to_pil(person_image)
    garment_img = to_pil(garment_image)

    pipeline = get_pipeline()
    try:
        result = pipeline(
            person_image=person_img,
            garment_image=garment_img,
            category=category,
            garment_photo_type=garment_photo_type,
            num_samples=1,
            num_timesteps=num_timesteps,
            guidance_scale=guidance_scale,
            seed=seed,
            segmentation_free=segmentation_free,
        )
        return result.images[0], "✅ Done!"
    except Exception as e:
        return None, f"❌ Error: {e}"
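
# For reference, the same call can be made outside the Gradio UI, e.g. from a
# script. A minimal sketch (file names are placeholders):
#
#   pipeline = get_pipeline()
#   result = pipeline(
#       person_image=Image.open("person.jpg").convert("RGB"),
#       garment_image=Image.open("garment.jpg").convert("RGB"),
#       category="tops",
#       garment_photo_type="model",
#       num_samples=1,
#       num_timesteps=30,
#       guidance_scale=1.5,
#       seed=42,
#       segmentation_free=True,
#   )
#   result.images[0].save("result.png")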
# ─────────────────────────── GRADIO UI ───────────────────────── #
CUSTOM_CSS = """
body { font-family: 'Inter', sans-serif; }
.contain img {
object-fit: contain !important;
max-height: 520px !important;
}
#run-btn {
background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%) !important;
border: none !important;
color: white !important;
font-size: 1.1rem !important;
font-weight: 600 !important;
padding: 0.75rem !important;
border-radius: 12px !important;
transition: opacity 0.2s;
}
#run-btn:hover { opacity: 0.85; }
.status-box textarea {
font-size: 0.9rem !important;
color: #a3e635 !important;
background: #1e1e2e !important;
border-radius: 8px !important;
}
.gr-accordion { border-radius: 10px !important; }
"""
BANNER_MD = """
# 👗 FASHN VTON — Virtual Try-On
Upload a **person image** and a **garment image**, choose the garment category and hit **Try On**!
> ⚠️ Running on **CPU** — inference may take a few minutes. Reduce *Sampling Steps* for faster results.
"""
TIPS_HTML = """
<div style="display: flex; justify-content: center; align-items: center; gap: 1rem; flex-wrap: wrap; margin-bottom: 20px; font-size: 0.95rem; color: #a1a1aa;">
<div style="font-weight: 600; color: #e4e4e7;">💡 Tips for best results:</div>
<div>👀 Single person, clearly visible</div>
<div style="color: #52525b;">|</div>
<div>👕 Match category to garment type</div>
<div style="color: #52525b;">|</div>
<div>📸 Use "flat-lay" for product shots</div>
<div style="color: #52525b;">|</div>
<div>📏 2:3 aspect ratio optimal</div>
</div>
"""
person_example = os.path.join(EXAMPLES_DIR, "model.jpeg")
garment_example = os.path.join(EXAMPLES_DIR, "garment.jpeg")
with gr.Blocks(css=CUSTOM_CSS, title="FASHN VTON — Virtual Try-On") as demo:
    gr.Markdown(BANNER_MD)
    gr.HTML(TIPS_HTML)
    with gr.Row(equal_height=False):
        # ── Column 1 : Person ──────────────────────────────────
        with gr.Column(scale=1):
            person_in = gr.Image(
                label="Person Image",
                type="pil",
                sources=["upload", "clipboard"],
                elem_classes=["contain"],
            )
            if os.path.exists(person_example):
                gr.Examples(
                    examples=[[person_example]],
                    inputs=[person_in],
                    label="Person Example",
                )
        # ── Column 2 : Garment ─────────────────────────────────
        with gr.Column(scale=1):
            garment_in = gr.Image(
                label="Garment Image",
                type="pil",
                sources=["upload", "clipboard"],
                elem_classes=["contain"],
            )
            with gr.Row():
                category = gr.Dropdown(
                    choices=CATEGORIES,
                    value="tops",
                    label="Category",
                )
                garment_photo_type = gr.Dropdown(
                    choices=GARMENT_PHOTO_TYPES,
                    value="model",
                    label="Photo Type",
                )
            if os.path.exists(garment_example):
                gr.Examples(
                    examples=[[garment_example]],
                    inputs=[garment_in],
                    label="Garment Example",
                )
        # ── Column 3 : Result ──────────────────────────────────
        with gr.Column(scale=1):
            result_img = gr.Image(
                label="Try-On Result",
                type="pil",
                interactive=False,
                elem_classes=["contain"],
            )
            status = gr.Textbox(
                value="Ready",
                label="Status",
                interactive=False,
                elem_classes=["status-box"],
            )
            run_btn = gr.Button("👗 Try On", variant="primary", elem_id="run-btn")
    with gr.Accordion("⚙️ Advanced Settings", open=False):
        num_timesteps = gr.Slider(
            minimum=10, maximum=50, value=30, step=5,
            label="Sampling Steps",
            info="Higher = better quality but slower. 30 is a good balance.",
        )
        guidance_scale = gr.Slider(
            minimum=1.0, maximum=3.0, value=1.5, step=0.1,
            label="Guidance Scale",
            info="How closely to follow the garment details. 1.5 recommended.",
        )
        seed = gr.Number(
            value=42, label="Seed", precision=0,
            info="Change seed to get a different variation of the result.",
        )
        segmentation_free = gr.Checkbox(
            value=True,
            label="Segmentation-Free (Recommended)",
            info="Preserves body features and allows unconstrained garment volume.",
        )
    # ── Event ──────────────────────────────────────────────────
    run_btn.click(
        fn=try_on,
        inputs=[
            person_in, garment_in,
            category, garment_photo_type,
            num_timesteps, guidance_scale,
            seed, segmentation_free,
        ],
        outputs=[result_img, status],
    )
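
# Process one request at a time: CPU inference is slow and concurrent runs
# would compete for memory; up to 10 further requests wait in the queue.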
demo.queue(default_concurrency_limit=1, max_size=10)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", share=False)