"""
app.py — OOTDiffusion Hugging Face Space
Place this file in the ROOT of your Space repo alongside the
OOTDiffusion source folders: ootd/, run/, preprocess/, checkpoints/
README.md front-matter required:
---
title: OOTDiffusion Virtual Try-On
emoji: 👗
colorFrom: purple
colorTo: pink
sdk: gradio
sdk_version: 4.16.0
app_file: app.py
pinned: false
license: cc-by-nc-sa-4.0
---
"""
import os
import random
import sys
# ── Path setup ────────────────────────────────────────────────────────────────
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
RUN_DIR = os.path.join(ROOT_DIR, "run")
sys.path.insert(0, ROOT_DIR)
sys.path.insert(0, RUN_DIR)
import torch
import numpy as np
import gradio as gr
from PIL import Image
# ── Device ────────────────────────────────────────────────────────────────────
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[OOTDiffusion] Device: {DEVICE}")
# ── Lazy-load models (loaded once on first request) ───────────────────────────
_pipe_hd = None  # VITON-HD — half-body
_pipe_dc = None  # Dress Code — full-body
def load_pipeline(model_type: str):
"""Import and cache the requested OOTDiffusion pipeline."""
global _pipe_hd, _pipe_dc
if model_type == "hd":
if _pipe_hd is None:
from ootd.inference_ootd_hd import OOTDiffusionHD
print("[OOTDiffusion] Loading HD pipeline …")
_pipe_hd = OOTDiffusionHD(ROOT_DIR)
return _pipe_hd
else: # dc
if _pipe_dc is None:
from ootd.inference_ootd_dc import OOTDiffusionDC
print("[OOTDiffusion] Loading DC pipeline …")
_pipe_dc = OOTDiffusionDC(ROOT_DIR)
return _pipe_dc
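# Optional warm-up (an assumption, not something Gradio requires): uncommenting
# the lines below would front-load the large checkpoint read at startup instead
# of on the first user request.
#
# if DEVICE == "cuda":
#     load_pipeline("hd")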
# ── Category mapping ──────────────────────────────────────────────────────────
CATEGORY_MAP = {
"Upper-body": 0,
"Lower-body": 1,
"Dress": 2,
}
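# The index convention mirrors upstream OOTDiffusion's category list
# (0 = upper-body, 1 = lower-body, 2 = dress), e.g. CATEGORY_MAP["Dress"] == 2.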
# ── Main inference function ───────────────────────────────────────────────────
def run_tryon(
model_image,
cloth_image,
model_type,
category_label,
n_samples,
n_steps,
guidance_scale,
seed,
):
if model_image is None:
raise gr.Error("Please upload a model (person) image.")
if cloth_image is None:
raise gr.Error("Please upload a garment image.")
# Convert to PIL just in case Gradio passes numpy arrays
if isinstance(model_image, np.ndarray):
model_image = Image.fromarray(model_image)
if isinstance(cloth_image, np.ndarray):
cloth_image = Image.fromarray(cloth_image)
    model_image = model_image.convert("RGB")
    cloth_image = cloth_image.convert("RGB")
    category_idx = CATEGORY_MAP[category_label]
    # The HD (VITON-HD) branch of OOTDiffusion handles upper-body garments only,
    # and the UI notes the category is only used for "dc"; pin it here.
    if model_type == "hd":
        category_idx = 0
    # The seed label promises -1 = random; draw a fresh seed so repeat runs differ.
    seed = int(seed)
    if seed == -1:
        seed = random.randint(0, 2**31 - 1)
try:
pipe = load_pipeline(model_type)
except Exception as e:
raise gr.Error(
f"Failed to load model: {e}\n"
"Make sure checkpoints/ and ootd/ folders are present."
)
    try:
        # The HD and DC pipelines share the same call signature, so a single
        # call covers both branches.
        result = pipe(
            model_type=model_type,
            category=category_idx,
            image_garm=cloth_image,
            image_vton=model_image,
            mask=None,
            image_ori=model_image,
            num_samples=int(n_samples),
            num_steps=int(n_steps),
            guidance_scale=guidance_scale,
            seed=seed,
        )
    except Exception as e:
        raise gr.Error(f"Inference failed: {e}")
# result is expected to be a list of PIL Images
if isinstance(result, (list, tuple)):
return result
return [result]
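# Local smoke test for run_tryon, bypassing the UI (the image paths below are
# hypothetical placeholders; substitute your own files):
#
#   imgs = run_tryon(Image.open("model.png"), Image.open("cloth.png"),
#                    model_type="hd", category_label="Upper-body",
#                    n_samples=1, n_steps=20, guidance_scale=2.0, seed=42)
#   imgs[0].save("tryon_result.png")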
# ── Gradio UI ─────────────────────────────────────────────────────────────────
with gr.Blocks(title="OOTDiffusion Virtual Try-On", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
        # 👗 OOTDiffusion — Virtual Try-On
**[AAAI 2025]** Upload a *model photo* and a *garment image*, choose settings, and click **Run Try-On**.
> ⚠️ Non-commercial use only (CC-BY-NC-SA-4.0)
"""
)
with gr.Row():
# ── Left column: inputs ───────────────────────────────────────────────
with gr.Column(scale=1):
model_img = gr.Image(
label="Model Image (person)",
type="pil",
height=400,
)
cloth_img = gr.Image(
label="Garment Image (clothing)",
type="pil",
height=400,
)
# ── Middle column: settings ───────────────────────────────────────────
with gr.Column(scale=1):
model_type = gr.Radio(
choices=["hd", "dc"],
value="hd",
label="Model Type",
info="hd = half-body (VITON-HD) | dc = full-body (Dress Code)",
)
category = gr.Dropdown(
choices=list(CATEGORY_MAP.keys()),
value="Upper-body",
label="Garment Category",
info="Only used when Model Type is 'dc'",
)
n_samples = gr.Slider(
minimum=1, maximum=4, step=1, value=1,
label="Number of Samples",
)
n_steps = gr.Slider(
minimum=10, maximum=40, step=5, value=20,
label="Denoising Steps",
info="More steps = better quality but slower",
)
guidance_scale = gr.Slider(
minimum=1.0, maximum=5.0, step=0.5, value=2.0,
label="Guidance Scale",
)
seed = gr.Number(
value=42,
label="Seed (-1 = random)",
precision=0,
)
            run_btn = gr.Button("🚀 Run Try-On", variant="primary")
# ── Right column: outputs ─────────────────────────────────────────────
with gr.Column(scale=1):
output_gallery = gr.Gallery(
label="Try-On Results",
columns=2,
height=500,
object_fit="contain",
)
gr.Markdown(
"""
### Tips
- **HD model**: best for upper-body garments on half-body photos
- **DC model**: supports upper-body / lower-body / dress on full-body photos
- Increasing **steps** to 30–40 noticeably improves quality
- Set **seed = -1** for random results each run
"""
)
# ── Wire up the button ────────────────────────────────────────────────────
run_btn.click(
fn=run_tryon,
inputs=[
model_img,
cloth_img,
model_type,
category,
n_samples,
n_steps,
guidance_scale,
seed,
],
outputs=output_gallery,
)
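# Optional tweak (not in the original script): on a shared GPU Space you can
# serialize concurrent requests through Gradio's queue, e.g.
#
#   demo.queue(max_size=10)  # call before demo.launch() below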
# ── Launch ────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
demo.launch()