MogensR committed on
Commit
c30b921
·
1 Parent(s): bc9a991

final run?

Browse files
Files changed (8) hide show
  1. .dockerignore +11 -24
  2. Dockerfile +36 -50
  3. app.py +77 -110
  4. models/__init__.py +66 -53
  5. perf_tuning.py +99 -49
  6. pipeline.py +180 -111
  7. requirements.txt +4 -3
  8. ui.py +95 -63
.dockerignore CHANGED
@@ -2,16 +2,12 @@
2
  # .dockerignore for HF Spaces
3
  # ===========================
4
 
5
- # ---------------------------
6
- # VCS (never needed in image)
7
- # ---------------------------
8
  .git
9
  .gitignore
10
  .gitattributes
11
 
12
- # ---------------------------
13
- # Python cache / build files
14
- # ---------------------------
15
  __pycache__/
16
  *.py[cod]
17
  *.pyo
@@ -20,49 +16,41 @@ __pycache__/
20
  *.egg-info/
21
  dist/
22
  build/
 
 
23
 
24
- # ---------------------------
25
  # Virtual environments
26
- # ---------------------------
27
  .env
28
  .venv/
29
  env/
30
  venv/
31
 
32
- # ---------------------------
33
  # External repos (cloned in Docker, not copied from local)
34
- # ---------------------------
35
  third_party/
36
 
37
- # ---------------------------
38
- # Hugging Face / Torch caches (but allow model files that might be needed)
39
- # ---------------------------
40
  .cache/
41
  huggingface/
42
  torch/
43
  data/
44
 
45
- # ---------------------------
46
  # HF Space metadata/state
47
- # ---------------------------
48
  .hf_space/
49
  space.log
50
  gradio_cached_examples/
51
  gradio_static/
52
  __outputs__/
53
 
54
- # ---------------------------
55
  # Logs & temp files
56
- # ---------------------------
57
  *.log
58
  logs/
59
  tmp/
60
  temp/
61
  *.swp
 
 
62
 
63
- # ---------------------------
64
  # Media test assets
65
- # ---------------------------
66
  *.mp4
67
  *.avi
68
  *.mov
@@ -72,9 +60,7 @@ temp/
72
  *.jpeg
73
  *.gif
74
 
75
- # ---------------------------
76
  # OS / IDE cruft
77
- # ---------------------------
78
  .DS_Store
79
  Thumbs.db
80
  .vscode/
@@ -82,10 +68,11 @@ Thumbs.db
82
  *.sublime-project
83
  *.sublime-workspace
84
 
85
- # ---------------------------
86
  # Node / frontend (if present)
87
- # ---------------------------
88
  node_modules/
89
  npm-debug.log
90
  yarn-debug.log
91
- yarn-error.log
 
 
 
 
2
  # .dockerignore for HF Spaces
3
  # ===========================
4
 
5
+ # VCS
 
 
6
  .git
7
  .gitignore
8
  .gitattributes
9
 
10
+ # Python cache / build
 
 
11
  __pycache__/
12
  *.py[cod]
13
  *.pyo
 
16
  *.egg-info/
17
  dist/
18
  build/
19
+ .pytest_cache/
20
+ .python-version
21
 
 
22
  # Virtual environments
 
23
  .env
24
  .venv/
25
  env/
26
  venv/
27
 
 
28
  # External repos (cloned in Docker, not copied from local)
 
29
  third_party/
30
 
31
+ # Hugging Face / Torch caches
 
 
32
  .cache/
33
  huggingface/
34
  torch/
35
  data/
36
 
 
37
  # HF Space metadata/state
 
38
  .hf_space/
39
  space.log
40
  gradio_cached_examples/
41
  gradio_static/
42
  __outputs__/
43
 
 
44
  # Logs & temp files
 
45
  *.log
46
  logs/
47
  tmp/
48
  temp/
49
  *.swp
50
+ .coverage
51
+ coverage.xml
52
 
 
53
  # Media test assets
 
54
  *.mp4
55
  *.avi
56
  *.mov
 
60
  *.jpeg
61
  *.gif
62
 
 
63
  # OS / IDE cruft
 
64
  .DS_Store
65
  Thumbs.db
66
  .vscode/
 
68
  *.sublime-project
69
  *.sublime-workspace
70
 
 
71
  # Node / frontend (if present)
 
72
  node_modules/
73
  npm-debug.log
74
  yarn-debug.log
75
+ yarn-error.log
76
+
77
+ # ---- Optional: allow specific checkpoints if needed ----
78
+ !checkpoints/
Dockerfile CHANGED
@@ -1,119 +1,105 @@
1
  # ===============================
2
- # BackgroundFX Pro — Dockerfile (Updated with Debug)
3
- # Hugging Face Spaces Pro (GPU)
4
  # ===============================
5
 
6
- # CUDA base image (T4-friendly). Build stage has NO GPU access.
7
- FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04
8
 
9
- # --- Build args (override in Space Settings → Build args) ---
10
- # Pin external repos for reproducible builds
11
  ARG SAM2_SHA=__PIN_ME__
12
  ARG MATANYONE_SHA=__PIN_ME__
13
-
14
- # (legacy/optional) Model IDs — you can still use these elsewhere if you want
15
  ARG SAM2_MODEL_ID=facebook/sam2
16
- ARG SAM2_VARIANT=sam2_hiera_large # sam2_hiera_small | sam2_hiera_base | sam2_hiera_large
17
  ARG MATANY_REPO_ID=PeiqingYang/MatAnyone
18
  ARG MATANY_FILENAME=matanyone_v1.0.pth
19
 
20
- # --- Create non-root user (uid 1000 required by HF) ---
21
  RUN useradd -m -u 1000 user
22
  ENV HOME=/home/user
23
  ENV PATH=/home/user/.local/bin:$PATH
24
- RUN mkdir -p /home/user/app && chown -R user:user /home/user
25
  WORKDIR /home/user/app
26
 
27
  # --- System packages ---
28
  USER root
29
  RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
30
  git ffmpeg python3 python3-pip python3-venv \
 
31
  libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev libgomp1 \
32
  && rm -rf /var/lib/apt/lists/*
33
 
34
- # Persistent cache dir for HF weights / torch / matplotlib
35
  RUN mkdir -p /data/.cache && chown -R user:user /data
36
  USER user
37
 
38
- # --- Python & CUDA wheels (Torch cu121) ---
39
  RUN python3 -m pip install --no-cache-dir --upgrade pip
40
  RUN python3 -m pip install --no-cache-dir --index-url https://download.pytorch.org/whl/cu121 \
41
  torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1
42
 
43
- # --- App Python deps ---
44
  COPY --chown=user requirements.txt ./requirements.txt
45
  RUN python3 -m pip install --no-cache-dir -r requirements.txt
46
- # Optional (recommended) nicer fallback segmentation:
47
  RUN python3 -m pip install --no-cache-dir mediapipe==0.10.14
48
 
49
- # --- Clone external repos (SAM2 & MatAnyone) ---
50
  RUN mkdir -p third_party
51
 
52
- # SAM2
53
- RUN git clone https://github.com/facebookresearch/segment-anything-2.git third_party/sam2 && \
54
  cd third_party/sam2 && \
55
- if [ "${SAM2_SHA}" != "__PIN_ME__" ]; then git checkout "${SAM2_SHA}"; fi
56
 
57
- # DEBUG: Check what was actually cloned
58
- RUN echo "=== DEBUG: SAM2 directory contents ===" && \
59
  ls -la third_party/sam2/ && \
60
- echo "=== DEBUG: Config directory ===" && \
61
- ls -la third_party/sam2/configs/ || echo "configs directory not found" && \
62
  echo "=== DEBUG: SAM2 configs ===" && \
63
- ls -la third_party/sam2/configs/sam2/ || echo "sam2 configs directory not found"
64
 
65
- # Install SAM2 requirements
66
  RUN cd third_party/sam2 && python3 -m pip install --no-cache-dir -e .
67
 
68
- # MatAnyone (pq-yang fork as per your previous setup)
69
- RUN git clone https://github.com/pq-yang/MatAnyone.git third_party/matanyone && \
70
  cd third_party/matanyone && \
71
- if [ "${MATANYONE_SHA}" != "__PIN_ME__" ]; then git checkout "${MATANYONE_SHA}"; fi
72
 
73
- # Install MatAnyone requirements if they exist
74
  RUN cd third_party/matanyone && \
75
  if [ -f requirements.txt ]; then python3 -m pip install --no-cache-dir -r requirements.txt; fi
76
 
77
- # --- App code ---
78
  COPY --chown=user . /home/user/app
79
 
80
- # DEBUG: Check if app code copy overwrote the cloned repos
81
- RUN echo "=== DEBUG: After app code copy - SAM2 status ===" && \
82
  ls -la third_party/sam2/ && \
83
- echo "=== DEBUG: Config files after copy ===" && \
84
- ls -la third_party/sam2/configs/sam2/ || echo "Config directory missing after copy"
85
 
86
- # --- Runtime environment (aligned with pipeline.py) ---
87
  ENV PYTHONUNBUFFERED=1 \
88
- OMP_NUM_THREADS=2 \
89
  TOKENIZERS_PARALLELISM=false \
90
  HF_HOME=/data/.cache/huggingface \
91
  TORCH_HOME=/data/.cache/torch \
92
  MPLCONFIGDIR=/data/.cache/matplotlib \
93
  PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 \
94
  PYTHONPATH="$PYTHONPATH:/home/user/app/third_party/sam2:/home/user/app/third_party/matanyone" \
95
- PORT=7860 \
96
  FFMPEG_BIN=ffmpeg \
97
- \
98
- # Let pipeline.py discover these dynamically (no hard-coded paths)
99
  THIRD_PARTY_SAM2_DIR=/home/user/app/third_party/sam2 \
100
  THIRD_PARTY_MATANY_DIR=/home/user/app/third_party/matanyone \
101
- \
102
- # --- SAM2 dynamic config (FIXED: relative path within SAM2 repo) ---
103
  SAM2_MODEL_CFG="configs/sam2/sam2_hiera_l.yaml" \
104
  SAM2_CHECKPOINT="" \
105
- \
106
- # --- MatAnyone dynamic config (used by pipeline.py) ---
107
  MATANY_REPO_ID=PeiqingYang/MatAnyone \
108
  MATANY_CHECKPOINT="" \
109
  ENABLE_MATANY=1
110
 
111
- # DEBUG: Final check of SAM2 installation
112
- RUN echo "=== FINAL DEBUG: SAM2 status ===" && \
113
- pwd && \
114
- ls -la /home/user/app/third_party/sam2/ || echo "SAM2 directory missing" && \
115
- ls -la /home/user/app/third_party/sam2/configs/sam2/ || echo "Config dir missing"
116
 
117
- # --- Networking / Entrypoint ---
118
  EXPOSE 7860
119
- CMD ["python3", "app.py"]
 
 
 
 
 
 
1
  # ===============================
2
+ # BackgroundFX Pro — Dockerfile (Hardened for Spaces GPU)
 
3
  # ===============================
4
 
5
+ # Match PyTorch cu121 wheels (critical to avoid CUDA probe stalls)
6
+ FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04
7
 
8
+ # --- Build args (optional pins) ---
 
9
  ARG SAM2_SHA=__PIN_ME__
10
  ARG MATANYONE_SHA=__PIN_ME__
 
 
11
  ARG SAM2_MODEL_ID=facebook/sam2
12
+ ARG SAM2_VARIANT=sam2_hiera_large
13
  ARG MATANY_REPO_ID=PeiqingYang/MatAnyone
14
  ARG MATANY_FILENAME=matanyone_v1.0.pth
15
 
16
+ # --- Non-root user (HF expects uid 1000) ---
17
  RUN useradd -m -u 1000 user
18
  ENV HOME=/home/user
19
  ENV PATH=/home/user/.local/bin:$PATH
 
20
  WORKDIR /home/user/app
21
 
22
  # --- System packages ---
23
  USER root
24
  RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
25
  git ffmpeg python3 python3-pip python3-venv \
26
+ wget curl ca-certificates \
27
  libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev libgomp1 \
28
  && rm -rf /var/lib/apt/lists/*
29
 
30
+ # Caches (writable)
31
  RUN mkdir -p /data/.cache && chown -R user:user /data
32
  USER user
33
 
34
+ # --- Python + Torch (cu121) ---
35
  RUN python3 -m pip install --no-cache-dir --upgrade pip
36
  RUN python3 -m pip install --no-cache-dir --index-url https://download.pytorch.org/whl/cu121 \
37
  torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1
38
 
39
+ # --- App deps ---
40
  COPY --chown=user requirements.txt ./requirements.txt
41
  RUN python3 -m pip install --no-cache-dir -r requirements.txt
42
+ # Optional nice fallback
43
  RUN python3 -m pip install --no-cache-dir mediapipe==0.10.14
44
 
45
+ # --- Third-party repos (build-time, never at runtime) ---
46
  RUN mkdir -p third_party
47
 
48
+ # SAM2 (shallow clone; optional SHA pin)
49
+ RUN git clone --depth=1 https://github.com/facebookresearch/segment-anything-2.git third_party/sam2 && \
50
  cd third_party/sam2 && \
51
+ if [ "${SAM2_SHA}" != "__PIN_ME__" ]; then git fetch --depth=1 origin ${SAM2_SHA} && git checkout ${SAM2_SHA}; fi
52
 
53
+ # Show what we got
54
+ RUN echo "=== DEBUG: SAM2 contents ===" && \
55
  ls -la third_party/sam2/ && \
 
 
56
  echo "=== DEBUG: SAM2 configs ===" && \
57
+ (ls -la third_party/sam2/configs/sam2/ || echo "configs missing")
58
 
59
+ # Install SAM2 (editable ok)
60
  RUN cd third_party/sam2 && python3 -m pip install --no-cache-dir -e .
61
 
62
+ # MatAnyone (pq-yang fork per your setup)
63
+ RUN git clone --depth=1 https://github.com/pq-yang/MatAnyone.git third_party/matanyone && \
64
  cd third_party/matanyone && \
65
+ if [ "${MATANYONE_SHA}" != "__PIN_ME__" ]; then git fetch --depth=1 origin ${MATANYONE_SHA} && git checkout ${MATANYONE_SHA}; fi
66
 
67
+ # Install MatAnyone requirements if present
68
  RUN cd third_party/matanyone && \
69
  if [ -f requirements.txt ]; then python3 -m pip install --no-cache-dir -r requirements.txt; fi
70
 
71
+ # --- App code last (so code changes don't invalidate heavy layers) ---
72
  COPY --chown=user . /home/user/app
73
 
74
+ # Verify clone not overwritten by COPY
75
+ RUN echo "=== DEBUG: After COPY ===" && \
76
  ls -la third_party/sam2/ && \
77
+ (ls -la third_party/sam2/configs/sam2/ || echo "SAM2 configs missing")
 
78
 
79
+ # --- Runtime environment ---
80
  ENV PYTHONUNBUFFERED=1 \
81
+ OMP_NUM_THREADS=1 OPENBLAS_NUM_THREADS=1 MKL_NUM_THREADS=1 NUMEXPR_NUM_THREADS=1 \
82
  TOKENIZERS_PARALLELISM=false \
83
  HF_HOME=/data/.cache/huggingface \
84
  TORCH_HOME=/data/.cache/torch \
85
  MPLCONFIGDIR=/data/.cache/matplotlib \
86
  PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 \
87
  PYTHONPATH="$PYTHONPATH:/home/user/app/third_party/sam2:/home/user/app/third_party/matanyone" \
 
88
  FFMPEG_BIN=ffmpeg \
 
 
89
  THIRD_PARTY_SAM2_DIR=/home/user/app/third_party/sam2 \
90
  THIRD_PARTY_MATANY_DIR=/home/user/app/third_party/matanyone \
 
 
91
  SAM2_MODEL_CFG="configs/sam2/sam2_hiera_l.yaml" \
92
  SAM2_CHECKPOINT="" \
 
 
93
  MATANY_REPO_ID=PeiqingYang/MatAnyone \
94
  MATANY_CHECKPOINT="" \
95
  ENABLE_MATANY=1
96
 
97
+ # Do NOT set PORT here; Spaces injects it.
 
 
 
 
98
 
 
99
  EXPOSE 7860
100
+
101
+ # Optional: basic health check to see if the server bound
102
+ HEALTHCHECK --interval=30s --timeout=5s --retries=5 CMD wget -qO- "http://127.0.0.1:${PORT:-7860}/" || exit 1
103
+
104
+ # Use exec form + unbuffered
105
+ CMD ["python3","-u","app.py"]
app.py CHANGED
@@ -1,28 +1,21 @@
1
- # app.py
2
  #!/usr/bin/env python3
3
  """
4
- BackgroundFX Pro - Gradio App (dynamic, GPU-ready, no hard-coded checkpoints)
5
- =============================================================================
6
-
7
- - Uses pipeline.process() which orchestrates:
8
- SAM2 first-frame segmentation → MatAnyone temporal matting → compositing
9
- - Robust fallbacks (MediaPipe / GrabCut; static-mask compositing)
10
- - Diagnostics JSON shows which engines ran and on which device
11
- - All paths/devices set by environment variables (see pipeline.py header)
12
  """
13
 
14
  import os
15
- import json
16
  import logging
 
 
17
  import subprocess
18
- from pathlib import Path
19
- from typing import Optional, Tuple
20
 
21
  import gradio as gr
22
 
23
- # --------------------------------------------------------------------------------------
24
- # Early GPU/perf diagnostics (IMPORT FIRST so logs show even if pipeline import fails)
25
- # --------------------------------------------------------------------------------------
26
  logger = logging.getLogger("backgroundfx_pro")
27
  if not logger.handlers:
28
  h = logging.StreamHandler()
@@ -30,116 +23,90 @@
30
  logger.addHandler(h)
31
  logger.setLevel(logging.INFO)
32
 
33
- # Try to load perf_tuning (forces CUDA or warns; sets cuDNN/TF32; logs banner)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  try:
35
  import perf_tuning # noqa: F401
36
  logger.info("perf_tuning imported successfully.")
37
  except Exception as e:
38
- logger.warning(f"perf_tuning not loaded: {e}")
 
 
39
 
40
- def _log_gpu_diag():
41
- # Torch info
 
 
42
  try:
43
  import torch
44
- logger.info(f"torch.__version__={torch.__version__} | torch.version.cuda={getattr(torch.version, 'cuda', None)}")
45
- logger.info(f"torch.cuda.is_available()={torch.cuda.is_available()}")
46
- if torch.cuda.is_available():
 
 
 
 
47
  try:
48
  idx = torch.cuda.current_device()
49
  name = torch.cuda.get_device_name(idx)
50
  cap = torch.cuda.get_device_capability(idx)
51
- logger.info(f"Current CUDA device: {idx} | {name} | cc {cap[0]}.{cap[1]}")
52
  except Exception as e:
53
- logger.info(f"CUDA device query failed: {e}")
54
- except Exception as e:
55
- logger.warning(f"Could not import torch for GPU diag: {e}")
56
-
57
- # nvidia-smi
58
- try:
59
- out = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True)
60
- if out.returncode == 0:
61
- logger.info("nvidia-smi -L:\n" + out.stdout.strip())
62
- else:
63
- logger.warning("nvidia-smi -L failed or unavailable.")
64
  except Exception as e:
65
- logger.warning(f"nvidia-smi not runnable: {e}")
66
 
67
- _log_gpu_diag()
68
-
69
- # --------------------------------------------------------------------------------------
70
- # Local pipeline
71
- # --------------------------------------------------------------------------------------
72
- import pipeline
73
-
74
-
75
- def _process_entry(video, bg_image, point_x, point_y, auto_box, progress=gr.Progress(track_tqdm=True)):
76
- """
77
- Gradio wrapper → returns (video_path, diagnostics_json_str)
78
- """
79
- if video is None or bg_image is None:
80
- return None, json.dumps({"error": "Please provide both a video and a background image."}, indent=2)
81
-
82
- # Gradio can pass dict-like objects for file with 'name' key, normalize to path
83
- vpath = video if isinstance(video, (str, Path)) else getattr(video, "name", None) or video.get("name")
84
- bpath = bg_image if isinstance(bg_image, (str, Path)) else getattr(bg_image, "name", None) or bg_image.get("name")
85
-
86
- progress(0.05, desc="Starting…")
87
- out_path, diag = pipeline.process(
88
- video_path=vpath,
89
- bg_image_path=bpath,
90
- point_x=point_x if point_x not in (None, "") else None,
91
- point_y=point_y if point_y not in (None, "") else None,
92
- auto_box=bool(auto_box),
93
- work_dir=None # pipeline will create a temp dir
94
- )
95
- progress(0.95, desc="Finalizing…")
96
-
97
- return (out_path if out_path else None), json.dumps(diag, indent=2)
98
-
99
-
100
- with gr.Blocks(title="BackgroundFX Pro (SAM2 + MatAnyone)", theme=gr.themes.Soft()) as demo:
101
- gr.Markdown(
102
- """
103
- # 🎬 BackgroundFX Pro
104
- **SAM2 + MatAnyone** with robust fallbacks. All configs/devices are dynamic via environment variables.
105
-
106
- - Upload a video and a background image.
107
- - Optionally provide a foreground point (x, y) in pixels for the first frame **or** tick *Auto subject box*.
108
- - Click **Process**. The app will try SAM2 → MatAnyone; if anything fails, it falls back automatically.
109
- """
110
- )
111
-
112
- with gr.Row():
113
- with gr.Column(scale=2):
114
- in_video = gr.Video(label="Input Video", sources=["upload"], interactive=True)
115
- in_bg = gr.Image(label="Background Image", type="filepath", interactive=True)
116
- with gr.Column(scale=1):
117
- point_x = gr.Number(label="Foreground point X (optional)", value=None, precision=0)
118
- point_y = gr.Number(label="Foreground point Y (optional)", value=None, precision=0)
119
- auto_box = gr.Checkbox(label="Auto subject box (ignore point)", value=True)
120
- process_btn = gr.Button("Process", variant="primary")
121
-
122
- with gr.Row():
123
- out_video = gr.Video(label="Output (H.264 MP4)")
124
- out_diag = gr.JSON(label="Diagnostics")
125
-
126
- def _on_click(video, bg, px, py, auto):
127
- v, d = _process_entry(video, bg, px, py, auto)
128
- try:
129
- d_dict = json.loads(d)
130
- except Exception:
131
- d_dict = {"raw": d}
132
- return v, d_dict
133
-
134
- process_btn.click(
135
- _on_click,
136
- inputs=[in_video, in_bg, point_x, point_y, auto_box],
137
- outputs=[out_video, out_diag]
138
- )
139
 
140
  if __name__ == "__main__":
141
- # Dynamic host/port via env; suitable defaults for Hugging Face Spaces
142
  host = os.environ.get("HOST", "0.0.0.0")
143
  port = int(os.environ.get("PORT", "7860"))
144
- # NOTE: gradio>=5 removed concurrency_count; use max_size only
145
- demo.queue(max_size=16).launch(server_name=host, server_port=port, show_error=True)
 
 
 
 
 
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ BackgroundFX Pro App Entrypoint (UI separated)
4
+ - UI is built in ui.py (create_interface)
5
+ - Hardened startup: heartbeat, safe diag, bind to $PORT
 
 
 
 
 
6
  """
7
 
8
  import os
 
9
  import logging
10
+ import threading
11
+ import time
12
  import subprocess
 
 
13
 
14
  import gradio as gr
15
 
16
+ # -----------------------------------------------------------------------------
17
+ # Logging early
18
+ # -----------------------------------------------------------------------------
19
  logger = logging.getLogger("backgroundfx_pro")
20
  if not logger.handlers:
21
  h = logging.StreamHandler()
 
23
  logger.addHandler(h)
24
  logger.setLevel(logging.INFO)
25
 
26
+ # Heartbeat so logs never go silent during startup/imports
27
+ def _heartbeat():
28
+ i = 0
29
+ while True:
30
+ i += 1
31
+ print(f"[startup-heartbeat] {i*5}s…", flush=True)
32
+ time.sleep(5)
33
+
34
+ threading.Thread(target=_heartbeat, daemon=True).start()
35
+
36
+ # -----------------------------------------------------------------------------
37
+ # Safe, minimal startup diagnostics (no long CUDA probes)
38
+ # -----------------------------------------------------------------------------
39
+ def _safe_startup_diag():
40
+ # Torch version only; defer CUDA availability checks to post-launch
41
+ try:
42
+ import torch # noqa: F401
43
+ import importlib
44
+ t = importlib.import_module("torch")
45
+ logger.info("torch imported: %s | torch.version.cuda=%s",
46
+ getattr(t, "__version__", "?"),
47
+ getattr(getattr(t, "version", None), "cuda", None))
48
+ except Exception as e:
49
+ logger.warning("Torch not available at startup: %s", e)
50
+
51
+ # nvidia-smi with short timeout (avoid indefinite block)
52
+ try:
53
+ out = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True, timeout=2)
54
+ if out.returncode == 0:
55
+ logger.info("nvidia-smi -L:\n%s", out.stdout.strip())
56
+ else:
57
+ logger.warning("nvidia-smi -L failed or unavailable (rc=%s).", out.returncode)
58
+ except subprocess.TimeoutExpired:
59
+ logger.warning("nvidia-smi -L timed out (skipping).")
60
+ except Exception as e:
61
+ logger.warning("nvidia-smi not runnable: %s", e)
62
+
63
+ # Optional perf tuning; never block startup
64
  try:
65
  import perf_tuning # noqa: F401
66
  logger.info("perf_tuning imported successfully.")
67
  except Exception as e:
68
+ logger.warning("perf_tuning not loaded: %s", e)
69
+
70
+ _safe_startup_diag()
71
 
72
+ # -----------------------------------------------------------------------------
73
+ # Post-launch CUDA diag in background (so it never blocks binding the port)
74
+ # -----------------------------------------------------------------------------
75
+ def _post_launch_diag():
76
  try:
77
  import torch
78
+ try:
79
+ avail = torch.cuda.is_available()
80
+ except Exception as e:
81
+ logger.warning("torch.cuda.is_available() failed: %s", e)
82
+ avail = False
83
+ logger.info("CUDA available: %s", avail)
84
+ if avail:
85
  try:
86
  idx = torch.cuda.current_device()
87
  name = torch.cuda.get_device_name(idx)
88
  cap = torch.cuda.get_device_capability(idx)
89
+ logger.info("CUDA device %d: %s (cc %d.%d)", idx, name, cap[0], cap[1])
90
  except Exception as e:
91
+ logger.warning("CUDA device query failed: %s", e)
 
 
 
 
 
 
 
 
 
 
92
  except Exception as e:
93
+ logger.warning("Post-launch torch diag failed: %s", e)
94
 
95
+ # -----------------------------------------------------------------------------
96
+ # Build UI (in separate module) and launch
97
+ # -----------------------------------------------------------------------------
98
+ def build_ui() -> gr.Blocks:
99
+ # Import here so any heavy imports inside ui.py (it shouldn’t) would show up after logs are configured
100
+ from ui import create_interface
101
+ return create_interface()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  if __name__ == "__main__":
 
104
  host = os.environ.get("HOST", "0.0.0.0")
105
  port = int(os.environ.get("PORT", "7860"))
106
+ logger.info("Launching Gradio on %s:%s …", host, port)
107
+
108
+ demo = build_ui()
109
+ demo.queue(max_size=16)
110
+
111
+ threading.Thread(target=_post_launch_diag, daemon=True).start()
112
+ demo.launch(server_name=host, server_port=port, show_error=True)
models/__init__.py CHANGED
@@ -1,9 +1,10 @@
1
  #!/usr/bin/env python3
2
  """
3
- BackgroundFX Pro - Model Loading & Utilities
4
- ===========================================
5
- Contains all model loading, inference functions, and utility functions
6
- moved from the main pipeline for better organization.
 
7
  """
8
 
9
  from __future__ import annotations
@@ -19,12 +20,24 @@
19
 
20
  import numpy as np
21
  import yaml
22
- import torch # For memory management and CUDA operations
23
 
24
  # --------------------------------------------------------------------------------------
25
- # Logging
26
  # --------------------------------------------------------------------------------------
27
  logger = logging.getLogger("backgroundfx_pro")
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  # --------------------------------------------------------------------------------------
30
  # Optional dependencies
@@ -38,35 +51,40 @@
38
  # --------------------------------------------------------------------------------------
39
  # Path setup for third_party repos
40
  # --------------------------------------------------------------------------------------
41
- ROOT = Path(__file__).resolve().parent.parent # Go up from models/ to project root
42
  TP_SAM2 = Path(os.environ.get("THIRD_PARTY_SAM2_DIR", ROOT / "third_party" / "sam2")).resolve()
43
  TP_MATANY = Path(os.environ.get("THIRD_PARTY_MATANY_DIR", ROOT / "third_party" / "matanyone")).resolve()
44
 
45
  def _add_sys_path(p: Path) -> None:
46
- p_str = str(p)
47
- if p_str not in sys.path:
48
- sys.path.insert(0, p_str)
 
 
 
49
 
50
  _add_sys_path(TP_SAM2)
51
  _add_sys_path(TP_MATANY)
52
 
53
  # --------------------------------------------------------------------------------------
54
- # Basic Utilities
55
  # --------------------------------------------------------------------------------------
56
- def _ffmpeg_bin() -> str:
57
- return os.environ.get("FFMPEG_BIN", "ffmpeg")
58
-
59
- def _probe_ffmpeg() -> bool:
60
  try:
61
- subprocess.run([_ffmpeg_bin(), "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
62
- return True
63
- except Exception:
64
- return False
 
65
 
66
  def _has_cuda() -> bool:
 
 
 
67
  try:
68
- return torch.cuda.is_available()
69
- except Exception:
 
70
  return False
71
 
72
  def _pick_device(env_key: str) -> str:
@@ -75,6 +93,19 @@ def _pick_device(env_key: str) -> str:
75
  return requested
76
  return "cuda" if _has_cuda() else "cpu"
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def _ensure_dir(p: Path) -> None:
79
  p.mkdir(parents=True, exist_ok=True)
80
 
@@ -141,7 +172,6 @@ def _mux_audio(src_video: Union[str, Path], silent_video: Union[str, Path], out_
141
  # Compositing & Image Processing
142
  # --------------------------------------------------------------------------------------
143
  def _refine_alpha(alpha: np.ndarray, erode_px: int = 1, dilate_px: int = 2, blur_px: float = 1.5) -> np.ndarray:
144
- """Erode→dilate + gentle blur → float alpha in [0,1]."""
145
  if alpha.dtype != np.float32:
146
  a = alpha.astype(np.float32)
147
  if a.max() > 1.0:
@@ -173,7 +203,6 @@ def _to_srgb(lin: np.ndarray, gamma: float = 2.2) -> np.ndarray:
173
  return np.clip(np.power(x, 1.0 / gamma) * 255.0, 0, 255).astype(np.uint8)
174
 
175
  def _light_wrap(bg_rgb: np.ndarray, alpha01: np.ndarray, radius: int = 5, amount: float = 0.18) -> np.ndarray:
176
- """Simple light wrap from background into subject edges."""
177
  r = max(1, int(radius))
178
  inv = 1.0 - alpha01
179
  inv_blur = cv2.GaussianBlur(inv, (r | 1, r | 1), 0)
@@ -181,8 +210,7 @@ def _light_wrap(bg_rgb: np.ndarray, alpha01: np.ndarray, radius: int = 5, amount
181
  return lw
182
 
183
  def _despill_edges(fg_rgb: np.ndarray, alpha01: np.ndarray, amount: float = 0.35) -> np.ndarray:
184
- """Reduce saturation in boundary band (alpha≈0.5) to remove old-background tint."""
185
- w = 1.0 - 2.0 * np.abs(alpha01 - 0.5) # bell-shaped weight
186
  w = np.clip(w, 0.0, 1.0)
187
  hsv = cv2.cvtColor(fg_rgb.astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
188
  H, S, V = cv2.split(hsv)
@@ -191,11 +219,11 @@ def _despill_edges(fg_rgb: np.ndarray, alpha01: np.ndarray, amount: float = 0.35
191
  out = cv2.cvtColor(hsv2.astype(np.uint8), cv2.COLOR_HSV2RGB)
192
  return out
193
 
194
- def _composite_frame_pro(fg_rgb: np.ndarray, alpha: np.ndarray, bg_rgb: np.ndarray,
195
- erode_px: int = None, dilate_px: int = None, blur_px: float = None,
196
- lw_radius: int = None, lw_amount: float = None,
197
- despill_amount: float = None) -> np.ndarray:
198
- """Gamma-aware composite + edge refinement + light wrap + boundary de-spill."""
199
  erode_px = erode_px if erode_px is not None else int(os.environ.get("EDGE_ERODE", "1"))
200
  dilate_px = dilate_px if dilate_px is not None else int(os.environ.get("EDGE_DILATE", "2"))
201
  blur_px = blur_px if blur_px is not None else float(os.environ.get("EDGE_BLUR", "1.5"))
@@ -203,17 +231,11 @@ def _composite_frame_pro(fg_rgb: np.ndarray, alpha: np.ndarray, bg_rgb: np.ndarr
203
  lw_amount = lw_amount if lw_amount is not None else float(os.environ.get("LIGHTWRAP_AMOUNT", "0.18"))
204
  despill_amount = despill_amount if despill_amount is not None else float(os.environ.get("DESPILL_AMOUNT", "0.35"))
205
 
206
- # refine alpha [0,1]
207
  a = _refine_alpha(alpha, erode_px=erode_px, dilate_px=dilate_px, blur_px=blur_px)
208
-
209
- # edge de-spill: temper saturation where a≈0.5
210
  fg_rgb = _despill_edges(fg_rgb, a, amount=despill_amount)
211
 
212
- # linearize for better blending
213
  fg_lin = _to_linear(fg_rgb)
214
  bg_lin = _to_linear(bg_rgb)
215
-
216
- # light wrap
217
  lw = _light_wrap(bg_rgb, a, radius=lw_radius, amount=lw_amount)
218
  lw_lin = _to_linear(np.clip(lw, 0, 255).astype(np.uint8))
219
 
@@ -233,30 +255,27 @@ def _resolve_sam2_cfg(cfg_str: str) -> str:
233
  return str(candidate)
234
  if cfg_path.exists():
235
  return str(cfg_path)
236
- # Last resort: common defaults inside the repo
237
  for name in ["configs/sam2/sam2_hiera_l.yaml", "configs/sam2/sam2_hiera_b.yaml", "configs/sam2/sam2_hiera_s.yaml"]:
238
  p = TP_SAM2 / name
239
  if p.exists():
240
  return str(p)
241
- return str(cfg_str) # let build_sam2 raise a clear error
242
 
243
  def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
244
  """If config references 'hieradet', try to find a 'hiera' config."""
245
  try:
246
  with open(cfg_path, "r") as f:
247
  data = yaml.safe_load(f)
248
- target = None
249
- model = data.get("model", {})
250
- enc = (model.get("image_encoder") or {})
251
- trunk = (enc.get("trunk") or {})
252
  target = trunk.get("_target_") or trunk.get("target")
253
  if isinstance(target, str) and "hieradet" in target:
254
  for y in TP_SAM2.rglob("*.yaml"):
255
  try:
256
  with open(y, "r") as f2:
257
- d2 = yaml.safe_load(f2)
258
- m2 = (d2 or {}).get("model", {})
259
- e2 = (m2.get("image_encoder") or {})
260
  t2 = (e2.get("trunk") or {})
261
  tgt2 = t2.get("_target_") or t2.get("target")
262
  if isinstance(tgt2, str) and ".hiera." in tgt2:
@@ -313,7 +332,7 @@ def _try_build(cfg_path: str):
313
  try:
314
  try:
315
  sam = _try_build(cfg)
316
- except Exception as e1:
317
  alt_cfg = _find_hiera_config_if_hieradet(cfg)
318
  if alt_cfg:
319
  logger.info(f"SAM2: retrying with alt config: {alt_cfg}")
@@ -426,7 +445,6 @@ def load_matany() -> Tuple[Optional[object], bool, Dict[str, Any]]:
426
  repo_id = os.environ.get("MATANY_REPO_ID", "")
427
  ckpt = os.environ.get("MATANY_CHECKPOINT", "")
428
 
429
- # Check if this fork needs a prebuilt network
430
  try:
431
  sig = inspect.signature(InferenceCore)
432
  if "network" in sig.parameters and sig.parameters["network"].default is inspect._empty:
@@ -656,7 +674,6 @@ def fallback_composite(video_path: Union[str, Path],
656
  # Stage-A (Transparent Export) Functions
657
  # --------------------------------------------------------------------------------------
658
  def _checkerboard_bg(w: int, h: int, tile: int = 32) -> np.ndarray:
659
- """RGB checkerboard for preview when no real alpha is possible."""
660
  y, x = np.mgrid[0:h, 0:w]
661
  c = ((x // tile) + (y // tile)) % 2
662
  a = np.where(c == 0, 200, 150).astype(np.uint8)
@@ -670,7 +687,6 @@ def _build_stage_a_rgba_vp9_from_fg_alpha(
670
  size: Tuple[int, int],
671
  src_audio: Optional[Union[str, Path]] = None,
672
  ) -> bool:
673
- """Merge FG+ALPHA → RGBA WebM (VP9 with alpha)."""
674
  if not _probe_ffmpeg():
675
  return False
676
  w, h = size
@@ -702,7 +718,6 @@ def _build_stage_a_rgba_vp9_from_mask(
702
  fps: int,
703
  size: Tuple[int, int],
704
  ) -> bool:
705
- """Merge original video + static mask → RGBA WebM (VP9 with alpha)."""
706
  if not _probe_ffmpeg():
707
  return False
708
  w, h = size
@@ -733,7 +748,6 @@ def _build_stage_a_checkerboard_from_fg_alpha(
733
  fps: int,
734
  size: Tuple[int, int],
735
  ) -> bool:
736
- """Preview: FG+ALPHA over checkerboard → MP4 (no real alpha)."""
737
  fg_cap = cv2.VideoCapture(str(fg_path))
738
  al_cap = cv2.VideoCapture(str(alpha_path))
739
  if not fg_cap.isOpened() or not al_cap.isOpened():
@@ -766,7 +780,6 @@ def _build_stage_a_checkerboard_from_mask(
766
  fps: int,
767
  size: Tuple[int, int],
768
  ) -> bool:
769
- """Preview: original video + static mask over checkerboard → MP4."""
770
  cap = cv2.VideoCapture(str(video_path))
771
  if not cap.isOpened():
772
  return False
@@ -790,4 +803,4 @@ def _build_stage_a_checkerboard_from_mask(
790
  finally:
791
  cap.release()
792
  writer.release()
793
- return ok_any
 
1
  #!/usr/bin/env python3
2
  """
3
+ BackgroundFX Pro - Model Loading & Utilities (Hardened)
4
+ ======================================================
5
+ - Avoids heavy CUDA/Hydra work at import time
6
+ - Adds timeouts to subprocess probes
7
+ - Safer sys.path wiring for third_party repos
8
  """
9
 
10
  from __future__ import annotations
 
20
 
21
  import numpy as np
22
  import yaml
 
23
 
24
  # --------------------------------------------------------------------------------------
25
+ # Logging (ensure a handler exists very early)
26
  # --------------------------------------------------------------------------------------
27
  logger = logging.getLogger("backgroundfx_pro")
28
+ if not logger.handlers:
29
+ _h = logging.StreamHandler()
30
+ _h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
31
+ logger.addHandler(_h)
32
+ logger.setLevel(logging.INFO)
33
+
34
+ # Pin OpenCV threads (helps libgomp stability in Spaces)
35
+ try:
36
+ cv_threads = int(os.environ.get("CV_THREADS", "1"))
37
+ if hasattr(cv2, "setNumThreads"):
38
+ cv2.setNumThreads(cv_threads)
39
+ except Exception:
40
+ pass
41
 
42
  # --------------------------------------------------------------------------------------
43
  # Optional dependencies
 
51
  # --------------------------------------------------------------------------------------
52
  # Path setup for third_party repos
53
  # --------------------------------------------------------------------------------------
54
+ ROOT = Path(__file__).resolve().parent.parent # project root
55
  TP_SAM2 = Path(os.environ.get("THIRD_PARTY_SAM2_DIR", ROOT / "third_party" / "sam2")).resolve()
56
  TP_MATANY = Path(os.environ.get("THIRD_PARTY_MATANY_DIR", ROOT / "third_party" / "matanyone")).resolve()
57
 
58
  def _add_sys_path(p: Path) -> None:
59
+ if p.exists():
60
+ p_str = str(p)
61
+ if p_str not in sys.path:
62
+ sys.path.insert(0, p_str)
63
+ else:
64
+ logger.warning(f"third_party path not found: {p}")
65
 
66
  _add_sys_path(TP_SAM2)
67
  _add_sys_path(TP_MATANY)
68
 
69
  # --------------------------------------------------------------------------------------
70
+ # Safe Torch accessors (no top-level import)
71
  # --------------------------------------------------------------------------------------
72
+ def _torch():
 
 
 
73
  try:
74
+ import torch # local import avoids early CUDA init during module import
75
+ return torch
76
+ except Exception as e:
77
+ logger.warning(f"[models.safe-torch] import failed: {e}")
78
+ return None
79
 
80
  def _has_cuda() -> bool:
81
+ t = _torch()
82
+ if t is None:
83
+ return False
84
  try:
85
+ return bool(t.cuda.is_available())
86
+ except Exception as e:
87
+ logger.warning(f"[models.safe-torch] cuda.is_available() failed: {e}")
88
  return False
89
 
90
  def _pick_device(env_key: str) -> str:
 
93
  return requested
94
  return "cuda" if _has_cuda() else "cpu"
95
 
96
+ # --------------------------------------------------------------------------------------
97
+ # Basic Utilities
98
+ # --------------------------------------------------------------------------------------
99
+ def _ffmpeg_bin() -> str:
100
+ return os.environ.get("FFMPEG_BIN", "ffmpeg")
101
+
102
+ def _probe_ffmpeg(timeout: int = 2) -> bool:
103
+ try:
104
+ subprocess.run([_ffmpeg_bin(), "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True, timeout=timeout)
105
+ return True
106
+ except Exception:
107
+ return False
108
+
109
  def _ensure_dir(p: Path) -> None:
110
  p.mkdir(parents=True, exist_ok=True)
111
 
 
172
  # Compositing & Image Processing
173
  # --------------------------------------------------------------------------------------
174
  def _refine_alpha(alpha: np.ndarray, erode_px: int = 1, dilate_px: int = 2, blur_px: float = 1.5) -> np.ndarray:
 
175
  if alpha.dtype != np.float32:
176
  a = alpha.astype(np.float32)
177
  if a.max() > 1.0:
 
203
  return np.clip(np.power(x, 1.0 / gamma) * 255.0, 0, 255).astype(np.uint8)
204
 
205
  def _light_wrap(bg_rgb: np.ndarray, alpha01: np.ndarray, radius: int = 5, amount: float = 0.18) -> np.ndarray:
 
206
  r = max(1, int(radius))
207
  inv = 1.0 - alpha01
208
  inv_blur = cv2.GaussianBlur(inv, (r | 1, r | 1), 0)
 
210
  return lw
211
 
212
  def _despill_edges(fg_rgb: np.ndarray, alpha01: np.ndarray, amount: float = 0.35) -> np.ndarray:
213
+ w = 1.0 - 2.0 * np.abs(alpha01 - 0.5)
 
214
  w = np.clip(w, 0.0, 1.0)
215
  hsv = cv2.cvtColor(fg_rgb.astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
216
  H, S, V = cv2.split(hsv)
 
219
  out = cv2.cvtColor(hsv2.astype(np.uint8), cv2.COLOR_HSV2RGB)
220
  return out
221
 
222
+ def _composite_frame_pro(
223
+ fg_rgb: np.ndarray, alpha: np.ndarray, bg_rgb: np.ndarray,
224
+ erode_px: int = None, dilate_px: int = None, blur_px: float = None,
225
+ lw_radius: int = None, lw_amount: float = None, despill_amount: float = None
226
+ ) -> np.ndarray:
227
  erode_px = erode_px if erode_px is not None else int(os.environ.get("EDGE_ERODE", "1"))
228
  dilate_px = dilate_px if dilate_px is not None else int(os.environ.get("EDGE_DILATE", "2"))
229
  blur_px = blur_px if blur_px is not None else float(os.environ.get("EDGE_BLUR", "1.5"))
 
231
  lw_amount = lw_amount if lw_amount is not None else float(os.environ.get("LIGHTWRAP_AMOUNT", "0.18"))
232
  despill_amount = despill_amount if despill_amount is not None else float(os.environ.get("DESPILL_AMOUNT", "0.35"))
233
 
 
234
  a = _refine_alpha(alpha, erode_px=erode_px, dilate_px=dilate_px, blur_px=blur_px)
 
 
235
  fg_rgb = _despill_edges(fg_rgb, a, amount=despill_amount)
236
 
 
237
  fg_lin = _to_linear(fg_rgb)
238
  bg_lin = _to_linear(bg_rgb)
 
 
239
  lw = _light_wrap(bg_rgb, a, radius=lw_radius, amount=lw_amount)
240
  lw_lin = _to_linear(np.clip(lw, 0, 255).astype(np.uint8))
241
 
 
255
  return str(candidate)
256
  if cfg_path.exists():
257
  return str(cfg_path)
 
258
  for name in ["configs/sam2/sam2_hiera_l.yaml", "configs/sam2/sam2_hiera_b.yaml", "configs/sam2/sam2_hiera_s.yaml"]:
259
  p = TP_SAM2 / name
260
  if p.exists():
261
  return str(p)
262
+ return str(cfg_str)
263
 
264
  def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
265
  """If config references 'hieradet', try to find a 'hiera' config."""
266
  try:
267
  with open(cfg_path, "r") as f:
268
  data = yaml.safe_load(f)
269
+ model = data.get("model", {}) or {}
270
+ enc = model.get("image_encoder") or {}
271
+ trunk = enc.get("trunk") or {}
 
272
  target = trunk.get("_target_") or trunk.get("target")
273
  if isinstance(target, str) and "hieradet" in target:
274
  for y in TP_SAM2.rglob("*.yaml"):
275
  try:
276
  with open(y, "r") as f2:
277
+ d2 = yaml.safe_load(f2) or {}
278
+ e2 = (d2.get("model", {}) or {}).get("image_encoder") or {}
 
279
  t2 = (e2.get("trunk") or {})
280
  tgt2 = t2.get("_target_") or t2.get("target")
281
  if isinstance(tgt2, str) and ".hiera." in tgt2:
 
332
  try:
333
  try:
334
  sam = _try_build(cfg)
335
+ except Exception:
336
  alt_cfg = _find_hiera_config_if_hieradet(cfg)
337
  if alt_cfg:
338
  logger.info(f"SAM2: retrying with alt config: {alt_cfg}")
 
445
  repo_id = os.environ.get("MATANY_REPO_ID", "")
446
  ckpt = os.environ.get("MATANY_CHECKPOINT", "")
447
 
 
448
  try:
449
  sig = inspect.signature(InferenceCore)
450
  if "network" in sig.parameters and sig.parameters["network"].default is inspect._empty:
 
674
  # Stage-A (Transparent Export) Functions
675
  # --------------------------------------------------------------------------------------
676
  def _checkerboard_bg(w: int, h: int, tile: int = 32) -> np.ndarray:
 
677
  y, x = np.mgrid[0:h, 0:w]
678
  c = ((x // tile) + (y // tile)) % 2
679
  a = np.where(c == 0, 200, 150).astype(np.uint8)
 
687
  size: Tuple[int, int],
688
  src_audio: Optional[Union[str, Path]] = None,
689
  ) -> bool:
 
690
  if not _probe_ffmpeg():
691
  return False
692
  w, h = size
 
718
  fps: int,
719
  size: Tuple[int, int],
720
  ) -> bool:
 
721
  if not _probe_ffmpeg():
722
  return False
723
  w, h = size
 
748
  fps: int,
749
  size: Tuple[int, int],
750
  ) -> bool:
 
751
  fg_cap = cv2.VideoCapture(str(fg_path))
752
  al_cap = cv2.VideoCapture(str(alpha_path))
753
  if not fg_cap.isOpened() or not al_cap.isOpened():
 
780
  fps: int,
781
  size: Tuple[int, int],
782
  ) -> bool:
 
783
  cap = cv2.VideoCapture(str(video_path))
784
  if not cap.isOpened():
785
  return False
 
803
  finally:
804
  cap.release()
805
  writer.release()
806
+ return ok_any
perf_tuning.py CHANGED
@@ -1,8 +1,10 @@
1
- # perf_tuning.py
2
  #!/usr/bin/env python3
3
  """
4
- Forces CUDA use (or fails fast), configures cuDNN/TF32, and logs a clear GPU banner.
5
- Loaded automatically because pipeline.py does: `import perf_tuning` (best-effort).
 
 
 
6
  """
7
 
8
  import os
@@ -15,59 +17,107 @@
15
  log.addHandler(h)
16
  log.setLevel(logging.INFO)
17
 
18
- try:
19
- import torch
20
- except Exception as e:
21
- raise RuntimeError(f"PyTorch not importable: {e}")
 
 
 
22
 
23
- require_cuda = os.environ.get("REQUIRE_CUDA", "0").strip() == "1"
24
- force_idx_env = os.environ.get("FORCE_CUDA_DEVICE", "").strip()
25
- mem_frac = float(os.environ.get("CUDA_MEMORY_FRACTION", "0.98"))
26
-
27
- if not torch.cuda.is_available():
28
- if require_cuda:
29
- raise RuntimeError("CUDA is NOT available, but REQUIRE_CUDA=1. "
30
- "Make sure the Space is on GPU and the container runs with --gpus all.")
31
- else:
32
- log.warning("CUDA not available; running on CPU. Set REQUIRE_CUDA=1 to fail fast.")
33
  else:
34
- # Choose device
35
  try:
36
- idx = int(force_idx_env) if force_idx_env != "" else 0
37
- except Exception:
38
- idx = 0
39
- if idx >= torch.cuda.device_count() or idx < 0:
40
- idx = 0
 
 
 
 
41
 
42
- torch.cuda.set_device(idx)
 
 
 
 
 
 
 
43
 
44
- # Perf knobs
45
- try:
46
- torch.backends.cuda.matmul.allow_tf32 = True
47
- except Exception:
48
- pass
49
- try:
50
- torch.backends.cudnn.allow_tf32 = True
51
- torch.backends.cudnn.benchmark = True
52
- except Exception:
53
- pass
54
 
55
- # Reserve VRAM fraction (best effort)
56
- try:
57
- torch.cuda.set_per_process_memory_fraction(mem_frac, idx)
58
- except Exception:
59
- pass
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- # Log a clear banner
62
- try:
63
- name = torch.cuda.get_device_name(idx)
64
- cap = torch.cuda.get_device_capability(idx)
65
- total_gb = torch.cuda.get_device_properties(idx).total_memory / (1024**3)
66
- free_gb = torch.cuda.mem_get_info()[0] / (1024**3)
67
- log.info(f"Using CUDA device {idx}: {name} | cc {cap[0]}.{cap[1]} | "
68
- f"VRAM {total_gb:.2f} GB (free ~{free_gb:.2f} GB) | TF32:ON | cuDNN benchmark:ON")
69
- except Exception:
70
- log.info("Using CUDA; device info unavailable (but cuda.is_available()==True).")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  # Optional: limit OpenCV threads if provided
73
  threads = os.environ.get("OPENCV_NUM_THREADS")
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ perf_tuning.py (Hardened)
4
+ - No hard CUDA touching at import time (prevents startup hangs on Spaces).
5
+ - Optional "strict" modes via env if you *really* want fail-fast behavior.
6
+ - Applies safe flags (TF32/cudnn.benchmark) best-effort.
7
+ - Short, defensive GPU banner (only if explicitly enabled).
8
  """
9
 
10
  import os
 
17
  log.addHandler(h)
18
  log.setLevel(logging.INFO)
19
 
20
+ # ---- Feature flags (env) -----------------------------------------------------
21
+ DISABLED = os.getenv("PERF_TUNING_DISABLED", "0").strip() == "1"
22
+ STRICT_IMPORT_FAIL = os.getenv("PERF_TUNING_IMPORT_STRICT", "0").strip() == "1" # if 1, may raise on import
23
+ EAGER_CUDA = os.getenv("PERF_TUNING_EAGER_CUDA", "0").strip() == "1" # if 1, do CUDA probing now
24
+ REQUIRE_CUDA = os.getenv("REQUIRE_CUDA", "0").strip() == "1" # prefer not to fail at import
25
+ FORCE_IDX_ENV = os.getenv("FORCE_CUDA_DEVICE", "").strip()
26
+ MEM_FRAC_STR = os.getenv("CUDA_MEMORY_FRACTION", "0.98").strip()
27
 
28
+ if DISABLED:
29
+ log.info("perf_tuning: disabled by PERF_TUNING_DISABLED=1")
 
 
 
 
 
 
 
 
30
  else:
31
+ # Import torch defensively (do NOT crash the app if it's not there)
32
  try:
33
+ import importlib
34
+ torch = importlib.import_module("torch")
35
+ except Exception as e:
36
+ msg = f"perf_tuning: PyTorch not importable at import-time: {e}"
37
+ if STRICT_IMPORT_FAIL:
38
+ raise RuntimeError(msg)
39
+ else:
40
+ log.warning(msg)
41
+ torch = None
42
 
43
+ def _bool_cuda_available():
44
+ if torch is None:
45
+ return False
46
+ try:
47
+ return bool(torch.cuda.is_available())
48
+ except Exception as e:
49
+ log.warning(f"perf_tuning: cuda.is_available() failed: {e}")
50
+ return False
51
 
52
+ # Soft gating: if user *requires* CUDA, set a marker we can read later
53
+ if REQUIRE_CUDA and not _bool_cuda_available():
54
+ os.environ["BFX_REQUIRE_CUDA_FAILED"] = "1"
55
+ msg = "CUDA NOT available but REQUIRE_CUDA=1 (will run on CPU unless app checks this later)."
56
+ if STRICT_IMPORT_FAIL:
57
+ raise RuntimeError(msg)
58
+ else:
59
+ log.warning(msg)
 
 
60
 
61
+ # Always try “cheap” flags that won’t touch devices
62
+ if torch is not None:
63
+ try:
64
+ # These do not require an active CUDA context
65
+ if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "matmul"):
66
+ try:
67
+ torch.backends.cuda.matmul.allow_tf32 = True
68
+ except Exception:
69
+ pass
70
+ if hasattr(torch.backends, "cudnn"):
71
+ try:
72
+ torch.backends.cudnn.allow_tf32 = True
73
+ torch.backends.cudnn.benchmark = True
74
+ except Exception:
75
+ pass
76
+ except Exception as e:
77
+ log.debug(f"perf_tuning: backend flags suppressed: {e}")
78
 
79
+ # Only do potentially blocking CUDA work if explicitly requested
80
+ if EAGER_CUDA and torch is not None:
81
+ try:
82
+ # Choose device (optional)
83
+ try:
84
+ idx = int(FORCE_IDX_ENV) if FORCE_IDX_ENV != "" else 0
85
+ except Exception:
86
+ idx = 0
87
+ try:
88
+ torch.cuda.set_device(idx)
89
+ except Exception as e:
90
+ log.warning(f"perf_tuning: set_device({idx}) failed: {e}")
91
+
92
+ # Memory fraction is optional and sometimes flaky—guard it
93
+ try:
94
+ mem_frac = float(MEM_FRAC_STR)
95
+ torch.cuda.set_per_process_memory_fraction(mem_frac, idx)
96
+ except Exception as e:
97
+ log.debug(f"perf_tuning: set_per_process_memory_fraction skipped: {e}")
98
+
99
+ # Best-effort banner; every call is wrapped so nothing blocks startup
100
+ try:
101
+ name = torch.cuda.get_device_name(idx)
102
+ except Exception as e:
103
+ name = f"? ({e})"
104
+ try:
105
+ cap = torch.cuda.get_device_capability(idx)
106
+ cap_s = f"{cap[0]}.{cap[1]}"
107
+ except Exception as e:
108
+ cap_s = f"? ({e})"
109
+ try:
110
+ total_gb = torch.cuda.get_device_properties(idx).total_memory / (1024**3)
111
+ except Exception as e:
112
+ total_gb = f"? ({e})"
113
+ try:
114
+ free_gb = torch.cuda.mem_get_info()[0] / (1024**3)
115
+ except Exception as e:
116
+ free_gb = f"? ({e})"
117
+
118
+ log.info(f"CUDA device {idx}: {name} | cc {cap_s} | VRAM {total_gb} GB (free ~{free_gb} GB) | TF32:ON | cuDNN benchmark:ON")
119
+ except Exception as e:
120
+ log.warning(f"perf_tuning: eager CUDA probe failed (non-fatal): {e}")
121
 
122
  # Optional: limit OpenCV threads if provided
123
  threads = os.environ.get("OPENCV_NUM_THREADS")
pipeline.py CHANGED
@@ -1,9 +1,11 @@
1
  #!/usr/bin/env python3
2
  """
3
- BackgroundFX Pro - Memory-Optimized Pipeline
4
- ===========================================
5
- Orchestrates SAM2 MatAnyone Compositing with aggressive memory management.
6
- Models are loaded sequentially and freed immediately after use.
 
 
7
  """
8
 
9
  from __future__ import annotations
@@ -13,85 +15,128 @@
13
  import time
14
  import tempfile
15
  import logging
 
16
  from pathlib import Path
17
  from typing import Optional, Tuple, Dict, Any, Union
18
 
19
- import torch
20
- from models import (
21
- load_sam2, run_sam2_mask, load_matany, run_matany,
22
- fallback_mask, fallback_composite, composite_video,
23
- _cv_read_first_frame, _save_mask_png, _ensure_dir, _mux_audio, _probe_ffmpeg,
24
- _refine_mask_grabcut, _build_stage_a_rgba_vp9_from_fg_alpha,
25
- _build_stage_a_rgba_vp9_from_mask, _build_stage_a_checkerboard_from_fg_alpha,
26
- _build_stage_a_checkerboard_from_mask
27
- )
28
-
29
- # Try to apply GPU/perf tuning early
30
- try:
31
- import perf_tuning # noqa: F401
32
- except Exception:
33
- pass
34
-
35
  # --------------------------------------------------------------------------------------
36
  # Logging
37
  # --------------------------------------------------------------------------------------
38
  logger = logging.getLogger("backgroundfx_pro")
39
- logger.setLevel(logging.INFO)
40
  if not logger.handlers:
41
  _h = logging.StreamHandler()
42
  _h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
43
  logger.addHandler(_h)
 
44
 
45
  # --------------------------------------------------------------------------------------
46
- # Memory Management Utilities
47
  # --------------------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  def _cleanup_temp_files(tmp_root: Path) -> None:
49
- """Clean up temporary files aggressively"""
50
  try:
51
- for pattern in ["*.tmp", "*.temp", "*.bak"]:
52
  for f in tmp_root.glob(pattern):
53
  f.unlink(missing_ok=True)
54
  except Exception:
55
  pass
56
 
57
  def _log_memory() -> float:
58
- """Log current GPU memory usage and return allocated GB"""
59
- if torch.cuda.is_available():
60
- try:
61
- allocated = torch.cuda.memory_allocated() / 1e9
62
- reserved = torch.cuda.memory_reserved() / 1e9
 
 
 
 
63
  logger.info(f"GPU memory: {allocated:.1f}GB allocated, {reserved:.1f}GB reserved")
64
- return allocated
65
- except Exception:
66
- pass
67
  return 0.0
68
 
69
  def _force_cleanup() -> None:
70
- """Aggressive memory cleanup"""
71
  try:
72
  gc.collect()
73
- if torch.cuda.is_available():
74
- torch.cuda.empty_cache()
75
- torch.cuda.synchronize()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  except Exception as e:
77
- logger.warning(f"Cleanup warning: {e}")
 
78
 
79
  # --------------------------------------------------------------------------------------
80
- # Main Processing Function (Memory-Optimized)
81
  # --------------------------------------------------------------------------------------
82
- def process(video_path: Union[str, Path],
83
- bg_image_path: Union[str, Path],
84
- point_x: Optional[float] = None,
85
- point_y: Optional[float] = None,
86
- auto_box: bool = False,
87
- work_dir: Optional[Union[str, Path]] = None) -> Tuple[Optional[str], Dict[str, Any]]:
 
 
88
  """
89
  Memory-optimized orchestration: lazy loading, sequential model usage, aggressive cleanup.
90
-
91
  Flow:
92
- 1. Load SAM2 → get mask → FREE SAM2 immediately
93
- 2. Load MatAnyone process FREE MatAnyone immediately
94
- 3. Composite & finalize (CPU-based operations)
 
 
 
95
  """
96
  t0 = time.time()
97
  diagnostics: Dict[str, Any] = {
@@ -110,105 +155,130 @@ def process(video_path: Union[str, Path],
110
  tmp_root = Path(work_dir) if work_dir else Path(tempfile.mkdtemp(prefix="bfx_"))
111
  _ensure_dir(tmp_root)
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  try:
114
  # 0) Basic video info
115
- logger.info("Reading video metadata...")
116
  first_frame, fps, (vw, vh) = _cv_read_first_frame(video_path)
117
  diagnostics["fps"] = int(fps or 25)
118
  diagnostics["resolution"] = [int(vw), int(vh)]
119
-
120
  if first_frame is None or vw == 0 or vh == 0:
121
  diagnostics["fallback_used"] = "invalid_video"
122
  return None, diagnostics
123
 
124
  diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
125
 
126
- # 1) PHASE 1: SAM2 Loading & Processing → IMMEDIATE CLEANUP
127
- logger.info("=== PHASE 1: Loading SAM2 for segmentation ===")
128
  predictor, sam2_ok, sam_meta = load_sam2()
129
- diagnostics["sam2_meta"] = sam_meta
130
- diagnostics["device_sam2"] = sam_meta.get("sam2_device") if sam_meta else None
131
-
132
  diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
133
 
134
  seed_mask = None
135
  mask_png = tmp_root / "seed_mask.png"
136
-
 
137
  if sam2_ok and predictor is not None:
138
- logger.info("Running SAM2 segmentation...")
139
  px = int(point_x) if point_x is not None else None
140
  py = int(point_y) if point_y is not None else None
141
-
142
  seed_mask, ok_mask = run_sam2_mask(
143
  predictor, first_frame,
144
  point=(px, py) if (px is not None and py is not None) else None,
145
  auto=auto_box
146
  )
147
  diagnostics["sam2_ok"] = bool(ok_mask)
148
-
149
- # CRITICAL: Free SAM2 immediately after getting the mask
150
- logger.info("Freeing SAM2 memory...")
151
- del predictor
152
- predictor = None
153
- _force_cleanup()
154
- diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
155
-
156
  else:
157
- ok_mask = False
158
- logger.info("SAM2 not available or failed to load")
 
 
 
 
 
 
 
 
159
 
160
  # Fallback mask generation if SAM2 failed
161
  if not ok_mask or seed_mask is None:
162
- logger.info("Using fallback mask generation...")
163
  seed_mask = fallback_mask(first_frame)
164
  diagnostics["fallback_used"] = "mask_generation"
165
  _force_cleanup()
166
 
167
  # Optional GrabCut refinement
168
  if int(os.environ.get("REFINE_GRABCUT", "1")) == 1:
169
- logger.info("Refining mask with GrabCut...")
170
  seed_mask = _refine_mask_grabcut(first_frame, seed_mask)
171
  _force_cleanup()
172
 
173
  _save_mask_png(seed_mask, mask_png)
174
-
175
- # Clean up the first frame from memory
176
- del first_frame
 
 
 
177
  _force_cleanup()
178
  _cleanup_temp_files(tmp_root)
179
 
180
- # 2) PHASE 2: MatAnyone Loading & Processing → IMMEDIATE CLEANUP
181
- logger.info("=== PHASE 2: Loading MatAnyone for temporal processing ===")
182
  matany, mat_ok, mat_meta = load_matany()
183
- diagnostics["matany_meta"] = mat_meta
184
- diagnostics["device_matany"] = mat_meta.get("matany_device") if mat_meta else None
185
-
186
  diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
187
 
188
  fg_path, al_path = None, None
189
  out_dir = tmp_root / "matany_out"
190
  _ensure_dir(out_dir)
191
-
 
192
  if mat_ok and matany is not None:
193
- logger.info("Running MatAnyone processing...")
194
  fg_path, al_path, ran = run_matany(matany, video_path, mask_png, out_dir)
195
  diagnostics["matany_ok"] = bool(ran)
196
-
197
- # CRITICAL: Free MatAnyone immediately after processing
198
- logger.info("Freeing MatAnyone memory...")
199
- del matany
200
- matany = None
201
- _force_cleanup()
202
- diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
203
  else:
204
- ran = False
205
- logger.info("MatAnyone not available, disabled, or failed to load")
 
 
 
 
 
 
 
 
206
 
207
- # 3) PHASE 3: Stage-A Creation (lightweight, CPU-based)
208
- logger.info("=== PHASE 3: Building Stage-A (transparent export) ===")
209
  stageA_path = None
210
  stageA_ok = False
211
-
212
  if diagnostics["matany_ok"] and fg_path and al_path:
213
  stageA_path = tmp_root / "stageA_transparent.webm"
214
  if _probe_ffmpeg():
@@ -238,57 +308,56 @@ def process(video_path: Union[str, Path],
238
  else ("MP4 checkerboard preview (no real alpha)" if stageA_ok else "Stage-A build failed")
239
  )
240
 
241
- # Optional: return Stage-A instead of final composite
242
  if os.environ.get("RETURN_STAGE_A", "0").strip() == "1" and stageA_ok:
243
  _force_cleanup()
244
  _cleanup_temp_files(tmp_root)
 
 
245
  return str(stageA_path), diagnostics
246
 
247
- # 4) PHASE 4: Final Compositing (CPU-based, memory-efficient)
248
- logger.info("=== PHASE 4: Creating final composite ===")
249
  output_path = tmp_root / "output.mp4"
250
-
251
  if diagnostics["matany_ok"] and fg_path and al_path:
252
- logger.info("Compositing with MatAnyone outputs...")
253
  ok_comp = composite_video(fg_path, al_path, bg_image_path, output_path, diagnostics["fps"], (vw, vh))
254
  if not ok_comp:
255
- logger.info("MatAnyone composite failed; falling back to static mask composite.")
256
  fallback_composite(video_path, mask_png, bg_image_path, output_path)
257
  diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") + "+composite_static"
258
  else:
259
- logger.info("Using static mask composite...")
260
  fallback_composite(video_path, mask_png, bg_image_path, output_path)
261
  diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") or "composite_static"
262
 
263
- # Clean up intermediate files
264
  _cleanup_temp_files(tmp_root)
265
  _force_cleanup()
266
 
267
- # 5) PHASE 5: Audio Muxing (final step)
268
- logger.info("=== PHASE 5: Adding audio track ===")
269
  final_path = tmp_root / "output_with_audio.mp4"
270
  if _probe_ffmpeg():
271
  mux_ok = _mux_audio(video_path, output_path, final_path)
272
  if mux_ok:
273
- # Clean up the silent version
274
  output_path.unlink(missing_ok=True)
275
  _force_cleanup()
276
  diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
277
- logger.info(f"Processing completed successfully in {diagnostics['elapsed_sec']}s")
278
- logger.info(f"Peak GPU memory usage: {diagnostics['memory_peak_gb']:.1f}GB")
279
  return str(final_path), diagnostics
280
 
281
- # Final cleanup
282
  _force_cleanup()
283
  diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
284
- logger.info(f"Processing completed in {diagnostics['elapsed_sec']}s (no audio)")
285
- logger.info(f"Peak GPU memory usage: {diagnostics['memory_peak_gb']:.1f}GB")
286
  return str(output_path), diagnostics
287
 
288
  except Exception as e:
289
- logger.error(f"Processing failed: {e}")
290
  import traceback
291
- logger.error(f"Traceback: {traceback.format_exc()}")
292
  _force_cleanup()
293
  diagnostics["error"] = str(e)
294
  diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
@@ -297,4 +366,4 @@ def process(video_path: Union[str, Path],
297
  finally:
298
  # Ensure cleanup even if something goes wrong
299
  _force_cleanup()
300
- _cleanup_temp_files(tmp_root)
 
1
  #!/usr/bin/env python3
2
  """
3
+ BackgroundFX Pro - Memory-Optimized Pipeline (Hardened)
4
+ ======================================================
5
+ - Lazy-imports heavy 'models' module to avoid Space boot stalls
6
+ - Sequential load → run → free (SAM2 then MatAnyone)
7
+ - Aggressive but non-blocking GPU cleanup (no synchronize())
8
+ - Verbose breadcrumbs for pinpointing stalls
9
  """
10
 
11
  from __future__ import annotations
 
15
  import time
16
  import tempfile
17
  import logging
18
+ import importlib
19
  from pathlib import Path
20
  from typing import Optional, Tuple, Dict, Any, Union
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  # --------------------------------------------------------------------------------------
23
  # Logging
24
  # --------------------------------------------------------------------------------------
25
  logger = logging.getLogger("backgroundfx_pro")
 
26
  if not logger.handlers:
27
  _h = logging.StreamHandler()
28
  _h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
29
  logger.addHandler(_h)
30
+ logger.setLevel(logging.INFO)
31
 
32
  # --------------------------------------------------------------------------------------
33
+ # Safe Torch accessors (avoid import-time CUDA touches)
34
  # --------------------------------------------------------------------------------------
35
+ def _torch():
36
+ try:
37
+ import torch # local import to avoid early CUDA init in module scope
38
+ return torch
39
+ except Exception as e:
40
+ logger.warning(f"[safe-torch] import failed: {e}")
41
+ return None
42
+
43
def _cuda_available() -> Optional[bool]:
    """Probe CUDA availability without letting torch errors escape.

    Returns the result of torch.cuda.is_available(), or None when torch
    is missing or the probe itself raises.
    """
    torch_mod = _torch()
    if torch_mod is None:
        return None
    try:
        available = torch_mod.cuda.is_available()
    except Exception as e:
        logger.warning(f"[safe-torch] torch.cuda.is_available() failed: {e}")
        return None
    return available
52
+
53
+ # --------------------------------------------------------------------------------------
54
+ # Lightweight utilities
55
+ # --------------------------------------------------------------------------------------
56
+ def _ensure_dir(p: Union[str, Path]) -> None:
57
+ Path(p).mkdir(parents=True, exist_ok=True)
58
+
59
  def _cleanup_temp_files(tmp_root: Path) -> None:
60
+ """Clean up temporary files aggressively."""
61
  try:
62
+ for pattern in ("*.tmp", "*.temp", "*.bak"):
63
  for f in tmp_root.glob(pattern):
64
  f.unlink(missing_ok=True)
65
  except Exception:
66
  pass
67
 
68
def _log_memory() -> float:
    """Log current GPU memory usage; return allocated GB, or 0.0 on any failure.

    Never blocks and never raises: all torch/CUDA errors are suppressed,
    since this is purely diagnostic.
    """
    torch_mod = _torch()
    if torch_mod is None:
        return 0.0
    try:
        if not _cuda_available():
            return 0.0
        allocated = torch_mod.cuda.memory_allocated() / 1e9
        reserved = torch_mod.cuda.memory_reserved() / 1e9
        logger.info(f"GPU memory: {allocated:.1f}GB allocated, {reserved:.1f}GB reserved")
        return float(allocated)
    except Exception as e:
        logger.debug(f"[mem-log] suppressed: {e}")
        return 0.0
83
 
84
def _force_cleanup() -> None:
    """Reclaim Python and GPU memory without any blocking CUDA calls."""
    try:
        gc.collect()
    except Exception:
        pass
    torch_mod = _torch()
    if torch_mod is not None:
        try:
            # Deliberately no torch.cuda.synchronize(): it can hang on
            # driver problems; empty_cache() alone is safe here.
            if _cuda_available():
                torch_mod.cuda.empty_cache()
        except Exception as e:
            logger.debug(f"[cleanup] suppressed: {e}")
99
+
100
+ # --------------------------------------------------------------------------------------
101
+ # Lazy import of heavy models module
102
+ # --------------------------------------------------------------------------------------
103
+ _models_ref = None
104
+
105
def _models():
    """Return the (cached) 'models' module, importing it on first use.

    Deferring this import keeps Space startup fast; an import failure is
    logged with a traceback and re-raised for the caller to handle.
    """
    global _models_ref
    if _models_ref is None:
        logger.info("[init] Importing models module lazily…")
        try:
            _models_ref = importlib.import_module("models")
        except Exception as e:
            logger.exception(f"[init] Failed to import models: {e}")
            raise
        logger.info("[init] models imported OK.")
    return _models_ref
118
 
119
  # --------------------------------------------------------------------------------------
120
+ # Main Processing Function
121
  # --------------------------------------------------------------------------------------
122
+ def process(
123
+ video_path: Union[str, Path],
124
+ bg_image_path: Union[str, Path],
125
+ point_x: Optional[float] = None,
126
+ point_y: Optional[float] = None,
127
+ auto_box: bool = False,
128
+ work_dir: Optional[Union[str, Path]] = None
129
+ ) -> Tuple[Optional[str], Dict[str, Any]]:
130
  """
131
  Memory-optimized orchestration: lazy loading, sequential model usage, aggressive cleanup.
132
+
133
  Flow:
134
+ 0. Read video metadata
135
+ 1. SAM2 → mask (free immediately)
136
+ 2. MatAnyone → FG/alpha (free immediately)
137
+ 3. Stage-A build (transparent or checkerboard)
138
+ 4. Final composite
139
+ 5. Audio mux
140
  """
141
  t0 = time.time()
142
  diagnostics: Dict[str, Any] = {
 
155
  tmp_root = Path(work_dir) if work_dir else Path(tempfile.mkdtemp(prefix="bfx_"))
156
  _ensure_dir(tmp_root)
157
 
158
+ # Defer heavy function imports until inside the call
159
+ M = _models()
160
+ # pull only the needed callables
161
+ _cv_read_first_frame = M._cv_read_first_frame
162
+ _save_mask_png = M._save_mask_png
163
+ _probe_ffmpeg = M._probe_ffmpeg
164
+ _mux_audio = M._mux_audio
165
+ _refine_mask_grabcut = M._refine_mask_grabcut
166
+ fallback_mask = M.fallback_mask
167
+ fallback_composite = M.fallback_composite
168
+ composite_video = M.composite_video
169
+ load_sam2 = M.load_sam2
170
+ run_sam2_mask = M.run_sam2_mask
171
+ load_matany = M.load_matany
172
+ run_matany = M.run_matany
173
+ _build_stage_a_rgba_vp9_from_fg_alpha = M._build_stage_a_rgba_vp9_from_fg_alpha
174
+ _build_stage_a_rgba_vp9_from_mask = M._build_stage_a_rgba_vp9_from_mask
175
+ _build_stage_a_checkerboard_from_fg_alpha = M._build_stage_a_checkerboard_from_fg_alpha
176
+ _build_stage_a_checkerboard_from_mask = M._build_stage_a_checkerboard_from_mask
177
+
178
  try:
179
  # 0) Basic video info
180
+ logger.info("[0] Reading video metadata")
181
  first_frame, fps, (vw, vh) = _cv_read_first_frame(video_path)
182
  diagnostics["fps"] = int(fps or 25)
183
  diagnostics["resolution"] = [int(vw), int(vh)]
184
+
185
  if first_frame is None or vw == 0 or vh == 0:
186
  diagnostics["fallback_used"] = "invalid_video"
187
  return None, diagnostics
188
 
189
  diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
190
 
191
+ # 1) PHASE 1: SAM2
192
+ logger.info("[1] Loading SAM2")
193
  predictor, sam2_ok, sam_meta = load_sam2()
194
+ diagnostics["sam2_meta"] = sam_meta or {}
195
+ diagnostics["device_sam2"] = (sam_meta or {}).get("sam2_device")
196
+
197
  diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
198
 
199
  seed_mask = None
200
  mask_png = tmp_root / "seed_mask.png"
201
+
202
+ ok_mask = False
203
  if sam2_ok and predictor is not None:
204
+ logger.info("[1] Running SAM2 segmentation")
205
  px = int(point_x) if point_x is not None else None
206
  py = int(point_y) if point_y is not None else None
 
207
  seed_mask, ok_mask = run_sam2_mask(
208
  predictor, first_frame,
209
  point=(px, py) if (px is not None and py is not None) else None,
210
  auto=auto_box
211
  )
212
  diagnostics["sam2_ok"] = bool(ok_mask)
 
 
 
 
 
 
 
 
213
  else:
214
+ logger.info("[1] SAM2 unavailable or failed to load.")
215
+
216
+ # Free SAM2 ASAP
217
+ try:
218
+ del predictor
219
+ except Exception:
220
+ pass
221
+ predictor = None
222
+ _force_cleanup()
223
+ diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
224
 
225
  # Fallback mask generation if SAM2 failed
226
  if not ok_mask or seed_mask is None:
227
+ logger.info("[1] Using fallback mask generation")
228
  seed_mask = fallback_mask(first_frame)
229
  diagnostics["fallback_used"] = "mask_generation"
230
  _force_cleanup()
231
 
232
  # Optional GrabCut refinement
233
  if int(os.environ.get("REFINE_GRABCUT", "1")) == 1:
234
+ logger.info("[1] Refining mask with GrabCut")
235
  seed_mask = _refine_mask_grabcut(first_frame, seed_mask)
236
  _force_cleanup()
237
 
238
  _save_mask_png(seed_mask, mask_png)
239
+
240
+ # Free first frame
241
+ try:
242
+ del first_frame
243
+ except Exception:
244
+ pass
245
  _force_cleanup()
246
  _cleanup_temp_files(tmp_root)
247
 
248
+ # 2) PHASE 2: MatAnyone
249
+ logger.info("[2] Loading MatAnyone")
250
  matany, mat_ok, mat_meta = load_matany()
251
+ diagnostics["matany_meta"] = mat_meta or {}
252
+ diagnostics["device_matany"] = (mat_meta or {}).get("matany_device")
253
+
254
  diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
255
 
256
  fg_path, al_path = None, None
257
  out_dir = tmp_root / "matany_out"
258
  _ensure_dir(out_dir)
259
+
260
+ ran = False
261
  if mat_ok and matany is not None:
262
+ logger.info("[2] Running MatAnyone processing")
263
  fg_path, al_path, ran = run_matany(matany, video_path, mask_png, out_dir)
264
  diagnostics["matany_ok"] = bool(ran)
 
 
 
 
 
 
 
265
  else:
266
+ logger.info("[2] MatAnyone unavailable/disabled/failed to load.")
267
+
268
+ # Free MatAnyone ASAP
269
+ try:
270
+ del matany
271
+ except Exception:
272
+ pass
273
+ matany = None
274
+ _force_cleanup()
275
+ diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
276
 
277
+ # 3) PHASE 3: Stage-A
278
+ logger.info("[3] Building Stage-A (transparent or checkerboard)…")
279
  stageA_path = None
280
  stageA_ok = False
281
+
282
  if diagnostics["matany_ok"] and fg_path and al_path:
283
  stageA_path = tmp_root / "stageA_transparent.webm"
284
  if _probe_ffmpeg():
 
308
  else ("MP4 checkerboard preview (no real alpha)" if stageA_ok else "Stage-A build failed")
309
  )
310
 
 
311
  if os.environ.get("RETURN_STAGE_A", "0").strip() == "1" and stageA_ok:
312
  _force_cleanup()
313
  _cleanup_temp_files(tmp_root)
314
+ diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
315
+ logger.info(f"[done] Returned Stage-A in {diagnostics['elapsed_sec']}s")
316
  return str(stageA_path), diagnostics
317
 
318
+ # 4) PHASE 4: Final Compositing
319
+ logger.info("[4] Creating final composite")
320
  output_path = tmp_root / "output.mp4"
321
+
322
  if diagnostics["matany_ok"] and fg_path and al_path:
323
+ logger.info("[4] Compositing with MatAnyone outputs")
324
  ok_comp = composite_video(fg_path, al_path, bg_image_path, output_path, diagnostics["fps"], (vw, vh))
325
  if not ok_comp:
326
+ logger.info("[4] Composite failed; falling back to static mask composite.")
327
  fallback_composite(video_path, mask_png, bg_image_path, output_path)
328
  diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") + "+composite_static"
329
  else:
330
+ logger.info("[4] Using static mask composite")
331
  fallback_composite(video_path, mask_png, bg_image_path, output_path)
332
  diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") or "composite_static"
333
 
 
334
  _cleanup_temp_files(tmp_root)
335
  _force_cleanup()
336
 
337
+ # 5) PHASE 5: Audio Mux
338
+ logger.info("[5] Adding audio track")
339
  final_path = tmp_root / "output_with_audio.mp4"
340
  if _probe_ffmpeg():
341
  mux_ok = _mux_audio(video_path, output_path, final_path)
342
  if mux_ok:
 
343
  output_path.unlink(missing_ok=True)
344
  _force_cleanup()
345
  diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
346
+ logger.info(f"[done] Success in {diagnostics['elapsed_sec']}s")
347
+ logger.info(f"[done] Peak GPU memory usage: {diagnostics['memory_peak_gb']:.1f}GB")
348
  return str(final_path), diagnostics
349
 
350
+ # Fallback return without audio
351
  _force_cleanup()
352
  diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
353
+ logger.info(f"[done] Completed (no audio) in {diagnostics['elapsed_sec']}s")
354
+ logger.info(f"[done] Peak GPU memory usage: {diagnostics['memory_peak_gb']:.1f}GB")
355
  return str(output_path), diagnostics
356
 
357
  except Exception as e:
358
+ logger.error(f"[error] Processing failed: {e}")
359
  import traceback
360
+ logger.error(f"[error] Traceback:\n{traceback.format_exc()}")
361
  _force_cleanup()
362
  diagnostics["error"] = str(e)
363
  diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
 
366
  finally:
367
  # Ensure cleanup even if something goes wrong
368
  _force_cleanup()
369
+ _cleanup_temp_files(tmp_root)
requirements.txt CHANGED
@@ -9,7 +9,8 @@ moviepy==1.0.3
9
  decord==0.6.0
10
  Pillow==10.4.0
11
  numpy==1.26.4
12
- mediapipe==0.10.14
 
13
 
14
  # ===== Gradio UI =====
15
  gradio==5.42.0
@@ -28,10 +29,10 @@ scikit-image==0.24.0
28
  tqdm==4.66.5
29
 
30
  # ===== Helpers / caching =====
31
- huggingface_hub>=0.33.5
32
  ffmpeg-python==0.2.0
33
  psutil==6.0.0
34
- requests==2.31.0
35
  scikit-learn==1.5.1
36
 
37
  # ===== (Optional) Extras =====
 
9
  decord==0.6.0
10
  Pillow==10.4.0
11
  numpy==1.26.4
12
+ mediapipe==0.10.14
13
+ protobuf==4.25.3
14
 
15
  # ===== Gradio UI =====
16
  gradio==5.42.0
 
29
  tqdm==4.66.5
30
 
31
  # ===== Helpers / caching =====
32
+ huggingface_hub==0.33.5
33
  ffmpeg-python==0.2.0
34
  psutil==6.0.0
35
+ requests==2.32.3
36
  scikit-learn==1.5.1
37
 
38
  # ===== (Optional) Extras =====
ui.py CHANGED
@@ -1,6 +1,8 @@
1
- # ui.py
2
  """
3
- BackgroundFX Pro — Gradio UI, background generators, and data sources.
 
 
4
  """
5
 
6
  import io
@@ -14,13 +16,12 @@
14
  from PIL import Image
15
  import gradio as gr
16
 
17
- from pipeline import (
18
- process_video_gpu_optimized, stop_processing, processing_active,
19
- SAM2_ENABLED, MATANY_ENABLED, GPU_NAME, GPU_MEMORY
20
- )
21
-
22
  logger = logging.getLogger("ui")
23
-
 
 
 
 
24
 
25
  # ---- Background generators ----
26
  def create_gradient_background(gradient_type: str, width: int, height: int) -> Image.Image:
@@ -51,7 +52,6 @@ def create_gradient_background(gradient_type: str, width: int, height: int) -> I
51
  img[i, :] = [r, g, b]
52
  return Image.fromarray(img)
53
 
54
-
55
  def create_solid_color(color: str, width: int, height: int) -> Image.Image:
56
  color_map = {
57
  "white": (255, 255, 255),
@@ -66,22 +66,25 @@ def create_solid_color(color: str, width: int, height: int) -> Image.Image:
66
  rgb = color_map.get(color, (70, 130, 180))
67
  return Image.fromarray(np.full((height, width, 3), rgb, dtype=np.uint8))
68
 
69
-
70
  def generate_ai_background(prompt: str) -> Tuple[Optional[Image.Image], str]:
71
  try:
72
- if not prompt.strip():
73
  return None, "Please enter a prompt"
74
  models = [
75
  "black-forest-labs/FLUX.1-schnell",
76
  "stabilityai/stable-diffusion-xl-base-1.0",
77
- "runwayml/stable-diffusion-v1-5"
78
  ]
79
  enhanced_prompt = f"professional video background, {prompt}, high quality, 16:9, cinematic lighting, detailed"
 
 
80
  for model in models:
81
  try:
82
  url = f"https://api-inference.huggingface.co/models/{model}"
83
- headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACE_TOKEN', 'hf_placeholder')}"}
84
- payload = {"inputs": enhanced_prompt, "parameters": {"width": 1024, "height": 576, "num_inference_steps": 20, "guidance_scale": 7.5}}
 
 
85
  r = requests.post(url, headers=headers, json=payload, timeout=60, stream=True)
86
  if r.status_code == 200 and "image" in r.headers.get("content-type", "").lower():
87
  buf = io.BytesIO(r.content if r.raw is None else r.raw.read())
@@ -95,11 +98,10 @@ def generate_ai_background(prompt: str) -> Tuple[Optional[Image.Image], str]:
95
  logger.error(f"AI background error: {e}")
96
  return create_gradient_background("default", 1920, 1080), "Default due to error"
97
 
98
-
99
  # ---- MyAvatar API ----
100
  class MyAvatarAPI:
101
  def __init__(self):
102
- self.api_base = "https://app.myavatar.dk/api"
103
  self.videos_cache: List[Dict[str, Any]] = []
104
  self.last_refresh = 0
105
 
@@ -140,11 +142,20 @@ def get_video_url(self, selection: str) -> Optional[str]:
140
  logger.error(f"Parse selection failed: {e}")
141
  return None
142
 
143
-
144
  myavatar_api = MyAvatarAPI()
145
 
 
 
 
 
 
 
146
 
147
- # ---- UI ↔ Pipeline bridge: streaming handler ----
 
 
 
 
148
  def process_video_with_background_stoppable(
149
  input_video: Optional[str],
150
  myavatar_selection: str,
@@ -154,15 +165,12 @@ def process_video_with_background_stoppable(
154
  custom_background: Optional[str],
155
  ai_prompt: str
156
  ):
157
- # start
158
- from pipeline import processing_active as _active_ref # ensure we use the module global
159
- import pipeline # to toggle the flag
160
-
161
- pipeline.processing_active = True
162
  try:
163
- yield gr.update(visible=False), gr.update(visible=True), None, "Starting processing..."
 
164
 
165
- # resolve video
166
  video_path = None
167
  if input_video:
168
  video_path = input_video
@@ -173,16 +181,23 @@ def process_video_with_background_stoppable(
173
  r.raise_for_status()
174
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
175
  for chunk in r.iter_content(chunk_size=1 << 20):
 
 
 
176
  if chunk:
177
  tmp.write(chunk)
178
  video_path = tmp.name
179
 
 
 
 
 
180
  if not video_path:
181
  yield gr.update(visible=True), gr.update(visible=False), None, "No video provided"
182
  return
183
 
184
- # background
185
- yield gr.update(visible=False), gr.update(visible=True), None, "Generating background..."
186
  bg_img = None
187
  if background_type == "gradient":
188
  bg_img = create_gradient_background(gradient_type, 1920, 1080)
@@ -190,50 +205,68 @@ def process_video_with_background_stoppable(
190
  bg_img = create_solid_color(solid_color, 1920, 1080)
191
  elif background_type == "custom" and custom_background:
192
  try:
193
- from PIL import Image
194
  bg_img = Image.open(custom_background).convert("RGB")
195
  except Exception:
196
  bg_img = None
197
  elif background_type == "ai" and ai_prompt:
198
  bg_img, _ = generate_ai_background(ai_prompt)
199
 
 
 
 
 
200
  if bg_img is None:
201
  yield gr.update(visible=True), gr.update(visible=False), None, "No background generated"
202
  return
203
 
204
- # process
205
- yield gr.update(visible=False), gr.update(visible=True), None, "Processing video with GPU optimization..."
206
- bg_array = np.array(bg_img.resize((1280, 720), Image.Resampling.LANCZOS))
207
- with tempfile.NamedTemporaryFile(suffix="_processed.mp4", delete=False) as tmp_final:
208
- final_path = tmp_final.name
209
-
210
- out = process_video_gpu_optimized(video_path, bg_array, final_path)
 
 
 
 
 
 
 
 
 
211
 
212
- try:
213
- if video_path != input_video and video_path and os.path.exists(video_path):
214
- os.unlink(video_path)
215
- except Exception:
216
- pass
217
 
218
- if out and pipeline.processing_active:
219
- yield gr.update(visible=True), gr.update(visible=False), out, "Video processing completed successfully!"
220
  else:
221
- yield gr.update(visible=True), gr.update(visible=False), None, "Processing was stopped or failed"
222
 
223
  except Exception as e:
224
  logger.error(f"UI pipeline error: {e}")
225
  yield gr.update(visible=True), gr.update(visible=False), None, f"Processing error: {e}"
226
  finally:
227
- pipeline.processing_active = False
228
-
229
-
230
- def stop_processing_button():
231
- from pipeline import stop_processing
232
- stop_processing()
233
- return gr.update(visible=False), "Processing stopped by user"
234
-
235
 
236
  # ---- UI factory ----
 
 
 
 
 
 
 
 
 
 
237
  def create_interface():
238
  css = """
239
  .main-container { max-width: 1200px; margin: 0 auto; }
@@ -241,13 +274,12 @@ def create_interface():
241
  .gradient-preview { border: 2px solid #ddd; border-radius: 10px; }
242
  """
243
 
244
- with gr.Blocks(css=css, title="BackgroundFX Pro - GPU Optimized") as app:
245
- gr.Markdown("# BackgroundFX Pro - GPU Optimized\n### Professional Video Background Replacement with SAM2 + MatAnyone")
246
 
247
  with gr.Row():
248
- sam2_status = "Ready" if SAM2_ENABLED else "Disabled"
249
- matany_status = "Ready" if MATANY_ENABLED else "Disabled"
250
- gr.Markdown(f"**System Status:** Online | **GPU:** {GPU_NAME} | **SAM2:** {sam2_status} | **MatAnyone:** {matany_status}")
251
 
252
  with gr.Row():
253
  with gr.Column(scale=1):
@@ -277,19 +309,19 @@ def create_interface():
277
  ai_preview = gr.Image(label="AI Generated Background", height=150, visible=False)
278
 
279
  with gr.Row():
280
- process_btn = gr.Button("Process Video", variant="primary", size="lg")
281
- stop_btn = gr.Button("Stop Processing", variant="stop", size="lg", visible=False)
282
 
283
  with gr.Column(scale=1):
284
  gr.Markdown("## Results")
285
  result_video = gr.Video(label="Processed Video", height=400)
286
  status_output = gr.Textbox(label="Processing Status", lines=5, max_lines=10, elem_classes=["status-box"])
287
  gr.Markdown("""
288
- ### Processing Pipeline:
289
- 1. **SAM2 Segmentation** GPU-accelerated person detection
290
- 2. **MatAnyone Matting** temporal consistency
291
- 3. **GPU Compositing** real-time background replacement
292
- 4. **Memory Optimization** — chunked processing + OOM recovery
293
  """)
294
 
295
  # handlers
 
1
+ #!/usr/bin/env python3
2
  """
3
+ BackgroundFX Pro — Gradio UI, background generators, and data sources (Hardened)
4
+ - No top-level import of pipeline (lazy import in handlers)
5
+ - Compatible with pipeline.process()
6
  """
7
 
8
  import io
 
16
  from PIL import Image
17
  import gradio as gr
18
 
 
 
 
 
 
19
  logger = logging.getLogger("ui")
20
+ if not logger.handlers:
21
+ h = logging.StreamHandler()
22
+ h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
23
+ logger.addHandler(h)
24
+ logger.setLevel(logging.INFO)
25
 
26
  # ---- Background generators ----
27
  def create_gradient_background(gradient_type: str, width: int, height: int) -> Image.Image:
 
52
  img[i, :] = [r, g, b]
53
  return Image.fromarray(img)
54
 
 
55
  def create_solid_color(color: str, width: int, height: int) -> Image.Image:
56
  color_map = {
57
  "white": (255, 255, 255),
 
66
  rgb = color_map.get(color, (70, 130, 180))
67
  return Image.fromarray(np.full((height, width, 3), rgb, dtype=np.uint8))
68
 
 
69
  def generate_ai_background(prompt: str) -> Tuple[Optional[Image.Image], str]:
70
  try:
71
+ if not prompt or not prompt.strip():
72
  return None, "Please enter a prompt"
73
  models = [
74
  "black-forest-labs/FLUX.1-schnell",
75
  "stabilityai/stable-diffusion-xl-base-1.0",
76
+ "runwayml/stable-diffusion-v1-5",
77
  ]
78
  enhanced_prompt = f"professional video background, {prompt}, high quality, 16:9, cinematic lighting, detailed"
79
+ token = os.getenv("HUGGINGFACE_TOKEN", "")
80
+ headers = {"Authorization": f"Bearer {token}"} if token else {}
81
  for model in models:
82
  try:
83
  url = f"https://api-inference.huggingface.co/models/{model}"
84
+ payload = {
85
+ "inputs": enhanced_prompt,
86
+ "parameters": {"width": 1024, "height": 576, "num_inference_steps": 20, "guidance_scale": 7.5},
87
+ }
88
  r = requests.post(url, headers=headers, json=payload, timeout=60, stream=True)
89
  if r.status_code == 200 and "image" in r.headers.get("content-type", "").lower():
90
  buf = io.BytesIO(r.content if r.raw is None else r.raw.read())
 
98
  logger.error(f"AI background error: {e}")
99
  return create_gradient_background("default", 1920, 1080), "Default due to error"
100
 
 
101
  # ---- MyAvatar API ----
102
  class MyAvatarAPI:
103
  def __init__(self):
104
+ self.api_base = os.getenv("MYAVATAR_API_BASE", "https://app.myavatar.dk/api")
105
  self.videos_cache: List[Dict[str, Any]] = []
106
  self.last_refresh = 0
107
 
 
142
  logger.error(f"Parse selection failed: {e}")
143
  return None
144
 
 
145
  myavatar_api = MyAvatarAPI()
146
 
147
+ # ---- Minimal stop flag (request-scoped) ----
148
+ # We avoid pipeline globals; this just short-circuits the generator.
149
class Stopper:
    """Mutable stop flag shared between the processing and stop handlers."""

    def __init__(self):
        # False until the user presses "Stop Processing".
        self.stop = False

# Single request-scoped flag; handlers flip it, process resets it.
STOP = Stopper()
153
 
154
def stop_processing_button():
    """Gradio handler: request a stop, hide the stop button, report status."""
    STOP.stop = True
    hidden = gr.update(visible=False)
    return hidden, "Processing stopped by user"
157
+
158
+ # ---- UI ↔ Pipeline bridge ----
159
  def process_video_with_background_stoppable(
160
  input_video: Optional[str],
161
  myavatar_selection: str,
 
165
  custom_background: Optional[str],
166
  ai_prompt: str
167
  ):
168
+ import importlib
 
 
 
 
169
  try:
170
+ STOP.stop = False
171
+ yield gr.update(visible=False), gr.update(visible=True), None, "Starting…"
172
 
173
+ # Resolve video
174
  video_path = None
175
  if input_video:
176
  video_path = input_video
 
181
  r.raise_for_status()
182
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
183
  for chunk in r.iter_content(chunk_size=1 << 20):
184
+ if STOP.stop:
185
+ yield gr.update(visible=True), gr.update(visible=False), None, "Stopped."
186
+ return
187
  if chunk:
188
  tmp.write(chunk)
189
  video_path = tmp.name
190
 
191
+ if STOP.stop:
192
+ yield gr.update(visible=True), gr.update(visible=False), None, "Stopped."
193
+ return
194
+
195
  if not video_path:
196
  yield gr.update(visible=True), gr.update(visible=False), None, "No video provided"
197
  return
198
 
199
+ # Background
200
+ yield gr.update(visible=False), gr.update(visible=True), None, "Preparing background"
201
  bg_img = None
202
  if background_type == "gradient":
203
  bg_img = create_gradient_background(gradient_type, 1920, 1080)
 
205
  bg_img = create_solid_color(solid_color, 1920, 1080)
206
  elif background_type == "custom" and custom_background:
207
  try:
 
208
  bg_img = Image.open(custom_background).convert("RGB")
209
  except Exception:
210
  bg_img = None
211
  elif background_type == "ai" and ai_prompt:
212
  bg_img, _ = generate_ai_background(ai_prompt)
213
 
214
+ if STOP.stop:
215
+ yield gr.update(visible=True), gr.update(visible=False), None, "Stopped."
216
+ return
217
+
218
  if bg_img is None:
219
  yield gr.update(visible=True), gr.update(visible=False), None, "No background generated"
220
  return
221
 
222
+ # Save background to a temp file for pipeline.process()
223
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_bg:
224
+ bg_img.save(tmp_bg.name, format="PNG")
225
+ bg_path = tmp_bg.name
226
+
227
+ # Run pipeline lazily
228
+ yield gr.update(visible=False), gr.update(visible=True), None, "Processing video…"
229
+ pipe = importlib.import_module("pipeline")
230
+ out_path, diag = pipe.process(
231
+ video_path=video_path,
232
+ bg_image_path=bg_path,
233
+ point_x=None,
234
+ point_y=None,
235
+ auto_box=True,
236
+ work_dir=None
237
+ )
238
 
239
+ if STOP.stop:
240
+ yield gr.update(visible=True), gr.update(visible=False), None, "Stopped."
241
+ return
 
 
242
 
243
+ if out_path:
244
+ yield gr.update(visible=True), gr.update(visible=False), out_path, "Video processing completed successfully!"
245
  else:
246
+ yield gr.update(visible=True), gr.update(visible=False), None, f"Processing failed: {diag.get('error','unknown error')}"
247
 
248
  except Exception as e:
249
  logger.error(f"UI pipeline error: {e}")
250
  yield gr.update(visible=True), gr.update(visible=False), None, f"Processing error: {e}"
251
  finally:
252
+ # Best-effort cleanup of any temp download
253
+ try:
254
+ if input_video is None and 'video_path' in locals() and video_path and os.path.exists(video_path):
255
+ os.unlink(video_path)
256
+ except Exception:
257
+ pass
 
 
258
 
259
  # ---- UI factory ----
260
+ def _system_status():
261
+ # Avoid early CUDA probing: only show torch version if available
262
+ try:
263
+ import torch
264
+ tver = getattr(torch, "__version__", "?")
265
+ cver = getattr(getattr(torch, "version", None), "cuda", None)
266
+ return f"torch {tver} (CUDA {cver})"
267
+ except Exception:
268
+ return "torch not available"
269
+
270
  def create_interface():
271
  css = """
272
  .main-container { max-width: 1200px; margin: 0 auto; }
 
274
  .gradient-preview { border: 2px solid #ddd; border-radius: 10px; }
275
  """
276
 
277
+ with gr.Blocks(css=css, title="BackgroundFX Pro") as app:
278
+ gr.Markdown("# BackgroundFX Pro SAM2 + MatAnyone (Hardened)")
279
 
280
  with gr.Row():
281
+ status = _system_status()
282
+ gr.Markdown(f"**System Status:** Online | **Runtime:** {status}")
 
283
 
284
  with gr.Row():
285
  with gr.Column(scale=1):
 
309
  ai_preview = gr.Image(label="AI Generated Background", height=150, visible=False)
310
 
311
  with gr.Row():
312
+ process_btn = gr.Button("Process Video", variant="primary")
313
+ stop_btn = gr.Button("Stop Processing", variant="stop", visible=False)
314
 
315
  with gr.Column(scale=1):
316
  gr.Markdown("## Results")
317
  result_video = gr.Video(label="Processed Video", height=400)
318
  status_output = gr.Textbox(label="Processing Status", lines=5, max_lines=10, elem_classes=["status-box"])
319
  gr.Markdown("""
320
+ ### Pipeline
321
+ 1. SAM2 Segmentation mask
322
+ 2. MatAnyone Matting FG + ALPHA
323
+ 3. Stage-A export (transparent WebM or checkerboard)
324
+ 4. Final compositing (H.264)
325
  """)
326
 
327
  # handlers