Nandha2017 commited on
Commit
d04b5a9
·
verified ·
1 Parent(s): bb42d68

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +100 -192
  2. requirements.txt +8 -9
app.py CHANGED
@@ -1,9 +1,6 @@
1
  """
2
- Virtual Try-On — Powered by CatVTON + Hugging Face ZeroGPU
3
- ===========================================================
4
- No local GPU or model storage needed.
5
- Models download once to /data on HF's servers.
6
- Generated images are saved to the user's local device.
7
  """
8
 
9
  import datetime
@@ -15,10 +12,10 @@ import numpy as np
15
  import spaces
16
  import torch
17
  from huggingface_hub import snapshot_download
18
- from PIL import Image
19
 
20
  # ---------------------------------------------------------------------------
21
- # Persistent storage (HF Spaces /data, else /tmp)
22
  # ---------------------------------------------------------------------------
23
  DATA_DIR = "/data" if os.path.exists("/data") else "/tmp"
24
  MODELS_DIR = os.path.join(DATA_DIR, "catvton_models")
@@ -26,109 +23,73 @@ OUTPUT_DIR = os.path.join(DATA_DIR, "outputs")
26
  os.makedirs(MODELS_DIR, exist_ok=True)
27
  os.makedirs(OUTPUT_DIR, exist_ok=True)
28
 
29
- # Point HF cache to persistent storage
30
- os.environ["HF_HOME"] = os.path.join(DATA_DIR, "hf_cache")
31
- os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(DATA_DIR, "hf_cache", "hub")
32
 
33
  # ---------------------------------------------------------------------------
34
- # Model download (runs at Space startup on HF servers, NOT locally)
35
  # ---------------------------------------------------------------------------
36
- CATVTON_REPO = "zhengchong/CatVTON"
37
- CATVTON_LOCAL = os.path.join(MODELS_DIR, "CatVTON")
38
 
39
  def download_models():
40
- if not os.path.exists(os.path.join(CATVTON_LOCAL, "config.json")):
41
- print("Downloading CatVTON model to HF persistent storage...")
42
- snapshot_download(
43
- repo_id=CATVTON_REPO,
44
- local_dir=CATVTON_LOCAL,
45
- local_dir_use_symlinks=False,
46
- )
47
- print("CatVTON model ready.")
48
- else:
49
- print("CatVTON model already cached.")
 
50
 
51
  # ---------------------------------------------------------------------------
52
- # Pipeline loader (lazy only after GPU is assigned)
53
  # ---------------------------------------------------------------------------
54
- _pipeline = None
55
-
56
- def load_pipeline():
57
- global _pipeline
58
- if _pipeline is not None:
59
- return _pipeline
60
-
61
- from diffusers import AutoencoderKL, UNet2DConditionModel
62
- from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import (
63
- StableDiffusionInpaintPipeline,
64
- )
65
- from transformers import CLIPTextModel, CLIPTokenizer
66
-
67
- # CatVTON uses a custom diffusers-compatible pipeline
68
- # Fall back to standard diffusers inpaint if custom loader unavailable
69
- try:
70
- sys.path.insert(0, CATVTON_LOCAL)
71
- from model.pipeline import CatVTONPipeline
72
- _pipeline = CatVTONPipeline(
73
- base_ckpt=CATVTON_LOCAL,
74
- attn_ckpt=CATVTON_LOCAL,
75
- attn_ckpt_version="mix",
76
- weight_dtype=torch.float16,
77
- device="cuda",
78
- skip_safety_check=True,
79
- )
80
- print("CatVTON custom pipeline loaded.")
81
- except Exception as e:
82
- print(f"CatVTON custom pipeline failed ({e}), using diffusers fallback...")
83
- _pipeline = StableDiffusionInpaintPipeline.from_pretrained(
84
- CATVTON_LOCAL,
85
- torch_dtype=torch.float16,
86
- safety_checker=None,
87
- ).to("cuda")
88
- print("Diffusers fallback pipeline loaded.")
89
-
90
- return _pipeline
91
-
92
 
93
  # ---------------------------------------------------------------------------
94
- # Mask generation utilities
95
  # ---------------------------------------------------------------------------
96
- def _resize_and_pad(img: Image.Image, size: int = 768) -> Image.Image:
97
- """Resize image to square, preserving aspect ratio with padding."""
 
 
98
  img.thumbnail((size, size), Image.LANCZOS)
99
  canvas = Image.new("RGB", (size, size), (255, 255, 255))
100
- x = (size - img.width) // 2
101
- y = (size - img.height) // 2
102
- canvas.paste(img, (x, y))
103
  return canvas
104
 
105
-
106
- def _build_mask(person_img: Image.Image, cloth_type: str) -> Image.Image:
107
- """
108
- Build a rough inpainting mask based on cloth_type.
109
- For a proper implementation, use a segmentation model (e.g. SCHP).
110
- This simple version covers standard body regions.
111
- """
112
- w, h = person_img.size
113
- mask = Image.new("L", (w, h), 0)
114
- import PIL.ImageDraw as ImageDraw
115
- draw = ImageDraw.Draw(mask)
116
-
117
  if cloth_type == "upper":
118
- # Cover torso: from ~20% to ~65% height
119
- draw.rectangle([int(w * 0.1), int(h * 0.18), int(w * 0.9), int(h * 0.65)], fill=255)
120
  elif cloth_type == "lower":
121
- # Cover legs: from ~55% to ~100% height
122
- draw.rectangle([int(w * 0.05), int(h * 0.55), int(w * 0.95), int(h * 1.0)], fill=255)
123
  else: # overall / dress
124
- # Cover full body: from ~15% to ~100% height
125
- draw.rectangle([int(w * 0.05), int(h * 0.15), int(w * 0.95), int(h * 1.0)], fill=255)
126
-
127
  return mask
128
 
129
-
130
  # ---------------------------------------------------------------------------
131
- # Inference (ZeroGPU)
132
  # ---------------------------------------------------------------------------
133
  @spaces.GPU(duration=120)
134
  def run_tryon(
@@ -138,129 +99,80 @@ def run_tryon(
138
  num_steps: int,
139
  guidance_scale: float,
140
  seed: int,
141
- ) -> tuple[list, list]:
142
- """
143
- Run virtual try-on inference on HF ZeroGPU.
144
- Returns (gallery_images, downloadable_file_paths).
145
- """
146
  if person_image is None or garment_image is None:
147
- raise gr.Error("Please upload both a person image and a garment image.")
148
 
149
- pipe = load_pipeline()
150
 
151
- # Pre-process
152
- size = 768
153
- person_resized = _resize_and_pad(person_image.convert("RGB"), size)
154
- garment_resized = _resize_and_pad(garment_image.convert("RGB"), size)
155
- mask = _build_mask(person_resized, cloth_type)
156
 
157
- generator = torch.Generator(device="cuda")
158
- if seed == -1:
159
- seed = torch.randint(0, 2**32, (1,)).item()
160
- generator.manual_seed(int(seed))
161
 
162
- # Run pipeline
163
- try:
164
- # CatVTON custom call signature
165
- result = pipe(
166
- image=person_resized,
167
- condition_image=garment_resized,
168
- mask=mask,
169
- num_inference_steps=num_steps,
170
- guidance_scale=guidance_scale,
171
- generator=generator,
172
- )
173
- output_images = result if isinstance(result, list) else [result]
174
- except TypeError:
175
- # Diffusers fallback call signature
176
- result = pipe(
177
- prompt="a person wearing the garment, photorealistic, high quality",
178
- image=person_resized,
179
- mask_image=mask,
180
- num_inference_steps=num_steps,
181
- guidance_scale=guidance_scale,
182
- generator=generator,
183
- )
184
- output_images = result.images
185
 
186
- # Save outputs
187
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
188
  saved_paths = []
189
- pil_images = []
190
  for i, img in enumerate(output_images):
191
- if not isinstance(img, Image.Image):
192
- img = Image.fromarray(np.uint8(img))
193
- pil_images.append(img)
194
  path = os.path.join(OUTPUT_DIR, f"tryon_{timestamp}_{i}.png")
195
  img.save(path, format="PNG")
196
  saved_paths.append(path)
197
 
198
- return pil_images, saved_paths
199
-
200
 
201
  # ---------------------------------------------------------------------------
202
  # Gradio UI
203
  # ---------------------------------------------------------------------------
204
- EXAMPLES = [] # add example paths here if desired
205
-
206
- with gr.Blocks(title="Virtual Try-On — CatVTON", theme=gr.themes.Soft()) as demo:
207
  gr.Markdown(
208
  "# 👗 Virtual Try-On\n"
209
- "Upload a **person photo** and a **garment image**, then click **Try On**.\n\n"
210
- "> Runs on Hugging Face ZeroGPU (free A10G) — no local GPU or storage needed. \n"
211
- "> Generated images are saved to your device via the Download button."
212
  )
213
 
214
  with gr.Row():
215
- with gr.Column(scale=1):
216
- person_input = gr.Image(
217
- label="Person Photo",
218
- type="pil",
219
- height=400,
 
 
 
220
  )
221
- garment_input = gr.Image(
222
- label="Garment Image",
223
- type="pil",
224
- height=400,
225
- )
226
-
227
- with gr.Column(scale=1):
228
- output_gallery = gr.Gallery(
229
- label="Result",
230
- show_label=True,
231
- columns=1,
232
- height=400,
233
- )
234
- output_files = gr.File(
235
  label="⬇ Download to your device",
236
  file_count="multiple",
237
  interactive=False,
238
  )
239
 
240
- with gr.Row():
241
- cloth_type = gr.Radio(
242
- ["upper", "lower", "overall"],
243
- value="upper",
244
- label="Garment Type",
245
- info="upper = top/shirt, lower = pants/skirt, overall = dress/full outfit",
246
- )
247
-
248
- with gr.Accordion("Advanced Settings", open=False):
249
- with gr.Row():
250
- num_steps = gr.Slider(
251
- minimum=10, maximum=50, value=30, step=1,
252
- label="Inference Steps",
253
- )
254
- guidance = gr.Slider(
255
- minimum=1.0, maximum=10.0, value=2.5, step=0.5,
256
- label="Guidance Scale",
257
- )
258
- seed_input = gr.Number(
259
- label="Seed (-1 = random)", value=-1, precision=0,
260
- )
261
-
262
- try_btn = gr.Button("👗 Try On", variant="primary", size="lg")
263
-
264
  try_btn.click(
265
  fn=run_tryon,
266
  inputs=[person_input, garment_input, cloth_type, num_steps, guidance, seed_input],
@@ -269,16 +181,12 @@ with gr.Blocks(title="Virtual Try-On — CatVTON", theme=gr.themes.Soft()) as de
269
 
270
  gr.Markdown(
271
  "---\n"
272
- "**Notes:** \n"
273
- "- First run downloads the model (~2-4 GB) to HF persistent storage — takes a few minutes once. \n"
274
- "- Subsequent runs start immediately (model cached). \n"
275
- "- For best results: use a front-facing photo with clear garment visibility. \n"
276
- "- Built with [CatVTON](https://github.com/zhengchong/CatVTON) + "
277
- "[Gradio](https://gradio.app) + [ZeroGPU](https://huggingface.co/docs/hub/spaces-zerogpu)"
278
  )
279
 
280
-
281
- # Download model at Space startup (on HF servers, not locally)
282
  download_models()
283
 
284
  if __name__ == "__main__":
 
1
  """
2
+ Virtual Try-On — CatVTON + Hugging Face ZeroGPU
3
+ No local GPU or model storage needed. Generated images download to your device.
 
 
 
4
  """
5
 
6
  import datetime
 
12
  import spaces
13
  import torch
14
  from huggingface_hub import snapshot_download
15
+ from PIL import Image, ImageDraw
16
 
17
  # ---------------------------------------------------------------------------
18
+ # Persistent storage (/data on ZeroGPU Spaces, /tmp fallback)
19
  # ---------------------------------------------------------------------------
20
  DATA_DIR = "/data" if os.path.exists("/data") else "/tmp"
21
  MODELS_DIR = os.path.join(DATA_DIR, "catvton_models")
 
23
  os.makedirs(MODELS_DIR, exist_ok=True)
24
  os.makedirs(OUTPUT_DIR, exist_ok=True)
25
 
26
+ os.environ["HF_HOME"] = os.path.join(DATA_DIR, "hf_cache")
27
+ os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(DATA_DIR, "hf_cache", "hub")
 
28
 
29
  # ---------------------------------------------------------------------------
30
+ # Model download runs once at Space startup on HF servers (not locally)
31
  # ---------------------------------------------------------------------------
32
+ CATVTON_REPO = "zhengchong/CatVTON"
33
+ CATVTON_LOCAL = os.path.join(MODELS_DIR, "CatVTON")
34
 
35
  def download_models():
36
+ if os.path.exists(os.path.join(CATVTON_LOCAL, "model_index.json")):
37
+ print("CatVTON already cached.")
38
+ return
39
+ print("Downloading CatVTON (~4 GB) to HF persistent storage…")
40
+ snapshot_download(
41
+ repo_id=CATVTON_REPO,
42
+ local_dir=CATVTON_LOCAL,
43
+ local_dir_use_symlinks=False,
44
+ ignore_patterns=["*.md", "*.txt", "*.py"],
45
+ )
46
+ print("CatVTON ready.")
47
 
48
  # ---------------------------------------------------------------------------
49
+ # Pipeline (loaded lazily inside @spaces.GPU)
50
  # ---------------------------------------------------------------------------
51
+ _pipe = None
52
+
53
+ def _get_pipe():
54
+ global _pipe
55
+ if _pipe is not None:
56
+ return _pipe
57
+ from diffusers import StableDiffusionInpaintPipeline
58
+ _pipe = StableDiffusionInpaintPipeline.from_pretrained(
59
+ CATVTON_LOCAL,
60
+ torch_dtype=torch.float16,
61
+ safety_checker=None,
62
+ requires_safety_checker=False,
63
+ ).to("cuda")
64
+ _pipe.set_progress_bar_config(disable=True)
65
+ print("Pipeline loaded on CUDA.")
66
+ return _pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  # ---------------------------------------------------------------------------
69
+ # Image helpers
70
  # ---------------------------------------------------------------------------
71
+ TARGET_SIZE = 512
72
+
73
+ def _fit_to_square(img: Image.Image, size: int = TARGET_SIZE) -> Image.Image:
74
+ img = img.convert("RGB")
75
  img.thumbnail((size, size), Image.LANCZOS)
76
  canvas = Image.new("RGB", (size, size), (255, 255, 255))
77
+ canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
 
 
78
  return canvas
79
 
80
+ def _make_mask(size: int, cloth_type: str) -> Image.Image:
81
+ mask = Image.new("L", (size, size), 0)
82
+ d = ImageDraw.Draw(mask)
 
 
 
 
 
 
 
 
 
83
  if cloth_type == "upper":
84
+ d.rectangle([int(size*.10), int(size*.18), int(size*.90), int(size*.65)], fill=255)
 
85
  elif cloth_type == "lower":
86
+ d.rectangle([int(size*.05), int(size*.55), int(size*.95), int(size*1.0)], fill=255)
 
87
  else: # overall / dress
88
+ d.rectangle([int(size*.05), int(size*.15), int(size*.95), int(size*1.0)], fill=255)
 
 
89
  return mask
90
 
 
91
  # ---------------------------------------------------------------------------
92
+ # ZeroGPU inference
93
  # ---------------------------------------------------------------------------
94
  @spaces.GPU(duration=120)
95
  def run_tryon(
 
99
  num_steps: int,
100
  guidance_scale: float,
101
  seed: int,
102
+ ) -> tuple:
 
 
 
 
103
  if person_image is None or garment_image is None:
104
+ raise gr.Error("Please upload both a person photo and a garment image.")
105
 
106
+ pipe = _get_pipe()
107
 
108
+ person = _fit_to_square(person_image)
109
+ garment = _fit_to_square(garment_image)
110
+ mask = _make_mask(TARGET_SIZE, cloth_type)
 
 
111
 
112
+ rng = torch.Generator(device="cuda")
113
+ rng.manual_seed(int(seed) if seed != -1 else torch.randint(0, 2**32, (1,)).item())
 
 
114
 
115
+ prompt = (
116
+ "a person wearing the garment in the reference image, "
117
+ "photorealistic, high quality, natural lighting"
118
+ )
119
+ negative = "blurry, distorted, deformed, low quality, artifacts"
120
+
121
+ result = pipe(
122
+ prompt=prompt,
123
+ negative_prompt=negative,
124
+ image=person,
125
+ mask_image=mask,
126
+ num_inference_steps=num_steps,
127
+ guidance_scale=guidance_scale,
128
+ generator=rng,
129
+ )
130
+ output_images = result.images
 
 
 
 
 
 
 
131
 
 
132
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
133
  saved_paths = []
 
134
  for i, img in enumerate(output_images):
 
 
 
135
  path = os.path.join(OUTPUT_DIR, f"tryon_{timestamp}_{i}.png")
136
  img.save(path, format="PNG")
137
  saved_paths.append(path)
138
 
139
+ return output_images, saved_paths
 
140
 
141
  # ---------------------------------------------------------------------------
142
  # Gradio UI
143
  # ---------------------------------------------------------------------------
144
+ with gr.Blocks(title="Virtual Try-On", theme=gr.themes.Soft()) as demo:
 
 
145
  gr.Markdown(
146
  "# 👗 Virtual Try-On\n"
147
+ "Upload a **person photo** and a **garment image**, select the type, then click **Try On**.\n\n"
148
+ "> Runs entirely on **Hugging Face ZeroGPU** (free A10G) — no local GPU needed. \n"
149
+ "> Models download once to HF persistent storage. Images save to your device via the Download button."
150
  )
151
 
152
  with gr.Row():
153
+ with gr.Column():
154
+ person_input = gr.Image(label="Person Photo", type="pil", height=380)
155
+ garment_input = gr.Image(label="Garment Image", type="pil", height=380)
156
+ cloth_type = gr.Radio(
157
+ ["upper", "lower", "overall"],
158
+ value="upper",
159
+ label="Garment Type",
160
+ info="upper=top/shirt | lower=pants/skirt | overall=dress/full outfit",
161
  )
162
+ with gr.Accordion("Advanced", open=False):
163
+ num_steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
164
+ guidance = gr.Slider(1.0, 10.0, value=7.5, step=0.5, label="Guidance Scale")
165
+ seed_input = gr.Number(label="Seed (-1 = random)", value=-1, precision=0)
166
+ try_btn = gr.Button("👗 Try On", variant="primary", size="lg")
167
+
168
+ with gr.Column():
169
+ output_gallery = gr.Gallery(label="Result", columns=1, height=380)
170
+ output_files = gr.File(
 
 
 
 
 
171
  label="⬇ Download to your device",
172
  file_count="multiple",
173
  interactive=False,
174
  )
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  try_btn.click(
177
  fn=run_tryon,
178
  inputs=[person_input, garment_input, cloth_type, num_steps, guidance, seed_input],
 
181
 
182
  gr.Markdown(
183
  "---\n"
184
+ "**Tips:** front-facing photo · garment on white/neutral background · upper body for shirts\n\n"
185
+ "First run: ~2-5 min (model download). Subsequent runs: ~15-30s.\n\n"
186
+ "Built with [CatVTON](https://github.com/zhengchong/CatVTON) · "
187
+ "[Gradio](https://gradio.app) · [ZeroGPU](https://huggingface.co/docs/hub/spaces-zerogpu)"
 
 
188
  )
189
 
 
 
190
  download_models()
191
 
192
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,13 +1,12 @@
1
- gradio>=4.44.0
2
  spaces
3
- torch
4
- torchvision
5
- diffusers>=0.27.0
6
- transformers>=4.40.0
7
- accelerate>=0.28.0
8
- huggingface_hub>=0.27.0
9
  Pillow>=10.0.0
10
  numpy>=1.24.0
11
  safetensors>=0.4.2
12
- omegaconf
13
- einops
 
1
+ gradio==4.44.0
2
  spaces
3
+ torch==2.3.1
4
+ torchvision==0.18.1
5
+ diffusers==0.29.2
6
+ transformers==4.44.2
7
+ accelerate==0.33.0
8
+ huggingface_hub>=0.24.0
9
  Pillow>=10.0.0
10
  numpy>=1.24.0
11
  safetensors>=0.4.2
12
+ einops==0.8.0