Nandha2017 commited on
Commit
bb42d68
·
verified ·
1 Parent(s): 5d5eb22

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. README.md +46 -12
  2. app.py +285 -0
  3. packages.txt +6 -0
  4. requirements.txt +13 -0
README.md CHANGED
@@ -1,12 +1,46 @@
1
- ---
2
- title: Virtual Tryon
3
- emoji: 🦀
4
- colorFrom: purple
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 6.9.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Virtual Try-On (CatVTON)
3
+ emoji: 👗
4
+ colorFrom: pink
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: "4.44.0"
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ hardware: zero-a10g
12
+ ---
13
+
14
+ # 👗 Virtual Try-On
15
+
16
+ Try on garments virtually using AI — runs entirely in your browser via Hugging Face ZeroGPU.
17
+
18
+ **No local GPU or storage needed.**
19
+
20
+ ## How to Use
21
+
22
+ 1. Upload a **person photo** (front-facing works best)
23
+ 2. Upload a **garment image** (product photo on white background works best)
24
+ 3. Select the garment type (upper / lower / overall)
25
+ 4. Click **Try On**
26
+ 5. Download the result to your device
27
+
28
+ ## Technical Details
29
+
30
+ - **Model**: [CatVTON](https://github.com/zhengchong/CatVTON) (`zhengchong/CatVTON`)
31
+ - **GPU**: Hugging Face ZeroGPU (A10G, free tier)
32
+ - **Model storage**: Downloaded once to `/data` persistent storage on HF servers
33
+ - **Your device**: Only needs a web browser — no downloads, no GPU
34
+
35
+ ## Notes
36
+
37
+ - First run takes ~2-5 minutes (model download to HF servers)
38
+ - Subsequent runs start immediately (model cached in persistent storage)
39
+ - For best results: clear front-facing photos, garment on white/neutral background
40
+ - ZeroGPU provides ~120 seconds of GPU time per generation
41
+
42
+ ## Built With
43
+
44
+ - [CatVTON](https://github.com/zhengchong/CatVTON) — virtual try-on model
45
+ - [Gradio](https://gradio.app) — web interface
46
+ - [Hugging Face ZeroGPU](https://huggingface.co/docs/hub/spaces-zerogpu) — free GPU
app.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Virtual Try-On — Powered by CatVTON + Hugging Face ZeroGPU
3
+ ===========================================================
4
+ No local GPU or model storage needed.
5
+ Models download once to /data on HF's servers.
6
+ Generated images are saved to the user's local device.
7
+ """
8
+
9
+ import datetime
10
+ import os
11
+ import sys
12
+
13
+ import gradio as gr
14
+ import numpy as np
15
+ import spaces
16
+ import torch
17
+ from huggingface_hub import snapshot_download
18
+ from PIL import Image
19
+
20
+ # ---------------------------------------------------------------------------
21
+ # Persistent storage (HF Spaces /data, else /tmp)
22
+ # ---------------------------------------------------------------------------
23
+ DATA_DIR = "/data" if os.path.exists("/data") else "/tmp"
24
+ MODELS_DIR = os.path.join(DATA_DIR, "catvton_models")
25
+ OUTPUT_DIR = os.path.join(DATA_DIR, "outputs")
26
+ os.makedirs(MODELS_DIR, exist_ok=True)
27
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
28
+
29
+ # Point HF cache to persistent storage
30
+ os.environ["HF_HOME"] = os.path.join(DATA_DIR, "hf_cache")
31
+ os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(DATA_DIR, "hf_cache", "hub")
32
+
33
+ # ---------------------------------------------------------------------------
34
+ # Model download (runs at Space startup — on HF servers, NOT locally)
35
+ # ---------------------------------------------------------------------------
36
+ CATVTON_REPO = "zhengchong/CatVTON"
37
+ CATVTON_LOCAL = os.path.join(MODELS_DIR, "CatVTON")
38
+
39
+ def download_models():
40
+ if not os.path.exists(os.path.join(CATVTON_LOCAL, "config.json")):
41
+ print("Downloading CatVTON model to HF persistent storage...")
42
+ snapshot_download(
43
+ repo_id=CATVTON_REPO,
44
+ local_dir=CATVTON_LOCAL,
45
+ local_dir_use_symlinks=False,
46
+ )
47
+ print("CatVTON model ready.")
48
+ else:
49
+ print("CatVTON model already cached.")
50
+
51
+ # ---------------------------------------------------------------------------
52
+ # Pipeline loader (lazy — only after GPU is assigned)
53
+ # ---------------------------------------------------------------------------
54
+ _pipeline = None
55
+
56
+ def load_pipeline():
57
+ global _pipeline
58
+ if _pipeline is not None:
59
+ return _pipeline
60
+
61
+ from diffusers import AutoencoderKL, UNet2DConditionModel
62
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import (
63
+ StableDiffusionInpaintPipeline,
64
+ )
65
+ from transformers import CLIPTextModel, CLIPTokenizer
66
+
67
+ # CatVTON uses a custom diffusers-compatible pipeline
68
+ # Fall back to standard diffusers inpaint if custom loader unavailable
69
+ try:
70
+ sys.path.insert(0, CATVTON_LOCAL)
71
+ from model.pipeline import CatVTONPipeline
72
+ _pipeline = CatVTONPipeline(
73
+ base_ckpt=CATVTON_LOCAL,
74
+ attn_ckpt=CATVTON_LOCAL,
75
+ attn_ckpt_version="mix",
76
+ weight_dtype=torch.float16,
77
+ device="cuda",
78
+ skip_safety_check=True,
79
+ )
80
+ print("CatVTON custom pipeline loaded.")
81
+ except Exception as e:
82
+ print(f"CatVTON custom pipeline failed ({e}), using diffusers fallback...")
83
+ _pipeline = StableDiffusionInpaintPipeline.from_pretrained(
84
+ CATVTON_LOCAL,
85
+ torch_dtype=torch.float16,
86
+ safety_checker=None,
87
+ ).to("cuda")
88
+ print("Diffusers fallback pipeline loaded.")
89
+
90
+ return _pipeline
91
+
92
+
93
+ # ---------------------------------------------------------------------------
94
+ # Mask generation utilities
95
+ # ---------------------------------------------------------------------------
96
+ def _resize_and_pad(img: Image.Image, size: int = 768) -> Image.Image:
97
+ """Resize image to square, preserving aspect ratio with padding."""
98
+ img.thumbnail((size, size), Image.LANCZOS)
99
+ canvas = Image.new("RGB", (size, size), (255, 255, 255))
100
+ x = (size - img.width) // 2
101
+ y = (size - img.height) // 2
102
+ canvas.paste(img, (x, y))
103
+ return canvas
104
+
105
+
106
+ def _build_mask(person_img: Image.Image, cloth_type: str) -> Image.Image:
107
+ """
108
+ Build a rough inpainting mask based on cloth_type.
109
+ For a proper implementation, use a segmentation model (e.g. SCHP).
110
+ This simple version covers standard body regions.
111
+ """
112
+ w, h = person_img.size
113
+ mask = Image.new("L", (w, h), 0)
114
+ import PIL.ImageDraw as ImageDraw
115
+ draw = ImageDraw.Draw(mask)
116
+
117
+ if cloth_type == "upper":
118
+ # Cover torso: from ~20% to ~65% height
119
+ draw.rectangle([int(w * 0.1), int(h * 0.18), int(w * 0.9), int(h * 0.65)], fill=255)
120
+ elif cloth_type == "lower":
121
+ # Cover legs: from ~55% to ~100% height
122
+ draw.rectangle([int(w * 0.05), int(h * 0.55), int(w * 0.95), int(h * 1.0)], fill=255)
123
+ else: # overall / dress
124
+ # Cover full body: from ~15% to ~100% height
125
+ draw.rectangle([int(w * 0.05), int(h * 0.15), int(w * 0.95), int(h * 1.0)], fill=255)
126
+
127
+ return mask
128
+
129
+
130
+ # ---------------------------------------------------------------------------
131
+ # Inference (ZeroGPU)
132
+ # ---------------------------------------------------------------------------
133
+ @spaces.GPU(duration=120)
134
+ def run_tryon(
135
+ person_image: Image.Image,
136
+ garment_image: Image.Image,
137
+ cloth_type: str,
138
+ num_steps: int,
139
+ guidance_scale: float,
140
+ seed: int,
141
+ ) -> tuple[list, list]:
142
+ """
143
+ Run virtual try-on inference on HF ZeroGPU.
144
+ Returns (gallery_images, downloadable_file_paths).
145
+ """
146
+ if person_image is None or garment_image is None:
147
+ raise gr.Error("Please upload both a person image and a garment image.")
148
+
149
+ pipe = load_pipeline()
150
+
151
+ # Pre-process
152
+ size = 768
153
+ person_resized = _resize_and_pad(person_image.convert("RGB"), size)
154
+ garment_resized = _resize_and_pad(garment_image.convert("RGB"), size)
155
+ mask = _build_mask(person_resized, cloth_type)
156
+
157
+ generator = torch.Generator(device="cuda")
158
+ if seed == -1:
159
+ seed = torch.randint(0, 2**32, (1,)).item()
160
+ generator.manual_seed(int(seed))
161
+
162
+ # Run pipeline
163
+ try:
164
+ # CatVTON custom call signature
165
+ result = pipe(
166
+ image=person_resized,
167
+ condition_image=garment_resized,
168
+ mask=mask,
169
+ num_inference_steps=num_steps,
170
+ guidance_scale=guidance_scale,
171
+ generator=generator,
172
+ )
173
+ output_images = result if isinstance(result, list) else [result]
174
+ except TypeError:
175
+ # Diffusers fallback call signature
176
+ result = pipe(
177
+ prompt="a person wearing the garment, photorealistic, high quality",
178
+ image=person_resized,
179
+ mask_image=mask,
180
+ num_inference_steps=num_steps,
181
+ guidance_scale=guidance_scale,
182
+ generator=generator,
183
+ )
184
+ output_images = result.images
185
+
186
+ # Save outputs
187
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
188
+ saved_paths = []
189
+ pil_images = []
190
+ for i, img in enumerate(output_images):
191
+ if not isinstance(img, Image.Image):
192
+ img = Image.fromarray(np.uint8(img))
193
+ pil_images.append(img)
194
+ path = os.path.join(OUTPUT_DIR, f"tryon_{timestamp}_{i}.png")
195
+ img.save(path, format="PNG")
196
+ saved_paths.append(path)
197
+
198
+ return pil_images, saved_paths
199
+
200
+
201
+ # ---------------------------------------------------------------------------
202
+ # Gradio UI
203
+ # ---------------------------------------------------------------------------
204
+ EXAMPLES = [] # add example paths here if desired
205
+
206
+ with gr.Blocks(title="Virtual Try-On — CatVTON", theme=gr.themes.Soft()) as demo:
207
+ gr.Markdown(
208
+ "# 👗 Virtual Try-On\n"
209
+ "Upload a **person photo** and a **garment image**, then click **Try On**.\n\n"
210
+ "> Runs on Hugging Face ZeroGPU (free A10G) — no local GPU or storage needed. \n"
211
+ "> Generated images are saved to your device via the Download button."
212
+ )
213
+
214
+ with gr.Row():
215
+ with gr.Column(scale=1):
216
+ person_input = gr.Image(
217
+ label="Person Photo",
218
+ type="pil",
219
+ height=400,
220
+ )
221
+ garment_input = gr.Image(
222
+ label="Garment Image",
223
+ type="pil",
224
+ height=400,
225
+ )
226
+
227
+ with gr.Column(scale=1):
228
+ output_gallery = gr.Gallery(
229
+ label="Result",
230
+ show_label=True,
231
+ columns=1,
232
+ height=400,
233
+ )
234
+ output_files = gr.File(
235
+ label="⬇ Download to your device",
236
+ file_count="multiple",
237
+ interactive=False,
238
+ )
239
+
240
+ with gr.Row():
241
+ cloth_type = gr.Radio(
242
+ ["upper", "lower", "overall"],
243
+ value="upper",
244
+ label="Garment Type",
245
+ info="upper = top/shirt, lower = pants/skirt, overall = dress/full outfit",
246
+ )
247
+
248
+ with gr.Accordion("Advanced Settings", open=False):
249
+ with gr.Row():
250
+ num_steps = gr.Slider(
251
+ minimum=10, maximum=50, value=30, step=1,
252
+ label="Inference Steps",
253
+ )
254
+ guidance = gr.Slider(
255
+ minimum=1.0, maximum=10.0, value=2.5, step=0.5,
256
+ label="Guidance Scale",
257
+ )
258
+ seed_input = gr.Number(
259
+ label="Seed (-1 = random)", value=-1, precision=0,
260
+ )
261
+
262
+ try_btn = gr.Button("👗 Try On", variant="primary", size="lg")
263
+
264
+ try_btn.click(
265
+ fn=run_tryon,
266
+ inputs=[person_input, garment_input, cloth_type, num_steps, guidance, seed_input],
267
+ outputs=[output_gallery, output_files],
268
+ )
269
+
270
+ gr.Markdown(
271
+ "---\n"
272
+ "**Notes:** \n"
273
+ "- First run downloads the model (~2-4 GB) to HF persistent storage — takes a few minutes once. \n"
274
+ "- Subsequent runs start immediately (model cached). \n"
275
+ "- For best results: use a front-facing photo with clear garment visibility. \n"
276
+ "- Built with [CatVTON](https://github.com/zhengchong/CatVTON) + "
277
+ "[Gradio](https://gradio.app) + [ZeroGPU](https://huggingface.co/docs/hub/spaces-zerogpu)"
278
+ )
279
+
280
+
281
+ # Download model at Space startup (on HF servers, not locally)
282
+ download_models()
283
+
284
+ if __name__ == "__main__":
285
+ demo.launch()
packages.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ libgl1-mesa-glx
2
+ libglib2.0-0
3
+ libsm6
4
+ libxext6
5
+ libxrender-dev
6
+ libgomp1
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.44.0
2
+ spaces
3
+ torch
4
+ torchvision
5
+ diffusers>=0.27.0
6
+ transformers>=4.40.0
7
+ accelerate>=0.28.0
8
+ huggingface_hub>=0.27.0
9
+ Pillow>=10.0.0
10
+ numpy>=1.24.0
11
+ safetensors>=0.4.2
12
+ omegaconf
13
+ einops