# Hugging Face Space — runs on ZeroGPU hardware.
"""
Virtual Try-On — Paint-by-Example + Hugging Face ZeroGPU
No local GPU or model storage needed.
"""
import datetime
import os

import gradio as gr
import spaces
import torch
from PIL import Image, ImageDraw

# ---------------------------------------------------------------------------
# Persistent storage
# ---------------------------------------------------------------------------
# Prefer the Space's persistent /data volume; fall back to /tmp otherwise.
DATA_DIR = "/data" if os.path.exists("/data") else "/tmp"
OUTPUT_DIR = os.path.join(DATA_DIR, "outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Point the Hugging Face caches at persistent storage so the model download
# survives restarts.
_HF_CACHE = os.path.join(DATA_DIR, "hf_cache")
os.environ["HF_HOME"] = _HF_CACHE
os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(_HF_CACHE, "hub")

# ---------------------------------------------------------------------------
# Image helpers
# ---------------------------------------------------------------------------
# Side length (px) of the square canvas every image is normalized to.
TARGET_SIZE = 512
def _fit_to_square(img: Image.Image, size: int = TARGET_SIZE) -> Image.Image:
    """Letterbox *img* onto a white ``size`` × ``size`` square.

    The image is converted to RGB, downscaled in place to fit within the
    square while preserving aspect ratio, then centered on a white canvas.
    """
    scaled = img.convert("RGB")
    scaled.thumbnail((size, size), Image.LANCZOS)
    x_off = (size - scaled.width) // 2
    y_off = (size - scaled.height) // 2
    board = Image.new("RGB", (size, size), (255, 255, 255))
    board.paste(scaled, (x_off, y_off))
    return board
def _make_mask(size: int, cloth_type: str) -> Image.Image:
    """Build an L-mode inpainting mask for a ``size`` × ``size`` image.

    White (255) marks the rectangular region the pipeline repaints; the
    rectangle is a rough body-region heuristic chosen by ``cloth_type``
    ("upper", "lower", or anything else for a full-outfit region).
    """
    # (left, top, right, bottom) as fractions of the square's side.
    regions = {
        "upper": (0.10, 0.18, 0.90, 0.65),
        "lower": (0.05, 0.55, 0.95, 1.00),
    }
    l, t, r, b = regions.get(cloth_type, (0.05, 0.15, 0.95, 1.00))
    mask = Image.new("L", (size, size), 0)
    ImageDraw.Draw(mask).rectangle(
        [int(size * l), int(size * t), int(size * r), int(size * b)], fill=255
    )
    return mask
# ---------------------------------------------------------------------------
# GPU inference — returns images + status string
# ---------------------------------------------------------------------------
_pipe = None  # lazily-loaded PaintByExamplePipeline, cached across calls


@spaces.GPU(duration=120)  # ZeroGPU: CUDA is only attached inside a
# @spaces.GPU-decorated function — without it, .to("cuda") fails on Spaces.
def run_tryon(
    person_image: Image.Image,
    garment_image: Image.Image,
    cloth_type: str,
    num_steps: int,
    guidance_scale: float,
    seed: int,
):
    """Run Paint-by-Example virtual try-on.

    Args:
        person_image: Photo of the person (any size; letterboxed to 512px).
        garment_image: Reference garment image (conditioning example).
        cloth_type: "upper", "lower", or "overall" — selects the mask region.
        num_steps: Diffusion denoising steps.
        guidance_scale: Classifier-free guidance strength.
        seed: RNG seed; -1 picks a random seed.

    Returns:
        (images, status) — a list of PIL images (or None on bad input) and a
        human-readable status string for the UI.
    """
    if person_image is None or garment_image is None:
        return None, "❌ Please upload both a person photo and a garment image."

    global _pipe
    if _pipe is None:
        from diffusers import PaintByExamplePipeline
        print("Loading Paint-by-Example (~5 GB, first run only)…")
        _pipe = PaintByExamplePipeline.from_pretrained(
            "Fantasy-Studio/Paint-by-Example",
            torch_dtype=torch.float16,
        ).to("cuda")
        _pipe.set_progress_bar_config(disable=True)
        print("Pipeline ready.")

    person = _fit_to_square(person_image)
    garment = _fit_to_square(garment_image)
    mask = _make_mask(TARGET_SIZE, cloth_type)

    rng = torch.Generator(device="cuda")
    rng.manual_seed(int(seed) if seed != -1 else torch.randint(0, 2**32, (1,)).item())

    result = _pipe(
        image=person,
        mask_image=mask,
        example_image=garment,
        # Gradio sliders can deliver floats; diffusers requires an int here.
        num_inference_steps=int(num_steps),
        guidance_scale=guidance_scale,
        generator=rng,
    )
    output_images = result.images

    # Persist results so they survive even if the browser tab is closed.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    for i, img in enumerate(output_images):
        img.save(os.path.join(OUTPUT_DIR, f"tryon_{timestamp}_{i}.png"), format="PNG")

    return output_images, "✅ Done! Right-click an image in the gallery to save it."
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
_INTRO_MD = (
    "# 👗 Virtual Try-On\n"
    "Upload a **person photo** and a **garment image**, select the type, then click **Try On**.\n\n"
    "> Runs on **Hugging Face ZeroGPU** (free A10G) — no local GPU needed.\n"
    "> **First run:** ~2-3 min (model download ~5 GB). **Subsequent runs:** ~15-30s."
)
_FOOTER_MD = (
    "---\n"
    "**Tips:** front-facing photo · garment on white/neutral background · upper body for shirts\n\n"
    "Built with [Paint-by-Example](https://github.com/Fantasy-Studio/Paint-by-Example) · "
    "[Gradio](https://gradio.app) · [ZeroGPU](https://huggingface.co/docs/hub/spaces-zerogpu)"
)

with gr.Blocks(title="Virtual Try-On", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_INTRO_MD)

    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column():
            person_in = gr.Image(label="Person Photo", type="pil", height=350)
            garment_in = gr.Image(label="Garment Image", type="pil", height=350)
            garment_kind = gr.Radio(
                ["upper", "lower", "overall"],
                value="upper",
                label="Garment Type",
                info="upper=top/shirt | lower=pants/skirt | overall=dress/full outfit",
            )
            with gr.Accordion("Advanced", open=False):
                steps_sl = gr.Slider(10, 50, value=30, step=1, label="Steps")
                guidance_sl = gr.Slider(1.0, 10.0, value=7.5, step=0.5, label="Guidance Scale")
                seed_num = gr.Number(label="Seed (-1 = random)", value=-1, precision=0)
            run_btn = gr.Button("👗 Try On", variant="primary", size="lg")

        # Right column: status + results.
        with gr.Column():
            status_out = gr.Textbox(
                label="Status", value="Ready — upload images and click Try On",
                interactive=False, max_lines=2,
            )
            gallery_out = gr.Gallery(label="Result", columns=1, height=420)

    # Two-step chain: flash a "working" status first, then run inference.
    run_btn.click(
        fn=lambda: "⏳ Requesting GPU + loading model… (first run ~3 min, please wait)",
        inputs=None,
        outputs=[status_out],
    ).then(
        fn=run_tryon,
        inputs=[person_in, garment_in, garment_kind, steps_sl, guidance_sl, seed_num],
        outputs=[gallery_out, status_out],
    )

    gr.Markdown(_FOOTER_MD)

if __name__ == "__main__":
    demo.launch()