"""

Virtual Try-On — Paint-by-Example + Hugging Face ZeroGPU

No local GPU or model storage needed.

"""

import datetime
import os

import gradio as gr
import spaces
import torch
from PIL import Image, ImageDraw

# ---------------------------------------------------------------------------
# Persistent storage
# ---------------------------------------------------------------------------
# Prefer the HF Spaces persistent volume (/data) when mounted; fall back to
# /tmp so the app still runs (outputs/cache are then ephemeral).
DATA_DIR   = "/data" if os.path.exists("/data") else "/tmp"
OUTPUT_DIR = os.path.join(DATA_DIR, "outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Redirect the Hugging Face cache onto persistent storage so the model
# download survives restarts.
# NOTE(review): these are set AFTER `import spaces` / `import torch` above;
# huggingface_hub resolves its cache paths at import time, so this redirect
# may have no effect — confirm, or set the env vars before any HF import.
os.environ["HF_HOME"]               = os.path.join(DATA_DIR, "hf_cache")
os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(DATA_DIR, "hf_cache", "hub")

# ---------------------------------------------------------------------------
# Image helpers
# ---------------------------------------------------------------------------
# Working resolution: all inputs and masks are squared to this edge length.
TARGET_SIZE = 512

def _fit_to_square(img: Image.Image, size: int = TARGET_SIZE) -> Image.Image:
    """Letterbox *img* onto a white ``size`` x ``size`` square.

    The image is converted to RGB, downscaled in place to fit within the
    square (``thumbnail`` never enlarges), and pasted centred on a white
    canvas so the aspect ratio is preserved.
    """
    scaled = img.convert("RGB")
    scaled.thumbnail((size, size), Image.LANCZOS)
    background = Image.new("RGB", (size, size), (255, 255, 255))
    x_off = (size - scaled.width) // 2
    y_off = (size - scaled.height) // 2
    background.paste(scaled, (x_off, y_off))
    return background

def _make_mask(size: int, cloth_type: str) -> Image.Image:
    """Build a greyscale inpainting mask (white = region to repaint).

    ``cloth_type`` selects a fixed rectangle, expressed as fractions of the
    canvas: "upper" covers the torso area, "lower" the legs; any other value
    (e.g. "overall") masks most of the body.
    """
    # Fractional (left, top, right, bottom) boxes per garment type; the
    # fallback box is the "overall" region.
    boxes = {
        "upper": (.10, .18, .90, .65),
        "lower": (.05, .55, .95, 1.0),
    }
    left, top, right, bottom = boxes.get(cloth_type, (.05, .15, .95, 1.0))
    mask = Image.new("L", (size, size), 0)
    drawer = ImageDraw.Draw(mask)
    drawer.rectangle(
        [int(size * left), int(size * top), int(size * right), int(size * bottom)],
        fill=255,
    )
    return mask

# ---------------------------------------------------------------------------
# GPU inference — returns images + status string
# ---------------------------------------------------------------------------
# Lazily-initialised Paint-by-Example pipeline; loaded once inside run_tryon
# and cached at module level for the lifetime of the process.
_pipe = None

@spaces.GPU(duration=120)
def run_tryon(
    person_image: Image.Image,
    garment_image: Image.Image,
    cloth_type: str,
    num_steps: int,
    guidance_scale: float,
    seed: int,
):
    """Run the Paint-by-Example try-on and return ``(images, status)``.

    Returns ``(None, error_message)`` when either input image is missing;
    otherwise a list of generated PIL images (also saved to OUTPUT_DIR as
    timestamped PNGs) and a success message for the status box.
    """
    if person_image is None or garment_image is None:
        return None, "❌ Please upload both a person photo and a garment image."

    global _pipe
    if _pipe is None:
        # Lazy import + load: only happens inside the ZeroGPU worker, and
        # only on the first call — the pipeline is cached module-wide.
        from diffusers import PaintByExamplePipeline
        print("Loading Paint-by-Example (~5 GB, first run only)…")
        _pipe = PaintByExamplePipeline.from_pretrained(
            "Fantasy-Studio/Paint-by-Example",
            torch_dtype=torch.float16,
        ).to("cuda")
        _pipe.set_progress_bar_config(disable=True)
        print("Pipeline ready.")

    prepared_person = _fit_to_square(person_image)
    prepared_garment = _fit_to_square(garment_image)
    region_mask = _make_mask(TARGET_SIZE, cloth_type)

    # Seed the CUDA generator: honour an explicit seed, otherwise draw one.
    generator = torch.Generator(device="cuda")
    chosen_seed = int(seed) if seed != -1 else torch.randint(0, 2**32, (1,)).item()
    generator.manual_seed(chosen_seed)

    images = _pipe(
        image=prepared_person,
        mask_image=region_mask,
        example_image=prepared_garment,
        num_inference_steps=num_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    ).images

    # Persist every output with a shared timestamp so a batch stays grouped.
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    for idx, out_img in enumerate(images):
        out_img.save(os.path.join(OUTPUT_DIR, f"tryon_{stamp}_{idx}.png"), format="PNG")

    return images, "✅ Done! Right-click an image in the gallery to save it."

# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(title="Virtual Try-On", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# 👗 Virtual Try-On\n"
        "Upload a **person photo** and a **garment image**, select the type, then click **Try On**.\n\n"
        "> Runs on **Hugging Face ZeroGPU** (free A10G) — no local GPU needed.\n"
        "> **First run:** ~2-3 min (model download ~5 GB).  **Subsequent runs:** ~15-30s."
    )

    with gr.Row():
        # Left column: input images, garment-type selector, advanced knobs,
        # and the run button.
        with gr.Column():
            person_input  = gr.Image(label="Person Photo", type="pil", height=350)
            garment_input = gr.Image(label="Garment Image", type="pil", height=350)
            cloth_type = gr.Radio(
                ["upper", "lower", "overall"],
                value="upper",
                label="Garment Type",
                info="upper=top/shirt  |  lower=pants/skirt  |  overall=dress/full outfit",
            )
            with gr.Accordion("Advanced", open=False):
                num_steps  = gr.Slider(10, 50, value=30, step=1, label="Steps")
                guidance   = gr.Slider(1.0, 10.0, value=7.5, step=0.5, label="Guidance Scale")
                seed_input = gr.Number(label="Seed (-1 = random)", value=-1, precision=0)
            try_btn = gr.Button("👗 Try On", variant="primary", size="lg")

        # Right column: live status line plus the output gallery.
        with gr.Column():
            status_box = gr.Textbox(
                label="Status", value="Ready — upload images and click Try On",
                interactive=False, max_lines=2,
            )
            output_gallery = gr.Gallery(label="Result", columns=1, height=420)

    # Chain: first update status immediately, then run inference.
    # The lambda only sets an interim "please wait" status; run_tryon then
    # overwrites both the gallery and the status box when it finishes.
    try_btn.click(
        fn=lambda: "⏳ Requesting GPU + loading model…  (first run ~3 min, please wait)",
        inputs=None,
        outputs=[status_box],
    ).then(
        fn=run_tryon,
        inputs=[person_input, garment_input, cloth_type, num_steps, guidance, seed_input],
        outputs=[output_gallery, status_box],
    )

    gr.Markdown(
        "---\n"
        "**Tips:** front-facing photo · garment on white/neutral background · upper body for shirts\n\n"
        "Built with [Paint-by-Example](https://github.com/Fantasy-Studio/Paint-by-Example) · "
        "[Gradio](https://gradio.app) · [ZeroGPU](https://huggingface.co/docs/hub/spaces-zerogpu)"
    )

if __name__ == "__main__":
    demo.launch()