MogensR committed on
Commit
8a850cc
·
1 Parent(s): 2cd2385
Files changed (2) hide show
  1. app.py +60 -54
  2. pipeline.py +82 -10
app.py CHANGED
@@ -13,26 +13,64 @@
13
 
14
  import os
15
  import json
16
- import gradio as gr
 
17
  from pathlib import Path
 
 
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # Local pipeline
 
20
  import pipeline
21
 
22
- # --- Compact styling to resemble older v4 look ---
23
- CSS = """
24
- /* Narrow overall width */
25
- .gradio-container { max-width: 1200px !important; }
26
-
27
- /* Tighten gaps between elements */
28
- .gap-4, .gap-6, .gap-8 { gap: 0.5rem !important; }
29
-
30
- /* Keep video panels from growing too tall */
31
- #in_video video, #out_video video { max-height: 420px; }
32
-
33
- /* Trim markdown spacing */
34
- .prose h1 { margin-bottom: 0.5rem !important; }
35
- """
36
 
37
  def _process_entry(video, bg_image, point_x, point_y, auto_box, progress=gr.Progress(track_tqdm=True)):
38
  """
@@ -59,18 +97,7 @@ def _process_entry(video, bg_image, point_x, point_y, auto_box, progress=gr.Prog
59
  return (out_path if out_path else None), json.dumps(diag, indent=2)
60
 
61
 
62
- with gr.Blocks(
63
- title="BackgroundFX Pro (SAM2 + MatAnyone)",
64
- theme=gr.themes.Soft(
65
- primary_hue="blue",
66
- radius_size=gr.themes.sizes.radius_sm,
67
- spacing_size=gr.themes.sizes.spacing_sm,
68
- text_size=gr.themes.sizes.text_md,
69
- ),
70
- css=CSS,
71
- fill_height=True,
72
- analytics_enabled=False
73
- ) as demo:
74
  gr.Markdown(
75
  """
76
  # 🎬 BackgroundFX Pro
@@ -82,24 +109,10 @@ def _process_entry(video, bg_image, point_x, point_y, auto_box, progress=gr.Prog
82
  """
83
  )
84
 
85
- with gr.Row(equal_height=True):
86
  with gr.Column(scale=2):
87
- in_video = gr.Video(
88
- label="Input Video",
89
- sources=["upload"],
90
- interactive=True,
91
- elem_id="in_video",
92
- show_label=True,
93
- height=420,
94
- autoplay=False
95
- )
96
- in_bg = gr.Image(
97
- label="Background Image",
98
- type="filepath",
99
- interactive=True,
100
- image_mode="RGB",
101
- show_label=True
102
- )
103
  with gr.Column(scale=1):
104
  point_x = gr.Number(label="Foreground point X (optional)", value=None, precision=0)
105
  point_y = gr.Number(label="Foreground point Y (optional)", value=None, precision=0)
@@ -107,18 +120,11 @@ def _process_entry(video, bg_image, point_x, point_y, auto_box, progress=gr.Prog
107
  process_btn = gr.Button("Process", variant="primary")
108
 
109
  with gr.Row():
110
- out_video = gr.Video(
111
- label="Output (H.264 MP4)",
112
- elem_id="out_video",
113
- height=420,
114
- autoplay=False,
115
- show_download_button=True
116
- )
117
- out_diag = gr.JSON(label="Diagnostics", show_label=True)
118
 
119
  def _on_click(video, bg, px, py, auto):
120
  v, d = _process_entry(video, bg, px, py, auto)
121
- # Gradio's Video output expects a filepath; JSON expects dict (we have string)
122
  try:
123
  d_dict = json.loads(d)
124
  except Exception:
@@ -135,5 +141,5 @@ def _on_click(video, bg, px, py, auto):
135
  # Dynamic host/port via env; suitable defaults for Hugging Face Spaces
136
  host = os.environ.get("HOST", "0.0.0.0")
137
  port = int(os.environ.get("PORT", "7860"))
138
- # Gradio 5.x: no concurrency_count
139
  demo.queue(max_size=16).launch(server_name=host, server_port=port, show_error=True)
 
13
 
14
  import os
15
  import json
16
+ import logging
17
+ import subprocess
18
  from pathlib import Path
19
+ from typing import Optional, Tuple
20
+
21
+ import gradio as gr
22
 
23
+ # --------------------------------------------------------------------------------------
24
+ # Early GPU/perf diagnostics (IMPORT FIRST so logs show even if pipeline import fails)
25
+ # --------------------------------------------------------------------------------------
26
+ logger = logging.getLogger("backgroundfx_pro")
27
+ if not logger.handlers:
28
+ h = logging.StreamHandler()
29
+ h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
30
+ logger.addHandler(h)
31
+ logger.setLevel(logging.INFO)
32
+
33
+ # Try to load perf_tuning (forces CUDA or warns; sets cuDNN/TF32; logs banner)
34
+ try:
35
+ import perf_tuning # noqa: F401
36
+ logger.info("perf_tuning imported successfully.")
37
+ except Exception as e:
38
+ logger.warning(f"perf_tuning not loaded: {e}")
39
+
40
def _log_gpu_diag(nvidia_smi_timeout: float = 10.0) -> None:
    """Log best-effort GPU diagnostics (torch/CUDA state and `nvidia-smi -L`).

    This function is invoked at module import time, so it must never raise
    and must never block indefinitely: every probe is wrapped in its own
    ``try`` and failures are only logged.

    Args:
        nvidia_smi_timeout: Maximum number of seconds to wait for
            ``nvidia-smi -L`` before giving up. A timeout is reported through
            the same warning path as any other failure to run the tool.
    """
    # --- Torch info -------------------------------------------------------
    try:
        import torch

        logger.info(f"torch.__version__={torch.__version__} | torch.version.cuda={getattr(torch.version, 'cuda', None)}")
        # Query once and reuse, so the logged value and the branch below agree.
        cuda_available = torch.cuda.is_available()
        logger.info(f"torch.cuda.is_available()={cuda_available}")
        if cuda_available:
            try:
                idx = torch.cuda.current_device()
                name = torch.cuda.get_device_name(idx)
                cap = torch.cuda.get_device_capability(idx)
                logger.info(f"Current CUDA device: {idx} | {name} | cc {cap[0]}.{cap[1]}")
            except Exception as e:
                logger.info(f"CUDA device query failed: {e}")
    except Exception as e:
        logger.warning(f"Could not import torch for GPU diag: {e}")

    # --- nvidia-smi -------------------------------------------------------
    try:
        # The timeout keeps a hung nvidia-smi from stalling startup;
        # subprocess.TimeoutExpired is handled by the generic except below.
        out = subprocess.run(
            ["nvidia-smi", "-L"],
            capture_output=True,
            text=True,
            timeout=nvidia_smi_timeout,
        )
        if out.returncode == 0:
            logger.info("nvidia-smi -L:\n" + out.stdout.strip())
        else:
            logger.warning("nvidia-smi -L failed or unavailable.")
    except Exception as e:
        logger.warning(f"nvidia-smi not runnable: {e}")
66
+
67
+ _log_gpu_diag()
68
+
69
+ # --------------------------------------------------------------------------------------
70
  # Local pipeline
71
+ # --------------------------------------------------------------------------------------
72
  import pipeline
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  def _process_entry(video, bg_image, point_x, point_y, auto_box, progress=gr.Progress(track_tqdm=True)):
76
  """
 
97
  return (out_path if out_path else None), json.dumps(diag, indent=2)
98
 
99
 
100
+ with gr.Blocks(title="BackgroundFX Pro (SAM2 + MatAnyone)", theme=gr.themes.Soft()) as demo:
 
 
 
 
 
 
 
 
 
 
 
101
  gr.Markdown(
102
  """
103
  # 🎬 BackgroundFX Pro
 
109
  """
110
  )
111
 
112
+ with gr.Row():
113
  with gr.Column(scale=2):
114
+ in_video = gr.Video(label="Input Video", sources=["upload"], interactive=True)
115
+ in_bg = gr.Image(label="Background Image", type="filepath", interactive=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  with gr.Column(scale=1):
117
  point_x = gr.Number(label="Foreground point X (optional)", value=None, precision=0)
118
  point_y = gr.Number(label="Foreground point Y (optional)", value=None, precision=0)
 
120
  process_btn = gr.Button("Process", variant="primary")
121
 
122
  with gr.Row():
123
+ out_video = gr.Video(label="Output (H.264 MP4)")
124
+ out_diag = gr.JSON(label="Diagnostics")
 
 
 
 
 
 
125
 
126
  def _on_click(video, bg, px, py, auto):
127
  v, d = _process_entry(video, bg, px, py, auto)
 
128
  try:
129
  d_dict = json.loads(d)
130
  except Exception:
 
141
  # Dynamic host/port via env; suitable defaults for Hugging Face Spaces
142
  host = os.environ.get("HOST", "0.0.0.0")
143
  port = int(os.environ.get("PORT", "7860"))
144
+ # NOTE: gradio>=5 removed concurrency_count; use max_size only
145
  demo.queue(max_size=16).launch(server_name=host, server_port=port, show_error=True)
pipeline.py CHANGED
@@ -12,12 +12,12 @@
12
  - Fallbacks: MediaPipe SelfieSegmentation → else OpenCV GrabCut
13
  - H.264 MP4 output (ffmpeg when available; OpenCV fallback)
14
  - Audio mux: original audio copied into final output (AAC) if present
15
- - NEW: Stage-A transparent export (VP9 with alpha or checkerboard preview)
16
 
17
  Environment knobs (all optional):
18
  - THIRD_PARTY_SAM2_DIR, THIRD_PARTY_MATANY_DIR
19
  - SAM2_MODEL_CFG, SAM2_CHECKPOINT, SAM2_DEVICE
20
- - MATANY_REPO_ID, MATANY_CHECKPOINT, MATANY_DEVICE
21
  - FFMPEG_BIN
22
  - REFINE_GRABCUT=1 | 0 (enable/disable seed mask GrabCut refinement)
23
  - REFINE_GRABCUT_ITERS=2 (GrabCut iterations)
@@ -39,6 +39,7 @@
39
  import tempfile
40
  import logging
41
  import subprocess
 
42
  from pathlib import Path
43
  from typing import Optional, Tuple, Dict, Any, Union
44
 
@@ -405,9 +406,15 @@ def _build_stage_a_checkerboard_from_mask(
405
  return ok_any
406
 
407
  # --------------------------------------------------------------------------------------
408
- # SAM2 Integration
409
  # --------------------------------------------------------------------------------------
410
  def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
 
 
 
 
 
 
411
  meta = {"sam2_import_ok": False, "sam2_init_ok": False}
412
  try:
413
  from sam2.build_sam import build_sam2 # type: ignore
@@ -422,12 +429,50 @@ def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
422
  ckpt = os.environ.get("SAM2_CHECKPOINT", "")
423
 
424
  try:
425
- sam = build_sam2(checkpoint=ckpt if ckpt else None, model_cfg=cfg, device=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
426
  predictor = SAM2ImagePredictor(sam)
427
- meta["sam2_init_ok"] = True
428
- meta["sam2_device"] = device
429
- meta["sam2_cfg"] = cfg
430
- meta["sam2_ckpt"] = ckpt or "(repo default)"
 
 
431
  return predictor, True, meta
432
  except Exception as e:
433
  logger.error(f"SAM2 init failed: {e}")
@@ -508,10 +553,23 @@ def _refine_mask_grabcut(image_bgr: np.ndarray,
508
  return m
509
 
510
  # --------------------------------------------------------------------------------------
511
- # MatAnyone Integration
512
  # --------------------------------------------------------------------------------------
513
  def load_matany() -> Tuple[Optional[object], bool, Dict[str, Any]]:
 
 
 
 
 
 
514
  meta = {"matany_import_ok": False, "matany_init_ok": False}
 
 
 
 
 
 
 
515
  try:
516
  try:
517
  from inference_core import InferenceCore # type: ignore
@@ -526,6 +584,20 @@ def load_matany() -> Tuple[Optional[object], bool, Dict[str, Any]]:
526
  repo_id = os.environ.get("MATANY_REPO_ID", "")
527
  ckpt = os.environ.get("MATANY_CHECKPOINT", "")
528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
  candidates = [
530
  {"kwargs": {"repo_id": repo_id or None, "checkpoint": ckpt or None, "device": device}},
531
  {"kwargs": {"checkpoint": ckpt or None, "device": device}},
@@ -544,7 +616,7 @@ def load_matany() -> Tuple[Optional[object], bool, Dict[str, Any]]:
544
  last_err = e
545
  continue
546
 
547
- logger.error(f"MatAnyone init failed: {last_err}")
548
  return None, False, meta
549
 
550
  def run_matany(matany: object,
 
12
  - Fallbacks: MediaPipe SelfieSegmentation → else OpenCV GrabCut
13
  - H.264 MP4 output (ffmpeg when available; OpenCV fallback)
14
  - Audio mux: original audio copied into final output (AAC) if present
15
+ - Stage-A transparent export (VP9 with alpha or checkerboard preview)
16
 
17
  Environment knobs (all optional):
18
  - THIRD_PARTY_SAM2_DIR, THIRD_PARTY_MATANY_DIR
19
  - SAM2_MODEL_CFG, SAM2_CHECKPOINT, SAM2_DEVICE
20
+ - MATANY_REPO_ID, MATANY_CHECKPOINT, MATANY_DEVICE, ENABLE_MATANY=1|0
21
  - FFMPEG_BIN
22
  - REFINE_GRABCUT=1 | 0 (enable/disable seed mask GrabCut refinement)
23
  - REFINE_GRABCUT_ITERS=2 (GrabCut iterations)
 
39
  import tempfile
40
  import logging
41
  import subprocess
42
+ import inspect
43
  from pathlib import Path
44
  from typing import Optional, Tuple, Dict, Any, Union
45
 
 
406
  return ok_any
407
 
408
  # --------------------------------------------------------------------------------------
409
+ # SAM2 Integration (robust to different build_sam2 signatures)
410
  # --------------------------------------------------------------------------------------
411
  def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
412
+ """
413
+ Robust SAM2 loader that adapts to different build_sam2 signatures:
414
+ - config_file vs model_cfg
415
+ - checkpoint vs ckpt_path vs weights
416
+ - optional device kwarg
417
+ """
418
  meta = {"sam2_import_ok": False, "sam2_init_ok": False}
419
  try:
420
  from sam2.build_sam import build_sam2 # type: ignore
 
429
  ckpt = os.environ.get("SAM2_CHECKPOINT", "")
430
 
431
  try:
432
+ params = set(inspect.signature(build_sam2).parameters.keys())
433
+ kwargs = {}
434
+
435
+ # Config arg
436
+ if "config_file" in params:
437
+ kwargs["config_file"] = cfg
438
+ elif "model_cfg" in params:
439
+ kwargs["model_cfg"] = cfg
440
+ else:
441
+ # if neither is present, try positional later
442
+ pass
443
+
444
+ # Checkpoint arg
445
+ if ckpt:
446
+ if "checkpoint" in params:
447
+ kwargs["checkpoint"] = ckpt
448
+ elif "ckpt_path" in params:
449
+ kwargs["ckpt_path"] = ckpt
450
+ elif "weights" in params:
451
+ kwargs["weights"] = ckpt
452
+
453
+ # Device (if supported via kwarg)
454
+ if "device" in params:
455
+ kwargs["device"] = device
456
+
457
+ # Try keyword call first
458
+ try:
459
+ sam = build_sam2(**kwargs)
460
+ except TypeError:
461
+ # Fallback to positional (cfg, ckpt?, device?)
462
+ pos = [cfg]
463
+ if ckpt:
464
+ pos.append(ckpt)
465
+ if "device" not in kwargs:
466
+ pos.append(device)
467
+ sam = build_sam2(*pos)
468
+
469
  predictor = SAM2ImagePredictor(sam)
470
+ meta.update({
471
+ "sam2_init_ok": True,
472
+ "sam2_device": device,
473
+ "sam2_cfg": cfg,
474
+ "sam2_ckpt": ckpt or "(repo default)"
475
+ })
476
  return predictor, True, meta
477
  except Exception as e:
478
  logger.error(f"SAM2 init failed: {e}")
 
553
  return m
554
 
555
  # --------------------------------------------------------------------------------------
556
+ # MatAnyone Integration (robust + disable switch)
557
  # --------------------------------------------------------------------------------------
558
  def load_matany() -> Tuple[Optional[object], bool, Dict[str, Any]]:
559
+ """
560
+ MatAnyone loader that:
561
+ - Skips if ENABLE_MATANY=0
562
+ - Detects forks that require a `network` arg and exits cleanly with diagnostics
563
+ - Otherwise tries repo/checkpoint style constructors
564
+ """
565
  meta = {"matany_import_ok": False, "matany_init_ok": False}
566
+
567
+ enable_env = os.environ.get("ENABLE_MATANY", "1").strip().lower()
568
+ if enable_env in {"0", "false", "off", "no"}:
569
+ logger.info("MatAnyone disabled by ENABLE_MATANY=0.")
570
+ meta["disabled"] = True
571
+ return None, False, meta
572
+
573
  try:
574
  try:
575
  from inference_core import InferenceCore # type: ignore
 
584
  repo_id = os.environ.get("MATANY_REPO_ID", "")
585
  ckpt = os.environ.get("MATANY_CHECKPOINT", "")
586
 
587
+ # If this fork needs a prebuilt network, tell the user and skip
588
+ try:
589
+ sig = inspect.signature(InferenceCore)
590
+ if "network" in sig.parameters and sig.parameters["network"].default is inspect._empty:
591
+ logger.error(
592
+ "This MatAnyone fork expects `InferenceCore(network=...)`. "
593
+ "Pin a fork/commit that supplies a checkpoint-based constructor, "
594
+ "or set ENABLE_MATANY=0 to skip."
595
+ )
596
+ meta["needs_network_arg"] = True
597
+ return None, False, meta
598
+ except Exception:
599
+ pass
600
+
601
  candidates = [
602
  {"kwargs": {"repo_id": repo_id or None, "checkpoint": ckpt or None, "device": device}},
603
  {"kwargs": {"checkpoint": ckpt or None, "device": device}},
 
616
  last_err = e
617
  continue
618
 
619
+ logger.error(f"MatAnyone init failed with all fallbacks: {last_err}")
620
  return None, False, meta
621
 
622
  def run_matany(matany: object,