ArmanRV committed on
Commit
84301a0
·
verified ·
1 Parent(s): 0548bcb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -71
app.py CHANGED
@@ -1,7 +1,6 @@
1
  # -*- coding: utf-8 -*-
2
  import os
3
  import time
4
- import tempfile
5
  from typing import List, Optional, Tuple
6
 
7
  import spaces
@@ -11,7 +10,6 @@ from PIL import Image
11
  import torch
12
  import numpy as np
13
  from torchvision import transforms
14
- from torchvision.transforms.functional import to_pil_image
15
 
16
  from huggingface_hub import login, snapshot_download
17
 
@@ -51,6 +49,7 @@ ALLOWED_EXTS = (".png", ".jpg", ".jpeg", ".webp")
51
  GARMENTS_DATASET = os.getenv("GARMENTS_DATASET", "").strip() # e.g. "ArmanRV/armanrv-garments"
52
  HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
53
 
 
54
  def ensure_garments_downloaded() -> None:
55
  """
56
  Downloads garments from HF Dataset into ./garments to avoid Space repo 1GB limit.
@@ -69,7 +68,6 @@ def ensure_garments_downloaded() -> None:
69
  return
70
 
71
  try:
72
- # Download snapshot to local garments/ (no symlinks for HF container)
73
  snapshot_download(
74
  repo_id=GARMENTS_DATASET,
75
  repo_type="dataset",
@@ -81,20 +79,27 @@ def ensure_garments_downloaded() -> None:
81
  except Exception as e:
82
  print("Garments download FAILED:", str(e)[:300])
83
 
 
84
  def list_garments() -> List[str]:
85
- try:
86
- files = []
87
- for f in os.listdir(GARMENT_DIR):
88
- if f.lower().endswith(ALLOWED_EXTS) and not f.startswith("."):
89
- files.append(f)
90
- files.sort()
91
  return files
92
- except Exception:
93
- return []
 
 
 
 
 
 
94
 
95
  def garment_path(filename: str) -> str:
96
  return os.path.join(GARMENT_DIR, filename)
97
 
 
98
  def load_garment_pil(filename: str) -> Optional[Image.Image]:
99
  if not filename:
100
  return None
@@ -106,8 +111,8 @@ def load_garment_pil(filename: str) -> Optional[Image.Image]:
106
  except Exception:
107
  return None
108
 
 
109
  def build_gallery_items(files: List[str]):
110
- # (image_path, caption) — caption empty for clean UI
111
  return [(garment_path(f), "") for f in files]
112
 
113
 
@@ -121,14 +126,7 @@ def clamp_int(x, lo, hi):
121
  x = lo
122
  return max(lo, min(hi, x))
123
 
124
- def pil_to_binary_mask(pil_image, threshold=0):
125
- np_image = np.array(pil_image)
126
- grayscale_image = Image.fromarray(np_image).convert("L")
127
- binary_mask = np.array(grayscale_image) > threshold
128
- mask = (binary_mask.astype(np.uint8) * 255)
129
- return Image.fromarray(mask)
130
 
131
- # global simple rate limit (helps avoid spam during internal demo)
132
  _last_call_ts = 0.0
133
  def allow_call(min_interval_sec: float = 2.5) -> Tuple[bool, str]:
134
  global _last_call_ts
@@ -145,13 +143,15 @@ def allow_call(min_interval_sec: float = 2.5) -> Tuple[bool, str]:
145
  # =========================
146
  base_path = "yisol/IDM-VTON"
147
 
148
- # device policy
149
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
150
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
151
-
152
  print("DEVICE:", DEVICE, "DTYPE:", DTYPE)
153
 
154
- # Load components
 
 
 
 
155
  unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=DTYPE)
156
  unet.requires_grad_(False)
157
 
@@ -164,25 +164,18 @@ text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_enco
164
  text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=DTYPE)
165
 
166
  image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=DTYPE)
167
-
168
  vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=DTYPE)
169
 
170
  UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=DTYPE)
171
  UNet_Encoder.requires_grad_(False)
172
 
173
- # Parsing/OpenPose init
174
- # These are heavy; GPU intended. On CPU it may be very slow.
175
  parsing_model = Parsing(0)
176
  openpose_model = OpenPose(0)
177
 
178
- # Freeze
179
  for m in [UNet_Encoder, image_encoder, vae, unet, text_encoder_one, text_encoder_two]:
180
  m.requires_grad_(False)
181
 
182
- tensor_transfrom = transforms.Compose(
183
- [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
184
- )
185
-
186
  pipe = TryonPipeline.from_pretrained(
187
  base_path,
188
  unet=unet,
@@ -200,9 +193,9 @@ pipe.unet_encoder = UNet_Encoder
200
 
201
 
202
  # =========================
203
- # Inference
204
  # =========================
205
- @spaces.GPU # ok on dedicated GPU too
206
  def start_tryon(
207
  human_pil: Image.Image,
208
  garm_img: Image.Image,
@@ -210,25 +203,21 @@ def start_tryon(
210
  crop_center: bool = True,
211
  denoise_steps: int = 25,
212
  seed: int = 42,
213
- ):
214
- """
215
- Simplified local try-on.
216
- Returns: (output_image, masked_preview)
217
- """
218
  device = "cuda" if torch.cuda.is_available() else "cpu"
219
  dtype = torch.float16 if device == "cuda" else torch.float32
220
 
221
- # move heavy models
222
  if device == "cuda":
223
  openpose_model.preprocessor.body_estimation.model.to(device)
224
  pipe.to(device)
225
  pipe.unet_encoder.to(device)
226
 
227
- # resize inputs to expected
228
  garm_img = garm_img.convert("RGB").resize((768, 1024))
229
  human_img_orig = human_pil.convert("RGB")
230
 
231
- # optional center crop
232
  if crop_center:
233
  width, height = human_img_orig.size
234
  target_width = int(min(width, height * (3 / 4)))
@@ -243,20 +232,16 @@ def start_tryon(
243
  else:
244
  human_img = human_img_orig.resize((768, 1024))
245
 
246
- # mask
247
  if auto_mask:
248
  keypoints = openpose_model(human_img.resize((384, 512)))
249
  model_parse, _ = parsing_model(human_img.resize((384, 512)))
250
  mask, _ = get_mask_location("hd", "upper_body", model_parse, keypoints)
251
  mask = mask.resize((768, 1024))
252
  else:
253
- # if someday you add manual mask, you can pass it here
254
  mask = Image.new("L", (768, 1024), 0)
255
 
256
- mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
257
- mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)
258
-
259
- # densepose
260
  human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
261
  human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
262
 
@@ -274,7 +259,7 @@ def start_tryon(
274
  pose_img = pose_img[:, :, ::-1]
275
  pose_img = Image.fromarray(pose_img).resize((768, 1024))
276
 
277
- # prompts (fixed, like your API demo)
278
  garment_des = "a garment"
279
  prompt_main = "model is wearing " + garment_des
280
  prompt_cloth = "a photo of " + garment_des
@@ -283,19 +268,16 @@ def start_tryon(
283
  denoise_steps = clamp_int(denoise_steps, 20, 40)
284
  seed = clamp_int(seed, 0, 999999)
285
 
286
- # inference
287
  with torch.no_grad():
288
  if device == "cuda":
289
  autocast_ctx = torch.cuda.amp.autocast()
290
  else:
291
- # no autocast on cpu
292
  class _NoCtx:
293
  def __enter__(self): return None
294
  def __exit__(self, *args): return False
295
  autocast_ctx = _NoCtx()
296
 
297
  with autocast_ctx:
298
- # encode prompts
299
  (
300
  prompt_embeds,
301
  negative_prompt_embeds,
@@ -348,8 +330,8 @@ def start_tryon(
348
  if crop_center:
349
  out_img_rs = out_img.resize(crop_size)
350
  human_img_orig.paste(out_img_rs, (int(left), int(top)))
351
- return human_img_orig, mask_gray
352
- return out_img, mask_gray
353
 
354
 
355
  # =========================
@@ -366,7 +348,7 @@ def refresh_catalog():
366
  ensure_garments_downloaded()
367
  files = list_garments()
368
  items = build_gallery_items(files)
369
- status = "✅ Каталог обновлён" if files else "⚠️ Каталог пуст (dataset не скачался или нет файлов)"
370
  return items, files, None, status
371
 
372
  def on_gallery_select(files_list: List[str], evt: gr.SelectData):
@@ -377,31 +359,41 @@ def on_gallery_select(files_list: List[str], evt: gr.SelectData):
377
  return files_list[idx], f"👕 Выбрано: {files_list[idx]}"
378
 
379
  def tryon_ui(person_pil, selected_filename):
 
 
 
380
  ok, msg = allow_call(2.5)
381
  if not ok:
382
- return None, None, msg
 
383
 
384
  if person_pil is None:
385
- return None, None, "❌ Загрузите фото человека"
 
386
  if not selected_filename:
387
- return None, None, "❌ Выберите одежду из каталога"
 
388
 
389
  garm = load_garment_pil(selected_filename)
390
  if garm is None:
391
- return None, None, "❌ Не удалось загрузить выбранную одежду"
392
-
393
- out, masked = start_tryon(
394
- human_pil=person_pil,
395
- garm_img=garm,
396
- auto_mask=True,
397
- crop_center=True,
398
- denoise_steps=25,
399
- seed=42,
400
- )
401
- return out, masked, "✅ Готово"
402
 
 
 
 
 
 
 
 
 
 
 
 
 
403
 
404
- # ensure garments present at startup (best effort)
 
405
  ensure_garments_downloaded()
406
  _initial_files = list_garments()
407
  _initial_items = build_gallery_items(_initial_files)
@@ -432,8 +424,7 @@ with gr.Blocks(title="Virtual Try-On Rendez-vous", css=CUSTOM_CSS) as demo:
432
  status = gr.Textbox(value="Ожидание...", interactive=False)
433
 
434
  with gr.Column():
435
- out = gr.Image(label="Результат", type="pil", height=520)
436
- masked = gr.Image(label="Маска/предпросмотр (служебное)", type="pil", height=320)
437
 
438
  garment_gallery.select(
439
  fn=on_gallery_select,
@@ -450,9 +441,12 @@ with gr.Blocks(title="Virtual Try-On Rendez-vous", css=CUSTOM_CSS) as demo:
450
  run.click(
451
  fn=tryon_ui,
452
  inputs=[person, selected_garment_state],
453
- outputs=[out, masked, status],
454
  )
455
 
 
 
 
456
  if __name__ == "__main__":
457
  demo.launch(
458
  server_name="0.0.0.0",
 
1
  # -*- coding: utf-8 -*-
2
  import os
3
  import time
 
4
  from typing import List, Optional, Tuple
5
 
6
  import spaces
 
10
  import torch
11
  import numpy as np
12
  from torchvision import transforms
 
13
 
14
  from huggingface_hub import login, snapshot_download
15
 
 
49
  GARMENTS_DATASET = os.getenv("GARMENTS_DATASET", "").strip() # e.g. "ArmanRV/armanrv-garments"
50
  HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
51
 
52
+
53
  def ensure_garments_downloaded() -> None:
54
  """
55
  Downloads garments from HF Dataset into ./garments to avoid Space repo 1GB limit.
 
68
  return
69
 
70
  try:
 
71
  snapshot_download(
72
  repo_id=GARMENTS_DATASET,
73
  repo_type="dataset",
 
79
  except Exception as e:
80
  print("Garments download FAILED:", str(e)[:300])
81
 
82
+
83
def list_garments() -> List[str]:
    """
    Recursively list garment images inside GARMENT_DIR.

    Walks GARMENT_DIR, including any subfolders, and collects every file whose
    lowercased name ends with one of ALLOWED_EXTS. Hidden files and hidden
    directories (names starting with ".") are skipped, so images that live
    inside dot-folders never show up in the catalog.

    Returns:
        Sorted file paths relative to GARMENT_DIR. Empty list when
        GARMENT_DIR is not an existing directory.
    """
    files: List[str] = []
    if not os.path.isdir(GARMENT_DIR):
        return files
    for root, dirnames, fnames in os.walk(GARMENT_DIR):
        # Prune hidden directories in place so os.walk does not descend into them.
        dirnames[:] = [d for d in dirnames if not d.startswith(".")]
        for f in fnames:
            if f.startswith(".") or not f.lower().endswith(ALLOWED_EXTS):
                continue
            files.append(os.path.relpath(os.path.join(root, f), GARMENT_DIR))
    files.sort()
    return files
97
+
98
 
99
def garment_path(filename: str) -> str:
    """Return *filename* joined onto GARMENT_DIR."""
    full_path = os.path.join(GARMENT_DIR, filename)
    return full_path
101
 
102
+
103
  def load_garment_pil(filename: str) -> Optional[Image.Image]:
104
  if not filename:
105
  return None
 
111
  except Exception:
112
  return None
113
 
114
+
115
def build_gallery_items(files: List[str]):
    """
    Build gallery entries for the given garment file names.

    Returns a list of (image_path, caption) tuples, one per entry of *files*
    in the same order; every caption is the empty string.
    """
    items = []
    for name in files:
        items.append((garment_path(name), ""))
    return items
117
 
118
 
 
126
  x = lo
127
  return max(lo, min(hi, x))
128
 
 
 
 
 
 
 
129
 
 
130
  _last_call_ts = 0.0
131
  def allow_call(min_interval_sec: float = 2.5) -> Tuple[bool, str]:
132
  global _last_call_ts
 
143
  # =========================
144
  base_path = "yisol/IDM-VTON"
145
 
 
146
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
147
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
 
148
  print("DEVICE:", DEVICE, "DTYPE:", DTYPE)
149
 
150
+ tensor_transfrom = transforms.Compose(
151
+ [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
152
+ )
153
+
154
+ # Components
155
  unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=DTYPE)
156
  unet.requires_grad_(False)
157
 
 
164
  text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=DTYPE)
165
 
166
  image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=DTYPE)
 
167
  vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=DTYPE)
168
 
169
  UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=DTYPE)
170
  UNet_Encoder.requires_grad_(False)
171
 
172
+ # Preprocessors
 
173
  parsing_model = Parsing(0)
174
  openpose_model = OpenPose(0)
175
 
 
176
  for m in [UNet_Encoder, image_encoder, vae, unet, text_encoder_one, text_encoder_two]:
177
  m.requires_grad_(False)
178
 
 
 
 
 
179
  pipe = TryonPipeline.from_pretrained(
180
  base_path,
181
  unet=unet,
 
193
 
194
 
195
  # =========================
196
+ # Inference (returns ONLY final image)
197
  # =========================
198
+ @spaces.GPU
199
  def start_tryon(
200
  human_pil: Image.Image,
201
  garm_img: Image.Image,
 
203
  crop_center: bool = True,
204
  denoise_steps: int = 25,
205
  seed: int = 42,
206
+ ) -> Image.Image:
207
+
 
 
 
208
  device = "cuda" if torch.cuda.is_available() else "cpu"
209
  dtype = torch.float16 if device == "cuda" else torch.float32
210
 
211
+ # Move models
212
  if device == "cuda":
213
  openpose_model.preprocessor.body_estimation.model.to(device)
214
  pipe.to(device)
215
  pipe.unet_encoder.to(device)
216
 
 
217
  garm_img = garm_img.convert("RGB").resize((768, 1024))
218
  human_img_orig = human_pil.convert("RGB")
219
 
220
+ # Crop
221
  if crop_center:
222
  width, height = human_img_orig.size
223
  target_width = int(min(width, height * (3 / 4)))
 
232
  else:
233
  human_img = human_img_orig.resize((768, 1024))
234
 
235
+ # Mask
236
  if auto_mask:
237
  keypoints = openpose_model(human_img.resize((384, 512)))
238
  model_parse, _ = parsing_model(human_img.resize((384, 512)))
239
  mask, _ = get_mask_location("hd", "upper_body", model_parse, keypoints)
240
  mask = mask.resize((768, 1024))
241
  else:
 
242
  mask = Image.new("L", (768, 1024), 0)
243
 
244
+ # DensePose
 
 
 
245
  human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
246
  human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
247
 
 
259
  pose_img = pose_img[:, :, ::-1]
260
  pose_img = Image.fromarray(pose_img).resize((768, 1024))
261
 
262
+ # Fixed prompts (like your API demo)
263
  garment_des = "a garment"
264
  prompt_main = "model is wearing " + garment_des
265
  prompt_cloth = "a photo of " + garment_des
 
268
  denoise_steps = clamp_int(denoise_steps, 20, 40)
269
  seed = clamp_int(seed, 0, 999999)
270
 
 
271
  with torch.no_grad():
272
  if device == "cuda":
273
  autocast_ctx = torch.cuda.amp.autocast()
274
  else:
 
275
  class _NoCtx:
276
  def __enter__(self): return None
277
  def __exit__(self, *args): return False
278
  autocast_ctx = _NoCtx()
279
 
280
  with autocast_ctx:
 
281
  (
282
  prompt_embeds,
283
  negative_prompt_embeds,
 
330
  if crop_center:
331
  out_img_rs = out_img.resize(crop_size)
332
  human_img_orig.paste(out_img_rs, (int(left), int(top)))
333
+ return human_img_orig
334
+ return out_img
335
 
336
 
337
  # =========================
 
348
  ensure_garments_downloaded()
349
  files = list_garments()
350
  items = build_gallery_items(files)
351
+ status = "✅ Каталог обновлён" if files else "⚠️ Каталог пуст (проверь dataset/токен)"
352
  return items, files, None, status
353
 
354
  def on_gallery_select(files_list: List[str], evt: gr.SelectData):
 
359
  return files_list[idx], f"👕 Выбрано: {files_list[idx]}"
360
 
361
def tryon_ui(person_pil, selected_filename):
    """
    Click handler for the try-on button, written as a generator.

    Every yielded value is a ``(result_image_or_None, status_text)`` pair.
    The first pair is an immediate "processing" status; it is followed by one
    final pair carrying either the try-on result or an error/validation
    message.

    Args:
        person_pil: Uploaded person photo, or None when nothing was uploaded.
        selected_filename: Garment file name picked in the gallery (falsy when
            nothing is selected).
    """
    # Report progress right away so the user can see the click registered.
    yield None, "⏳ Обработка... (первый запуск может быть дольше)"

    permitted, reason = allow_call(2.5)
    if not permitted:
        yield None, reason
        return

    # Input validation: the person photo is checked before the garment choice.
    problem = None
    if person_pil is None:
        problem = "❌ Загрузите фото человека"
    elif not selected_filename:
        problem = "❌ Выберите одежду (клик по превью)"
    if problem is not None:
        yield None, problem
        return

    garment_image = load_garment_pil(selected_filename)
    if garment_image is None:
        yield None, "❌ Не удалось загрузить выбранную одежду"
        return

    tryon_kwargs = dict(
        human_pil=person_pil,
        garm_img=garment_image,
        auto_mask=True,
        crop_center=True,
        denoise_steps=25,
        seed=42,
    )
    try:
        result = start_tryon(**tryon_kwargs)
        yield result, "✅ Готово"
    except Exception as exc:
        # Surface a shortened error in the status box instead of failing the event.
        yield None, f"❌ Ошибка: {type(exc).__name__}: {str(exc)[:220]}"
394
 
395
+
396
+ # Preload garments (best-effort)
397
  ensure_garments_downloaded()
398
  _initial_files = list_garments()
399
  _initial_items = build_gallery_items(_initial_files)
 
424
  status = gr.Textbox(value="Ожидание...", interactive=False)
425
 
426
  with gr.Column():
427
+ out = gr.Image(label="Результат", type="pil", height=760)
 
428
 
429
  garment_gallery.select(
430
  fn=on_gallery_select,
 
441
  run.click(
442
  fn=tryon_ui,
443
  inputs=[person, selected_garment_state],
444
+ outputs=[out, status],
445
  )
446
 
447
# IMPORTANT: queue helps stability on GPU. Requests are queued (up to 20
# waiting) and processed one at a time.
try:
    # Gradio 4+ spells the concurrency setting `default_concurrency_limit`.
    demo.queue(default_concurrency_limit=1, max_size=20)
except TypeError:
    # Older Gradio releases only accept `concurrency_count`.
    demo.queue(concurrency_count=1, max_size=20)
449
+
450
  if __name__ == "__main__":
451
  demo.launch(
452
  server_name="0.0.0.0",