MogensR committed on
Commit
df850a4
·
verified ·
1 Parent(s): 937519c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -54
app.py CHANGED
@@ -4,12 +4,11 @@
4
  # =============================================================================
5
  """
6
  Enhanced Video Background Replacement (SAM2 + MatAnyone + AI Backgrounds)
7
- - Robust memory management & cleanup
8
- - SAM2 person mask (CUDA)
9
- - MatAnyone matting w/ **probability** mask on the first frame (no idx_mask)
10
- - Cleaned tensor shapes (image: 3xHxW, prob: 1xHxW), consistent device/dtype
11
- - Optional SDXL / Playground / OpenAI background generation
12
- - Gradio UI with “chapters” in code for quick edits
13
  """
14
 
15
  # =============================================================================
@@ -22,9 +21,7 @@
22
  import psutil
23
  import time
24
  import json
25
- import math
26
  import base64
27
- import queue
28
  import random
29
  import shutil
30
  import logging
@@ -45,7 +42,7 @@
45
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
46
  logger = logging.getLogger("bgx")
47
 
48
- # Environment tuning (safe defaults; do not overwrite if already set)
49
  os.environ.setdefault("CUDA_MODULE_LOADING", "LAZY")
50
  os.environ.setdefault("TORCH_CUDNN_V8_API_ENABLED", "1")
51
  os.environ.setdefault("PYTHONUNBUFFERED", "1")
@@ -351,7 +348,7 @@ def initialize(self) -> bool:
351
  model = build_sam2("sam2.1/sam2.1_hiera_l.yaml", str(ckpt), device="cuda")
352
  self.predictor = SAM2ImagePredictor(model)
353
 
354
- # Quick smoke test
355
  test = np.zeros((64, 64, 3), dtype=np.uint8)
356
  self.predictor.set_image(test)
357
  masks, scores, _ = self.predictor.predict(
@@ -416,47 +413,37 @@ def __init__(self):
416
  self.core = None
417
  self.initialized = False
418
 
419
- # ----- small tensor helpers -----
420
- def _to_chw_float(self, img01: np.ndarray) -> torch.Tensor:
421
- # img01: HxWx3, float32 [0..1]
422
  assert img01.ndim == 3 and img01.shape[2] == 3, f"Expected HxWx3, got {img01.shape}"
423
  t = torch.from_numpy(img01.transpose(2, 0, 1)).contiguous().float() # 3xHxW
424
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
425
 
426
- def _prob_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> torch.Tensor:
427
- # returns 1xHxW float32 [0..1]
428
  if mask_u8.shape[0] != h or mask_u8.shape[1] != w:
429
  mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
430
  prob = (mask_u8.astype(np.float32) / 255.0)[None, ...] # 1xHxW
431
  t = torch.from_numpy(prob).contiguous().float()
432
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
433
 
434
- def _alpha_to_u8_hw(self, alpha_like: torch.Tensor) -> np.ndarray:
435
- # Accepts tensor with shapes: (1,H,W) or (H,W) or (K,H,W) where K==1
436
  if isinstance(alpha_like, (list, tuple)) and len(alpha_like) > 1:
437
- # Many MatAnyone step() return (indices, probs)
438
- alpha_like = alpha_like[1]
439
  if isinstance(alpha_like, torch.Tensor):
440
  t = alpha_like.detach()
441
  if t.is_cuda:
442
  t = t.cpu()
443
- t = t.float().clamp(0, 1)
444
- a = t.numpy()
445
  else:
446
  a = np.asarray(alpha_like, dtype=np.float32)
447
  a = np.clip(a, 0, 1)
448
-
449
- if a.ndim == 3 and a.shape[0] == 1:
450
- a = a[0] # (H,W)
451
- elif a.ndim == 3 and a.shape[0] > 1:
452
- a = a[0] # take first object
453
- elif a.ndim == 2:
454
- pass
455
- else:
456
- # try to squeeze any trailing singleton dims
457
- a = np.squeeze(a)
458
- if a.ndim != 2:
459
- raise ValueError(f"Alpha map must be HxW; got shape {a.shape}")
460
  return (np.clip(a * 255.0, 0, 255).astype(np.uint8))
461
 
462
  def initialize(self) -> bool:
@@ -498,7 +485,7 @@ def initialize(self) -> bool:
498
  state.matanyone_error = f"MatAnyone init error: {e}"
499
  return False
500
 
501
- # ----- main video matting using PROB mask on first frame -----
502
  def process_video(self, input_path: str, mask_path: str, output_path: str) -> str:
503
  if not self.initialized or self.core is None:
504
  raise RuntimeError("MatAnyone not initialized")
@@ -526,7 +513,7 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
526
 
527
  frame_idx = 0
528
 
529
- # ---------- First frame (with PROB mask) ----------
530
  ok, frame_bgr = cap.read()
531
  if not ok or frame_bgr is None:
532
  cap.release()
@@ -536,14 +523,13 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
536
  prob_chw = self._prob_from_mask_u8(seed_mask, w, h) # 1xHxW
537
 
538
  with torch.no_grad():
539
- # Use PROB path (no idx_mask, no objects). Some forks require `matting=True`
540
  out_prob = self.core.step(img_chw, prob=prob_chw, matting=True)
541
 
542
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
543
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
544
  frame_idx += 1
545
 
546
- # ---------- Remaining frames (no mask) ----------
547
  while True:
548
  ok, frame_bgr = cap.read()
549
  if not ok or frame_bgr is None:
@@ -579,7 +565,7 @@ def process_video(self, input_path: str, mask_path: str, output_path: str) -> st
579
  return str(alpha_path)
580
 
581
  # =============================================================================
582
- # CHAPTER 7: AI BACKGROUNDS (SDXL / Playground / OpenAI)
583
  # =============================================================================
584
  def _maybe_enable_xformers(pipe):
585
  try:
@@ -639,7 +625,7 @@ def generate_sdxl_background(width:int, height:int, prompt:str, steps:int=30, gu
639
  generator=generator
640
  ).images[0]
641
 
642
- out = TEMP_DIR / f"sdxl_bg_{int(time.time())}_{seed:08d}.jpg"
643
  img.save(out, quality=95, optimize=True)
644
  memory_manager.register_temp_file(str(out))
645
  del pipe, img
@@ -680,7 +666,7 @@ def generate_playground_v25_background(width:int, height:int, prompt:str, steps:
680
  generator=generator
681
  ).images[0]
682
 
683
- out = TEMP_DIR / f"pg25_bg_{int(time.time())}_{seed:08d}.jpg"
684
  img.save(out, quality=95, optimize=True)
685
  memory_manager.register_temp_file(str(out))
686
  del pipe, img
@@ -726,7 +712,7 @@ def generate_sd15_background(width:int, height:int, prompt:str, steps:int=25, gu
726
  generator=generator
727
  ).images[0]
728
 
729
- out = TEMP_DIR / f"sd15_bg_{int(time.time())}_{seed:08d}.jpg"
730
  img.save(out, quality=95, optimize=True)
731
  memory_manager.register_temp_file(str(out))
732
  del pipe, img
@@ -785,7 +771,7 @@ def generate_ai_background_router(width:int, height:int, prompt:str, model:str="
785
  return str(out)
786
 
787
  # =============================================================================
788
- # CHAPTER 8: CHUNKED PROCESSOR (optional, unchanged)
789
  # =============================================================================
790
  class ChunkedVideoProcessor:
791
  def __init__(self, chunk_size_frames: int = 60):
@@ -934,7 +920,7 @@ def process_video_main(
934
  alpha_clip = VideoFileClip(alpha_video)
935
 
936
  if background_path and os.path.exists(background_path):
937
- messages.append(f"🖼️ Using background file")
938
  bg_bgr = cv2.imread(background_path)
939
  bg_bgr = cv2.resize(bg_bgr, (w, h))
940
  bg_rgb = cv2.cvtColor(bg_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
@@ -1049,17 +1035,13 @@ def preload(ai_model, openai_key, force_gpu, progress=gr.Progress()):
1049
  progress(0, desc="Preloading...")
1050
  msg = ""
1051
  if ai_model in ("SDXL", "Playground v2.5", "SD 1.5 (fallback)"):
1052
- # “preload lite”: generate tiny image once (2 steps)
1053
  try:
1054
  if ai_model == "SDXL":
1055
- _ = generate_sdxl_background(64, 64, "plain background", steps=2, guidance=3.5,
1056
- seed=42, require_gpu=bool(force_gpu))
1057
  elif ai_model == "Playground v2.5":
1058
- _ = generate_playground_v25_background(64, 64, "plain background", steps=2, guidance=3.5,
1059
- seed=42, require_gpu=bool(force_gpu))
1060
  else:
1061
- _ = generate_sd15_background(64, 64, "plain background", steps=2, guidance=3.5,
1062
- seed=42, require_gpu=bool(force_gpu))
1063
  msg += f"{ai_model} preloaded.\n"
1064
  except Exception as e:
1065
  msg += f"{ai_model} preload failed: {e}\n"
@@ -1152,12 +1134,18 @@ def approve_background(bg_path):
1152
  gr.Markdown("### Background")
1153
  bg_method = gr.Radio(choices=["Upload Image", "Gradients", "AI Generated"],
1154
  value="AI Generated", label="Background Method")
 
 
1155
  with gr.Group(visible=False) as upload_group:
1156
  upload_img = gr.Image(label="Background Image", type="filepath")
1157
- with gr.Group(visible=True) as gradient_group:
 
 
1158
  gradient_choice = gr.Dropdown(label="Gradient Style",
1159
  choices=list(GRADIENT_PRESETS.keys()),
1160
  value="Slate")
 
 
1161
  with gr.Group(visible=True) as ai_group:
1162
  prompt_suggestions = gr.Dropdown(label="💡 Prompt Inspiration",
1163
  choices=AI_PROMPT_SUGGESTIONS,
@@ -1216,6 +1204,7 @@ def approve_background(bg_path):
1216
 
1217
  # --- Wiring ---
1218
  def update_background_visibility(method):
 
1219
  return (
1220
  gr.update(visible=(method == "Upload Image")),
1221
  gr.update(visible=(method == "Gradients")),
@@ -1228,7 +1217,7 @@ def update_prompt_from_suggestion(suggestion):
1228
  return gr.update(value=suggestion)
1229
 
1230
  bg_method.change(
1231
- lambda m: update_background_visibility(m),
1232
  inputs=[bg_method],
1233
  outputs=[upload_group, gradient_group, ai_group]
1234
  )
@@ -1255,10 +1244,11 @@ def update_prompt_from_suggestion(suggestion):
1255
  diagnostics_btn.click(diag, outputs=[diagnostics_output])
1256
  cleanup_btn.click(cleanup, outputs=[diagnostics_output])
1257
 
 
1258
  def process_video(
1259
  video_file,
1260
  bg_method,
1261
- upload_img,
1262
  gradient_choice,
1263
  approved_background_path,
1264
  last_generated_bg,
@@ -1321,7 +1311,7 @@ def process_video(
1321
  inputs=[
1322
  video_input,
1323
  bg_method,
1324
- upload_group, # this group passes the image component value
1325
  gradient_choice,
1326
  approved_background_path, last_generated_bg,
1327
  trim_enabled, trim_seconds, crf_value, audio_enabled,
 
4
  # =============================================================================
5
  """
6
  Enhanced Video Background Replacement (SAM2 + MatAnyone + AI Backgrounds)
7
+ - Strict tensor shapes for MatAnyone (image: 3xHxW, first-frame prob mask: 1xHxW)
8
+ - First frame uses PROB path (no idx_mask / objects) to avoid assertion
9
+ - Memory management & cleanup
10
+ - SDXL / Playground / OpenAI backgrounds
11
+ - Gradio UI with “CHAPTER” dividers
 
12
  """
13
 
14
  # =============================================================================
 
21
  import psutil
22
  import time
23
  import json
 
24
  import base64
 
25
  import random
26
  import shutil
27
  import logging
 
42
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
43
  logger = logging.getLogger("bgx")
44
 
45
+ # Environment tuning (safe defaults)
46
  os.environ.setdefault("CUDA_MODULE_LOADING", "LAZY")
47
  os.environ.setdefault("TORCH_CUDNN_V8_API_ENABLED", "1")
48
  os.environ.setdefault("PYTHONUNBUFFERED", "1")
 
348
  model = build_sam2("sam2.1/sam2.1_hiera_l.yaml", str(ckpt), device="cuda")
349
  self.predictor = SAM2ImagePredictor(model)
350
 
351
+ # Smoke test
352
  test = np.zeros((64, 64, 3), dtype=np.uint8)
353
  self.predictor.set_image(test)
354
  masks, scores, _ = self.predictor.predict(
 
413
  self.core = None
414
  self.initialized = False
415
 
416
+ # ----- tensor helpers -----
417
+ def _to_chw_float(self, img01: np.ndarray) -> "torch.Tensor":
 
418
  assert img01.ndim == 3 and img01.shape[2] == 3, f"Expected HxWx3, got {img01.shape}"
419
  t = torch.from_numpy(img01.transpose(2, 0, 1)).contiguous().float() # 3xHxW
420
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
421
 
422
+ def _prob_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
 
423
  if mask_u8.shape[0] != h or mask_u8.shape[1] != w:
424
  mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
425
  prob = (mask_u8.astype(np.float32) / 255.0)[None, ...] # 1xHxW
426
  t = torch.from_numpy(prob).contiguous().float()
427
  return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)
428
 
429
+ def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray:
 
430
  if isinstance(alpha_like, (list, tuple)) and len(alpha_like) > 1:
431
+ alpha_like = alpha_like[1] # handle (indices, probs)
 
432
  if isinstance(alpha_like, torch.Tensor):
433
  t = alpha_like.detach()
434
  if t.is_cuda:
435
  t = t.cpu()
436
+ a = t.float().clamp(0, 1).numpy()
 
437
  else:
438
  a = np.asarray(alpha_like, dtype=np.float32)
439
  a = np.clip(a, 0, 1)
440
+ a = np.squeeze(a)
441
+ if a.ndim != 2:
442
+ # handle shapes (1,H,W) or (K,H,W) → pick first
443
+ if a.ndim == 3 and a.shape[0] >= 1:
444
+ a = a[0]
445
+ else:
446
+ raise ValueError(f"Alpha must be HxW; got {a.shape}")
 
 
 
 
 
447
  return (np.clip(a * 255.0, 0, 255).astype(np.uint8))
448
 
449
  def initialize(self) -> bool:
 
485
  state.matanyone_error = f"MatAnyone init error: {e}"
486
  return False
487
 
488
+ # ----- video matting using first-frame PROB mask -----
489
  def process_video(self, input_path: str, mask_path: str, output_path: str) -> str:
490
  if not self.initialized or self.core is None:
491
  raise RuntimeError("MatAnyone not initialized")
 
513
 
514
  frame_idx = 0
515
 
516
+ # First frame (with PROB mask)
517
  ok, frame_bgr = cap.read()
518
  if not ok or frame_bgr is None:
519
  cap.release()
 
523
  prob_chw = self._prob_from_mask_u8(seed_mask, w, h) # 1xHxW
524
 
525
  with torch.no_grad():
 
526
  out_prob = self.core.step(img_chw, prob=prob_chw, matting=True)
527
 
528
  alpha_u8 = self._alpha_to_u8_hw(out_prob)
529
  cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
530
  frame_idx += 1
531
 
532
+ # Remaining frames (no mask)
533
  while True:
534
  ok, frame_bgr = cap.read()
535
  if not ok or frame_bgr is None:
 
565
  return str(alpha_path)
566
 
567
  # =============================================================================
568
+ # CHAPTER 7: AI BACKGROUNDS
569
  # =============================================================================
570
  def _maybe_enable_xformers(pipe):
571
  try:
 
625
  generator=generator
626
  ).images[0]
627
 
628
+ out = TEMP_DIR / f"sdxl_bg_{int(time.time())}_{seed or 0:08d}.jpg"
629
  img.save(out, quality=95, optimize=True)
630
  memory_manager.register_temp_file(str(out))
631
  del pipe, img
 
666
  generator=generator
667
  ).images[0]
668
 
669
+ out = TEMP_DIR / f"pg25_bg_{int(time.time())}_{seed or 0:08d}.jpg"
670
  img.save(out, quality=95, optimize=True)
671
  memory_manager.register_temp_file(str(out))
672
  del pipe, img
 
712
  generator=generator
713
  ).images[0]
714
 
715
+ out = TEMP_DIR / f"sd15_bg_{int(time.time())}_{seed or 0:08d}.jpg"
716
  img.save(out, quality=95, optimize=True)
717
  memory_manager.register_temp_file(str(out))
718
  del pipe, img
 
771
  return str(out)
772
 
773
  # =============================================================================
774
+ # CHAPTER 8: CHUNKED PROCESSOR (optional)
775
  # =============================================================================
776
  class ChunkedVideoProcessor:
777
  def __init__(self, chunk_size_frames: int = 60):
 
920
  alpha_clip = VideoFileClip(alpha_video)
921
 
922
  if background_path and os.path.exists(background_path):
923
+ messages.append("🖼️ Using background file")
924
  bg_bgr = cv2.imread(background_path)
925
  bg_bgr = cv2.resize(bg_bgr, (w, h))
926
  bg_rgb = cv2.cvtColor(bg_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
 
1035
  progress(0, desc="Preloading...")
1036
  msg = ""
1037
  if ai_model in ("SDXL", "Playground v2.5", "SD 1.5 (fallback)"):
 
1038
  try:
1039
  if ai_model == "SDXL":
1040
+ _ = generate_sdxl_background(64, 64, "plain", steps=2, guidance=3.5, seed=42, require_gpu=bool(force_gpu))
 
1041
  elif ai_model == "Playground v2.5":
1042
+ _ = generate_playground_v25_background(64, 64, "plain", steps=2, guidance=3.5, seed=42, require_gpu=bool(force_gpu))
 
1043
  else:
1044
+ _ = generate_sd15_background(64, 64, "plain", steps=2, guidance=3.5, seed=42, require_gpu=bool(force_gpu))
 
1045
  msg += f"{ai_model} preloaded.\n"
1046
  except Exception as e:
1047
  msg += f"{ai_model} preload failed: {e}\n"
 
1134
  gr.Markdown("### Background")
1135
  bg_method = gr.Radio(choices=["Upload Image", "Gradients", "AI Generated"],
1136
  value="AI Generated", label="Background Method")
1137
+
1138
+ # Upload group (hidden by default)
1139
  with gr.Group(visible=False) as upload_group:
1140
  upload_img = gr.Image(label="Background Image", type="filepath")
1141
+
1142
+ # Gradient group (hidden by default)
1143
+ with gr.Group(visible=False) as gradient_group:
1144
  gradient_choice = gr.Dropdown(label="Gradient Style",
1145
  choices=list(GRADIENT_PRESETS.keys()),
1146
  value="Slate")
1147
+
1148
+ # AI group (visible by default)
1149
  with gr.Group(visible=True) as ai_group:
1150
  prompt_suggestions = gr.Dropdown(label="💡 Prompt Inspiration",
1151
  choices=AI_PROMPT_SUGGESTIONS,
 
1204
 
1205
  # --- Wiring ---
1206
  def update_background_visibility(method):
1207
+ # return visibilities for: upload_group, gradient_group, ai_group
1208
  return (
1209
  gr.update(visible=(method == "Upload Image")),
1210
  gr.update(visible=(method == "Gradients")),
 
1217
  return gr.update(value=suggestion)
1218
 
1219
  bg_method.change(
1220
+ update_background_visibility,
1221
  inputs=[bg_method],
1222
  outputs=[upload_group, gradient_group, ai_group]
1223
  )
 
1244
  diagnostics_btn.click(diag, outputs=[diagnostics_output])
1245
  cleanup_btn.click(cleanup, outputs=[diagnostics_output])
1246
 
1247
+ # ----- FIXED: use upload_img (Image component), not upload_group (Group) -----
1248
  def process_video(
1249
  video_file,
1250
  bg_method,
1251
+ upload_img, # <-- correct input
1252
  gradient_choice,
1253
  approved_background_path,
1254
  last_generated_bg,
 
1311
  inputs=[
1312
  video_input,
1313
  bg_method,
1314
+ upload_img, # <-- FIXED here
1315
  gradient_choice,
1316
  approved_background_path, last_generated_bg,
1317
  trim_enabled, trim_seconds, crf_value, audio_enabled,