ArmanRV committed on
Commit
f364720
·
verified ·
1 Parent(s): 0cb8e39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -102
app.py CHANGED
@@ -1,30 +1,12 @@
1
  # -*- coding: utf-8 -*-
2
- """
3
- Virtual Try-On Rendez-vous — production wrapper for IDM-VTON (SDXL)
4
-
5
- Что изменено по твоему запросу (убрано/исправлено):
6
- 1) НЕТ “жёстко upper_body для всего” — маска выбирается АВТО по имени/папке одежды (dress/lower/upper),
7
- либо можно отключить авто-маску полностью.
8
- 2) НЕТ fixed strength=1.0 — strength настраиваемый (по умолчанию 0.9).
9
- 3) НЕТ фиксированных промптов “a garment” — промпт генерируется из имени файла/папки одежды + эвристики,
10
- можно переопределить вручную.
11
- 4) НЕТ crop-center + paste обратно — используется letterbox (масштаб с сохранением пропорций + padding),
12
- затем padding убирается, и результат возвращается в исходный размер.
13
- 5) НЕТ принудительного 768×1024 “всегда” — размер выбирается ДИНАМИЧЕСКИ от входного фото (с ограничением max_side),
14
- кратно 8.
15
- 6) НЕТ низких/фиксированных CFG/steps/seed — все параметры управляемые в UI; seed может быть -1 (рандом).
16
-
17
- Остальное (датасет одежды, галерея, queue, patch gradio_client) оставлено как инфраструктура.
18
- """
19
  import os
20
  import re
21
  import time
22
- import math
23
- from typing import List, Optional, Tuple, Dict, Any
24
 
25
  import spaces
26
  import gradio as gr
27
- from PIL import Image, ImageOps
28
 
29
  # =========================
30
  # FIX: gradio 4.24 / gradio_client crashes on boolean JSON Schemas in /api_info
@@ -116,14 +98,11 @@ APP_AUTH = (DEMO_USER, DEMO_PASS) if (DEMO_USER and DEMO_PASS) else None
116
  # =========================
117
  GARMENT_DIR = "garments"
118
  ALLOWED_EXTS = (".png", ".jpg", ".jpeg", ".webp")
119
- GARMENTS_DATASET = os.getenv("GARMENTS_DATASET", "").strip() # e.g. "ArmanRV/armanrv-garments"
120
  HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
121
 
122
 
123
  def ensure_garments_downloaded() -> None:
124
- """
125
- Downloads garments from HF Dataset into ./garments to avoid Space repo 1GB limit.
126
- """
127
  os.makedirs(GARMENT_DIR, exist_ok=True)
128
 
129
  if HF_TOKEN:
@@ -151,9 +130,6 @@ def ensure_garments_downloaded() -> None:
151
 
152
 
153
  def list_garments() -> List[str]:
154
- """
155
- Recursively list images inside ./garments (handles dataset subfolders).
156
- """
157
  files: List[str] = []
158
  if not os.path.isdir(GARMENT_DIR):
159
  return files
@@ -183,7 +159,6 @@ def load_garment_pil(filename: str) -> Optional[Image.Image]:
183
 
184
 
185
  def build_gallery_items(files: List[str]):
186
- # Gallery items format: [(filepath, caption), ...]
187
  return [(garment_path(f), "") for f in files]
188
 
189
 
@@ -225,17 +200,15 @@ def round_to_multiple(x: int, m: int = 8) -> int:
225
 
226
  def pick_target_size_keep_aspect(w: int, h: int, max_side: int) -> Tuple[int, int]:
227
  """
228
- Возвращает (tw, th) <= max_side по большей стороне, кратно 8.
229
  """
230
  if w <= 0 or h <= 0:
231
  return 768, 1024
232
  scale = min(max_side / float(max(w, h)), 1.0)
233
  tw = round_to_multiple(int(w * scale), 8)
234
  th = round_to_multiple(int(h * scale), 8)
235
- # защитимся от слишком маленьких
236
  tw = max(512, tw)
237
  th = max(512, th)
238
- # еще раз не превышать max_side
239
  if max(tw, th) > max_side:
240
  scale2 = max_side / float(max(tw, th))
241
  tw = round_to_multiple(int(tw * scale2), 8)
@@ -243,10 +216,10 @@ def pick_target_size_keep_aspect(w: int, h: int, max_side: int) -> Tuple[int, in
243
  return tw, th
244
 
245
 
246
- def letterbox(img: Image.Image, target_w: int, target_h: int, fill=(0, 0, 0)) -> Tuple[Image.Image, Dict[str, int]]:
247
  """
248
- Масштабирует с сохранением пропорций + padding до target_w/target_h.
249
- Возвращает (img_lb, meta) где meta содержит offset/size для обратного unletterbox.
250
  """
251
  src_w, src_h = img.size
252
  if src_w <= 0 or src_h <= 0:
@@ -262,45 +235,44 @@ def letterbox(img: Image.Image, target_w: int, target_h: int, fill=(0, 0, 0)) ->
262
  x = (target_w - new_w) // 2
263
  y = (target_h - new_h) // 2
264
  canvas.paste(img_rs, (x, y))
265
- meta = {"x": x, "y": y, "w": new_w, "h": new_h, "src_w": src_w, "src_h": src_h}
266
- return canvas, meta
267
 
268
 
269
  def unletterbox(img_lb: Image.Image, meta: Dict[str, int]) -> Image.Image:
 
 
 
 
 
270
  """
271
- Вырезает область без padding и возвращает как есть (потом можно resize к исходнику).
272
  """
273
  x, y, w, h = meta["x"], meta["y"], meta["w"], meta["h"]
274
- return img_lb.crop((x, y, x + w, y + h))
 
 
 
 
275
 
276
 
277
  def infer_garment_class_from_path(relpath: str) -> str:
278
  """
279
- Возвращает тип для get_mask_location: 'upper_body' | 'lower_body' | 'dresses'
280
- Это НЕ “жестко upper_body” — эвристика по папке/имени.
281
  """
282
- s = (relpath or "").lower()
283
- # папки/имена под платья
284
- if any(k in s for k in ["dress", "dresses", "suk", "plate", "плать", "sarafan"]):
285
  return "dresses"
286
- # низ
287
  if any(k in s for k in ["pants", "trouser", "jeans", "skirt", "short", "брюк", "джин", "юбк", "шорт"]):
288
  return "lower_body"
289
- # верх по умолчанию
290
  return "upper_body"
291
 
292
 
293
  def guess_garment_description(relpath: str) -> str:
294
- """
295
- Генерирует более полезное текстовое описание одежды из имени файла/папки.
296
- (Это замена твоего фиксированного 'a garment'.)
297
- """
298
- s = (relpath or "").replace("\\", "/").lower()
299
- # словарь эвристик
300
  mapping = [
301
- (["shearling", "dub", "дублен", "sheepskin"], "a shearling jacket"),
302
  (["coat", "пальт", "overcoat"], "a coat"),
303
- (["jacket", "куртк", "bomber", "парка", "parka"], "a jacket"),
304
  (["blazer", "пидж", "suit"], "a blazer"),
305
  (["hoodie", "худи"], "a hoodie"),
306
  (["sweater", "свит", "jumper"], "a sweater"),
@@ -314,18 +286,56 @@ def guess_garment_description(relpath: str) -> str:
314
  if any(k in s for k in keys):
315
  return desc
316
 
317
- # иначе — попытка вытащить “человеческое” имя
318
  base = os.path.splitext(os.path.basename(s))[0]
319
  base = re.sub(r"[_\-]+", " ", base)
320
  base = re.sub(r"\d+", " ", base)
321
  base = re.sub(r"\s+", " ", base).strip()
322
  if len(base) >= 3:
323
- # ограничим длину
324
- words = base.split()[:4]
325
- return "a " + " ".join(words)
326
  return "a piece of clothing"
327
 
328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  # =========================
330
  # Model init (local IDM-VTON)
331
  # =========================
@@ -337,7 +347,6 @@ print("DEVICE:", DEVICE, "DTYPE:", DTYPE, flush=True)
337
 
338
  tensor_transfrom = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
339
 
340
- # Components
341
  unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=DTYPE)
342
  unet.requires_grad_(False)
343
 
@@ -355,7 +364,6 @@ vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=DTYP
355
  UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=DTYPE)
356
  UNet_Encoder.requires_grad_(False)
357
 
358
- # Preprocessors
359
  parsing_model = Parsing(0)
360
  openpose_model = OpenPose(0)
361
 
@@ -379,58 +387,74 @@ pipe.unet_encoder = UNet_Encoder
379
 
380
 
381
  # =========================
382
- # Inference (returns ONLY final image)
383
  # =========================
384
  @spaces.GPU
385
  def start_tryon(
386
  human_pil: Image.Image,
387
  garm_img: Image.Image,
388
  garm_relpath: str = "",
 
389
  auto_mask: bool = True,
390
- denoise_steps: int = 30,
391
- guidance_scale: float = 3.5,
 
 
392
  strength: float = 0.90,
393
  seed: int = -1,
394
  max_side: int = 1024,
395
  prompt_override: str = "",
396
  negative_prompt: str = "monochrome, lowres, bad anatomy, worst quality, low quality",
397
  ) -> Image.Image:
398
- # pick device/dtype
399
  device = "cuda" if torch.cuda.is_available() else "cpu"
400
  dtype = torch.float16 if device == "cuda" else torch.float32
401
 
402
- # Move models
403
  if device == "cuda":
404
  openpose_model.preprocessor.body_estimation.model.to(device)
 
405
  pipe.to(device)
406
  pipe.unet_encoder.to(device)
407
 
408
- # --- sizes (dynamic, no forced 768x1024) ---
409
  human_img_orig = human_pil.convert("RGB")
410
  src_w, src_h = human_img_orig.size
 
411
  target_w, target_h = pick_target_size_keep_aspect(src_w, src_h, max_side=max_side)
412
 
413
- # letterbox to target size (no crop-center, no paste-back)
414
- human_lb, lb_meta = letterbox(human_img_orig, target_w, target_h, fill=(0, 0, 0))
415
  garm_img = garm_img.convert("RGB")
416
- garm_lb, _ = letterbox(garm_img, target_w, target_h, fill=(0, 0, 0))
417
 
418
- # --- Mask (not fixed upper_body) ---
419
- if auto_mask:
420
- # preprocess runs on 384x512; use letterbox to avoid distortion
421
- human_384, _m = letterbox(human_lb, 384, 512, fill=(0, 0, 0))
422
- keypoints = openpose_model(human_384)
423
- model_parse, _ = parsing_model(human_384)
424
 
 
 
 
 
425
  cloth_class = infer_garment_class_from_path(garm_relpath)
426
- mask, _ = get_mask_location("hd", cloth_class, model_parse, keypoints)
427
- # upscale mask back to target size
428
- mask = mask.resize((target_w, target_h), Image.BILINEAR)
 
 
 
 
 
 
 
 
 
 
 
 
 
429
  else:
430
  mask = Image.new("L", (target_w, target_h), 0)
431
 
432
- # --- DensePose ---
433
- human_dp = _apply_exif_orientation(human_lb.resize((384, 512)))
434
  human_dp = convert_PIL_to_numpy(human_dp, format="BGR")
435
 
436
  args = apply_net.create_argument_parser().parse_args(
@@ -445,11 +469,12 @@ def start_tryon(
445
  "cuda" if device == "cuda" else "cpu",
446
  )
447
  )
448
- pose_img = args.func(args, human_dp)
449
- pose_img = pose_img[:, :, ::-1]
450
- pose_img = Image.fromarray(pose_img).resize((target_w, target_h), Image.BILINEAR)
 
451
 
452
- # --- prompts (not fixed “a garment”) ---
453
  garment_desc = guess_garment_description(garm_relpath)
454
  if prompt_override and prompt_override.strip():
455
  garment_desc = prompt_override.strip()
@@ -457,15 +482,14 @@ def start_tryon(
457
  prompt_main = f"model is wearing {garment_desc}"
458
  prompt_cloth = f"a photo of {garment_desc}"
459
 
460
- # --- params (no fixed low steps/cfg/seed) ---
461
  denoise_steps = clamp_int(denoise_steps, 15, 60)
462
  guidance_scale = clamp_float(guidance_scale, 0.0, 12.0)
463
  strength = clamp_float(strength, 0.50, 1.00)
464
- if seed is None:
465
- seed = -1
466
- seed = int(seed)
467
  if seed < 0:
468
- # random but reproducible per call if needed
469
  seed = int.from_bytes(os.urandom(2), "big") + int(time.time() * 1000) % 1000000
470
 
471
  with torch.no_grad():
@@ -504,7 +528,6 @@ def start_tryon(
504
 
505
  pose_t = tensor_transfrom(pose_img).unsqueeze(0).to(device=device, dtype=dtype)
506
  garm_t = tensor_transfrom(garm_lb).unsqueeze(0).to(device=device, dtype=dtype)
507
-
508
  generator = torch.Generator(device).manual_seed(seed)
509
 
510
  images = pipe(
@@ -514,7 +537,7 @@ def start_tryon(
514
  negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device=device, dtype=dtype),
515
  num_inference_steps=denoise_steps,
516
  generator=generator,
517
- strength=strength, # <-- not fixed 1.0
518
  pose_img=pose_t,
519
  text_embeds_cloth=prompt_embeds_c.to(device=device, dtype=dtype),
520
  cloth=garm_t,
@@ -522,13 +545,13 @@ def start_tryon(
522
  image=human_lb,
523
  height=target_h,
524
  width=target_w,
525
- ip_adapter_image=garm_lb, # keep conditioning, but not hard-resized 768x1024
526
- guidance_scale=guidance_scale, # <-- not fixed low value
527
  )[0]
528
 
529
  out_img_lb = images[0].convert("RGB")
530
 
531
- # remove letterbox padding and resize back to original size (no crop-center paste)
532
  out_core = unletterbox(out_img_lb, lb_meta)
533
  out_final = out_core.resize((src_w, src_h), Image.LANCZOS)
534
  return out_final
@@ -563,7 +586,10 @@ def on_gallery_select(files_list: List[str], evt: gr.SelectData):
563
  def tryon_ui(
564
  person_pil,
565
  selected_filename,
 
566
  auto_mask,
 
 
567
  steps,
568
  cfg,
569
  strength,
@@ -595,7 +621,10 @@ def tryon_ui(
595
  human_pil=person_pil,
596
  garm_img=garm,
597
  garm_relpath=selected_filename,
 
598
  auto_mask=bool(auto_mask),
 
 
599
  denoise_steps=int(steps),
600
  guidance_scale=float(cfg),
601
  strength=float(strength),
@@ -608,7 +637,7 @@ def tryon_ui(
608
  yield None, f"❌ Ошибка: {type(e).__name__}: {str(e)[:220]}"
609
 
610
 
611
- # Preload garments
612
  ensure_garments_downloaded()
613
  _initial_files = list_garments()
614
  _initial_items = build_gallery_items(_initial_files)
@@ -635,17 +664,34 @@ with gr.Blocks(title="Virtual Try-On Rendez-vous", css=CUSTOM_CSS) as demo:
635
  allow_preview=True,
636
  )
637
 
638
- with gr.Accordion("⚙️ Настройки качества", open=False):
639
- auto_mask = gr.Checkbox(value=True, label="Auto mask (парсинг + поза)")
640
- steps = gr.Slider(15, 60, value=30, step=1, label="Шаги (num_inference_steps)")
641
- cfg = gr.Slider(0.0, 12.0, value=3.5, step=0.1, label="Guidance scale (CFG)")
642
- strength = gr.Slider(0.50, 1.00, value=0.90, step=0.01, label="Strength (насколько сильно перерисовывать)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
643
  seed = gr.Number(value=-1, precision=0, label="Seed (-1 = случайный)")
644
- max_side = gr.Slider(768, 1408, value=1024, step=64, label="Максимальный размер стороны (динамический)")
 
645
  prompt_override = gr.Textbox(
646
  value="",
647
  label="Описание одежды (опц.)",
648
- placeholder="Напр.: a black leather jacket / a blazer / a coat ... (если пусто — авто по имени файла)",
649
  )
650
 
651
  run = gr.Button("Примерить", variant="primary")
@@ -668,7 +714,20 @@ with gr.Blocks(title="Virtual Try-On Rendez-vous", css=CUSTOM_CSS) as demo:
668
 
669
  run.click(
670
  fn=tryon_ui,
671
- inputs=[person, selected_garment_state, auto_mask, steps, cfg, strength, seed, max_side, prompt_override],
 
 
 
 
 
 
 
 
 
 
 
 
 
672
  outputs=[out, status],
673
  concurrency_limit=1,
674
  )
 
1
  # -*- coding: utf-8 -*-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import os
3
  import re
4
  import time
5
+ from typing import List, Optional, Tuple, Dict
 
6
 
7
  import spaces
8
  import gradio as gr
9
+ from PIL import Image
10
 
11
  # =========================
12
  # FIX: gradio 4.24 / gradio_client crashes on boolean JSON Schemas in /api_info
 
98
  # =========================
99
  GARMENT_DIR = "garments"
100
  ALLOWED_EXTS = (".png", ".jpg", ".jpeg", ".webp")
101
+ GARMENTS_DATASET = os.getenv("GARMENTS_DATASET", "").strip()
102
  HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
103
 
104
 
105
  def ensure_garments_downloaded() -> None:
 
 
 
106
  os.makedirs(GARMENT_DIR, exist_ok=True)
107
 
108
  if HF_TOKEN:
 
130
 
131
 
132
  def list_garments() -> List[str]:
 
 
 
133
  files: List[str] = []
134
  if not os.path.isdir(GARMENT_DIR):
135
  return files
 
159
 
160
 
161
  def build_gallery_items(files: List[str]):
 
162
  return [(garment_path(f), "") for f in files]
163
 
164
 
 
200
 
201
def pick_target_size_keep_aspect(w: int, h: int, max_side: int) -> Tuple[int, int]:
    """Pick the model canvas size (tw, th) for a source image of size (w, h).

    Keeps the aspect ratio, caps the longer side at ``max_side``, keeps both
    sides >= 512, and rounds each side to a multiple of 8 (SDXL-friendly).
    Degenerate input sizes fall back to the 768x1024 default.

    NOTE(review): the diff hides one line after the second rescale; the
    symmetric ``th`` recompute is restored here — confirm against the full file.
    """
    if w <= 0 or h <= 0:
        return 768, 1024
    # Never upscale: the factor is capped at 1.0.
    scale = min(max_side / float(max(w, h)), 1.0)
    tw = round_to_multiple(int(w * scale), 8)
    th = round_to_multiple(int(h * scale), 8)
    # Guard against too-small canvases.
    tw = max(512, tw)
    th = max(512, th)
    # The 512 floor may have pushed a side past max_side again; rescale once more.
    if max(tw, th) > max_side:
        scale2 = max_side / float(max(tw, th))
        tw = round_to_multiple(int(tw * scale2), 8)
        th = round_to_multiple(int(th * scale2), 8)
    return tw, th
217
 
218
 
219
+ def letterbox(img: Image.Image, target_w: int, target_h: int, fill=(127, 127, 127)) -> Tuple[Image.Image, Dict[str, int]]:
220
  """
221
+ Resize with aspect + padding to (target_w,target_h).
222
+ meta: x,y,w,h for core region inside padded canvas
223
  """
224
  src_w, src_h = img.size
225
  if src_w <= 0 or src_h <= 0:
 
235
  x = (target_w - new_w) // 2
236
  y = (target_h - new_h) // 2
237
  canvas.paste(img_rs, (x, y))
238
+ return canvas, {"x": x, "y": y, "w": new_w, "h": new_h, "src_w": src_w, "src_h": src_h}
 
239
 
240
 
241
def unletterbox(img_lb: Image.Image, meta: Dict[str, int]) -> Image.Image:
    """Crop the letterboxed image back to its padding-free core region."""
    left = meta["x"]
    top = meta["y"]
    right = left + meta["w"]
    bottom = top + meta["h"]
    return img_lb.crop((left, top, right, bottom))
244
+
245
+
246
def paste_into_canvas(canvas_mode: str, canvas_size: Tuple[int, int], core_img: Image.Image, meta: Dict[str, int], fill):
    """Paste ``core_img`` into a fresh canvas at the offset recorded in ``meta``.

    A new canvas of ``canvas_size`` is created in ``canvas_mode`` filled with
    ``fill``; the core image is bilinearly resized to meta's w/h first when its
    size does not already match.
    """
    offset_x, offset_y = meta["x"], meta["y"]
    core_w, core_h = meta["w"], meta["h"]
    canvas = Image.new(canvas_mode, canvas_size, fill)
    if core_img.size != (core_w, core_h):
        core_img = core_img.resize((core_w, core_h), Image.BILINEAR)
    canvas.paste(core_img, (offset_x, offset_y))
    return canvas
256
 
257
 
258
def infer_garment_class_from_path(relpath: str) -> str:
    """Classify a garment file path for mask selection.

    Returns one of 'upper_body' | 'lower_body' | 'dresses' using keyword
    heuristics on the lowercased, slash-normalized path; anything that does
    not match a dress or lower-body keyword defaults to 'upper_body'.
    """
    path = (relpath or "").lower().replace("\\", "/")
    dress_keys = ("dress", "dresses", "sarafan", "plate", "плать", "сараф")
    lower_keys = ("pants", "trouser", "jeans", "skirt", "short", "брюк", "джин", "юбк", "шорт")
    for keys, label in ((dress_keys, "dresses"), (lower_keys, "lower_body")):
        if any(key in path for key in keys):
            return label
    return "upper_body"
268
 
269
 
270
  def guess_garment_description(relpath: str) -> str:
271
+ s = (relpath or "").lower().replace("\\", "/")
 
 
 
 
 
272
  mapping = [
273
+ (["shearling", "дублен", "sheepskin"], "a shearling jacket"),
274
  (["coat", "пальт", "overcoat"], "a coat"),
275
+ (["jacket", "куртк", "парка", "parka", "bomber"], "a jacket"),
276
  (["blazer", "пидж", "suit"], "a blazer"),
277
  (["hoodie", "худи"], "a hoodie"),
278
  (["sweater", "свит", "jumper"], "a sweater"),
 
286
  if any(k in s for k in keys):
287
  return desc
288
 
 
289
  base = os.path.splitext(os.path.basename(s))[0]
290
  base = re.sub(r"[_\-]+", " ", base)
291
  base = re.sub(r"\d+", " ", base)
292
  base = re.sub(r"\s+", " ", base).strip()
293
  if len(base) >= 3:
294
+ return "a " + " ".join(base.split()[:4])
 
 
295
  return "a piece of clothing"
296
 
297
 
298
def apply_safety_clamp(mask_full: Image.Image, meta: Dict[str, int], garment_class: str, clamp_strength: float) -> Image.Image:
    """Safety guard against the editable mask drifting too far up or down.

    - upper_body: zero out everything below a hip cut line (higher
      ``clamp_strength`` pushes the line up).
    - lower_body: zero out everything above a waist/hip cut line (higher
      ``clamp_strength`` pushes the line down).
    - dresses (and unknown classes): mask returned unchanged.

    ``clamp_strength`` is clamped to 0..1 (0 = mild, 1 = strongest).
    """
    if garment_class == "dresses":
        return mask_full

    tw, th = mask_full.size
    # Only the vertical extent of the core (unpadded) region is needed.
    y, h = meta["y"], meta["h"]
    clamp_strength = clamp_float(clamp_strength, 0.0, 1.0)

    # Empirical cut-line bands as fractions of the core height (full-body photos):
    # upper_body boundary ~0.60..0.72, lower_body boundary ~0.34..0.48.
    if garment_class == "upper_body":
        lo, hi = 0.60, 0.72
        # Stronger clamp -> boundary closer to lo (higher on the body).
        frac = lo + (hi - lo) * (1.0 - clamp_strength)
        cut_y = max(0, min(th, y + int(frac * h)))
        out = Image.new("L", (tw, th), 0)
        out.paste(mask_full.crop((0, 0, tw, cut_y)), (0, 0))
        return out

    if garment_class == "lower_body":
        lo, hi = 0.34, 0.48
        # Stronger clamp -> boundary closer to hi (lower on the body).
        frac = lo + (hi - lo) * clamp_strength
        cut_y = max(0, min(th, y + int(frac * h)))
        out = Image.new("L", (tw, th), 0)
        out.paste(mask_full.crop((0, cut_y, tw, th)), (0, cut_y))
        return out

    # Unknown class: leave the mask untouched.
    return mask_full
+ return mask_full
337
+
338
+
339
  # =========================
340
  # Model init (local IDM-VTON)
341
  # =========================
 
347
 
348
  tensor_transfrom = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
349
 
 
350
  unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=DTYPE)
351
  unet.requires_grad_(False)
352
 
 
364
  UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=DTYPE)
365
  UNet_Encoder.requires_grad_(False)
366
 
 
367
  parsing_model = Parsing(0)
368
  openpose_model = OpenPose(0)
369
 
 
387
 
388
 
389
  # =========================
390
+ # Inference
391
  # =========================
392
  @spaces.GPU
393
  def start_tryon(
394
  human_pil: Image.Image,
395
  garm_img: Image.Image,
396
  garm_relpath: str = "",
397
+ garment_type_override: str = "auto", # auto | upper_body | lower_body | dresses
398
  auto_mask: bool = True,
399
+ safety_clamp: bool = True,
400
+ clamp_strength: float = 0.55, # 0..1
401
+ denoise_steps: int = 34,
402
+ guidance_scale: float = 3.8,
403
  strength: float = 0.90,
404
  seed: int = -1,
405
  max_side: int = 1024,
406
  prompt_override: str = "",
407
  negative_prompt: str = "monochrome, lowres, bad anatomy, worst quality, low quality",
408
  ) -> Image.Image:
 
409
  device = "cuda" if torch.cuda.is_available() else "cpu"
410
  dtype = torch.float16 if device == "cuda" else torch.float32
411
 
 
412
  if device == "cuda":
413
  openpose_model.preprocessor.body_estimation.model.to(device)
414
+
415
  pipe.to(device)
416
  pipe.unet_encoder.to(device)
417
 
 
418
  human_img_orig = human_pil.convert("RGB")
419
  src_w, src_h = human_img_orig.size
420
+
421
  target_w, target_h = pick_target_size_keep_aspect(src_w, src_h, max_side=max_side)
422
 
423
+ # letterbox for model canvas (important: gray padding)
424
+ human_lb, lb_meta = letterbox(human_img_orig, target_w, target_h, fill=(127, 127, 127))
425
  garm_img = garm_img.convert("RGB")
426
+ garm_lb, _ = letterbox(garm_img, target_w, target_h, fill=(127, 127, 127))
427
 
428
+ # Core region (no padding) — IMPORTANT for preprocessors
429
+ human_core = unletterbox(human_lb, lb_meta)
430
+ x, y, w, h = lb_meta["x"], lb_meta["y"], lb_meta["w"], lb_meta["h"]
 
 
 
431
 
432
+ # garment class
433
+ if garment_type_override and garment_type_override != "auto":
434
+ cloth_class = garment_type_override
435
+ else:
436
  cloth_class = infer_garment_class_from_path(garm_relpath)
437
+
438
+ # ---- MASK (compute on core -> paste to full) ----
439
+ if auto_mask:
440
+ human_core_384 = human_core.resize((384, 512), Image.BILINEAR)
441
+ keypoints = openpose_model(human_core_384)
442
+ model_parse, _ = parsing_model(human_core_384)
443
+
444
+ mask_core_384, _ = get_mask_location("hd", cloth_class, model_parse, keypoints)
445
+ mask_core = mask_core_384.resize((w, h), Image.BILINEAR)
446
+
447
+ mask_full = Image.new("L", (target_w, target_h), 0)
448
+ mask_full.paste(mask_core, (x, y))
449
+
450
+ if safety_clamp:
451
+ mask_full = apply_safety_clamp(mask_full, lb_meta, cloth_class, clamp_strength)
452
+ mask = mask_full
453
  else:
454
  mask = Image.new("L", (target_w, target_h), 0)
455
 
456
+ # ---- DensePose (compute on core -> paste to full) ----
457
+ human_dp = _apply_exif_orientation(human_core.resize((384, 512), Image.BILINEAR))
458
  human_dp = convert_PIL_to_numpy(human_dp, format="BGR")
459
 
460
  args = apply_net.create_argument_parser().parse_args(
 
469
  "cuda" if device == "cuda" else "cpu",
470
  )
471
  )
472
+ pose_core = args.func(args, human_dp)
473
+ pose_core = pose_core[:, :, ::-1]
474
+ pose_core = Image.fromarray(pose_core).resize((w, h), Image.BILINEAR)
475
+ pose_img = paste_into_canvas("RGB", (target_w, target_h), pose_core, lb_meta, (127, 127, 127))
476
 
477
+ # ---- prompts (not fixed) ----
478
  garment_desc = guess_garment_description(garm_relpath)
479
  if prompt_override and prompt_override.strip():
480
  garment_desc = prompt_override.strip()
 
482
  prompt_main = f"model is wearing {garment_desc}"
483
  prompt_cloth = f"a photo of {garment_desc}"
484
 
485
+ # ---- params ----
486
  denoise_steps = clamp_int(denoise_steps, 15, 60)
487
  guidance_scale = clamp_float(guidance_scale, 0.0, 12.0)
488
  strength = clamp_float(strength, 0.50, 1.00)
489
+ max_side = clamp_int(max_side, 640, 2048)
490
+
491
+ seed = int(seed) if seed is not None else -1
492
  if seed < 0:
 
493
  seed = int.from_bytes(os.urandom(2), "big") + int(time.time() * 1000) % 1000000
494
 
495
  with torch.no_grad():
 
528
 
529
  pose_t = tensor_transfrom(pose_img).unsqueeze(0).to(device=device, dtype=dtype)
530
  garm_t = tensor_transfrom(garm_lb).unsqueeze(0).to(device=device, dtype=dtype)
 
531
  generator = torch.Generator(device).manual_seed(seed)
532
 
533
  images = pipe(
 
537
  negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device=device, dtype=dtype),
538
  num_inference_steps=denoise_steps,
539
  generator=generator,
540
+ strength=strength,
541
  pose_img=pose_t,
542
  text_embeds_cloth=prompt_embeds_c.to(device=device, dtype=dtype),
543
  cloth=garm_t,
 
545
  image=human_lb,
546
  height=target_h,
547
  width=target_w,
548
+ ip_adapter_image=garm_lb,
549
+ guidance_scale=guidance_scale,
550
  )[0]
551
 
552
  out_img_lb = images[0].convert("RGB")
553
 
554
+ # remove padding and return to original resolution
555
  out_core = unletterbox(out_img_lb, lb_meta)
556
  out_final = out_core.resize((src_w, src_h), Image.LANCZOS)
557
  return out_final
 
586
  def tryon_ui(
587
  person_pil,
588
  selected_filename,
589
+ garment_type_override,
590
  auto_mask,
591
+ safety_clamp,
592
+ clamp_strength,
593
  steps,
594
  cfg,
595
  strength,
 
621
  human_pil=person_pil,
622
  garm_img=garm,
623
  garm_relpath=selected_filename,
624
+ garment_type_override=str(garment_type_override),
625
  auto_mask=bool(auto_mask),
626
+ safety_clamp=bool(safety_clamp),
627
+ clamp_strength=float(clamp_strength),
628
  denoise_steps=int(steps),
629
  guidance_scale=float(cfg),
630
  strength=float(strength),
 
637
  yield None, f"❌ Ошибка: {type(e).__name__}: {str(e)[:220]}"
638
 
639
 
640
+ # preload garments
641
  ensure_garments_downloaded()
642
  _initial_files = list_garments()
643
  _initial_items = build_gallery_items(_initial_files)
 
664
  allow_preview=True,
665
  )
666
 
667
+ with gr.Accordion("⚙️ Настройки", open=False):
668
+ garment_type_override = gr.Dropdown(
669
+ choices=["auto", "upper_body", "lower_body", "dresses"],
670
+ value="auto",
671
+ label="Тип одежды (override)",
672
+ )
673
+ auto_mask = gr.Checkbox(value=True, label="Auto mask (parsing + openpose)")
674
+
675
+ safety_clamp = gr.Checkbox(
676
+ value=True,
677
+ label="Safety clamp (защита от съезда зоны редактирования)",
678
+ )
679
+ clamp_strength = gr.Slider(
680
+ 0.0, 1.0, value=0.55, step=0.01,
681
+ label="Clamp strength (0 = мягко, 1 = сильнее)",
682
+ )
683
+
684
+ steps = gr.Slider(15, 60, value=34, step=1, label="Шаги (num_inference_steps)")
685
+ cfg = gr.Slider(0.0, 12.0, value=3.8, step=0.1, label="Guidance scale (CFG)")
686
+ strength = gr.Slider(0.50, 1.00, value=0.90, step=0.01, label="Strength")
687
+
688
  seed = gr.Number(value=-1, precision=0, label="Seed (-1 = случайный)")
689
+ max_side = gr.Slider(768, 1536, value=1024, step=64, label="Макс. сторона (динамический размер)")
690
+
691
  prompt_override = gr.Textbox(
692
  value="",
693
  label="Описание одежды (опц.)",
694
+ placeholder="Напр.: a blazer / a dress / a t-shirt ... (если пусто — авто по имени файла)",
695
  )
696
 
697
  run = gr.Button("Примерить", variant="primary")
 
714
 
715
  run.click(
716
  fn=tryon_ui,
717
+ inputs=[
718
+ person,
719
+ selected_garment_state,
720
+ garment_type_override,
721
+ auto_mask,
722
+ safety_clamp,
723
+ clamp_strength,
724
+ steps,
725
+ cfg,
726
+ strength,
727
+ seed,
728
+ max_side,
729
+ prompt_override,
730
+ ],
731
  outputs=[out, status],
732
  concurrency_limit=1,
733
  )