| """ | |
| app.py β OOTDiffusion Hugging Face Space | |
| Place this file in the ROOT of your Space repo alongside the | |
| OOTDiffusion source folders: ootd/, run/, preprocess/, checkpoints/ | |
| README.md front-matter required: | |
| --- | |
| title: OOTDiffusion Virtual Try-On | |
| emoji: π | |
| colorFrom: purple | |
| colorTo: pink | |
| sdk: gradio | |
| sdk_version: 4.16.0 | |
| app_file: app.py | |
| pinned: false | |
| license: cc-by-nc-sa-4.0 | |
| --- | |
| """ | |
import os
import random
import sys

# ── Path setup ────────────────────────────────────────────────────────────────
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
RUN_DIR = os.path.join(ROOT_DIR, "run")
sys.path.insert(0, ROOT_DIR)
sys.path.insert(0, RUN_DIR)
import torch
import numpy as np
import gradio as gr
from PIL import Image

# ── Device ────────────────────────────────────────────────────────────────────
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[OOTDiffusion] Device: {DEVICE}")
# ── Lazy-load models (loaded once on first request) ───────────────────────────
_pipe_hd = None  # VITON-HD (half-body)
_pipe_dc = None  # Dress Code (full-body)


def load_pipeline(model_type: str):
    """Import and cache the requested OOTDiffusion pipeline."""
    global _pipe_hd, _pipe_dc
    if model_type == "hd":
        if _pipe_hd is None:
            from ootd.inference_ootd_hd import OOTDiffusionHD
            print("[OOTDiffusion] Loading HD pipeline …")
            _pipe_hd = OOTDiffusionHD(ROOT_DIR)
        return _pipe_hd
    else:  # dc
        if _pipe_dc is None:
            from ootd.inference_ootd_dc import OOTDiffusionDC
            print("[OOTDiffusion] Loading DC pipeline …")
            _pipe_dc = OOTDiffusionDC(ROOT_DIR)
        return _pipe_dc
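
# Optional: fetch checkpoints at startup instead of committing multi-GB files
# to the Space repo. A minimal sketch using huggingface_hub; the repo id
# "levihsu/OOTDiffusion" is the official weights repo, but wiring this in at
# import time (and the target layout) is an assumption, so adapt it to your
# checkpoints/ structure before enabling:
#
#   from huggingface_hub import snapshot_download
#   snapshot_download(
#       repo_id="levihsu/OOTDiffusion",
#       local_dir=os.path.join(ROOT_DIR, "checkpoints"),
#   )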
# ── Category mapping ──────────────────────────────────────────────────────────
CATEGORY_MAP = {
    "Upper-body": 0,
    "Lower-body": 1,
    "Dress": 2,
}
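# Note: upstream OOTDiffusion's run/run_ootd.py maps these indices to category
# *strings* before calling the pipeline. If your vendored ootd/ follows the
# upstream signature, convert the index first; a sketch under that assumption
# (check your fork's __call__ signature):
#
#   CATEGORY_STRINGS = ["upperbody", "lowerbody", "dress"]
#   category = CATEGORY_STRINGS[category_idx]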
# ── Main inference function ───────────────────────────────────────────────────
def run_tryon(
    model_image,
    cloth_image,
    model_type,
    category_label,
    n_samples,
    n_steps,
    guidance_scale,
    seed,
):
    if model_image is None:
        raise gr.Error("Please upload a model (person) image.")
    if cloth_image is None:
        raise gr.Error("Please upload a garment image.")

    # Convert to PIL in case Gradio passes numpy arrays
    if isinstance(model_image, np.ndarray):
        model_image = Image.fromarray(model_image)
    if isinstance(cloth_image, np.ndarray):
        cloth_image = Image.fromarray(cloth_image)
    model_image = model_image.convert("RGB")
    cloth_image = cloth_image.convert("RGB")

    category_idx = CATEGORY_MAP[category_label]
    if model_type == "hd":
        category_idx = 0  # the HD (VITON-HD) pipeline is upper-body only
    try:
        pipe = load_pipeline(model_type)
    except Exception as e:
        raise gr.Error(
            f"Failed to load model: {e}\n"
            "Make sure checkpoints/ and ootd/ folders are present."
        )

    # The UI documents seed = -1 as "random", so honor that here
    seed = int(seed)
    if seed < 0:
        seed = random.randint(0, 2**31 - 1)

    try:
        # Both pipelines take identical arguments; only model_type differs
        result = pipe(
            model_type=model_type,
            category=category_idx,
            image_garm=cloth_image,
            image_vton=model_image,
            mask=None,
            image_ori=model_image,
            num_samples=int(n_samples),
            num_steps=int(n_steps),
            guidance_scale=guidance_scale,
            seed=seed,
        )
    except Exception as e:
        raise gr.Error(f"Inference failed: {e}")

    # result is expected to be a list of PIL Images
    if isinstance(result, (list, tuple)):
        return result
    return [result]
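
# Quick local smoke test for run_tryon, bypassing the UI. The example paths
# are hypothetical placeholders; point them at your own images:
#
#   outs = run_tryon(
#       Image.open("examples/model.png"),
#       Image.open("examples/cloth.png"),
#       model_type="hd", category_label="Upper-body",
#       n_samples=1, n_steps=20, guidance_scale=2.0, seed=42,
#   )
#   outs[0].save("tryon_result.png")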
# ── Gradio UI ─────────────────────────────────────────────────────────────────
with gr.Blocks(title="OOTDiffusion Virtual Try-On", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 👗 OOTDiffusion: Virtual Try-On

        **[AAAI 2025]** Upload a *model photo* and a *garment image*, choose settings, and click **Run Try-On**.

        > ⚠️ Non-commercial use only (CC-BY-NC-SA-4.0)
        """
    )
    with gr.Row():
        # ── Left column: inputs ───────────────────────────────────────────────
        with gr.Column(scale=1):
            model_img = gr.Image(
                label="Model Image (person)",
                type="pil",
                height=400,
            )
            cloth_img = gr.Image(
                label="Garment Image (clothing)",
                type="pil",
                height=400,
            )
        # ── Middle column: settings ───────────────────────────────────────────
        with gr.Column(scale=1):
            model_type = gr.Radio(
                choices=["hd", "dc"],
                value="hd",
                label="Model Type",
                info="hd = half-body (VITON-HD) | dc = full-body (Dress Code)",
            )
            category = gr.Dropdown(
                choices=list(CATEGORY_MAP.keys()),
                value="Upper-body",
                label="Garment Category",
                info="Only used when Model Type is 'dc'",
            )
            n_samples = gr.Slider(
                minimum=1, maximum=4, step=1, value=1,
                label="Number of Samples",
            )
            n_steps = gr.Slider(
                minimum=10, maximum=40, step=5, value=20,
                label="Denoising Steps",
                info="More steps = better quality but slower",
            )
            guidance_scale = gr.Slider(
                minimum=1.0, maximum=5.0, step=0.5, value=2.0,
                label="Guidance Scale",
            )
            seed = gr.Number(
                value=42,
                label="Seed (-1 = random)",
                precision=0,
            )
            run_btn = gr.Button("👗 Run Try-On", variant="primary")
        # ── Right column: outputs ─────────────────────────────────────────────
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(
                label="Try-On Results",
                columns=2,
                height=500,
                object_fit="contain",
            )
            gr.Markdown(
                """
                ### Tips
                - **HD model**: best for upper-body garments on half-body photos
                - **DC model**: supports upper-body / lower-body / dress on full-body photos
                - Increasing **steps** to 30-40 noticeably improves quality
                - Set **seed = -1** for random results each run
                """
            )
    # ── Wire up the button ─────────────────────────────────────────────────────
    run_btn.click(
        fn=run_tryon,
        inputs=[
            model_img,
            cloth_img,
            model_type,
            category,
            n_samples,
            n_steps,
            guidance_scale,
            seed,
        ],
        outputs=output_gallery,
    )
# ── Launch ────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
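    # Queueing is recommended on Spaces so concurrent requests wait in line
    # rather than contending for the GPU (max_size here is an assumption; tune
    # it for your hardware):
    demo.queue(max_size=8)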
    demo.launch()