# virtual-tryon / app.py
# Last commit by Nandha2017 (1659459, verified):
#   Fix: .then() chaining for instant status, remove gr.File, robust error return
"""
Virtual Try-On — Paint-by-Example + Hugging Face ZeroGPU
No local GPU or model storage needed.
"""
import datetime
import os
import traceback

import gradio as gr
import spaces
import torch
from PIL import Image, ImageDraw
# ---------------------------------------------------------------------------
# Persistent storage
# ---------------------------------------------------------------------------
# Use the persistent /data volume only when it is mounted AND writable;
# otherwise fall back to /tmp so os.makedirs below does not hit a
# permission error on a read-only /data.
DATA_DIR = (
    "/data"
    if os.path.isdir("/data") and os.access("/data", os.W_OK)
    else "/tmp"
)
OUTPUT_DIR = os.path.join(DATA_DIR, "outputs")  # saved try-on PNGs
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Point the Hugging Face cache at DATA_DIR so downloaded weights survive
# restarts when /data is persistent.
# NOTE(review): huggingface_hub reads these variables when it is first
# imported; gradio/spaces are imported above this point and may already have
# imported it, in which case these assignments have no effect. Confirm, and
# pass cache_dir= explicitly to from_pretrained to be safe.
os.environ["HF_HOME"] = os.path.join(DATA_DIR, "hf_cache")
# HF_HUB_CACHE is the current variable name and takes precedence;
# HUGGINGFACE_HUB_CACHE is the legacy alias. Set both to the same path.
os.environ["HF_HUB_CACHE"] = os.path.join(DATA_DIR, "hf_cache", "hub")
os.environ["HUGGINGFACE_HUB_CACHE"] = os.environ["HF_HUB_CACHE"]
# ---------------------------------------------------------------------------
# Image helpers
# ---------------------------------------------------------------------------
TARGET_SIZE = 512  # side length (px) of the square canvas fed to the pipeline


def _fit_to_square(img: Image.Image, size: int = TARGET_SIZE) -> Image.Image:
    """Letterbox *img* onto a white ``size`` x ``size`` RGB canvas.

    The image is shrunk (``thumbnail`` never enlarges) to fit inside the
    square while keeping its aspect ratio, then centred on white padding.

    Transparent pixels are flattened onto white first. A plain
    ``convert("RGB")`` would simply drop the alpha channel, so a garment
    cut-out on a transparent background would end up on whatever colour the
    hidden RGB channels hold (often black) rather than on white like the
    padding.

    Args:
        img: Source image in any PIL mode.
        size: Side length of the output canvas in pixels.

    Returns:
        A new RGB image of exactly ``size`` x ``size``; *img* is not modified.
    """
    if img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info:
        rgba = img.convert("RGBA")
        white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        img = Image.alpha_composite(white, rgba).convert("RGB")
    else:
        img = img.convert("RGB")  # returns a copy, so thumbnail() below is safe
    img.thumbnail((size, size), Image.LANCZOS)
    canvas = Image.new("RGB", (size, size), (255, 255, 255))
    canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
    return canvas
def _make_mask(size: int, cloth_type: str) -> Image.Image:
    """Build the inpainting mask for the chosen garment category.

    Returns a ``size`` x ``size`` single-channel ("L") image that is 255 inside
    one axis-aligned rectangle and 0 everywhere else. The rectangle corners
    are fractions of ``size``:

    * "upper": x 10%-90%, y 18%-65%
    * "lower": x 5%-95%,  y 55%-100%
    * anything else (including "overall"): x 5%-95%, y 15%-100%
    """
    regions = {
        "upper": (.10, .18, .90, .65),
        "lower": (.05, .55, .95, 1.0),
    }
    left, top, right, bottom = regions.get(cloth_type, (.05, .15, .95, 1.0))
    box = [int(size * left), int(size * top), int(size * right), int(size * bottom)]

    mask = Image.new("L", (size, size), 0)
    ImageDraw.Draw(mask).rectangle(box, fill=255)
    return mask
# ---------------------------------------------------------------------------
# GPU inference — returns images + status string
# ---------------------------------------------------------------------------
# Paint-by-Example pipeline, created lazily by _get_pipe() on first use.
# NOTE(review): ZeroGPU may execute @spaces.GPU functions in a separate worker
# process; if so, assignments to this global inside run_tryon do not persist
# between calls and the model is rebuilt on every run. Confirm against the
# ZeroGPU docs (which suggest building the pipeline at module level).
_pipe = None


def _get_pipe():
    """Return the Paint-by-Example pipeline on CUDA, loading it if needed."""
    global _pipe
    if _pipe is None:
        from diffusers import PaintByExamplePipeline

        print("Loading Paint-by-Example (~5 GB, first run only)…")
        _pipe = PaintByExamplePipeline.from_pretrained(
            "Fantasy-Studio/Paint-by-Example",
            torch_dtype=torch.float16,
            # Explicit cache dir: the HF_* env vars are set after gradio is
            # imported, so the hub library may not honour them.
            cache_dir=os.path.join(DATA_DIR, "hf_cache", "hub"),
        ).to("cuda")
        _pipe.set_progress_bar_config(disable=True)
        print("Pipeline ready.")
    return _pipe


def _resolve_seed(seed) -> int:
    """Return the seed to use: a fresh random one when *seed* is None or -1."""
    # A cleared gr.Number box can arrive as None; treat it like -1 (random).
    if seed is None or int(seed) == -1:
        return int(torch.randint(0, 2**32, (1,)).item())
    return int(seed)


def _save_outputs(images) -> None:
    """Best-effort save of each result image as a PNG in OUTPUT_DIR."""
    # Microseconds keep two runs finishing within the same second from
    # overwriting each other's files.
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    for i, img in enumerate(images):
        path = os.path.join(OUTPUT_DIR, f"tryon_{stamp}_{i}.png")
        try:
            img.save(path, format="PNG")
        except OSError as exc:
            # A failed save must not discard the result shown to the user.
            print(f"Could not save {path}: {exc}")


@spaces.GPU(duration=120)
def run_tryon(
    person_image: Image.Image,
    garment_image: Image.Image,
    cloth_type: str,
    num_steps: int,
    guidance_scale: float,
    seed: int,
):
    """Run one Paint-by-Example try-on and report the outcome.

    Args:
        person_image: Photo of the person (any size; letterboxed to TARGET_SIZE).
        garment_image: Garment example image (letterboxed the same way).
        cloth_type: "upper", "lower" or "overall"; selects the mask region.
        num_steps: Number of diffusion steps (coerced to int).
        guidance_scale: Classifier-free guidance scale (coerced to float).
        seed: RNG seed; -1 or None picks a random seed, which is reported
            in the status message so the result can be reproduced.

    Returns:
        ``(images, status)``: the pipeline's output images (or None on
        failure / missing input) and a status string for the UI. Errors are
        returned as a status message instead of raised, so the status box
        never stays stuck on the "in progress" text.
    """
    if person_image is None or garment_image is None:
        return None, "❌ Please upload both a person photo and a garment image."

    try:
        pipe = _get_pipe()
        seed_value = _resolve_seed(seed)
        rng = torch.Generator(device="cuda").manual_seed(seed_value)
        result = pipe(
            image=_fit_to_square(person_image),
            mask_image=_make_mask(TARGET_SIZE, cloth_type),
            example_image=_fit_to_square(garment_image),
            num_inference_steps=int(num_steps),
            guidance_scale=float(guidance_scale),
            generator=rng,
        )
    except Exception as exc:  # UI boundary: log and surface the error
        traceback.print_exc()
        return None, f"❌ Try-on failed: {type(exc).__name__}: {exc}"

    output_images = result.images
    _save_outputs(output_images)
    return (
        output_images,
        f"✅ Done (seed {seed_value})! Right-click an image in the gallery to save it.",
    )
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(title="Virtual Try-On", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# 👗 Virtual Try-On\n"
        "Upload a **person photo** and a **garment image**, select the type, then click **Try On**.\n\n"
        "> Runs on **Hugging Face ZeroGPU** (free shared GPU) — no local GPU needed.\n"
        "> **First run:** ~2-3 min (model download ~5 GB). **Subsequent runs:** ~15-30s."
    )
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column():
            person_input = gr.Image(label="Person Photo", type="pil", height=350)
            garment_input = gr.Image(label="Garment Image", type="pil", height=350)
            cloth_type = gr.Radio(
                ["upper", "lower", "overall"],
                value="upper",
                label="Garment Type",
                info="upper=top/shirt | lower=pants/skirt | overall=dress/full outfit",
            )
            with gr.Accordion("Advanced", open=False):
                num_steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
                guidance = gr.Slider(1.0, 10.0, value=7.5, step=0.5, label="Guidance Scale")
                seed_input = gr.Number(label="Seed (-1 = random)", value=-1, precision=0)
            try_btn = gr.Button("👗 Try On", variant="primary", size="lg")
        # Right column: status and results.
        with gr.Column():
            status_box = gr.Textbox(
                label="Status", value="Ready — upload images and click Try On",
                interactive=False, max_lines=2,
            )
            output_gallery = gr.Gallery(label="Result", columns=1, height=420)

    # Event chain:
    #   1. show the "in progress" status immediately and disable the button so
    #      repeated clicks don't queue duplicate GPU jobs;
    #   2. run inference (run_tryon writes the final status);
    #   3. re-enable the button. .then() (unlike .success()) runs even if the
    #      previous step raised, so the button cannot stay disabled.
    try_btn.click(
        fn=lambda: (
            "⏳ Requesting GPU + loading model… (first run ~3 min, please wait)",
            gr.update(interactive=False),
        ),
        inputs=None,
        outputs=[status_box, try_btn],
    ).then(
        fn=run_tryon,
        inputs=[person_input, garment_input, cloth_type, num_steps, guidance, seed_input],
        outputs=[output_gallery, status_box],
    ).then(
        fn=lambda: gr.update(interactive=True),
        inputs=None,
        outputs=[try_btn],
    )

    gr.Markdown(
        "---\n"
        "**Tips:** front-facing photo · garment on white/neutral background · upper body for shirts\n\n"
        "Built with [Paint-by-Example](https://github.com/Fantasy-Studio/Paint-by-Example) · "
        "[Gradio](https://gradio.app) · [ZeroGPU](https://huggingface.co/docs/hub/spaces-zerogpu)"
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script,
    # not when it is imported.
    demo.launch()