ssoxye committed on
Commit
c94fce0
·
1 Parent(s): 1a497a0

update sketch input

Browse files
Files changed (1) hide show
  1. app.py +120 -41
app.py CHANGED
@@ -90,7 +90,7 @@ W: Optional[int] = None
90
  @dataclass
91
  class Paths:
92
  person_path: str
93
- depth_path: str
94
  style_path: str
95
  output_path: str
96
 
@@ -101,6 +101,7 @@ def _imread_or_raise(path: str, flag=cv2.IMREAD_COLOR):
101
  raise FileNotFoundError(f"cv2.imread failed: {path} (exists={os.path.exists(path)})")
102
  return img
103
 
 
104
  def apply_parsing_white_mask_to_person_cv2(
105
  person_pil: Image.Image,
106
  parsing_img: Image.Image
@@ -108,16 +109,12 @@ def apply_parsing_white_mask_to_person_cv2(
108
  """
109
  person_pil(RGB) ํฌ๊ธฐ์— parsing_img(L) ๋งˆ์Šคํฌ๋ฅผ ๋งž์ถฐ์„œ
110
  ํฐ์ƒ‰(255) ์˜์—ญ๋งŒ person์„ ๋‚จ๊ธฐ๊ณ  ๋‚˜๋จธ์ง€๋Š” ํฐ์ƒ‰ ๋ฐฐ๊ฒฝ์œผ๋กœ ๋งŒ๋“œ๋Š” ํ•จ์ˆ˜.
111
-
112
- - parsing_img๋Š” person ํฌ๊ธฐ์— ๋ฐ˜๋“œ์‹œ ๋งž์ถฐ์•ผ ํ•จ (NEAREST)
113
  """
114
  person_rgb = np.array(person_pil.convert("RGB"), dtype=np.uint8)
115
 
116
- # parsing ๋งˆ์Šคํฌ (L)
117
  mask = np.array(parsing_img.convert("L"), dtype=np.uint8)
118
 
119
- # โœ… ํ•ต์‹ฌ: ํฌ๊ธฐ ๋ถˆ์ผ์น˜ ํ•ด๊ฒฐ (H,W) ๋งž์ถค
120
- if mask.shape[0] != person_rgb.shape[0] or mask.shape[1] != person_rgb.shape[1]:
121
  mask = cv2.resize(mask, (person_rgb.shape[1], person_rgb.shape[0]), interpolation=cv2.INTER_NEAREST)
122
 
123
  white_mask = (mask == 255)
@@ -129,7 +126,6 @@ def apply_parsing_white_mask_to_person_cv2(
129
  return result_bgr
130
 
131
 
132
-
133
  def compute_hw_from_person(person_path: str):
134
  img = _imread_or_raise(person_path)
135
  orig_h, orig_w = img.shape[:2]
@@ -141,11 +137,10 @@ def compute_hw_from_person(person_path: str):
141
  return new_h, new_w
142
 
143
 
144
- def invert_sketch_area(sketch_pil: Image.Image) -> Image.Image:
145
- return ImageOps.invert(sketch_pil.convert("L")).convert("RGB")
146
-
147
-
148
  def fill_sketch_from_image_path_to_pil(image_path: str) -> Image.Image:
 
 
 
149
  global H, W
150
  if H is None or W is None:
151
  raise RuntimeError("Global H/W not set.")
@@ -186,9 +181,7 @@ def preprocess_mask(mask_img: Image.Image) -> Image.Image:
186
  left_padding = total_padding // 2
187
  right_padding = total_padding - left_padding
188
  m = cv2.copyMakeBorder(
189
- m,
190
- top=0, bottom=0,
191
- left=left_padding, right=right_padding,
192
  borderType=cv2.BORDER_CONSTANT,
193
  value=0,
194
  )
@@ -204,7 +197,11 @@ def preprocess_mask(mask_img: Image.Image) -> Image.Image:
204
 
205
  return Image.fromarray(m, mode="L").convert("RGB")
206
 
 
207
  def make_depth(depth_path: str) -> Image.Image:
 
 
 
208
  global H, W
209
  if H is None or W is None:
210
  raise RuntimeError("Global H/W not set. Call run_one() first.")
@@ -242,6 +239,70 @@ def make_depth(depth_path: str) -> Image.Image:
242
  return image_depth
243
 
244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
 
247
  def center_crop_lr_to_768x1024(arr: np.ndarray) -> np.ndarray:
@@ -269,7 +330,7 @@ def save_cropped(imgs, out_path: str):
269
  @lru_cache(maxsize=1)
270
  def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, str, torch.dtype]:
271
  device = "cuda" if torch.cuda.is_available() else "cpu"
272
- dtype = torch.float32 # ํ˜„์žฌ ๋„ˆ ์„ค์ • ์œ ์ง€
273
 
274
  print(f"[PIPE] device={device}, dtype={dtype}", flush=True)
275
 
@@ -332,6 +393,7 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
332
 
333
  H, W = compute_hw_from_person(paths.person_path)
334
 
 
335
  res = run_simple_extractor(
336
  category="Upper-clothes",
337
  input_path=os.path.abspath(paths.person_path),
@@ -341,10 +403,27 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
341
  if parsing_img is None:
342
  raise RuntimeError("run_simple_extractor returned no parsing images.")
343
 
344
- sketch_area = fill_sketch_from_image_path_to_pil(paths.depth_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  merged_img = merge_white_regions_or(parsing_img, sketch_area)
346
  mask_pil = preprocess_mask(merged_img)
347
 
 
348
  person_bgr = _imread_or_raise(paths.person_path)
349
  person_bgr = cv2.resize(person_bgr, (W, H), interpolation=cv2.INTER_AREA)
350
 
@@ -367,25 +446,23 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
367
  person_rgb = cv2.cvtColor(padded_person, cv2.COLOR_BGR2RGB)
368
  person_pil = Image.fromarray(person_rgb)
369
 
370
- depth_map = make_depth(paths.depth_path)
371
-
372
-
373
-
374
- personn = Image.open(paths.person_path)
375
-
376
- garment_ = apply_parsing_white_mask_to_person_cv2(
377
- personn,
378
- parsing_img
379
- )
380
 
 
 
 
381
  garment_rgb = cv2.cvtColor(garment_, cv2.COLOR_BGR2RGB)
382
-
383
- # โœ… (์ค‘์š”) garment_๋Š” ์›๋ณธ person ํฌ๊ธฐ์ผ ์ˆ˜ ์žˆ์œผ๋‹ˆ ์ „์—ญ (W,H)๋กœ ๋งž์ถ˜ ๋’ค padding
384
  garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA)
385
-
386
- target_width = 1024 # โœ… ๊ณ ์ •
387
- padding = (target_width - person_bgr.shape[1]) // 2
388
 
 
389
  garment_rgb = cv2.copyMakeBorder(
390
  garment_rgb,
391
  top=0, bottom=0,
@@ -395,9 +472,11 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
395
  )
396
  garment_pil = Image.fromarray(garment_rgb)
397
 
 
398
  gm = np.array(parsing_img.convert("L"), dtype=np.uint8)
399
- gm = cv2.resize(gm, (W, H), interpolation=cv2.INTER_AREA)
400
  gm = cv2.cvtColor(gm, cv2.COLOR_GRAY2RGB)
 
401
  cur_w2 = gm.shape[1]
402
  if cur_w2 < target_width:
403
  total = target_width - cur_w2
@@ -409,7 +488,6 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
409
  gm = gm[:, left2:left2 + target_width]
410
  garment_mask_pil = Image.fromarray(gm)
411
 
412
- # --- sanity sizes (optional)
413
  print(
414
  "[SIZE] person:", person_pil.size,
415
  "mask:", mask_pil.size,
@@ -474,8 +552,10 @@ def set_seed(seed: int):
474
 
475
  def infer_web(person_fp, sketch_fp, style_fp, prompt, steps, seed):
476
  print("[UI] infer_web called", flush=True)
477
- if person_fp is None or sketch_fp is None or style_fp is None:
478
- raise gr.Error("person / sketch / style ์ด๋ฏธ์ง€๋ฅผ ๋ชจ๋‘ ์—…๋กœ๋“œํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.")
 
 
479
 
480
  set_seed(int(seed) if seed is not None else -1)
481
 
@@ -484,7 +564,7 @@ def infer_web(person_fp, sketch_fp, style_fp, prompt, steps, seed):
484
 
485
  paths = Paths(
486
  person_path=person_fp,
487
- depth_path=sketch_fp,
488
  style_path=style_fp,
489
  output_path=out_path,
490
  )
@@ -498,12 +578,12 @@ def infer_web(person_fp, sketch_fp, style_fp, prompt, steps, seed):
498
 
499
 
500
  with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
501
- gr.Markdown("## VISTA Demo\nperson / sketch(guide) / style ์ž…๋ ฅ์œผ๋กœ ๊ฒฐ๊ณผ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.")
502
 
503
  with gr.Row():
504
- person_in = gr.Image(label="Person Image", type="filepath")
505
- sketch_in = gr.Image(label="Sketch / Guide (depth_path)", type="filepath")
506
- style_in = gr.Image(label="Style Image", type="filepath")
507
 
508
  with gr.Row():
509
  prompt_in = gr.Textbox(label="Prompt", value="upper garment", lines=2)
@@ -534,4 +614,3 @@ with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
534
  demo.queue()
535
  if __name__ == "__main__":
536
  demo.launch(server_name="0.0.0.0", server_port=7860)
537
-
 
90
  @dataclass
91
  class Paths:
92
  person_path: str
93
+ depth_path: Optional[str] # โœ… (1) sketch(guide) optional
94
  style_path: str
95
  output_path: str
96
 
 
101
  raise FileNotFoundError(f"cv2.imread failed: {path} (exists={os.path.exists(path)})")
102
  return img
103
 
104
+
105
  def apply_parsing_white_mask_to_person_cv2(
106
  person_pil: Image.Image,
107
  parsing_img: Image.Image
 
109
  """
110
  person_pil(RGB) ํฌ๊ธฐ์— parsing_img(L) ๋งˆ์Šคํฌ๋ฅผ ๋งž์ถฐ์„œ
111
  ํฐ์ƒ‰(255) ์˜์—ญ๋งŒ person์„ ๋‚จ๊ธฐ๊ณ  ๋‚˜๋จธ์ง€๋Š” ํฐ์ƒ‰ ๋ฐฐ๊ฒฝ์œผ๋กœ ๋งŒ๋“œ๋Š” ํ•จ์ˆ˜.
 
 
112
  """
113
  person_rgb = np.array(person_pil.convert("RGB"), dtype=np.uint8)
114
 
 
115
  mask = np.array(parsing_img.convert("L"), dtype=np.uint8)
116
 
117
+ if mask.shape[:2] != person_rgb.shape[:2]:
 
118
  mask = cv2.resize(mask, (person_rgb.shape[1], person_rgb.shape[0]), interpolation=cv2.INTER_NEAREST)
119
 
120
  white_mask = (mask == 255)
 
126
  return result_bgr
127
 
128
 
 
129
  def compute_hw_from_person(person_path: str):
130
  img = _imread_or_raise(person_path)
131
  orig_h, orig_w = img.shape[:2]
 
137
  return new_h, new_w
138
 
139
 
 
 
 
 
140
  def fill_sketch_from_image_path_to_pil(image_path: str) -> Image.Image:
141
+ """
142
+ sketch(guide) ์—…๋กœ๋“œ ์ด๋ฏธ์ง€๋ฅผ filled mask(RGB)๋กœ ๋งŒ๋“œ๋Š” ํ•จ์ˆ˜
143
+ """
144
  global H, W
145
  if H is None or W is None:
146
  raise RuntimeError("Global H/W not set.")
 
181
  left_padding = total_padding // 2
182
  right_padding = total_padding - left_padding
183
  m = cv2.copyMakeBorder(
184
+ m, 0, 0, left_padding, right_padding,
 
 
185
  borderType=cv2.BORDER_CONSTANT,
186
  value=0,
187
  )
 
197
 
198
  return Image.fromarray(m, mode="L").convert("RGB")
199
 
200
+
201
  def make_depth(depth_path: str) -> Image.Image:
202
+ """
203
+ depth_path(guide/sketch)๋กœ๋ถ€ํ„ฐ depth_map ์ƒ์„ฑ (๊ธฐ์กด ๋กœ์ง)
204
+ """
205
  global H, W
206
  if H is None or W is None:
207
  raise RuntimeError("Global H/W not set. Call run_one() first.")
 
239
  return image_depth
240
 
241
 
242
def _edges_from_parsing(parsing_img: Image.Image) -> np.ndarray:
    """
    Extract an edge image (uint8, values 0/255) from a parsing mask.

    Assumes parsing_img is a mask whose white (255) pixels mark the garment
    region; the white edge lines produced here feed the same contour-fill
    pipeline used for uploaded sketches.
    """
    gray = np.array(parsing_img.convert("L"), dtype=np.uint8)

    # Binarize defensively in case the parsing values are not strictly 0/255.
    _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)

    # Canny edge detection; on a binary mask the thresholds matter little.
    edge_map = cv2.Canny(binary, 50, 150)

    # Thicken the edges slightly so the later contour fill is stable.
    kernel = np.ones((3, 3), np.uint8)
    edge_map = cv2.dilate(edge_map, kernel, iterations=1)

    return edge_map.astype(np.uint8)
261
+
262
+
263
def make_depth_from_parsing_edges(parsing_img: Image.Image) -> Image.Image:
    """
    Fallback depth-map builder used when no sketch/guide image was uploaded.

    Extracts edges from the parsing mask, fills the enclosed regions, then
    applies the same resize / pad / invert / depth-estimator steps as
    make_depth so both code paths yield compatible depth maps.

    Raises RuntimeError if the global H/W have not been set by run_one().
    """
    global H, W
    if H is None or W is None:
        raise RuntimeError("Global H/W not set. Call run_one() first.")

    edge_img = _edges_from_parsing(parsing_img)

    # Fill the closed regions outlined by the (inverted) edge lines.
    inverted = cv2.bitwise_not(edge_img)
    contours, _ = cv2.findContours(inverted, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    filled = inverted.copy()
    cv2.drawContours(filled, contours, -1, (255), thickness=cv2.FILLED)

    # Match the global (W, H) derived from the person image.
    filled = cv2.resize(filled, (W, H), interpolation=cv2.INTER_AREA)

    # Center-pad horizontally to the fixed 1024-pixel canvas width.
    h, w = filled.shape
    pad_total = 1024 - w
    pad_left = pad_total // 2
    pad_right = pad_total - pad_left
    padded = cv2.copyMakeBorder(
        filled, 0, 0, pad_left, pad_right,
        borderType=cv2.BORDER_CONSTANT,
        value=0,
    )

    inverted_image = ImageOps.invert(Image.fromarray(padded))

    # Run the (project-level) depth estimator on the inverted, filled mask.
    with torch.inference_mode():
        image_depth = depth_estimator(inverted_image)["depth"]

    if DEBUG_SAVE:
        image_depth.save("depth.png")

    return image_depth
306
 
307
 
308
  def center_crop_lr_to_768x1024(arr: np.ndarray) -> np.ndarray:
 
330
  @lru_cache(maxsize=1)
331
  def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, str, torch.dtype]:
332
  device = "cuda" if torch.cuda.is_available() else "cpu"
333
+ dtype = torch.float32 # ์œ ์ง€
334
 
335
  print(f"[PIPE] device={device}, dtype={dtype}", flush=True)
336
 
 
393
 
394
  H, W = compute_hw_from_person(paths.person_path)
395
 
396
+ # parsing ์ถ”์ถœ (ํ™”์ดํŠธ=255 ๋งˆ์Šคํฌ ํ˜•ํƒœ๋กœ ์˜ค๋Š” ๊ฑธ ๊ฐ€์ •)
397
  res = run_simple_extractor(
398
  category="Upper-clothes",
399
  input_path=os.path.abspath(paths.person_path),
 
403
  if parsing_img is None:
404
  raise RuntimeError("run_simple_extractor returned no parsing images.")
405
 
406
+ # -------------------------------------------------
407
+ # โœ… (2) UI sketch ์—…๋กœ๋“œ๋Š” optional
408
+ # โœ… (3) depth_path ์—†์œผ๋ฉด sketch_area = parsing_img
409
+ # -------------------------------------------------
410
+ use_depth_path = (
411
+ paths.depth_path is not None
412
+ and isinstance(paths.depth_path, str)
413
+ and len(paths.depth_path) > 0
414
+ and os.path.exists(paths.depth_path)
415
+ )
416
+
417
+ if use_depth_path:
418
+ sketch_area = fill_sketch_from_image_path_to_pil(paths.depth_path)
419
+ else:
420
+ sketch_area = parsing_img.convert("RGB") # โœ… depth_path ์—†์œผ๋ฉด sketch_area = parsing_img
421
+
422
+ # mask ์ƒ์„ฑ
423
  merged_img = merge_white_regions_or(parsing_img, sketch_area)
424
  mask_pil = preprocess_mask(merged_img)
425
 
426
+ # person padding to 1024 width
427
  person_bgr = _imread_or_raise(paths.person_path)
428
  person_bgr = cv2.resize(person_bgr, (W, H), interpolation=cv2.INTER_AREA)
429
 
 
446
  person_rgb = cv2.cvtColor(padded_person, cv2.COLOR_BGR2RGB)
447
  person_pil = Image.fromarray(person_rgb)
448
 
449
+ # -------------------------------------------------
450
+ # โœ… (4) depth_map:
451
+ # - depth_path ์žˆ์œผ๋ฉด make_depth(paths.depth_path)
452
+ # - ์—†์œผ๋ฉด parsing_pil edge ์ถ”์ถœ -> ์ œ๊ณต ์ฝ”๋“œ ์ ์šฉ
453
+ # -------------------------------------------------
454
+ if use_depth_path:
455
+ depth_map = make_depth(paths.depth_path)
456
+ else:
457
+ depth_map = make_depth_from_parsing_edges(parsing_img)
 
458
 
459
+ # garment ์ถ”์ถœ (parsing white mask ๊ธฐ๋ฐ˜)
460
+ personn = Image.open(paths.person_path).convert("RGB")
461
+ garment_ = apply_parsing_white_mask_to_person_cv2(personn, parsing_img)
462
  garment_rgb = cv2.cvtColor(garment_, cv2.COLOR_BGR2RGB)
 
 
463
  garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA)
 
 
 
464
 
465
+ padding = (target_width - W) // 2 if W < target_width else 0
466
  garment_rgb = cv2.copyMakeBorder(
467
  garment_rgb,
468
  top=0, bottom=0,
 
472
  )
473
  garment_pil = Image.fromarray(garment_rgb)
474
 
475
+ # garment mask (parsing ์ž์ฒด๋ฅผ ๋™์ผ ํฌ๊ธฐ๋กœ)
476
  gm = np.array(parsing_img.convert("L"), dtype=np.uint8)
477
+ gm = cv2.resize(gm, (W, H), interpolation=cv2.INTER_NEAREST)
478
  gm = cv2.cvtColor(gm, cv2.COLOR_GRAY2RGB)
479
+
480
  cur_w2 = gm.shape[1]
481
  if cur_w2 < target_width:
482
  total = target_width - cur_w2
 
488
  gm = gm[:, left2:left2 + target_width]
489
  garment_mask_pil = Image.fromarray(gm)
490
 
 
491
  print(
492
  "[SIZE] person:", person_pil.size,
493
  "mask:", mask_pil.size,
 
552
 
553
  def infer_web(person_fp, sketch_fp, style_fp, prompt, steps, seed):
554
  print("[UI] infer_web called", flush=True)
555
+
556
+ # โœ… person / style๋งŒ ํ•„์ˆ˜. sketch๋Š” ์„ ํƒ.
557
+ if person_fp is None or style_fp is None:
558
+ raise gr.Error("person / style ์ด๋ฏธ์ง€๋Š” ํ•„์ˆ˜์ž…๋‹ˆ๋‹ค. (sketch๋Š” ์„ ํƒ)")
559
 
560
  set_seed(int(seed) if seed is not None else -1)
561
 
 
564
 
565
  paths = Paths(
566
  person_path=person_fp,
567
+ depth_path=sketch_fp, # None ๊ฐ€๋Šฅ
568
  style_path=style_fp,
569
  output_path=out_path,
570
  )
 
578
 
579
 
580
  with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
581
+ gr.Markdown("## VISTA Demo\nperson / style ํ•„์ˆ˜, sketch(guide)๋Š” ์„ ํƒ์ž…๋‹ˆ๋‹ค.")
582
 
583
  with gr.Row():
584
+ person_in = gr.Image(label="Person Image (required)", type="filepath")
585
+ sketch_in = gr.Image(label="Sketch / Guide (optional)", type="filepath") # โœ… (2) ์„ ํƒ
586
+ style_in = gr.Image(label="Style Image (required)", type="filepath")
587
 
588
  with gr.Row():
589
  prompt_in = gr.Textbox(label="Prompt", value="upper garment", lines=2)
 
614
  demo.queue()
615
  if __name__ == "__main__":
616
  demo.launch(server_name="0.0.0.0", server_port=7860)