ssoxye committed
Commit
6b3f441
1 Parent(s): ef6ba1e

update examples

Files changed (1)
  1. app.py +130 -106
app.py CHANGED
@@ -23,14 +23,12 @@
23
  # import imageio
24
  # from PIL import Image, ImageOps
25
  # from transformers import pipeline
26
-
27
  # from huggingface_hub import hf_hub_download
28
 
29
  # import diffusers3
30
  # print("[BOOT] diffusers3 loaded from:", getattr(diffusers3, "__file__", "<?>"), flush=True)
31
 
32
  # from diffusers import UniPCMultistepScheduler, AutoencoderKL, UNet2DConditionModel
33
-
34
  # from diffusers3.models.controlnet import ControlNetModel
35
  # from diffusers3.pipelines.controlnet.pipeline_controlnet_sd_xl_img2img_img import (
36
  # StableDiffusionXLControlNetImg2ImgPipeline,
@@ -89,7 +87,7 @@
89
  # @dataclass
90
  # class Paths:
91
  # person_path: str
92
- # depth_path: Optional[str] # sketch(guide) optional
93
  # style_path: str
94
  # output_path: str
95
 
@@ -101,22 +99,58 @@
101
  # return img
102
 
103
 
104
  # def apply_parsing_white_mask_to_person_cv2(
105
  # person_pil: Image.Image,
106
  # parsing_img: Image.Image
107
  # ) -> np.ndarray:
108
  # person_rgb = np.array(person_pil.convert("RGB"), dtype=np.uint8)
109
-
110
  # mask = np.array(parsing_img.convert("L"), dtype=np.uint8)
111
 
112
  # if mask.shape[:2] != person_rgb.shape[:2]:
113
  # mask = cv2.resize(mask, (person_rgb.shape[1], person_rgb.shape[0]), interpolation=cv2.INTER_NEAREST)
114
 
115
  # white_mask = (mask == 255)
116
-
117
  # result_rgb = np.full_like(person_rgb, 255, dtype=np.uint8)
118
  # result_rgb[white_mask] = person_rgb[white_mask]
119
-
120
  # result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
121
  # return result_bgr
122
 
@@ -134,19 +168,15 @@
134
  # if not isinstance(parsing_img, Image.Image):
135
  # raise TypeError("parsing_img must be a PIL.Image.Image")
136
 
137
- # img_l = parsing_img.convert("L")
138
- # arr = np.array(img_l, dtype=np.uint8)
139
-
140
  # mask = np.where(arr >= white_threshold, 255, 0).astype(np.uint8)
141
 
142
  # num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
143
-
144
  # keep = np.zeros_like(mask)
145
  # for lab in range(1, num_labels):
146
  # area = int(stats[lab, cv2.CC_STAT_AREA])
147
  # if area >= min_white_area:
148
  # keep[labels == lab] = 255
149
-
150
  # mask = keep
151
 
152
  # def _odd_or_one(k: int) -> int:
@@ -219,20 +249,7 @@
219
  # _, m = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY)
220
 
221
  # target_width = 1024
222
- # h, w = m.shape[:2]
223
-
224
- # if w < target_width:
225
- # total_padding = target_width - w
226
- # left_padding = total_padding // 2
227
- # right_padding = total_padding - left_padding
228
- # m = cv2.copyMakeBorder(
229
- # m, 0, 0, left_padding, right_padding,
230
- # borderType=cv2.BORDER_CONSTANT,
231
- # value=0,
232
- # )
233
- # elif w > target_width:
234
- # left = (w - target_width) // 2
235
- # m = m[:, left:left + target_width]
236
 
237
  # kernel = np.ones((17, 17), np.uint8)
238
  # m = cv2.dilate(m, kernel, iterations=1)
@@ -249,7 +266,6 @@
249
  # raise RuntimeError("Global H/W not set. Call run_one() first.")
250
 
251
  # depth_img = _imread_or_raise(depth_path, 0)
252
-
253
  # inverted_depth = cv2.bitwise_not(depth_img)
254
  # contours, _ = cv2.findContours(inverted_depth, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
255
 
@@ -257,19 +273,9 @@
257
  # cv2.drawContours(filled_depth, contours, -1, (255), thickness=cv2.FILLED)
258
 
259
  # filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA)
 
260
 
261
- # height, width = filled_depth.shape
262
- # total_padding = 1024 - width
263
- # left_padding = total_padding // 2
264
- # right_padding = total_padding - left_padding
265
-
266
- # padded_depth = cv2.copyMakeBorder(
267
- # filled_depth, 0, 0, left_padding, right_padding,
268
- # borderType=cv2.BORDER_CONSTANT,
269
- # value=0,
270
- # )
271
-
272
- # inverted_image = ImageOps.invert(Image.fromarray(padded_depth))
273
 
274
  # with torch.inference_mode():
275
  # image_depth = depth_estimator(inverted_image)["depth"]
@@ -294,7 +300,6 @@
294
  # raise RuntimeError("Global H/W not set. Call run_one() first.")
295
 
296
  # depth_img = _edges_from_parsing(parsing_img)
297
-
298
  # # inverted_depth = cv2.bitwise_not(depth_img)
299
  # contours, _ = cv2.findContours(depth_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
300
 
@@ -302,19 +307,9 @@
302
  # cv2.drawContours(filled_depth, contours, -1, (255), thickness=cv2.FILLED)
303
 
304
  # filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA)
 
305
 
306
- # height, width = filled_depth.shape
307
- # total_padding = 1024 - width
308
- # left_padding = total_padding // 2
309
- # right_padding = total_padding - left_padding
310
-
311
- # padded_depth = cv2.copyMakeBorder(
312
- # filled_depth, 0, 0, left_padding, right_padding,
313
- # borderType=cv2.BORDER_CONSTANT,
314
- # value=0,
315
- # )
316
-
317
- # inverted_image = ImageOps.invert(Image.fromarray(padded_depth))
318
 
319
  # with torch.inference_mode():
320
  # image_depth = depth_estimator(inverted_image)["depth"]
@@ -402,7 +397,7 @@
402
  # return pipe, device, dtype
403
 
404
 
405
 - # # UI label → internal extractor category string mapping
406
  # _UI_TO_EXTRACTOR_CATEGORY = {
407
  # "Upper-body": "Upper-cloth",
408
  # "Lower-body": "Bottom",
@@ -411,16 +406,12 @@
411
 
412
 
413
  # def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str = "Dress"):
414
- # """
415
 - # category: the value coming from the UI (Upper-body/Lower-body/Dress)
416
- # """
417
  # global H, W
418
  # pipe, device, _dtype = get_pipe_and_device()
419
  # image_encoder_dir, ip_ckpt, schp_ckpt = get_assets()
420
 
421
  # H, W = compute_hw_from_person(paths.person_path)
422
 
423
 - # # ✅ Convert the UI category into the string the extractor expects
424
  # extractor_category = _UI_TO_EXTRACTOR_CATEGORY.get(category, "Dress")
425
 
426
  # res = run_simple_extractor(
@@ -456,61 +447,34 @@
456
  # merged_img = merge_white_regions_or(parsing_img, sketch_area)
457
  # mask_pil = preprocess_mask(merged_img)
458
 
 
459
  # person_bgr = _imread_or_raise(paths.person_path)
460
  # person_bgr = cv2.resize(person_bgr, (W, H), interpolation=cv2.INTER_AREA)
461
-
462
- # target_width = 1024
463
- # cur_w = person_bgr.shape[1]
464
- # if cur_w < target_width:
465
- # total = target_width - cur_w
466
- # left = total // 2
467
- # right = total - left
468
- # padded_person = cv2.copyMakeBorder(
469
- # person_bgr, 0, 0, left, right,
470
- # borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255]
471
- # )
472
- # elif cur_w > target_width:
473
- # left = (cur_w - target_width) // 2
474
- # padded_person = person_bgr[:, left:left + target_width]
475
- # else:
476
- # padded_person = person_bgr
477
-
478
- # person_rgb = cv2.cvtColor(padded_person, cv2.COLOR_BGR2RGB)
479
  # person_pil = Image.fromarray(person_rgb)
480
 
 
481
  # if use_depth_path:
482
  # depth_map = make_depth(paths.depth_path)
483
  # else:
484
  # depth_map = make_depth_from_parsing_edges(parsing_img)
485
 
 
486
  # personn = Image.open(paths.person_path).convert("RGB")
487
- # garment_ = apply_parsing_white_mask_to_person_cv2(personn, parsing_img)
488
- # garment_rgb = cv2.cvtColor(garment_, cv2.COLOR_BGR2RGB)
489
  # garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA)
490
 
491
- # padding = (target_width - W) // 2 if W < target_width else 0
492
- # garment_rgb = cv2.copyMakeBorder(
493
- # garment_rgb,
494
- # top=0, bottom=0,
495
- # left=padding, right=padding,
496
- # borderType=cv2.BORDER_CONSTANT,
497
- # value=[255, 255, 255],
498
- # )
499
  # garment_pil = Image.fromarray(garment_rgb)
500
 
 
501
  # gm = np.array(parsing_img.convert("L"), dtype=np.uint8)
502
  # gm = cv2.resize(gm, (W, H), interpolation=cv2.INTER_NEAREST)
503
  # gm = cv2.cvtColor(gm, cv2.COLOR_GRAY2RGB)
504
-
505
- # cur_w2 = gm.shape[1]
506
- # if cur_w2 < target_width:
507
- # total = target_width - cur_w2
508
- # left = total // 2
509
- # right = total - left
510
- # gm = cv2.copyMakeBorder(gm, 0, 0, left, right, cv2.BORDER_CONSTANT, value=[0, 0, 0])
511
- # elif cur_w2 > target_width:
512
- # left2 = (cur_w2 - target_width) // 2
513
- # gm = gm[:, left2:left2 + target_width]
514
  # garment_mask_pil = Image.fromarray(gm)
515
 
516
  # print(
@@ -531,7 +495,7 @@
531
  # device,
532
  # mask_pil,
533
  # person_pil,
534
- # content_scale=0.2,
535
  # style_scale=0.5,
536
  # garment_images=garment_pil,
537
  # garment_mask=garment_mask_pil,
@@ -547,9 +511,9 @@
547
 
548
  # style_img = Image.open(paths.style_path).convert("RGB")
549
 
550
- # prompt = category + "with " + category
551
 
552
- # print("====prompt? ", prompt)
553
 
554
  # with torch.inference_mode():
555
  # images = ip_model.generate(
@@ -613,7 +577,6 @@
613
  # with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
614
  # gr.Markdown("## VISTA Demo\nperson / style are required; sketch (guide) is optional.")
615
 
616
 - # # ✅ Keep the UI labels as Upper-body/Lower-body/Dress (default: Dress)
617
  # category_toggle = gr.Radio(
618
  # choices=["Dress", "Upper-body", "Lower-body"],
619
  # value="Dress",
@@ -621,7 +584,7 @@
621
  # interactive=True,
622
  # )
623
 
624
 - # # Lay out Person / Style / Output in a single row
625
  # with gr.Row():
626
  # person_in = gr.Image(label="Person Image (required)", type="filepath")
627
  # style_in = gr.Image(label="Style Image (required)", type="filepath")
@@ -636,8 +599,6 @@
636
  # seed_in = gr.Number(label="Seed (-1=random)", value=-1, precision=0)
637
 
638
  # run_btn = gr.Button("Run")
639
-
640
 - # # The file download generally looks best below Output (on the next row)
641
  # out_file = gr.File(label="Download result.png")
642
 
643
  # gr.Markdown("### Debug Visualizations (mask/depth/etc)")
@@ -662,6 +623,7 @@
662
 
663
  import os
664
  import sys
 
665
 
666
  # ---------------------------------------------------------
667
  # 0) Make sure local packages (diffusers3, preprocess, etc.) are importable on HF Spaces
@@ -676,7 +638,7 @@ print("[BOOT] sys.path[:5] =", sys.path[:5], flush=True)
676
  import tempfile
677
  from dataclasses import dataclass
678
  from functools import lru_cache
679
- from typing import Optional, Tuple
680
 
681
  import gradio as gr
682
  import torch
@@ -739,6 +701,52 @@ def get_assets():
739
  return image_encoder_dir, ip_ckpt, schp_ckpt
740
 
741
 
742
  DEFAULT_STEPS = 40
743
  DEBUG_SAVE = False
744
 
@@ -791,8 +799,6 @@ def _pad_or_crop_to_width_np(arr: np.ndarray, target_width: int, pad_value):
791
  value=pad_value,
792
  )
793
  else:
794
 - # For 3-channel input, value can be a scalar or a [b,g,r]/[r,g,b] triple;
795
 - # whether arr is RGB or BGR is up to the caller.
796
  return cv2.copyMakeBorder(
797
  arr, 0, 0, left, right,
798
  borderType=cv2.BORDER_CONSTANT,
@@ -962,7 +968,6 @@ def make_depth_from_parsing_edges(parsing_img: Image.Image) -> Image.Image:
962
  raise RuntimeError("Global H/W not set. Call run_one() first.")
963
 
964
  depth_img = _edges_from_parsing(parsing_img)
965
- # inverted_depth = cv2.bitwise_not(depth_img)
966
  contours, _ = cv2.findContours(depth_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
967
 
968
  filled_depth = depth_img.copy()
@@ -1128,7 +1133,6 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str
1128
  garment_rgb = cv2.cvtColor(garment_bgr, cv2.COLOR_BGR2RGB)
1129
  garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA)
1130
 
1131
 - # ✅ Removed the old padding=(1024-W)//2 approach → asymmetric pad/crop to exactly 1024
1132
  garment_rgb = _pad_or_crop_to_width_np(garment_rgb, 1024, pad_value=[255, 255, 255])
1133
  garment_pil = Image.fromarray(garment_rgb)
1134
 
@@ -1172,6 +1176,10 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str
1172
  pass
1173
 
1174
  style_img = Image.open(paths.style_path).convert("RGB")
1175
 
1176
  with torch.inference_mode():
1177
  images = ip_model.generate(
@@ -1251,6 +1259,22 @@ with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
1251
  with gr.Accordion("Sketch / Guide (optional)", open=False):
1252
  sketch_in = gr.Image(label="Sketch / Guide", type="filepath")
1253
 
1254
  with gr.Row():
1255
  prompt_in = gr.Textbox(label="Prompt", value="", lines=2)
1256
  steps_in = gr.Slider(1, 80, value=DEFAULT_STEPS, step=1, label="Steps")
 
23
  # import imageio
24
  # from PIL import Image, ImageOps
25
  # from transformers import pipeline
 
26
  # from huggingface_hub import hf_hub_download
27
 
28
  # import diffusers3
29
  # print("[BOOT] diffusers3 loaded from:", getattr(diffusers3, "__file__", "<?>"), flush=True)
30
 
31
  # from diffusers import UniPCMultistepScheduler, AutoencoderKL, UNet2DConditionModel
 
32
  # from diffusers3.models.controlnet import ControlNetModel
33
  # from diffusers3.pipelines.controlnet.pipeline_controlnet_sd_xl_img2img_img import (
34
  # StableDiffusionXLControlNetImg2ImgPipeline,
 
87
  # @dataclass
88
  # class Paths:
89
  # person_path: str
90
+ # depth_path: Optional[str] # sketch(guide) optional
91
  # style_path: str
92
  # output_path: str
93
 
 
99
  # return img
100
 
101
 
102
+ # def _pad_or_crop_to_width_np(arr: np.ndarray, target_width: int, pad_value):
103
+ # """
104
+ # arr: HxWxC or HxW
105
+ # target_width로 center crop 또는 좌/우 padding(비대칭 포함)해서 정확히 맞춤.
106
+ # """
107
+ # if arr.ndim not in (2, 3):
108
+ # raise ValueError(f"arr must be 2D or 3D, got shape={arr.shape}")
109
+
110
+ # h = arr.shape[0]
111
+ # w = arr.shape[1]
112
+
113
+ # if w == target_width:
114
+ # return arr
115
+
116
+ # if w > target_width:
117
+ # left = (w - target_width) // 2
118
+ # return arr[:, left:left + target_width] if arr.ndim == 2 else arr[:, left:left + target_width, :]
119
+
120
+ # # w < target_width: pad
121
+ # total = target_width - w
122
+ # left = total // 2
123
 + # right = total - left # ✅ the remainder goes to the right side, so the width is always exactly target_width
124
+
125
+ # if arr.ndim == 2:
126
+ # return cv2.copyMakeBorder(
127
+ # arr, 0, 0, left, right,
128
+ # borderType=cv2.BORDER_CONSTANT,
129
+ # value=pad_value,
130
+ # )
131
+ # else:
132
+ # # 3채널일 때 value는 스칼라 or [b,g,r]/[r,g,b] 모두 허용되는데,
133
+ # # 여기선 arr가 RGB/BGR인지 호출자가 정해줌.
134
+ # return cv2.copyMakeBorder(
135
+ # arr, 0, 0, left, right,
136
+ # borderType=cv2.BORDER_CONSTANT,
137
+ # value=pad_value,
138
+ # )
139
+
140
+
141
  # def apply_parsing_white_mask_to_person_cv2(
142
  # person_pil: Image.Image,
143
  # parsing_img: Image.Image
144
  # ) -> np.ndarray:
145
  # person_rgb = np.array(person_pil.convert("RGB"), dtype=np.uint8)
 
146
  # mask = np.array(parsing_img.convert("L"), dtype=np.uint8)
147
 
148
  # if mask.shape[:2] != person_rgb.shape[:2]:
149
  # mask = cv2.resize(mask, (person_rgb.shape[1], person_rgb.shape[0]), interpolation=cv2.INTER_NEAREST)
150
 
151
  # white_mask = (mask == 255)
 
152
  # result_rgb = np.full_like(person_rgb, 255, dtype=np.uint8)
153
  # result_rgb[white_mask] = person_rgb[white_mask]
 
154
  # result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
155
  # return result_bgr
156
 
 
168
  # if not isinstance(parsing_img, Image.Image):
169
  # raise TypeError("parsing_img must be a PIL.Image.Image")
170
 
171
+ # arr = np.array(parsing_img.convert("L"), dtype=np.uint8)
 
 
172
  # mask = np.where(arr >= white_threshold, 255, 0).astype(np.uint8)
173
 
174
  # num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
 
175
  # keep = np.zeros_like(mask)
176
  # for lab in range(1, num_labels):
177
  # area = int(stats[lab, cv2.CC_STAT_AREA])
178
  # if area >= min_white_area:
179
  # keep[labels == lab] = 255
 
180
  # mask = keep
181
 
182
  # def _odd_or_one(k: int) -> int:
 
249
  # _, m = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY)
250
 
251
  # target_width = 1024
252
+ # m = _pad_or_crop_to_width_np(m, target_width, pad_value=0)
253
 
254
  # kernel = np.ones((17, 17), np.uint8)
255
  # m = cv2.dilate(m, kernel, iterations=1)
 
266
  # raise RuntimeError("Global H/W not set. Call run_one() first.")
267
 
268
  # depth_img = _imread_or_raise(depth_path, 0)
 
269
  # inverted_depth = cv2.bitwise_not(depth_img)
270
  # contours, _ = cv2.findContours(inverted_depth, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
271
 
 
273
  # cv2.drawContours(filled_depth, contours, -1, (255), thickness=cv2.FILLED)
274
 
275
  # filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA)
276
+ # filled_depth = _pad_or_crop_to_width_np(filled_depth, 1024, pad_value=0)
277
 
278
+ # inverted_image = ImageOps.invert(Image.fromarray(filled_depth))
279
 
280
  # with torch.inference_mode():
281
  # image_depth = depth_estimator(inverted_image)["depth"]
 
300
  # raise RuntimeError("Global H/W not set. Call run_one() first.")
301
 
302
  # depth_img = _edges_from_parsing(parsing_img)
 
303
  # # inverted_depth = cv2.bitwise_not(depth_img)
304
  # contours, _ = cv2.findContours(depth_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
305
 
 
307
  # cv2.drawContours(filled_depth, contours, -1, (255), thickness=cv2.FILLED)
308
 
309
  # filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA)
310
+ # filled_depth = _pad_or_crop_to_width_np(filled_depth, 1024, pad_value=0)
311
 
312
+ # inverted_image = ImageOps.invert(Image.fromarray(filled_depth))
313
 
314
  # with torch.inference_mode():
315
  # image_depth = depth_estimator(inverted_image)["depth"]
 
397
  # return pipe, device, dtype
398
 
399
 
400
 + # # UI label → internal extractor category string mapping
401
  # _UI_TO_EXTRACTOR_CATEGORY = {
402
  # "Upper-body": "Upper-cloth",
403
  # "Lower-body": "Bottom",
 
406
 
407
 
408
  # def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str = "Dress"):
409
  # global H, W
410
  # pipe, device, _dtype = get_pipe_and_device()
411
  # image_encoder_dir, ip_ckpt, schp_ckpt = get_assets()
412
 
413
  # H, W = compute_hw_from_person(paths.person_path)
414
 
 
415
  # extractor_category = _UI_TO_EXTRACTOR_CATEGORY.get(category, "Dress")
416
 
417
  # res = run_simple_extractor(
 
447
  # merged_img = merge_white_regions_or(parsing_img, sketch_area)
448
  # mask_pil = preprocess_mask(merged_img)
449
 
450
+ # # person
451
  # person_bgr = _imread_or_raise(paths.person_path)
452
  # person_bgr = cv2.resize(person_bgr, (W, H), interpolation=cv2.INTER_AREA)
453
+ # person_bgr = _pad_or_crop_to_width_np(person_bgr, 1024, pad_value=[255, 255, 255])
454
+ # person_rgb = cv2.cvtColor(person_bgr, cv2.COLOR_BGR2RGB)
455
  # person_pil = Image.fromarray(person_rgb)
456
 
457
+ # # depth
458
  # if use_depth_path:
459
  # depth_map = make_depth(paths.depth_path)
460
  # else:
461
  # depth_map = make_depth_from_parsing_edges(parsing_img)
462
 
463
 + # # garment image (✅ the key step: force the width to 1024)
464
  # personn = Image.open(paths.person_path).convert("RGB")
465
+ # garment_bgr = apply_parsing_white_mask_to_person_cv2(personn, parsing_img)
466
+ # garment_rgb = cv2.cvtColor(garment_bgr, cv2.COLOR_BGR2RGB)
467
  # garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA)
468
 
469
 + # # ✅ Replaced the old padding=(1024-W)//2 approach with asymmetric pad/crop to exactly 1024
470
+ # garment_rgb = _pad_or_crop_to_width_np(garment_rgb, 1024, pad_value=[255, 255, 255])
471
  # garment_pil = Image.fromarray(garment_rgb)
472
 
473
 + # # garment mask (✅ brought to width 1024 the same way)
474
  # gm = np.array(parsing_img.convert("L"), dtype=np.uint8)
475
  # gm = cv2.resize(gm, (W, H), interpolation=cv2.INTER_NEAREST)
476
  # gm = cv2.cvtColor(gm, cv2.COLOR_GRAY2RGB)
477
+ # gm = _pad_or_crop_to_width_np(gm, 1024, pad_value=[0, 0, 0])
478
  # garment_mask_pil = Image.fromarray(gm)
479
 
480
  # print(
 
495
  # device,
496
  # mask_pil,
497
  # person_pil,
498
+ # content_scale=0.3,
499
  # style_scale=0.5,
500
  # garment_images=garment_pil,
501
  # garment_mask=garment_mask_pil,
 
511
 
512
  # style_img = Image.open(paths.style_path).convert("RGB")
513
 
514
 + # prompt = extractor_category + " with " + prompt
515
 
516
+ # print("====prompt? ", prompt)
517
 
518
  # with torch.inference_mode():
519
  # images = ip_model.generate(
 
577
  # with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
578
  # gr.Markdown("## VISTA Demo\nperson / style are required; sketch (guide) is optional.")
579
 
 
580
  # category_toggle = gr.Radio(
581
  # choices=["Dress", "Upper-body", "Lower-body"],
582
  # value="Dress",
 
584
  # interactive=True,
585
  # )
586
 
587
 + # # Person / Style / Output in a single row
588
  # with gr.Row():
589
  # person_in = gr.Image(label="Person Image (required)", type="filepath")
590
  # style_in = gr.Image(label="Style Image (required)", type="filepath")
 
599
  # seed_in = gr.Number(label="Seed (-1=random)", value=-1, precision=0)
600
 
601
  # run_btn = gr.Button("Run")
 
 
602
  # out_file = gr.File(label="Download result.png")
603
 
604
  # gr.Markdown("### Debug Visualizations (mask/depth/etc)")
 
623
 
624
  import os
625
  import sys
626
+ import glob
627
 
628
  # ---------------------------------------------------------
629
  # 0) Make sure local packages (diffusers3, preprocess, etc.) are importable on HF Spaces
 
638
  import tempfile
639
  from dataclasses import dataclass
640
  from functools import lru_cache
641
+ from typing import Optional, Tuple, List
642
 
643
  import gradio as gr
644
  import torch
 
701
  return image_encoder_dir, ip_ckpt, schp_ckpt
702
 
703
 
704
+ # =========================
705
+ # Example assets for Gradio UI
706
+ # =========================
707
+ def _is_image_file(p: str) -> bool:
708
+ ext = os.path.splitext(p.lower())[1]
709
+ return ext in (".png", ".jpg", ".jpeg", ".webp")
710
+
711
+
712
+ def build_ui_examples(root_dir: str = ROOT) -> List[List[object]]:
713
+ """
714
+ Returns a list of [person_fp, sketch_fp_or_None, style_fp] example rows.
715
+
716
+ Put your example files under:
717
+ - {root}/examples/person/*
718
+ - {root}/examples/style/*
719
+ - {root}/examples/sketch/* (optional)
720
+
721
+ If folders are missing or empty, this returns [] and the UI will hide Examples.
722
+ """
723
+ person_dir = os.path.join(root_dir, "examples", "person")
724
+ style_dir = os.path.join(root_dir, "examples", "style")
725
+ sketch_dir = os.path.join(root_dir, "examples", "sketch")
726
+
727
+ persons = [p for p in sorted(glob.glob(os.path.join(person_dir, "*"))) if _is_image_file(p)]
728
+ styles = [p for p in sorted(glob.glob(os.path.join(style_dir, "*"))) if _is_image_file(p)]
729
+ sketches = [p for p in sorted(glob.glob(os.path.join(sketch_dir, "*"))) if _is_image_file(p)]
730
+
731
+ if not persons or not styles:
732
+ return []
733
+
734
+ n_basic = min(len(persons), len(styles))
735
+ rows: List[List[object]] = []
736
+
737
+ # Examples without sketch (sketch is optional)
738
+ for i in range(n_basic):
739
+ rows.append([persons[i], None, styles[i]])
740
+
741
+ # Examples with sketch (if available)
742
+ if sketches:
743
+ n_sk = min(len(persons), len(styles), len(sketches))
744
+ for i in range(n_sk):
745
+ rows.append([persons[i], sketches[i], styles[i]])
746
+
747
+ return rows
748
+
749
+
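For reference, a hedged sketch of the rows build_ui_examples produces for a hypothetical layout (file names are illustrative, nothing here is shipped with the repo):

# Given ROOT/examples/person/p1.jpg, ROOT/examples/style/s1.jpg, ROOT/examples/sketch/g1.png,
# build_ui_examples(ROOT) returns sketch-less rows first, then rows that also fill the optional sketch:
#   [[".../examples/person/p1.jpg", None, ".../examples/style/s1.jpg"],
#    [".../examples/person/p1.jpg", ".../examples/sketch/g1.png", ".../examples/style/s1.jpg"]]
# With no person or style images found it returns [], and the UI falls back to the instructions text below.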
750
  DEFAULT_STEPS = 40
751
  DEBUG_SAVE = False
752
 
 
799
  value=pad_value,
800
  )
801
  else:
 
 
802
  return cv2.copyMakeBorder(
803
  arr, 0, 0, left, right,
804
  borderType=cv2.BORDER_CONSTANT,
 
968
  raise RuntimeError("Global H/W not set. Call run_one() first.")
969
 
970
  depth_img = _edges_from_parsing(parsing_img)
 
971
  contours, _ = cv2.findContours(depth_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
972
 
973
  filled_depth = depth_img.copy()
 
1133
  garment_rgb = cv2.cvtColor(garment_bgr, cv2.COLOR_BGR2RGB)
1134
  garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA)
1135
 
 
1136
  garment_rgb = _pad_or_crop_to_width_np(garment_rgb, 1024, pad_value=[255, 255, 255])
1137
  garment_pil = Image.fromarray(garment_rgb)
1138
 
 
1176
  pass
1177
 
1178
  style_img = Image.open(paths.style_path).convert("RGB")
1179
+
1180
 + prompt = extractor_category + " with " + prompt
1181
+
1182
+ print("==== prompt? ", prompt)
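# Illustrative trace of the prompt composition above (the user prompt is hypothetical):
#   category "Upper-body" -> extractor_category "Upper-cloth"
#   user prompt "red floral blouse" -> final prompt "Upper-cloth with red floral blouse"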
1183
 
1184
  with torch.inference_mode():
1185
  images = ip_model.generate(
 
1259
  with gr.Accordion("Sketch / Guide (optional)", open=False):
1260
  sketch_in = gr.Image(label="Sketch / Guide", type="filepath")
1261
 
1262
+ # ✅ Examples: click to auto-fill Person/Sketch/Style inputs
1263
+ examples_data = build_ui_examples(ROOT)
1264
+ if examples_data:
1265
+ gr.Markdown("### Examples (click to load)")
1266
+ gr.Examples(
1267
+ examples=examples_data,
1268
+ inputs=[person_in, sketch_in, style_in],
1269
+ examples_per_page=8,
1270
+ )
1271
+ else:
1272
+ gr.Markdown(
1273
+ "### Examples\n"
1274
+ "Add example files under `./examples/person/` and `./examples/style/` "
1275
+ "(optional: `./examples/sketch/`) to show clickable examples here."
1276
+ )
1277
+
1278
  with gr.Row():
1279
  prompt_in = gr.Textbox(label="Prompt", value="", lines=2)
1280
  steps_in = gr.Slider(1, 80, value=DEFAULT_STEPS, step=1, label="Steps")