ssoxye commited on
Commit
2cc3e3c
·
1 Parent(s): d19c49c

update bitwise

Browse files
Files changed (1) hide show
  1. app.py +63 -33
app.py CHANGED
@@ -114,8 +114,8 @@ W: Optional[int] = None
114
  @dataclass
115
  class Paths:
116
  person_path: str
117
- depth_path: Optional[str] # sketch(guide) optional
118
- style_path: str
119
  output_path: str
120
 
121
 
@@ -184,7 +184,7 @@ def remove_small_white_components(
184
  parsing_img: Image.Image,
185
  *,
186
  white_threshold: int = 128,
187
- min_white_area: int = 50,
188
  use_open: bool = False,
189
  open_ksize: int = 3,
190
  morph_iters: int = 1,
@@ -220,7 +220,6 @@ def remove_small_white_components(
220
  return Image.fromarray(mask, mode="L")
221
 
222
 
223
-
224
  def compute_hw_from_person(person_path: str):
225
  img = _imread_or_raise(person_path)
226
  orig_h, orig_w = img.shape[:2]
@@ -237,6 +236,7 @@ def fill_sketch_from_image_path_to_pil(image_path: str) -> Image.Image:
237
  if H is None or W is None:
238
  raise RuntimeError("Global H/W not set.")
239
  img = _imread_or_raise(image_path, cv2.IMREAD_GRAYSCALE)
 
240
  img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST)
241
  _, binary = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)
242
  contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
@@ -283,10 +283,10 @@ def make_depth(depth_path: str) -> Image.Image:
283
  raise RuntimeError("Global H/W not set. Call run_one() first.")
284
 
285
  depth_img = _imread_or_raise(depth_path, 0)
286
- inverted_depth = cv2.bitwise_not(depth_img)
287
- contours, _ = cv2.findContours(inverted_depth, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
288
 
289
- filled_depth = inverted_depth.copy()
290
  cv2.drawContours(filled_depth, contours, -1, (255), thickness=cv2.FILLED)
291
 
292
  filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA)
@@ -421,6 +421,29 @@ _UI_TO_EXTRACTOR_CATEGORY = {
421
  }
422
 
423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
  def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str = "Dress"):
425
  global H, W
426
  pipe, device, _dtype = get_pipe_and_device()
@@ -438,9 +461,7 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str
438
  parsing_img = res["images"][0] if res.get("images") else None
439
  if parsing_img is None:
440
  raise RuntimeError("run_simple_extractor returned no parsing images.")
441
-
442
 
443
-
444
  parsing_img = remove_small_white_components(
445
  parsing_img,
446
  white_threshold=128,
@@ -448,13 +469,7 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str
448
  use_open=False,
449
  )
450
 
451
-
452
- use_depth_path = (
453
- paths.depth_path is not None
454
- and isinstance(paths.depth_path, str)
455
- and len(paths.depth_path) > 0
456
- and os.path.exists(paths.depth_path)
457
- )
458
 
459
  if use_depth_path:
460
  sketch_area = fill_sketch_from_image_path_to_pil(paths.depth_path)
@@ -482,7 +497,6 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str
482
  garment_bgr = apply_parsing_white_mask_to_person_cv2(personn, parsing_img)
483
  garment_rgb = cv2.cvtColor(garment_bgr, cv2.COLOR_BGR2RGB)
484
  garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA)
485
-
486
  garment_rgb = _pad_or_crop_to_width_np(garment_rgb, 1024, pad_value=[255, 255, 255])
487
  garment_pil = Image.fromarray(garment_rgb)
488
 
@@ -493,6 +507,11 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str
493
  gm = _pad_or_crop_to_width_np(gm, 1024, pad_value=[0, 0, 0])
494
  garment_mask_pil = Image.fromarray(gm)
495
 
 
 
 
 
 
496
  print(
497
  "[SIZE] person:", person_pil.size,
498
  "mask:", mask_pil.size,
@@ -501,6 +520,10 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str
501
  "gmask:", garment_mask_pil.size,
502
  "ui_category:", category,
503
  "extractor_category:", extractor_category,
 
 
 
 
504
  flush=True
505
  )
506
 
@@ -511,8 +534,8 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str
511
  device,
512
  mask_pil,
513
  person_pil,
514
- content_scale=0.3,
515
- style_scale=0.5,
516
  garment_images=garment_pil,
517
  garment_mask=garment_mask_pil,
518
  )
@@ -525,13 +548,19 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str
525
  except Exception:
526
  pass
527
 
528
- style_img = Image.open(paths.style_path).convert("RGB")
529
-
530
- if prompt != "":
531
- prompt = extractor_category + " with " + prompt
 
 
 
 
 
 
532
  else:
533
  prompt = extractor_category
534
-
535
  print("==== prompt? ", prompt, flush=True)
536
 
537
  with torch.inference_mode():
@@ -566,8 +595,9 @@ def set_seed(seed: int):
566
  def infer_web(person_fp, sketch_fp, style_fp, prompt, steps, seed, category):
567
  print("[UI] infer_web called", flush=True)
568
 
569
- if person_fp is None or style_fp is None:
570
- raise gr.Error("person / style 이미지는 필수입니다. (sketch는 선택)")
 
571
 
572
  if category not in ("Upper-body", "Lower-body", "Dress"):
573
  raise gr.Error(f"Invalid category: {category}")
@@ -580,7 +610,7 @@ def infer_web(person_fp, sketch_fp, style_fp, prompt, steps, seed, category):
580
  paths = Paths(
581
  person_path=person_fp,
582
  depth_path=sketch_fp,
583
- style_path=style_fp,
584
  output_path=out_path,
585
  )
586
 
@@ -593,7 +623,7 @@ def infer_web(person_fp, sketch_fp, style_fp, prompt, steps, seed, category):
593
 
594
 
595
  with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
596
- gr.Markdown("## VISTA Demo\nperson / style 필수, sketch(guide)는 선택입니다.")
597
 
598
  category_toggle = gr.Radio(
599
  choices=["Dress", "Upper-body", "Lower-body"],
@@ -617,18 +647,18 @@ with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
617
  gr.Markdown("#### Examples")
618
  gr.Examples(
619
  examples=person_examples,
620
- inputs=[person_in], # ✅ person만 채움 (독립 선택)
621
  examples_per_page=8,
622
  )
623
 
624
  # -------- Style column --------
625
  with gr.Column(scale=1):
626
- style_in = gr.Image(label="Style Image (required)", type="filepath")
627
  if style_examples:
628
  gr.Markdown("#### Examples")
629
  gr.Examples(
630
  examples=style_examples,
631
- inputs=[style_in], # ✅ style만 채움 (독립 선택)
632
  examples_per_page=8,
633
  )
634
 
@@ -642,7 +672,7 @@ with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
642
  gr.Markdown("#### Examples")
643
  gr.Examples(
644
  examples=sketch_examples,
645
- inputs=[sketch_in], # ✅ sketch만 채움 (독립 선택)
646
  examples_per_page=8,
647
  )
648
 
@@ -650,7 +680,7 @@ with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
650
  prompt_in = gr.Textbox(
651
  label="Prompt",
652
  value="",
653
- placeholder="ex) lace, button, …",
654
  lines=2,
655
  )
656
  steps_in = gr.Slider(1, 80, value=DEFAULT_STEPS, step=1, label="Steps")
 
114
  @dataclass
115
  class Paths:
116
  person_path: str
117
+ depth_path: Optional[str] # sketch(guide) optional
118
+ style_path: Optional[str] # ✅ style optional (변경)
119
  output_path: str
120
 
121
 
 
184
  parsing_img: Image.Image,
185
  *,
186
  white_threshold: int = 128,
187
+ min_white_area: int = 150,
188
  use_open: bool = False,
189
  open_ksize: int = 3,
190
  morph_iters: int = 1,
 
220
  return Image.fromarray(mask, mode="L")
221
 
222
 
 
223
  def compute_hw_from_person(person_path: str):
224
  img = _imread_or_raise(person_path)
225
  orig_h, orig_w = img.shape[:2]
 
236
  if H is None or W is None:
237
  raise RuntimeError("Global H/W not set.")
238
  img = _imread_or_raise(image_path, cv2.IMREAD_GRAYSCALE)
239
+ img = cv2.bitwise_not(img)
240
  img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST)
241
  _, binary = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)
242
  contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
283
  raise RuntimeError("Global H/W not set. Call run_one() first.")
284
 
285
  depth_img = _imread_or_raise(depth_path, 0)
286
+ # inverted_depth = cv2.bitwise_not(depth_img)
287
+ contours, _ = cv2.findContours(depth_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
288
 
289
+ filled_depth = depth_img.copy()
290
  cv2.drawContours(filled_depth, contours, -1, (255), thickness=cv2.FILLED)
291
 
292
  filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA)
 
421
  }
422
 
423
 
424
+ def _has_valid_file(path: Optional[str]) -> bool:
425
+ return (
426
+ path is not None
427
+ and isinstance(path, str)
428
+ and len(path) > 0
429
+ and os.path.exists(path)
430
+ )
431
+
432
+
433
+ def _resolve_content_style_scales(style_present: bool, prompt_present: bool) -> Tuple[float, float]:
434
+ """
435
+ 요구사항:
436
+ - style image 없으면: (0.0, 0.0)
437
+ - prompt 없으면: (0.4, 0.6)
438
+ - 둘 다 있으면: 기존 유지 (0.3, 0.5)
439
+ """
440
+ if not style_present:
441
+ return 0.0, 0.0
442
+ if not prompt_present:
443
+ return 0.4, 0.65
444
+ return 0.4, 0.5
445
+
446
+
447
  def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str = "Dress"):
448
  global H, W
449
  pipe, device, _dtype = get_pipe_and_device()
 
461
  parsing_img = res["images"][0] if res.get("images") else None
462
  if parsing_img is None:
463
  raise RuntimeError("run_simple_extractor returned no parsing images.")
 
464
 
 
465
  parsing_img = remove_small_white_components(
466
  parsing_img,
467
  white_threshold=128,
 
469
  use_open=False,
470
  )
471
 
472
+ use_depth_path = _has_valid_file(paths.depth_path)
 
 
 
 
 
 
473
 
474
  if use_depth_path:
475
  sketch_area = fill_sketch_from_image_path_to_pil(paths.depth_path)
 
497
  garment_bgr = apply_parsing_white_mask_to_person_cv2(personn, parsing_img)
498
  garment_rgb = cv2.cvtColor(garment_bgr, cv2.COLOR_BGR2RGB)
499
  garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA)
 
500
  garment_rgb = _pad_or_crop_to_width_np(garment_rgb, 1024, pad_value=[255, 255, 255])
501
  garment_pil = Image.fromarray(garment_rgb)
502
 
 
507
  gm = _pad_or_crop_to_width_np(gm, 1024, pad_value=[0, 0, 0])
508
  garment_mask_pil = Image.fromarray(gm)
509
 
510
+ # ✅ 조건에 따른 scale 결정
511
+ style_present = _has_valid_file(paths.style_path)
512
+ prompt_present = (prompt is not None) and (str(prompt).strip() != "")
513
+ content_scale, style_scale = _resolve_content_style_scales(style_present, prompt_present)
514
+
515
  print(
516
  "[SIZE] person:", person_pil.size,
517
  "mask:", mask_pil.size,
 
520
  "gmask:", garment_mask_pil.size,
521
  "ui_category:", category,
522
  "extractor_category:", extractor_category,
523
+ "style_present:", style_present,
524
+ "prompt_present:", prompt_present,
525
+ "content_scale:", content_scale,
526
+ "style_scale:", style_scale,
527
  flush=True
528
  )
529
 
 
534
  device,
535
  mask_pil,
536
  person_pil,
537
+ content_scale=content_scale, # ✅ 변경
538
+ style_scale=style_scale, # ✅ 변경
539
  garment_images=garment_pil,
540
  garment_mask=garment_mask_pil,
541
  )
 
548
  except Exception:
549
  pass
550
 
551
+ # style image 없을 때도 generate 입력이 None이 되지 않게 대체
552
+ if style_present:
553
+ style_img = Image.open(paths.style_path).convert("RGB")
554
+ else:
555
+ # scale이 0이므로 영향은 없고, 함수 시그니처만 만족시키기 위한 대체값
556
+ style_img = garment_pil
557
+
558
+ # prompt 구성은 기존 유지
559
+ if prompt is not None and str(prompt).strip() != "":
560
+ prompt = extractor_category + " with " + str(prompt).strip()
561
  else:
562
  prompt = extractor_category
563
+
564
  print("==== prompt? ", prompt, flush=True)
565
 
566
  with torch.inference_mode():
 
595
  def infer_web(person_fp, sketch_fp, style_fp, prompt, steps, seed, category):
596
  print("[UI] infer_web called", flush=True)
597
 
598
+ # person만 필수, style은 선택
599
+ if person_fp is None:
600
+ raise gr.Error("person 이미지는 필수입니다. (style/sketch는 선택)")
601
 
602
  if category not in ("Upper-body", "Lower-body", "Dress"):
603
  raise gr.Error(f"Invalid category: {category}")
 
610
  paths = Paths(
611
  person_path=person_fp,
612
  depth_path=sketch_fp,
613
+ style_path=style_fp, # ✅ None 가능
614
  output_path=out_path,
615
  )
616
 
 
623
 
624
 
625
  with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
626
+ gr.Markdown("## VISTA Demo\nperson 필수, style/sketch(guide)는 선택입니다.")
627
 
628
  category_toggle = gr.Radio(
629
  choices=["Dress", "Upper-body", "Lower-body"],
 
647
  gr.Markdown("#### Examples")
648
  gr.Examples(
649
  examples=person_examples,
650
+ inputs=[person_in],
651
  examples_per_page=8,
652
  )
653
 
654
  # -------- Style column --------
655
  with gr.Column(scale=1):
656
+ style_in = gr.Image(label="Style Image (optional)", type="filepath")
657
  if style_examples:
658
  gr.Markdown("#### Examples")
659
  gr.Examples(
660
  examples=style_examples,
661
+ inputs=[style_in],
662
  examples_per_page=8,
663
  )
664
 
 
672
  gr.Markdown("#### Examples")
673
  gr.Examples(
674
  examples=sketch_examples,
675
+ inputs=[sketch_in],
676
  examples_per_page=8,
677
  )
678
 
 
680
  prompt_in = gr.Textbox(
681
  label="Prompt",
682
  value="",
683
+ placeholder="ex) crystal, lace, button, …",
684
  lines=2,
685
  )
686
  steps_in = gr.Slider(1, 80, value=DEFAULT_STEPS, step=1, label="Steps")