ssoxye committed
Commit 6ecb196 · 1 Parent(s): 9ffd8e3

update save

Files changed (1):
  1. app.py +76 -712
app.py CHANGED
@@ -1,714 +1,3 @@
- # import os
- # import sys
- # import glob
-
- # # ---------------------------------------------------------
- # # 0) Make sure local packages (diffusers3, preprocess, etc.) are importable on HF Spaces
- # # ---------------------------------------------------------
- # ROOT = os.path.dirname(os.path.abspath(__file__))
- # if ROOT not in sys.path:
- #     sys.path.insert(0, ROOT)
-
- # print("[BOOT] ROOT =", ROOT, flush=True)
- # print("[BOOT] sys.path[:5] =", sys.path[:5], flush=True)
-
- # import tempfile
- # from dataclasses import dataclass
- # from functools import lru_cache
- # from typing import Optional, Tuple, List, Dict
-
- # import gradio as gr
- # import torch
- # import numpy as np
- # import cv2
- # import imageio
- # from PIL import Image, ImageOps
- # from transformers import pipeline
- # from huggingface_hub import hf_hub_download
-
- # import diffusers3
- # print("[BOOT] diffusers3 loaded from:", getattr(diffusers3, "__file__", "<?>"), flush=True)
-
- # from diffusers import UniPCMultistepScheduler, AutoencoderKL, UNet2DConditionModel
- # from diffusers3.models.controlnet import ControlNetModel
- # from diffusers3.pipelines.controlnet.pipeline_controlnet_sd_xl_img2img_img import (
- #     StableDiffusionXLControlNetImg2ImgPipeline,
- # )
- # from ip_adapter import IPAdapterXL
-
- # # extractor
- # from preprocess.simple_extractor import run as run_simple_extractor
-
-
- # # =========================
- # # HF Hub repo ids
- # # =========================
- # BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
- # CONTROLNET_ID = "diffusers/controlnet-depth-sdxl-1.0"
-
- # # assets dataset repo
- # ASSETS_REPO = os.getenv("ASSETS_REPO", "soye/VISTA_assets")
- # ASSETS_REPO_TYPE = "dataset"
-
- # depth_estimator = pipeline("depth-estimation")
-
-
- # def asset_path(relpath: str) -> str:
- #     return hf_hub_download(
- #         repo_id=ASSETS_REPO,
- #         repo_type=ASSETS_REPO_TYPE,
- #         filename=relpath,
- #     )
-
-
- # @lru_cache(maxsize=1)
- # def get_assets():
- #     print("[ASSETS] Downloading assets from:", ASSETS_REPO, flush=True)
-
- #     image_encoder_weight = asset_path("image_encoder/model.safetensors")
- #     _ = asset_path("image_encoder/config.json")
- #     image_encoder_dir = os.path.dirname(image_encoder_weight)
-
- #     ip_ckpt = asset_path("ip_adapter/ip-adapter_sdxl_vit-h.bin")
- #     schp_ckpt = asset_path("preprocess_ckpts/exp-schp-201908301523-atr.pth")
-
- #     print("[ASSETS] image_encoder_dir =", image_encoder_dir, flush=True)
- #     print("[ASSETS] ip_ckpt =", ip_ckpt, flush=True)
- #     print("[ASSETS] schp_ckpt =", schp_ckpt, flush=True)
- #     return image_encoder_dir, ip_ckpt, schp_ckpt
-
-
- # # =========================
- # # Example assets for Gradio UI (✅ separate lists)
- # # =========================
- # def _is_image_file(p: str) -> bool:
- #     ext = os.path.splitext(p.lower())[1]
- #     return ext in (".png", ".jpg", ".jpeg", ".webp")
-
-
- # def build_ui_example_lists(root_dir: str = ROOT) -> Dict[str, List[str]]:
- #     """
- #     Returns dict of example filepaths:
- #       - persons: [{root}/examples/person/*]
- #       - styles : [{root}/examples/style/*]
- #       - sketches: [{root}/examples/sketch/*] (optional)
- #     """
- #     person_dir = os.path.join(root_dir, "examples", "person")
- #     style_dir = os.path.join(root_dir, "examples", "style")
- #     sketch_dir = os.path.join(root_dir, "examples", "sketch")
-
- #     persons = [p for p in sorted(glob.glob(os.path.join(person_dir, "*"))) if _is_image_file(p)]
- #     styles = [p for p in sorted(glob.glob(os.path.join(style_dir, "*"))) if _is_image_file(p)]
- #     sketches = [p for p in sorted(glob.glob(os.path.join(sketch_dir, "*"))) if _is_image_file(p)]
-
- #     return {"persons": persons, "styles": styles, "sketches": sketches}
-
-
- # DEFAULT_STEPS = 40
- # DEBUG_SAVE = False
-
- # H: Optional[int] = None
- # W: Optional[int] = None
-
-
- # @dataclass
- # class Paths:
- #     person_path: str
- #     depth_path: Optional[str]   # sketch (guide) optional
- #     style_path: Optional[str]   # ✅ style optional (changed)
- #     output_path: str
-
-
- # def _imread_or_raise(path: str, flag=cv2.IMREAD_COLOR):
- #     img = cv2.imread(path, flag)
- #     if img is None:
- #         raise FileNotFoundError(f"cv2.imread failed: {path} (exists={os.path.exists(path)})")
- #     return img
-
-
- # def _pad_or_crop_to_width_np(arr: np.ndarray, target_width: int, pad_value):
- #     """
- #     arr: HxWxC or HxW
- #     Center-crop, or pad left/right (possibly asymmetrically), to exactly target_width.
- #     """
- #     if arr.ndim not in (2, 3):
- #         raise ValueError(f"arr must be 2D or 3D, got shape={arr.shape}")
-
- #     h = arr.shape[0]
- #     w = arr.shape[1]
-
- #     if w == target_width:
- #         return arr
-
- #     if w > target_width:
- #         left = (w - target_width) // 2
- #         return arr[:, left:left + target_width] if arr.ndim == 2 else arr[:, left:left + target_width, :]
-
- #     # w < target_width: pad
- #     total = target_width - w
- #     left = total // 2
- #     right = total - left  # ✅ the right side absorbs the remainder, so the width is always exactly target_width
-
- #     if arr.ndim == 2:
- #         return cv2.copyMakeBorder(
- #             arr, 0, 0, left, right,
- #             borderType=cv2.BORDER_CONSTANT,
- #             value=pad_value,
- #         )
- #     else:
- #         return cv2.copyMakeBorder(
- #             arr, 0, 0, left, right,
- #             borderType=cv2.BORDER_CONSTANT,
- #             value=pad_value,
- #         )
-
-
- # def apply_parsing_white_mask_to_person_cv2(
- #     person_pil: Image.Image,
- #     parsing_img: Image.Image
- # ) -> np.ndarray:
- #     person_rgb = np.array(person_pil.convert("RGB"), dtype=np.uint8)
- #     mask = np.array(parsing_img.convert("L"), dtype=np.uint8)
-
- #     if mask.shape[:2] != person_rgb.shape[:2]:
- #         mask = cv2.resize(mask, (person_rgb.shape[1], person_rgb.shape[0]), interpolation=cv2.INTER_NEAREST)
-
- #     white_mask = (mask == 255)
- #     result_rgb = np.full_like(person_rgb, 255, dtype=np.uint8)
- #     result_rgb[white_mask] = person_rgb[white_mask]
- #     result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
- #     return result_bgr
-
-
- # def remove_small_white_components(
- #     parsing_img: Image.Image,
- #     *,
- #     white_threshold: int = 128,
- #     min_white_area: int = 150,
- #     use_open: bool = False,
- #     open_ksize: int = 3,
- #     morph_iters: int = 1,
- # ) -> Image.Image:
- #     """
- #     - Binarize with white as the foreground
- #     - Remove only small white blobs via connected components
- #     - (Optional) apply a very light OPEN to remove tiny dots/spurs (no CLOSE, which would grow the white regions)
- #     """
- #     if not isinstance(parsing_img, Image.Image):
- #         raise TypeError("parsing_img must be a PIL.Image.Image")
-
- #     arr = np.array(parsing_img.convert("L"), dtype=np.uint8)
- #     mask = np.where(arr >= int(white_threshold), 255, 0).astype(np.uint8)
-
- #     # 1) Remove small white connected components
- #     num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
- #     keep = np.zeros_like(mask)
- #     for lab in range(1, num_labels):
- #         area = int(stats[lab, cv2.CC_STAT_AREA])
- #         if area >= int(min_white_area):
- #             keep[labels == lab] = 255
- #     mask = keep
-
- #     # 2) (Optional) OPEN: remove tiny white dots/spurs and tidy the boundary slightly (does not grow the white regions)
- #     if use_open and int(open_ksize) > 1:
- #         k = int(open_ksize)
- #         if k % 2 == 0:
- #             k += 1
- #         kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
- #         mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=int(morph_iters))
-
- #     return Image.fromarray(mask, mode="L")
-
-
- # def compute_hw_from_person(person_path: str):
- #     img = _imread_or_raise(person_path)
- #     orig_h, orig_w = img.shape[:2]
- #     scale = 1024.0 / float(orig_h)
- #     new_h = 1024
- #     new_w = int(round(orig_w * scale))
- #     if new_w > 1024:
- #         new_w = 1024
- #     return new_h, new_w
-
-
- # def fill_sketch_from_image_path_to_pil(image_path: str) -> Image.Image:
- #     global H, W
- #     if H is None or W is None:
- #         raise RuntimeError("Global H/W not set.")
- #     img = _imread_or_raise(image_path, cv2.IMREAD_GRAYSCALE)
- #     img = cv2.bitwise_not(img)
- #     img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST)
- #     _, binary = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)
- #     contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
- #     filled = np.zeros_like(binary)
- #     cv2.drawContours(filled, contours, -1, 255, thickness=cv2.FILLED)
- #     filled_rgb = cv2.cvtColor(filled, cv2.COLOR_GRAY2RGB)
- #     return Image.fromarray(filled_rgb)
-
-
- # def merge_white_regions_or(img1: Image.Image, img2: Image.Image) -> Image.Image:
- #     a = np.array(img1.convert("RGB"), dtype=np.uint8)
- #     b = np.array(img2.convert("RGB"), dtype=np.uint8)
- #     white_a = np.all(a == 255, axis=-1)
- #     white_b = np.all(b == 255, axis=-1)
- #     out = a.copy()
- #     out[white_a | white_b] = 255
- #     return Image.fromarray(out)
-
-
- # def preprocess_mask(mask_img: Image.Image) -> Image.Image:
- #     global H, W
- #     m = np.array(mask_img.convert("L"), dtype=np.uint8)
-
- #     if (H is not None) and (W is not None):
- #         m = cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST)
-
- #     _, m = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY)
-
- #     target_width = 1024
- #     m = _pad_or_crop_to_width_np(m, target_width, pad_value=0)
-
- #     kernel = np.ones((12, 12), np.uint8)
- #     m = cv2.dilate(m, kernel, iterations=1)
-
- #     if DEBUG_SAVE:
- #         cv2.imwrite("mask_final_1024.png", m)
-
- #     return Image.fromarray(m, mode="L").convert("RGB")
-
-
- # def make_depth(depth_path: str) -> Image.Image:
- #     global H, W
- #     if H is None or W is None:
- #         raise RuntimeError("Global H/W not set. Call run_one() first.")
-
- #     depth_img = _imread_or_raise(depth_path, 0)
- #     # inverted_depth = cv2.bitwise_not(depth_img)
- #     contours, _ = cv2.findContours(depth_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-
- #     filled_depth = depth_img.copy()
- #     cv2.drawContours(filled_depth, contours, -1, (255), thickness=cv2.FILLED)
-
- #     filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA)
- #     filled_depth = _pad_or_crop_to_width_np(filled_depth, 1024, pad_value=0)
-
- #     inverted_image = ImageOps.invert(Image.fromarray(filled_depth))
-
- #     with torch.inference_mode():
- #         image_depth = depth_estimator(inverted_image)["depth"]
-
- #     if DEBUG_SAVE:
- #         image_depth.save("depth.png")
-
- #     return image_depth
-
-
- # def _edges_from_parsing(parsing_img: Image.Image) -> np.ndarray:
- #     m = np.array(parsing_img.convert("L"), dtype=np.uint8)
- #     _, m_bin = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY)
- #     edges = cv2.Canny(m_bin, 50, 150)
- #     edges = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1)
- #     return edges.astype(np.uint8)
-
-
- # def make_depth_from_parsing_edges(parsing_img: Image.Image) -> Image.Image:
- #     global H, W
- #     if H is None or W is None:
- #         raise RuntimeError("Global H/W not set. Call run_one() first.")
-
- #     depth_img = _edges_from_parsing(parsing_img)
- #     contours, _ = cv2.findContours(depth_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-
- #     filled_depth = depth_img.copy()
- #     cv2.drawContours(filled_depth, contours, -1, (255), thickness=cv2.FILLED)
-
- #     filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA)
- #     filled_depth = _pad_or_crop_to_width_np(filled_depth, 1024, pad_value=0)
-
- #     inverted_image = ImageOps.invert(Image.fromarray(filled_depth))
-
- #     with torch.inference_mode():
- #         image_depth = depth_estimator(inverted_image)["depth"]
-
- #     if DEBUG_SAVE:
- #         image_depth.save("depth.png")
-
- #     return image_depth
-
-
- # def center_crop_lr_to_768x1024(arr: np.ndarray) -> np.ndarray:
- #     target_h, target_w = 1024, 768
- #     h, w = arr.shape[:2]
- #     if h != target_h:
- #         arr = cv2.resize(arr, (w, target_h), interpolation=cv2.INTER_AREA)
- #         h, w = arr.shape[:2]
- #     if w < target_w:
- #         pad = (target_w - w) // 2
- #         arr = cv2.copyMakeBorder(arr, 0, 0, pad, pad, cv2.BORDER_CONSTANT, value=[255, 255, 255])
- #         w = arr.shape[1]
- #     left = (w - target_w) // 2
- #     return arr[:, left:left + target_w]
-
-
- # def save_cropped(imgs, out_path: str):
- #     np_imgs = [np.asarray(im) for im in imgs]
- #     cropped = [center_crop_lr_to_768x1024(x) for x in np_imgs]
- #     out = np.concatenate(cropped, axis=1)
- #     os.makedirs(os.path.dirname(out_path), exist_ok=True)
- #     imageio.imsave(out_path, out)
-
-
- # @lru_cache(maxsize=1)
- # def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, str, torch.dtype]:
- #     device = "cuda" if torch.cuda.is_available() else "cpu"
- #     dtype = torch.float32
-
- #     print(f"[PIPE] device={device}, dtype={dtype}", flush=True)
-
- #     controlnet = ControlNetModel.from_pretrained(
- #         CONTROLNET_ID,
- #         torch_dtype=dtype,
- #         use_safetensors=True,
- #     ).to(device)
-
- #     vae = AutoencoderKL.from_pretrained(
- #         BASE_MODEL_ID,
- #         subfolder="vae",
- #         torch_dtype=dtype,
- #         use_safetensors=True,
- #     ).to(device)
-
- #     unet = UNet2DConditionModel.from_pretrained(
- #         BASE_MODEL_ID,
- #         subfolder="unet",
- #         torch_dtype=dtype,
- #         use_safetensors=True,
- #     ).to(device)
-
- #     pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
- #         BASE_MODEL_ID,
- #         controlnet=controlnet,
- #         vae=vae,
- #         unet=unet,
- #         torch_dtype=dtype,
- #         use_safetensors=True,
- #         add_watermarker=False,
- #     ).to(device)
-
- #     if device == "cuda":
- #         try:
- #             pipe.vae.to(dtype=dtype)
- #             if hasattr(pipe.vae, "config") and hasattr(pipe.vae.config, "force_upcast"):
- #                 pipe.vae.config.force_upcast = False
- #         except Exception as e:
- #             print("[PIPE] VAE dtype cast failed:", repr(e), flush=True)
-
- #     pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
- #     pipe.enable_attention_slicing()
- #     try:
- #         pipe.enable_xformers_memory_efficient_attention()
- #     except Exception as e:
- #         print("[PIPE] xformers not enabled:", repr(e), flush=True)
-
- #     return pipe, device, dtype
-
-
- # # Mapping from UI labels to internal extractor category strings
- # _UI_TO_EXTRACTOR_CATEGORY = {
- #     "Upper-body": "Upper-cloth",
- #     "Lower-body": "Bottom",
- #     "Dress": "Dress",
- # }
-
-
- # def _has_valid_file(path: Optional[str]) -> bool:
- #     return (
- #         path is not None
- #         and isinstance(path, str)
- #         and len(path) > 0
- #         and os.path.exists(path)
- #     )
-
-
- # def _resolve_content_style_scales(style_present: bool, prompt_present: bool) -> Tuple[float, float]:
- #     """
- #     Requirements:
- #       - no style image: (0.0, 0.0)
- #       - no prompt: (0.4, 0.65)
- #       - both present: keep the existing behavior (0.4, 0.5)
- #     """
- #     if not style_present:
- #         return 0.0, 0.0
- #     if not prompt_present:
- #         return 0.4, 0.65
- #     return 0.4, 0.5
-
-
- # def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str = "Dress"):
- #     global H, W
- #     pipe, device, _dtype = get_pipe_and_device()
- #     image_encoder_dir, ip_ckpt, schp_ckpt = get_assets()
-
- #     H, W = compute_hw_from_person(paths.person_path)
-
- #     extractor_category = _UI_TO_EXTRACTOR_CATEGORY.get(category, "Dress")
-
- #     res = run_simple_extractor(
- #         category=extractor_category,
- #         input_path=os.path.abspath(paths.person_path),
- #         model_restore=schp_ckpt,
- #     )
- #     parsing_img = res["images"][0] if res.get("images") else None
- #     if parsing_img is None:
- #         raise RuntimeError("run_simple_extractor returned no parsing images.")
-
- #     parsing_img = remove_small_white_components(
- #         parsing_img,
- #         white_threshold=128,
- #         min_white_area=150,  # tune between roughly 30 and 200 to suit the data
- #         use_open=False,
- #     )
-
- #     use_depth_path = _has_valid_file(paths.depth_path)
-
- #     if use_depth_path:
- #         sketch_area = fill_sketch_from_image_path_to_pil(paths.depth_path)
- #     else:
- #         sketch_area = parsing_img.convert("RGB")
-
- #     merged_img = merge_white_regions_or(parsing_img, sketch_area)
- #     mask_pil = preprocess_mask(merged_img)
-
- #     # person
- #     person_bgr = _imread_or_raise(paths.person_path)
- #     person_bgr = cv2.resize(person_bgr, (W, H), interpolation=cv2.INTER_AREA)
- #     person_bgr = _pad_or_crop_to_width_np(person_bgr, 1024, pad_value=[255, 255, 255])
- #     person_rgb = cv2.cvtColor(person_bgr, cv2.COLOR_BGR2RGB)
- #     person_pil = Image.fromarray(person_rgb)
-
- #     # depth
- #     if use_depth_path:
- #         depth_map = make_depth(paths.depth_path)
- #     else:
- #         depth_map = make_depth_from_parsing_edges(parsing_img)
-
- #     # garment image (✅ the key part: force a 1024 width from here on)
- #     personn = Image.open(paths.person_path).convert("RGB")
- #     garment_bgr = apply_parsing_white_mask_to_person_cv2(personn, parsing_img)
- #     garment_rgb = cv2.cvtColor(garment_bgr, cv2.COLOR_BGR2RGB)
- #     garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA)
- #     garment_rgb = _pad_or_crop_to_width_np(garment_rgb, 1024, pad_value=[255, 255, 255])
- #     garment_pil = Image.fromarray(garment_rgb)
-
- #     # garment mask (✅ likewise fit to 1024)
- #     gm = np.array(parsing_img.convert("L"), dtype=np.uint8)
- #     gm = cv2.resize(gm, (W, H), interpolation=cv2.INTER_NEAREST)
- #     gm = cv2.cvtColor(gm, cv2.COLOR_GRAY2RGB)
- #     gm = _pad_or_crop_to_width_np(gm, 1024, pad_value=[0, 0, 0])
- #     garment_mask_pil = Image.fromarray(gm)
-
- #     # ✅ decide the scales based on the conditions
- #     style_present = _has_valid_file(paths.style_path)
- #     prompt_present = (prompt is not None) and (str(prompt).strip() != "")
- #     content_scale, style_scale = _resolve_content_style_scales(style_present, prompt_present)
-
- #     print(
- #         "[SIZE] person:", person_pil.size,
- #         "mask:", mask_pil.size,
- #         "depth:", depth_map.size,
- #         "garment:", garment_pil.size,
- #         "gmask:", garment_mask_pil.size,
- #         "ui_category:", category,
- #         "extractor_category:", extractor_category,
- #         "style_present:", style_present,
- #         "prompt_present:", prompt_present,
- #         "content_scale:", content_scale,
- #         "style_scale:", style_scale,
- #         flush=True
- #     )
-
- #     ip_model = IPAdapterXL(
- #         pipe,
- #         image_encoder_dir,
- #         ip_ckpt,
- #         device,
- #         mask_pil,
- #         person_pil,
- #         content_scale=content_scale,  # ✅ changed
- #         style_scale=style_scale,  # ✅ changed
- #         garment_images=garment_pil,
- #         garment_mask=garment_mask_pil,
- #     )
-
- #     if device == "cuda":
- #         pipe.to(dtype=torch.float32)
- #         try:
- #             for _, proc in pipe.unet.attn_processors.items():
- #                 proc.to(dtype=torch.float32)
- #         except Exception:
- #             pass
-
- #     # ✅ Substitute something so the generate input is never None even without a style image
- #     if style_present:
- #         style_img = Image.open(paths.style_path).convert("RGB")
- #     else:
- #         # scale is 0, so this has no effect; it only satisfies the function signature
- #         style_img = garment_pil
-
- #     # prompt construction kept as before
- #     if prompt is not None and str(prompt).strip() != "":
- #         prompt = extractor_category + " with " + str(prompt).strip()
- #     else:
- #         prompt = extractor_category
-
- #     print("==== prompt? ", prompt, flush=True)
-
- #     with torch.inference_mode():
- #         images = ip_model.generate(
- #             pil_image=style_img,
- #             image=person_pil,
- #             control_image=depth_map,
- #             strength=1.0,
- #             num_samples=1,
- #             num_inference_steps=int(steps),
- #             shape_prompt="",
- #             prompt=prompt or "",
- #             num=0,
- #             scale=None,
- #             controlnet_conditioning_scale=0.7,
- #             guidance_scale=7.5,
- #         )
-
- #     save_cropped(images, paths.output_path)
- #     return images, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil
-
-
- # def set_seed(seed: int):
- #     if seed is None or seed < 0:
- #         return
- #     np.random.seed(seed)
- #     torch.manual_seed(seed)
- #     if torch.cuda.is_available():
- #         torch.cuda.manual_seed_all(seed)
-
-
- # def infer_web(person_fp, sketch_fp, style_fp, prompt, steps, seed, category):
- #     print("[UI] infer_web called", flush=True)
-
- #     # ✅ only person is required; style is optional
- #     if person_fp is None:
- #         raise gr.Error("The person image is required. (style/sketch are optional)")
-
- #     if category not in ("Upper-body", "Lower-body", "Dress"):
- #         raise gr.Error(f"Invalid category: {category}")
-
- #     set_seed(int(seed) if seed is not None else -1)
-
- #     tmp_dir = tempfile.mkdtemp(prefix="vista_demo_")
- #     out_path = os.path.join(tmp_dir, "result.png")
-
- #     paths = Paths(
- #         person_path=person_fp,
- #         depth_path=sketch_fp,
- #         style_path=style_fp,  # ✅ may be None
- #         output_path=out_path,
- #     )
-
- #     _, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil = run_one(
- #         paths, prompt=prompt, steps=int(steps), category=category
- #     )
-
- #     out_img = Image.open(out_path).convert("RGB")
- #     return out_img, out_path, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil
-
-
- # with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
- #     gr.Markdown("## VISTA Demo\nperson is required; style/sketch (guide) are optional.")
-
- #     category_toggle = gr.Radio(
- #         choices=["Dress", "Upper-body", "Lower-body"],
- #         value="Dress",
- #         label="Category",
- #         interactive=True,
- #     )
-
- #     # ✅ example lists (separate)
- #     ex = build_ui_example_lists(ROOT)
- #     person_examples = [[p] for p in ex["persons"]]
- #     style_examples = [[p] for p in ex["styles"]]
- #     sketch_examples = [[p] for p in ex["sketches"]]
-
- #     # Person / Style / Output in a single row
- #     with gr.Row():
- #         # -------- Person column --------
- #         with gr.Column(scale=1):
- #             person_in = gr.Image(label="Person Image (required)", type="filepath")
- #             if person_examples:
- #                 gr.Markdown("#### Examples")
- #                 gr.Examples(
- #                     examples=person_examples,
- #                     inputs=[person_in],
- #                     examples_per_page=8,
- #                 )
-
- #         # -------- Style column --------
- #         with gr.Column(scale=1):
- #             style_in = gr.Image(label="Style Image (optional)", type="filepath")
- #             if style_examples:
- #                 gr.Markdown("#### Examples")
- #                 gr.Examples(
- #                     examples=style_examples,
- #                     inputs=[style_in],
- #                     examples_per_page=8,
- #                 )
-
- #         # -------- Output column --------
- #         with gr.Column(scale=1):
- #             out_img = gr.Image(label="Output", type="pil")
-
- #     with gr.Accordion("Sketch / Guide (optional)", open=False):
- #         sketch_in = gr.Image(label="Sketch / Guide", type="filepath")
- #         if sketch_examples:
- #             gr.Markdown("#### Examples")
- #             gr.Examples(
- #                 examples=sketch_examples,
- #                 inputs=[sketch_in],
- #                 examples_per_page=8,
- #             )
-
- #     with gr.Row():
- #         prompt_in = gr.Textbox(
- #             label="Prompt",
- #             value="",
- #             placeholder="e.g. crystal, lace, button, …",
- #             lines=2,
- #         )
- #         steps_in = gr.Slider(1, 80, value=DEFAULT_STEPS, step=1, label="Steps")
- #         seed_in = gr.Number(label="Seed (-1=random)", value=-1, precision=0)
-
- #     run_btn = gr.Button("Run")
- #     out_file = gr.File(label="Download result.png")
-
- #     gr.Markdown("### Debug Visualizations (mask/depth/etc)")
- #     with gr.Row():
- #         dbg_mask = gr.Image(label="mask_pil", type="pil")
- #         dbg_depth = gr.Image(label="depth_map", type="pil")
-
- #     with gr.Row():
- #         dbg_person = gr.Image(label="person_pil", type="pil")
- #         dbg_garment = gr.Image(label="garment_pil", type="pil")
- #         dbg_gmask = gr.Image(label="garment_mask_pil", type="pil")
-
- #     run_btn.click(
- #         fn=infer_web,
- #         inputs=[person_in, sketch_in, style_in, prompt_in, steps_in, seed_in, category_toggle],
- #         outputs=[out_img, out_file, dbg_mask, dbg_depth, dbg_person, dbg_garment, dbg_gmask],
- #     )
-
- # demo.queue()
- # if __name__ == "__main__":
- #     demo.launch(server_name="0.0.0.0", server_port=7860)
-
  import os
  import sys
  import glob
@@ -1152,6 +441,78 @@ def save_cropped(imgs, out_path: str):
      out = np.concatenate(cropped, axis=1)
      os.makedirs(os.path.dirname(out_path), exist_ok=True)
      imageio.imsave(out_path, out)
+
+ def _read_hw(path: str) -> Tuple[int, int]:
+     img = _imread_or_raise(path)  # BGR
+     h, w = img.shape[:2]
+     return h, w
+
+
+ def _center_crop_lr_to_aspect(arr: np.ndarray, target_aspect: float, *, pad_value=255) -> np.ndarray:
+     """
+     arr: HxWxC (RGB) or HxW
+     target_aspect = target_w / target_h
+     - Keep the height (H)
+     - Crop equally from the left/right to match target_aspect
+     - If the current width falls short, pad the left/right instead
+     """
+     if arr.ndim == 2:
+         arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
+
+     h, w = arr.shape[:2]
+     if h <= 0 or w <= 0:
+         raise ValueError(f"Invalid image shape: {arr.shape}")
+
+     desired_w = int(round(h * float(target_aspect)))
+     if desired_w <= 0:
+         desired_w = 1
+
+     # If the width is sufficient, crop left/right
+     if w >= desired_w:
+         left = (w - desired_w) // 2
+         right = left + desired_w
+         return arr[:, left:right]
+
+     # If the width falls short, pad left/right (crop was requested, but this is a safety net)
+     total = desired_w - w
+     left_pad = total // 2
+     right_pad = total - left_pad
+     return cv2.copyMakeBorder(
+         arr,
+         0, 0,
+         left_pad, right_pad,
+         borderType=cv2.BORDER_CONSTANT,
+         value=[pad_value, pad_value, pad_value],
+     )
+
+
+ def save_output_match_person(imgs, out_path: str, person_path: str):
+     """
+     - Center-crop the output imgs (usually of length 1) left/right to the person's original aspect ratio
+     - Resize to the person's original (W, H)
+     - (If imgs holds several images) concatenate them horizontally after processing and save
+     """
+     person_h, person_w = _read_hw(person_path)
+     target_aspect = float(person_w) / float(person_h)
+
+     np_imgs = []
+     for im in imgs:
+         if isinstance(im, Image.Image):
+             arr = np.asarray(im.convert("RGB"), dtype=np.uint8)
+         else:
+             # In case a numpy array is passed in
+             arr = np.asarray(im, dtype=np.uint8)
+             if arr.ndim == 2:
+                 arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
+
+         cropped = _center_crop_lr_to_aspect(arr, target_aspect, pad_value=255)
+         resized = cv2.resize(cropped, (person_w, person_h), interpolation=cv2.INTER_AREA)
+         np_imgs.append(resized)
+
+     out = np.concatenate(np_imgs, axis=1)  # a no-op when imgs holds a single image
+     os.makedirs(os.path.dirname(out_path), exist_ok=True)
+     imageio.imsave(out_path, out)
+


  @lru_cache(maxsize=1)
@@ -1379,10 +740,13 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str
              guidance_scale=7.5,
          )

-     save_cropped(images, paths.output_path)
+     # save_cropped(images, paths.output_path)
+     # return images, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil
+     save_output_match_person(images, paths.output_path, paths.person_path)
      return images, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil


+
  def set_seed(seed: int):
      if seed is None or seed < 0:
          return
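For reference, here is a minimal standalone sketch of what the new save path does, assuming only numpy and OpenCV. The function demo_center_crop below is a simplified, hypothetical stand-in for the _center_crop_lr_to_aspect helper added in this commit, and the 768x1024 person size is an invented example, not taken from the repo. The generated image, whose width was padded out to 1024 for SDXL, is center-cropped back to the person's original aspect ratio and then resized to the original resolution:

import numpy as np
import cv2

def demo_center_crop(arr: np.ndarray, target_aspect: float) -> np.ndarray:
    # Keep the height; crop (or, as a fallback, pad) left/right so that
    # width / height matches target_aspect.
    h, w = arr.shape[:2]
    desired_w = max(1, int(round(h * target_aspect)))
    if w >= desired_w:
        left = (w - desired_w) // 2
        return arr[:, left:left + desired_w]
    pad = desired_w - w
    return cv2.copyMakeBorder(arr, 0, 0, pad // 2, pad - pad // 2,
                              borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255])

# A 768x1024 person photo is padded to 1024x1024 for generation; the output
# is cropped back to the 0.75 aspect ratio and resized to the original size.
person_h, person_w = 1024, 768                    # invented original resolution
generated = np.zeros((1024, 1024, 3), np.uint8)   # stand-in for the model output
cropped = demo_center_crop(generated, person_w / person_h)
restored = cv2.resize(cropped, (person_w, person_h), interpolation=cv2.INTER_AREA)
assert restored.shape == (1024, 768, 3)

Compared with the old save_cropped, which always forced a fixed 768x1024 output, this keeps result.png at the same resolution and aspect ratio as the uploaded person image.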