ssoxye commited on
Commit
b5310ec
·
1 Parent(s): 6b3f441

update examples

Browse files
Files changed (1) hide show
  1. app.py +54 -677
app.py CHANGED
@@ -1,626 +1,3 @@
1
- # import os
2
- # import sys
3
-
4
- # # ---------------------------------------------------------
5
- # # 0) Make sure local packages (diffusers3, preprocess, etc.) are importable on HF Spaces
6
- # # ---------------------------------------------------------
7
- # ROOT = os.path.dirname(os.path.abspath(__file__))
8
- # if ROOT not in sys.path:
9
- # sys.path.insert(0, ROOT)
10
-
11
- # print("[BOOT] ROOT =", ROOT, flush=True)
12
- # print("[BOOT] sys.path[:5] =", sys.path[:5], flush=True)
13
-
14
- # import tempfile
15
- # from dataclasses import dataclass
16
- # from functools import lru_cache
17
- # from typing import Optional, Tuple
18
-
19
- # import gradio as gr
20
- # import torch
21
- # import numpy as np
22
- # import cv2
23
- # import imageio
24
- # from PIL import Image, ImageOps
25
- # from transformers import pipeline
26
- # from huggingface_hub import hf_hub_download
27
-
28
- # import diffusers3
29
- # print("[BOOT] diffusers3 loaded from:", getattr(diffusers3, "__file__", "<?>"), flush=True)
30
-
31
- # from diffusers import UniPCMultistepScheduler, AutoencoderKL, UNet2DConditionModel
32
- # from diffusers3.models.controlnet import ControlNetModel
33
- # from diffusers3.pipelines.controlnet.pipeline_controlnet_sd_xl_img2img_img import (
34
- # StableDiffusionXLControlNetImg2ImgPipeline,
35
- # )
36
- # from ip_adapter import IPAdapterXL
37
-
38
- # # extractor
39
- # from preprocess.simple_extractor import run as run_simple_extractor
40
-
41
-
42
- # # =========================
43
- # # HF Hub repo ids
44
- # # =========================
45
- # BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
46
- # CONTROLNET_ID = "diffusers/controlnet-depth-sdxl-1.0"
47
-
48
- # # assets dataset repo
49
- # ASSETS_REPO = os.getenv("ASSETS_REPO", "soye/VISTA_assets")
50
- # ASSETS_REPO_TYPE = "dataset"
51
-
52
- # depth_estimator = pipeline("depth-estimation")
53
-
54
-
55
- # def asset_path(relpath: str) -> str:
56
- # return hf_hub_download(
57
- # repo_id=ASSETS_REPO,
58
- # repo_type=ASSETS_REPO_TYPE,
59
- # filename=relpath,
60
- # )
61
-
62
-
63
- # @lru_cache(maxsize=1)
64
- # def get_assets():
65
- # print("[ASSETS] Downloading assets from:", ASSETS_REPO, flush=True)
66
-
67
- # image_encoder_weight = asset_path("image_encoder/model.safetensors")
68
- # _ = asset_path("image_encoder/config.json")
69
- # image_encoder_dir = os.path.dirname(image_encoder_weight)
70
-
71
- # ip_ckpt = asset_path("ip_adapter/ip-adapter_sdxl_vit-h.bin")
72
- # schp_ckpt = asset_path("preprocess_ckpts/exp-schp-201908301523-atr.pth")
73
-
74
- # print("[ASSETS] image_encoder_dir =", image_encoder_dir, flush=True)
75
- # print("[ASSETS] ip_ckpt =", ip_ckpt, flush=True)
76
- # print("[ASSETS] schp_ckpt =", schp_ckpt, flush=True)
77
- # return image_encoder_dir, ip_ckpt, schp_ckpt
78
-
79
-
80
- # DEFAULT_STEPS = 40
81
- # DEBUG_SAVE = False
82
-
83
- # H: Optional[int] = None
84
- # W: Optional[int] = None
85
-
86
-
87
- # @dataclass
88
- # class Paths:
89
- # person_path: str
90
- # depth_path: Optional[str] # sketch(guide) optional
91
- # style_path: str
92
- # output_path: str
93
-
94
-
95
- # def _imread_or_raise(path: str, flag=cv2.IMREAD_COLOR):
96
- # img = cv2.imread(path, flag)
97
- # if img is None:
98
- # raise FileNotFoundError(f"cv2.imread failed: {path} (exists={os.path.exists(path)})")
99
- # return img
100
-
101
-
102
- # def _pad_or_crop_to_width_np(arr: np.ndarray, target_width: int, pad_value):
103
- # """
104
- # arr: HxWxC or HxW
105
- # target_width로 center crop 또는 좌/우 padding(비대칭 포함)해서 정확히 맞춤.
106
- # """
107
- # if arr.ndim not in (2, 3):
108
- # raise ValueError(f"arr must be 2D or 3D, got shape={arr.shape}")
109
-
110
- # h = arr.shape[0]
111
- # w = arr.shape[1]
112
-
113
- # if w == target_width:
114
- # return arr
115
-
116
- # if w > target_width:
117
- # left = (w - target_width) // 2
118
- # return arr[:, left:left + target_width] if arr.ndim == 2 else arr[:, left:left + target_width, :]
119
-
120
- # # w < target_width: pad
121
- # total = target_width - w
122
- # left = total // 2
123
- # right = total - left # ✅ remainder를 오른쪽이 먹어서 항상 정확히 target_width
124
-
125
- # if arr.ndim == 2:
126
- # return cv2.copyMakeBorder(
127
- # arr, 0, 0, left, right,
128
- # borderType=cv2.BORDER_CONSTANT,
129
- # value=pad_value,
130
- # )
131
- # else:
132
- # # 3채널일 때 value는 스칼라 or [b,g,r]/[r,g,b] 모두 허용되는데,
133
- # # 여기선 arr가 RGB/BGR인지 호출자가 정해줌.
134
- # return cv2.copyMakeBorder(
135
- # arr, 0, 0, left, right,
136
- # borderType=cv2.BORDER_CONSTANT,
137
- # value=pad_value,
138
- # )
139
-
140
-
141
- # def apply_parsing_white_mask_to_person_cv2(
142
- # person_pil: Image.Image,
143
- # parsing_img: Image.Image
144
- # ) -> np.ndarray:
145
- # person_rgb = np.array(person_pil.convert("RGB"), dtype=np.uint8)
146
- # mask = np.array(parsing_img.convert("L"), dtype=np.uint8)
147
-
148
- # if mask.shape[:2] != person_rgb.shape[:2]:
149
- # mask = cv2.resize(mask, (person_rgb.shape[1], person_rgb.shape[0]), interpolation=cv2.INTER_NEAREST)
150
-
151
- # white_mask = (mask == 255)
152
- # result_rgb = np.full_like(person_rgb, 255, dtype=np.uint8)
153
- # result_rgb[white_mask] = person_rgb[white_mask]
154
- # result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
155
- # return result_bgr
156
-
157
-
158
- # def clean_and_smooth_parsing_mask(
159
- # parsing_img: Image.Image,
160
- # *,
161
- # white_threshold: int = 128,
162
- # min_white_area: int = 300,
163
- # close_ksize: int = 7,
164
- # open_ksize: int = 3,
165
- # morph_iters: int = 1,
166
- # blur_ksize: int = 0,
167
- # ) -> Image.Image:
168
- # if not isinstance(parsing_img, Image.Image):
169
- # raise TypeError("parsing_img must be a PIL.Image.Image")
170
-
171
- # arr = np.array(parsing_img.convert("L"), dtype=np.uint8)
172
- # mask = np.where(arr >= white_threshold, 255, 0).astype(np.uint8)
173
-
174
- # num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
175
- # keep = np.zeros_like(mask)
176
- # for lab in range(1, num_labels):
177
- # area = int(stats[lab, cv2.CC_STAT_AREA])
178
- # if area >= min_white_area:
179
- # keep[labels == lab] = 255
180
- # mask = keep
181
-
182
- # def _odd_or_one(k: int) -> int:
183
- # k = int(k)
184
- # if k <= 1:
185
- # return 1
186
- # return k if (k % 2 == 1) else (k + 1)
187
-
188
- # close_k = _odd_or_one(close_ksize)
189
- # open_k = _odd_or_one(open_ksize)
190
-
191
- # if close_k > 1:
192
- # k_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (close_k, close_k))
193
- # mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, k_close, iterations=int(morph_iters))
194
-
195
- # if open_k > 1:
196
- # k_open = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (open_k, open_k))
197
- # mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, k_open, iterations=int(morph_iters))
198
-
199
- # if blur_ksize and int(blur_ksize) > 1:
200
- # b = _odd_or_one(int(blur_ksize))
201
- # mask_blur = cv2.GaussianBlur(mask, (b, b), 0)
202
- # mask = np.where(mask_blur >= 128, 255, 0).astype(np.uint8)
203
-
204
- # return Image.fromarray(mask, mode="L")
205
-
206
-
207
- # def compute_hw_from_person(person_path: str):
208
- # img = _imread_or_raise(person_path)
209
- # orig_h, orig_w = img.shape[:2]
210
- # scale = 1024.0 / float(orig_h)
211
- # new_h = 1024
212
- # new_w = int(round(orig_w * scale))
213
- # if new_w > 1024:
214
- # new_w = 1024
215
- # return new_h, new_w
216
-
217
-
218
- # def fill_sketch_from_image_path_to_pil(image_path: str) -> Image.Image:
219
- # global H, W
220
- # if H is None or W is None:
221
- # raise RuntimeError("Global H/W not set.")
222
- # img = _imread_or_raise(image_path, cv2.IMREAD_GRAYSCALE)
223
- # img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST)
224
- # _, binary = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)
225
- # contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
226
- # filled = np.zeros_like(binary)
227
- # cv2.drawContours(filled, contours, -1, 255, thickness=cv2.FILLED)
228
- # filled_rgb = cv2.cvtColor(filled, cv2.COLOR_GRAY2RGB)
229
- # return Image.fromarray(filled_rgb)
230
-
231
-
232
- # def merge_white_regions_or(img1: Image.Image, img2: Image.Image) -> Image.Image:
233
- # a = np.array(img1.convert("RGB"), dtype=np.uint8)
234
- # b = np.array(img2.convert("RGB"), dtype=np.uint8)
235
- # white_a = np.all(a == 255, axis=-1)
236
- # white_b = np.all(b == 255, axis=-1)
237
- # out = a.copy()
238
- # out[white_a | white_b] = 255
239
- # return Image.fromarray(out)
240
-
241
-
242
- # def preprocess_mask(mask_img: Image.Image) -> Image.Image:
243
- # global H, W
244
- # m = np.array(mask_img.convert("L"), dtype=np.uint8)
245
-
246
- # if (H is not None) and (W is not None):
247
- # m = cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST)
248
-
249
- # _, m = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY)
250
-
251
- # target_width = 1024
252
- # m = _pad_or_crop_to_width_np(m, target_width, pad_value=0)
253
-
254
- # kernel = np.ones((17, 17), np.uint8)
255
- # m = cv2.dilate(m, kernel, iterations=1)
256
-
257
- # if DEBUG_SAVE:
258
- # cv2.imwrite("mask_final_1024.png", m)
259
-
260
- # return Image.fromarray(m, mode="L").convert("RGB")
261
-
262
-
263
- # def make_depth(depth_path: str) -> Image.Image:
264
- # global H, W
265
- # if H is None or W is None:
266
- # raise RuntimeError("Global H/W not set. Call run_one() first.")
267
-
268
- # depth_img = _imread_or_raise(depth_path, 0)
269
- # inverted_depth = cv2.bitwise_not(depth_img)
270
- # contours, _ = cv2.findContours(inverted_depth, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
271
-
272
- # filled_depth = inverted_depth.copy()
273
- # cv2.drawContours(filled_depth, contours, -1, (255), thickness=cv2.FILLED)
274
-
275
- # filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA)
276
- # filled_depth = _pad_or_crop_to_width_np(filled_depth, 1024, pad_value=0)
277
-
278
- # inverted_image = ImageOps.invert(Image.fromarray(filled_depth))
279
-
280
- # with torch.inference_mode():
281
- # image_depth = depth_estimator(inverted_image)["depth"]
282
-
283
- # if DEBUG_SAVE:
284
- # image_depth.save("depth.png")
285
-
286
- # return image_depth
287
-
288
-
289
- # def _edges_from_parsing(parsing_img: Image.Image) -> np.ndarray:
290
- # m = np.array(parsing_img.convert("L"), dtype=np.uint8)
291
- # _, m_bin = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY)
292
- # edges = cv2.Canny(m_bin, 50, 150)
293
- # edges = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1)
294
- # return edges.astype(np.uint8)
295
-
296
-
297
- # def make_depth_from_parsing_edges(parsing_img: Image.Image) -> Image.Image:
298
- # global H, W
299
- # if H is None or W is None:
300
- # raise RuntimeError("Global H/W not set. Call run_one() first.")
301
-
302
- # depth_img = _edges_from_parsing(parsing_img)
303
- # # inverted_depth = cv2.bitwise_not(depth_img)
304
- # contours, _ = cv2.findContours(depth_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
305
-
306
- # filled_depth = depth_img.copy()
307
- # cv2.drawContours(filled_depth, contours, -1, (255), thickness=cv2.FILLED)
308
-
309
- # filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA)
310
- # filled_depth = _pad_or_crop_to_width_np(filled_depth, 1024, pad_value=0)
311
-
312
- # inverted_image = ImageOps.invert(Image.fromarray(filled_depth))
313
-
314
- # with torch.inference_mode():
315
- # image_depth = depth_estimator(inverted_image)["depth"]
316
-
317
- # if DEBUG_SAVE:
318
- # image_depth.save("depth.png")
319
-
320
- # return image_depth
321
-
322
-
323
- # def center_crop_lr_to_768x1024(arr: np.ndarray) -> np.ndarray:
324
- # target_h, target_w = 1024, 768
325
- # h, w = arr.shape[:2]
326
- # if h != target_h:
327
- # arr = cv2.resize(arr, (w, target_h), interpolation=cv2.INTER_AREA)
328
- # h, w = arr.shape[:2]
329
- # if w < target_w:
330
- # pad = (target_w - w) // 2
331
- # arr = cv2.copyMakeBorder(arr, 0, 0, pad, pad, cv2.BORDER_CONSTANT, value=[255, 255, 255])
332
- # w = arr.shape[1]
333
- # left = (w - target_w) // 2
334
- # return arr[:, left:left + target_w]
335
-
336
-
337
- # def save_cropped(imgs, out_path: str):
338
- # np_imgs = [np.asarray(im) for im in imgs]
339
- # cropped = [center_crop_lr_to_768x1024(x) for x in np_imgs]
340
- # out = np.concatenate(cropped, axis=1)
341
- # os.makedirs(os.path.dirname(out_path), exist_ok=True)
342
- # imageio.imsave(out_path, out)
343
-
344
-
345
- # @lru_cache(maxsize=1)
346
- # def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, str, torch.dtype]:
347
- # device = "cuda" if torch.cuda.is_available() else "cpu"
348
- # dtype = torch.float32
349
-
350
- # print(f"[PIPE] device={device}, dtype={dtype}", flush=True)
351
-
352
- # controlnet = ControlNetModel.from_pretrained(
353
- # CONTROLNET_ID,
354
- # torch_dtype=dtype,
355
- # use_safetensors=True,
356
- # ).to(device)
357
-
358
- # vae = AutoencoderKL.from_pretrained(
359
- # BASE_MODEL_ID,
360
- # subfolder="vae",
361
- # torch_dtype=dtype,
362
- # use_safetensors=True,
363
- # ).to(device)
364
-
365
- # unet = UNet2DConditionModel.from_pretrained(
366
- # BASE_MODEL_ID,
367
- # subfolder="unet",
368
- # torch_dtype=dtype,
369
- # use_safetensors=True,
370
- # ).to(device)
371
-
372
- # pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
373
- # BASE_MODEL_ID,
374
- # controlnet=controlnet,
375
- # vae=vae,
376
- # unet=unet,
377
- # torch_dtype=dtype,
378
- # use_safetensors=True,
379
- # add_watermarker=False,
380
- # ).to(device)
381
-
382
- # if device == "cuda":
383
- # try:
384
- # pipe.vae.to(dtype=dtype)
385
- # if hasattr(pipe.vae, "config") and hasattr(pipe.vae.config, "force_upcast"):
386
- # pipe.vae.config.force_upcast = False
387
- # except Exception as e:
388
- # print("[PIPE] VAE dtype cast failed:", repr(e), flush=True)
389
-
390
- # pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
391
- # pipe.enable_attention_slicing()
392
- # try:
393
- # pipe.enable_xformers_memory_efficient_attention()
394
- # except Exception as e:
395
- # print("[PIPE] xformers not enabled:", repr(e), flush=True)
396
-
397
- # return pipe, device, dtype
398
-
399
-
400
- # # UI 표기 → 내부 extractor category 문자열 매핑
401
- # _UI_TO_EXTRACTOR_CATEGORY = {
402
- # "Upper-body": "Upper-cloth",
403
- # "Lower-body": "Bottom",
404
- # "Dress": "Dress",
405
- # }
406
-
407
-
408
- # def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str = "Dress"):
409
- # global H, W
410
- # pipe, device, _dtype = get_pipe_and_device()
411
- # image_encoder_dir, ip_ckpt, schp_ckpt = get_assets()
412
-
413
- # H, W = compute_hw_from_person(paths.person_path)
414
-
415
- # extractor_category = _UI_TO_EXTRACTOR_CATEGORY.get(category, "Dress")
416
-
417
- # res = run_simple_extractor(
418
- # category=extractor_category,
419
- # input_path=os.path.abspath(paths.person_path),
420
- # model_restore=schp_ckpt,
421
- # )
422
- # parsing_img = res["images"][0] if res.get("images") else None
423
- # if parsing_img is None:
424
- # raise RuntimeError("run_simple_extractor returned no parsing images.")
425
-
426
- # parsing_img = clean_and_smooth_parsing_mask(
427
- # parsing_img,
428
- # min_white_area=50,
429
- # close_ksize=9,
430
- # open_ksize=3,
431
- # morph_iters=1,
432
- # blur_ksize=7,
433
- # )
434
-
435
- # use_depth_path = (
436
- # paths.depth_path is not None
437
- # and isinstance(paths.depth_path, str)
438
- # and len(paths.depth_path) > 0
439
- # and os.path.exists(paths.depth_path)
440
- # )
441
-
442
- # if use_depth_path:
443
- # sketch_area = fill_sketch_from_image_path_to_pil(paths.depth_path)
444
- # else:
445
- # sketch_area = parsing_img.convert("RGB")
446
-
447
- # merged_img = merge_white_regions_or(parsing_img, sketch_area)
448
- # mask_pil = preprocess_mask(merged_img)
449
-
450
- # # person
451
- # person_bgr = _imread_or_raise(paths.person_path)
452
- # person_bgr = cv2.resize(person_bgr, (W, H), interpolation=cv2.INTER_AREA)
453
- # person_bgr = _pad_or_crop_to_width_np(person_bgr, 1024, pad_value=[255, 255, 255])
454
- # person_rgb = cv2.cvtColor(person_bgr, cv2.COLOR_BGR2RGB)
455
- # person_pil = Image.fromarray(person_rgb)
456
-
457
- # # depth
458
- # if use_depth_path:
459
- # depth_map = make_depth(paths.depth_path)
460
- # else:
461
- # depth_map = make_depth_from_parsing_edges(parsing_img)
462
-
463
- # # garment image (✅ 여기서부터가 핵심: 1024 폭 강제)
464
- # personn = Image.open(paths.person_path).convert("RGB")
465
- # garment_bgr = apply_parsing_white_mask_to_person_cv2(personn, parsing_img)
466
- # garment_rgb = cv2.cvtColor(garment_bgr, cv2.COLOR_BGR2RGB)
467
- # garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA)
468
-
469
- # # ✅ 기존 padding=(1024-W)//2 방식 제거 → 비대칭 패딩/크롭으로 정확히 1024
470
- # garment_rgb = _pad_or_crop_to_width_np(garment_rgb, 1024, pad_value=[255, 255, 255])
471
- # garment_pil = Image.fromarray(garment_rgb)
472
-
473
- # # garment mask (✅ 동일하게 1024 맞춤)
474
- # gm = np.array(parsing_img.convert("L"), dtype=np.uint8)
475
- # gm = cv2.resize(gm, (W, H), interpolation=cv2.INTER_NEAREST)
476
- # gm = cv2.cvtColor(gm, cv2.COLOR_GRAY2RGB)
477
- # gm = _pad_or_crop_to_width_np(gm, 1024, pad_value=[0, 0, 0])
478
- # garment_mask_pil = Image.fromarray(gm)
479
-
480
- # print(
481
- # "[SIZE] person:", person_pil.size,
482
- # "mask:", mask_pil.size,
483
- # "depth:", depth_map.size,
484
- # "garment:", garment_pil.size,
485
- # "gmask:", garment_mask_pil.size,
486
- # "ui_category:", category,
487
- # "extractor_category:", extractor_category,
488
- # flush=True
489
- # )
490
-
491
- # ip_model = IPAdapterXL(
492
- # pipe,
493
- # image_encoder_dir,
494
- # ip_ckpt,
495
- # device,
496
- # mask_pil,
497
- # person_pil,
498
- # content_scale=0.3,
499
- # style_scale=0.5,
500
- # garment_images=garment_pil,
501
- # garment_mask=garment_mask_pil,
502
- # )
503
-
504
- # if device == "cuda":
505
- # pipe.to(dtype=torch.float32)
506
- # try:
507
- # for _, proc in pipe.unet.attn_processors.items():
508
- # proc.to(dtype=torch.float32)
509
- # except Exception:
510
- # pass
511
-
512
- # style_img = Image.open(paths.style_path).convert("RGB")
513
-
514
- # prompt = extractor_category + "with " + prompt
515
-
516
- # print("====prompt? ", prompt)
517
-
518
- # with torch.inference_mode():
519
- # images = ip_model.generate(
520
- # pil_image=style_img,
521
- # image=person_pil,
522
- # control_image=depth_map,
523
- # strength=1.0,
524
- # num_samples=1,
525
- # num_inference_steps=int(steps),
526
- # shape_prompt="",
527
- # prompt=prompt or "",
528
- # num=0,
529
- # scale=None,
530
- # controlnet_conditioning_scale=0.7,
531
- # guidance_scale=7.5,
532
- # )
533
-
534
- # save_cropped(images, paths.output_path)
535
-
536
- # return images, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil
537
-
538
-
539
- # def set_seed(seed: int):
540
- # if seed is None or seed < 0:
541
- # return
542
- # np.random.seed(seed)
543
- # torch.manual_seed(seed)
544
- # if torch.cuda.is_available():
545
- # torch.cuda.manual_seed_all(seed)
546
-
547
-
548
- # def infer_web(person_fp, sketch_fp, style_fp, prompt, steps, seed, category):
549
- # print("[UI] infer_web called", flush=True)
550
-
551
- # if person_fp is None or style_fp is None:
552
- # raise gr.Error("person / style 이미지는 필수입니다. (sketch는 선택)")
553
-
554
- # if category not in ("Upper-body", "Lower-body", "Dress"):
555
- # raise gr.Error(f"Invalid category: {category}")
556
-
557
- # set_seed(int(seed) if seed is not None else -1)
558
-
559
- # tmp_dir = tempfile.mkdtemp(prefix="vista_demo_")
560
- # out_path = os.path.join(tmp_dir, "result.png")
561
-
562
- # paths = Paths(
563
- # person_path=person_fp,
564
- # depth_path=sketch_fp,
565
- # style_path=style_fp,
566
- # output_path=out_path,
567
- # )
568
-
569
- # _, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil = run_one(
570
- # paths, prompt=prompt, steps=int(steps), category=category
571
- # )
572
-
573
- # out_img = Image.open(out_path).convert("RGB")
574
- # return out_img, out_path, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil
575
-
576
-
577
- # with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
578
- # gr.Markdown("## VISTA Demo\nperson / style 필수, sketch(guide)는 선택입니다.")
579
-
580
- # category_toggle = gr.Radio(
581
- # choices=["Dress", "Upper-body", "Lower-body"],
582
- # value="Dress",
583
- # label="Category",
584
- # interactive=True,
585
- # )
586
-
587
- # # 한 행에 Person / Style / Output
588
- # with gr.Row():
589
- # person_in = gr.Image(label="Person Image (required)", type="filepath")
590
- # style_in = gr.Image(label="Style Image (required)", type="filepath")
591
- # out_img = gr.Image(label="Output", type="pil")
592
-
593
- # with gr.Accordion("Sketch / Guide (optional)", open=False):
594
- # sketch_in = gr.Image(label="Sketch / Guide", type="filepath")
595
-
596
- # with gr.Row():
597
- # prompt_in = gr.Textbox(label="Prompt", value="", lines=2)
598
- # steps_in = gr.Slider(1, 80, value=DEFAULT_STEPS, step=1, label="Steps")
599
- # seed_in = gr.Number(label="Seed (-1=random)", value=-1, precision=0)
600
-
601
- # run_btn = gr.Button("Run")
602
- # out_file = gr.File(label="Download result.png")
603
-
604
- # gr.Markdown("### Debug Visualizations (mask/depth/etc)")
605
- # with gr.Row():
606
- # dbg_mask = gr.Image(label="mask_pil", type="pil")
607
- # dbg_depth = gr.Image(label="depth_map", type="pil")
608
-
609
- # with gr.Row():
610
- # dbg_person = gr.Image(label="person_pil", type="pil")
611
- # dbg_garment = gr.Image(label="garment_pil", type="pil")
612
- # dbg_gmask = gr.Image(label="garment_mask_pil", type="pil")
613
-
614
- # run_btn.click(
615
- # fn=infer_web,
616
- # inputs=[person_in, sketch_in, style_in, prompt_in, steps_in, seed_in, category_toggle],
617
- # outputs=[out_img, out_file, dbg_mask, dbg_depth, dbg_person, dbg_garment, dbg_gmask],
618
- # )
619
-
620
- # demo.queue()
621
- # if __name__ == "__main__":
622
- # demo.launch(server_name="0.0.0.0", server_port=7860)
623
-
624
  import os
625
  import sys
626
  import glob
@@ -638,7 +15,7 @@ print("[BOOT] sys.path[:5] =", sys.path[:5], flush=True)
638
  import tempfile
639
  from dataclasses import dataclass
640
  from functools import lru_cache
641
- from typing import Optional, Tuple, List
642
 
643
  import gradio as gr
644
  import torch
@@ -702,23 +79,19 @@ def get_assets():
702
 
703
 
704
  # =========================
705
- # Example assets for Gradio UI
706
  # =========================
707
  def _is_image_file(p: str) -> bool:
708
  ext = os.path.splitext(p.lower())[1]
709
  return ext in (".png", ".jpg", ".jpeg", ".webp")
710
 
711
 
712
- def build_ui_examples(root_dir: str = ROOT) -> List[List[object]]:
713
  """
714
- Returns a list of [person_fp, sketch_fp_or_None, style_fp] example rows.
715
-
716
- Put your example files under:
717
- - {root}/examples/person/*
718
- - {root}/examples/style/*
719
- - {root}/examples/sketch/* (optional)
720
-
721
- If folders are missing or empty, this returns [] and the UI will hide Examples.
722
  """
723
  person_dir = os.path.join(root_dir, "examples", "person")
724
  style_dir = os.path.join(root_dir, "examples", "style")
@@ -728,23 +101,7 @@ def build_ui_examples(root_dir: str = ROOT) -> List[List[object]]:
728
  styles = [p for p in sorted(glob.glob(os.path.join(style_dir, "*"))) if _is_image_file(p)]
729
  sketches = [p for p in sorted(glob.glob(os.path.join(sketch_dir, "*"))) if _is_image_file(p)]
730
 
731
- if not persons or not styles:
732
- return []
733
-
734
- n_basic = min(len(persons), len(styles))
735
- rows: List[List[object]] = []
736
-
737
- # Examples without sketch (sketch is optional)
738
- for i in range(n_basic):
739
- rows.append([persons[i], None, styles[i]])
740
-
741
- # Examples with sketch (if available)
742
- if sketches:
743
- n_sk = min(len(persons), len(styles), len(sketches))
744
- for i in range(n_sk):
745
- rows.append([persons[i], sketches[i], styles[i]])
746
-
747
- return rows
748
 
749
 
750
  DEFAULT_STEPS = 40
@@ -1176,10 +533,9 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str
1176
  pass
1177
 
1178
  style_img = Image.open(paths.style_path).convert("RGB")
1179
-
1180
  prompt = extractor_category + "with " + prompt
1181
-
1182
- print("==== prompt? ", prompt)
1183
 
1184
  with torch.inference_mode():
1185
  images = ip_model.generate(
@@ -1198,7 +554,6 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str
1198
  )
1199
 
1200
  save_cropped(images, paths.output_path)
1201
-
1202
  return images, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil
1203
 
1204
 
@@ -1250,33 +605,57 @@ with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
1250
  interactive=True,
1251
  )
1252
 
 
 
 
 
 
 
1253
  # 한 행에 Person / Style / Output
1254
  with gr.Row():
1255
- person_in = gr.Image(label="Person Image (required)", type="filepath")
1256
- style_in = gr.Image(label="Style Image (required)", type="filepath")
1257
- out_img = gr.Image(label="Output", type="pil")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1258
 
1259
  with gr.Accordion("Sketch / Guide (optional)", open=False):
1260
  sketch_in = gr.Image(label="Sketch / Guide", type="filepath")
1261
-
1262
- # Examples: click to auto-fill Person/Sketch/Style inputs
1263
- examples_data = build_ui_examples(ROOT)
1264
- if examples_data:
1265
- gr.Markdown("### Examples (click to load)")
1266
- gr.Examples(
1267
- examples=examples_data,
1268
- inputs=[person_in, sketch_in, style_in],
1269
- examples_per_page=8,
1270
- )
1271
- else:
1272
- gr.Markdown(
1273
- "### Examples\n"
1274
- "Add example files under `./examples/person/` and `./examples/style/` "
1275
- "(optional: `./examples/sketch/`) to show clickable examples here."
1276
- )
1277
 
1278
  with gr.Row():
1279
- prompt_in = gr.Textbox(label="Prompt", value="", lines=2)
 
 
 
 
 
1280
  steps_in = gr.Slider(1, 80, value=DEFAULT_STEPS, step=1, label="Steps")
1281
  seed_in = gr.Number(label="Seed (-1=random)", value=-1, precision=0)
1282
 
@@ -1302,5 +681,3 @@ with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
1302
  demo.queue()
1303
  if __name__ == "__main__":
1304
  demo.launch(server_name="0.0.0.0", server_port=7860)
1305
-
1306
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import sys
3
  import glob
 
15
  import tempfile
16
  from dataclasses import dataclass
17
  from functools import lru_cache
18
+ from typing import Optional, Tuple, List, Dict
19
 
20
  import gradio as gr
21
  import torch
 
79
 
80
 
81
  # =========================
82
+ # Example assets for Gradio UI (✅ 분리형)
83
  # =========================
84
  def _is_image_file(p: str) -> bool:
85
  ext = os.path.splitext(p.lower())[1]
86
  return ext in (".png", ".jpg", ".jpeg", ".webp")
87
 
88
 
89
+ def build_ui_example_lists(root_dir: str = ROOT) -> Dict[str, List[str]]:
90
  """
91
+ Returns dict of example filepaths:
92
+ - persons: [{root}/examples/person/*]
93
+ - styles : [{root}/examples/style/*]
94
+ - sketches: [{root}/examples/sketch/*] (optional)
 
 
 
 
95
  """
96
  person_dir = os.path.join(root_dir, "examples", "person")
97
  style_dir = os.path.join(root_dir, "examples", "style")
 
101
  styles = [p for p in sorted(glob.glob(os.path.join(style_dir, "*"))) if _is_image_file(p)]
102
  sketches = [p for p in sorted(glob.glob(os.path.join(sketch_dir, "*"))) if _is_image_file(p)]
103
 
104
+ return {"persons": persons, "styles": styles, "sketches": sketches}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
 
107
  DEFAULT_STEPS = 40
 
533
  pass
534
 
535
  style_img = Image.open(paths.style_path).convert("RGB")
536
+
537
  prompt = extractor_category + "with " + prompt
538
+ print("==== prompt? ", prompt, flush=True)
 
539
 
540
  with torch.inference_mode():
541
  images = ip_model.generate(
 
554
  )
555
 
556
  save_cropped(images, paths.output_path)
 
557
  return images, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil
558
 
559
 
 
605
  interactive=True,
606
  )
607
 
608
+ # ✅ 예시 리스트(분리)
609
+ ex = build_ui_example_lists(ROOT)
610
+ person_examples = [[p] for p in ex["persons"]]
611
+ style_examples = [[p] for p in ex["styles"]]
612
+ sketch_examples = [[p] for p in ex["sketches"]]
613
+
614
  # 한 행에 Person / Style / Output
615
  with gr.Row():
616
+ # -------- Person column --------
617
+ with gr.Column(scale=1):
618
+ person_in = gr.Image(label="Person Image (required)", type="filepath")
619
+ if person_examples:
620
+ gr.Markdown("#### Examples")
621
+ gr.Examples(
622
+ examples=person_examples,
623
+ inputs=[person_in], # ✅ person만 채움 (독립 선택)
624
+ examples_per_page=8,
625
+ )
626
+
627
+ # -------- Style column --------
628
+ with gr.Column(scale=1):
629
+ style_in = gr.Image(label="Style Image (required)", type="filepath")
630
+ if style_examples:
631
+ gr.Markdown("#### Examples")
632
+ gr.Examples(
633
+ examples=style_examples,
634
+ inputs=[style_in], # ✅ style만 채움 (독립 선택)
635
+ examples_per_page=8,
636
+ )
637
+
638
+ # -------- Output column --------
639
+ with gr.Column(scale=1):
640
+ out_img = gr.Image(label="Output", type="pil")
641
 
642
  with gr.Accordion("Sketch / Guide (optional)", open=False):
643
  sketch_in = gr.Image(label="Sketch / Guide", type="filepath")
644
+ if sketch_examples:
645
+ gr.Markdown("#### Examples")
646
+ gr.Examples(
647
+ examples=sketch_examples,
648
+ inputs=[sketch_in], # sketch만 채움 (독립 선택)
649
+ examples_per_page=8,
650
+ )
 
 
 
 
 
 
 
 
 
651
 
652
  with gr.Row():
653
+ prompt_in = gr.Textbox(
654
+ label="Prompt",
655
+ value="",
656
+ placeholder="예: floral pattern, silk texture, studio lighting",
657
+ lines=2,
658
+ )
659
  steps_in = gr.Slider(1, 80, value=DEFAULT_STEPS, step=1, label="Steps")
660
  seed_in = gr.Number(label="Seed (-1=random)", value=-1, precision=0)
661
 
 
681
  demo.queue()
682
  if __name__ == "__main__":
683
  demo.launch(server_name="0.0.0.0", server_port=7860)