File size: 7,080 Bytes
6a75bb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
"""
app.py — OOTDiffusion Hugging Face Space
Place this file in the ROOT of your Space repo.

Your Space structure should look like:
  OOODdiffusion/
  ├── app.py                  ← this file (root level)
  ├── requirements.txt        ← root level
  ├── README.md               ← root level
  └── OOTDiffusion-main/      ← the uploaded zip contents
      ├── ootd/
      ├── run/
      ├── preprocess/
      ├── checkpoints/
      └── ...
"""

import sys
import os

# ── Path setup ────────────────────────────────────────────────────────────────
# Directory that contains this file; every other path is derived from it.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))

# The OOTDiffusion sources may sit next to this file (flat layout) or inside
# an "OOTDiffusion-main/" or "OOTDiffusion/" sub-folder; the first sub-folder
# that exists wins, otherwise ROOT_DIR itself is used.
OOTD_DIR = ROOT_DIR
for candidate in ("OOTDiffusion-main", "OOTDiffusion"):
    candidate_path = os.path.join(ROOT_DIR, candidate)
    if not os.path.isdir(candidate_path):
        continue
    OOTD_DIR = candidate_path
    break

RUN_DIR = os.path.join(OOTD_DIR, "run")

# Put RUN_DIR first and OOTD_DIR second on the import path so modules under
# both directories can be imported by name.
sys.path[:0] = [RUN_DIR, OOTD_DIR]

for label, value in (("ROOT_DIR", ROOT_DIR), ("OOTD_DIR", OOTD_DIR)):
    print(f"[OOTDiffusion] {label} : {value}")

import torch
import numpy as np
import gradio as gr
from PIL import Image

# ── Device ────────────────────────────────────────────────────────────────────
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
# NOTE(review): in the code visible here DEVICE is only printed, never passed
# to the pipelines — confirm whether it is meant to be used elsewhere.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
print(f"[OOTDiffusion] Device: {DEVICE}")

# ── Lazy-load models ──────────────────────────────────────────────────────────
# Cache slots filled by load_pipeline() the first time each pipeline is needed.
_pipe_hd = _pipe_dc = None


def load_pipeline(model_type: str):
    """Return the cached pipeline for *model_type*, creating it on first use.

    ``"hd"`` selects the HD pipeline; any other value selects the DC pipeline.
    Each pipeline class is imported and instantiated at most once and then kept
    in a module-level global, so later calls return the same object.
    """
    global _pipe_hd, _pipe_dc

    if model_type != "hd":
        if _pipe_dc is not None:
            return _pipe_dc
        # Imported here rather than at module import time, so the DC code is
        # only loaded when a DC request actually arrives.
        from ootd.inference_ootd_dc import OOTDiffusionDC
        print("[OOTDiffusion] Loading DC pipeline …")
        # NOTE(review): the constructor receives OOTD_DIR; the upstream
        # OOTDiffusion classes are believed to take a GPU id instead — confirm.
        _pipe_dc = OOTDiffusionDC(OOTD_DIR)
        return _pipe_dc

    if _pipe_hd is not None:
        return _pipe_hd
    # Same deferred-import approach as the DC branch above.
    from ootd.inference_ootd_hd import OOTDiffusionHD
    print("[OOTDiffusion] Loading HD pipeline …")
    # NOTE(review): same constructor-argument question as for the DC pipeline.
    _pipe_hd = OOTDiffusionHD(OOTD_DIR)
    return _pipe_hd


# ── Category mapping ──────────────────────────────────────────────────────────
# Dropdown labels in display order; each label maps to its position (0, 1, 2),
# which is the value handed to the pipeline as ``category``.
# NOTE(review): the category encoding the pipeline expects is not visible in
# this file — confirm that integer indices are what it wants.
CATEGORY_MAP = {
    label: index
    for index, label in enumerate(("Upper-body", "Lower-body", "Dress"))
}


# ── Inference ─────────────────────────────────────────────────────────────────
def _to_rgb(image):
    """Return *image* as an RGB PIL image, converting NumPy arrays first."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    return image.convert("RGB")


def _resolve_seed(seed):
    """Turn the UI seed value into a concrete non-negative integer.

    ``None`` (an emptied number box) and negative values (the UI documents
    ``-1`` as "random") are replaced by a fresh random 31-bit seed; any other
    value is returned as ``int(seed)``.
    """
    # Local import keeps this block independent of the file's import section.
    import random

    if seed is None or int(seed) < 0:
        return random.randint(0, 2**31 - 1)
    return int(seed)


def run_tryon(model_image, cloth_image, model_type, category_label,
              n_samples, n_steps, guidance_scale, seed):
    """Run one virtual try-on request coming from the Gradio UI.

    Args:
        model_image: Person photo as a PIL image or NumPy array, or ``None``
            when nothing was uploaded.
        cloth_image: Garment photo, same accepted types as ``model_image``.
        model_type: ``"hd"`` or ``"dc"``; passed to ``load_pipeline`` and on
            to the pipeline call.
        category_label: A key of ``CATEGORY_MAP``. Ignored when ``model_type``
            is ``"hd"`` (the UI states the category only matters for "dc"),
            in which case the "Upper-body" index is used.
        n_samples: Number of images to generate (converted with ``int``).
        n_steps: Number of denoising steps (converted with ``int``).
        guidance_scale: Guidance scale (converted with ``float``).
        seed: Seed from the UI; ``None`` or a negative value means "pick a
            random seed".

    Returns:
        A list with the pipeline's output: the result itself when it already
        is a list/tuple, otherwise a one-element list wrapping it.

    Raises:
        gr.Error: If an image is missing, the category label is unknown, the
            pipeline cannot be loaded, or the pipeline call fails.
    """
    if model_image is None:
        raise gr.Error("Please upload a model (person) image.")
    if cloth_image is None:
        raise gr.Error("Please upload a garment image.")

    model_image = _to_rgb(model_image)
    cloth_image = _to_rgb(cloth_image)

    if model_type == "hd":
        # The UI promises the dropdown is irrelevant for the HD model, so do
        # not let a stale "Lower-body"/"Dress" selection reach the pipeline.
        category_idx = CATEGORY_MAP["Upper-body"]
    else:
        try:
            category_idx = CATEGORY_MAP[category_label]
        except (KeyError, TypeError):
            # A cleared dropdown yields None; show a readable message instead
            # of a raw KeyError.
            raise gr.Error(
                f"Unknown garment category: {category_label!r}"
            ) from None

    try:
        pipe = load_pipeline(model_type)
    except Exception as e:
        raise gr.Error(
            f"Failed to load model: {e}\n"
            "Make sure OOTDiffusion-main/ folder with ootd/ and checkpoints/ is present."
        ) from e

    # NOTE(review): ``mask=None`` and the ``guidance_scale`` keyword are
    # forwarded exactly as before; the upstream OOTDiffusion ``__call__`` is
    # believed to expect a precomputed mask and an ``image_scale`` keyword —
    # confirm against ootd/inference_ootd_hd.py / inference_ootd_dc.py.
    try:
        result = pipe(
            model_type=model_type,
            category=category_idx,
            image_garm=cloth_image,
            image_vton=model_image,
            mask=None,
            image_ori=model_image,
            num_samples=int(n_samples),
            num_steps=int(n_steps),
            guidance_scale=float(guidance_scale),
            seed=_resolve_seed(seed),
        )
    except Exception as e:
        raise gr.Error(f"Inference failed: {e}") from e

    if isinstance(result, (list, tuple)):
        return list(result)
    return [result]


# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Three-column layout: inputs (images), settings, and the result gallery.
with gr.Blocks(title="OOTDiffusion Virtual Try-On", theme=gr.themes.Soft()) as demo:

    gr.Markdown("""
    # 👗 OOTDiffusion — Virtual Try-On
    **[AAAI 2025]** Upload a *person photo* and a *garment image*, then click **Run Try-On**.
    > ⚠️ Non-commercial use only (CC-BY-NC-SA-4.0)
    """)

    with gr.Row():
        # Column 1: the two input images, delivered to run_tryon as PIL images.
        with gr.Column(scale=1):
            model_img = gr.Image(label="👤 Model Image (person)", type="pil", height=380)
            cloth_img = gr.Image(label="👕 Garment Image", type="pil", height=380)

        # Column 2: model selection and sampling settings.
        with gr.Column(scale=1):
            model_type = gr.Radio(
                choices=["hd", "dc"], value="hd",
                label="Model Type",
                info="hd = half-body (VITON-HD)  |  dc = full-body (Dress Code)"
            )
            category = gr.Dropdown(
                choices=list(CATEGORY_MAP.keys()), value="Upper-body",
                label="Garment Category",
                info="Only matters when Model Type = dc"
            )
            n_samples = gr.Slider(1, 4, step=1, value=1, label="Number of Samples")
            n_steps   = gr.Slider(10, 40, step=5, value=20, label="Denoising Steps",
                                  info="More steps = better quality, slower")
            guidance  = gr.Slider(1.0, 5.0, step=0.5, value=2.0, label="Guidance Scale")
            seed      = gr.Number(value=42, label="Seed  (-1 = random)", precision=0)
            run_btn   = gr.Button("🚀 Run Try-On", variant="primary", size="lg")

        # Column 3: generated images.
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(
                label="✨ Try-On Results",
                columns=2, height=500, object_fit="contain"
            )

    gr.Markdown("""
    ### 💡 Tips
    - **HD model** — best for upper-body garments on half-body photos
    - **DC model** — supports upper / lower / dress on full-body photos
    - Steps **30–40** give noticeably better quality
    - **Seed = -1** gives a different result each run
    """)

    # The input order must match run_tryon's positional parameters.
    run_btn.click(
        fn=run_tryon,
        inputs=[model_img, cloth_img, model_type, category,
                n_samples, n_steps, guidance, seed],
        outputs=output_gallery,
    )

# ── Launch ────────────────────────────────────────────────────────────────────
# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()