final run?
Browse files- .dockerignore +11 -24
- Dockerfile +36 -50
- app.py +77 -110
- models/__init__.py +66 -53
- perf_tuning.py +99 -49
- pipeline.py +180 -111
- requirements.txt +4 -3
- ui.py +95 -63
.dockerignore
CHANGED
|
@@ -2,16 +2,12 @@
|
|
| 2 |
# .dockerignore for HF Spaces
|
| 3 |
# ===========================
|
| 4 |
|
| 5 |
-
#
|
| 6 |
-
# VCS (never needed in image)
|
| 7 |
-
# ---------------------------
|
| 8 |
.git
|
| 9 |
.gitignore
|
| 10 |
.gitattributes
|
| 11 |
|
| 12 |
-
#
|
| 13 |
-
# Python cache / build files
|
| 14 |
-
# ---------------------------
|
| 15 |
__pycache__/
|
| 16 |
*.py[cod]
|
| 17 |
*.pyo
|
|
@@ -20,49 +16,41 @@ __pycache__/
|
|
| 20 |
*.egg-info/
|
| 21 |
dist/
|
| 22 |
build/
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
# ---------------------------
|
| 25 |
# Virtual environments
|
| 26 |
-
# ---------------------------
|
| 27 |
.env
|
| 28 |
.venv/
|
| 29 |
env/
|
| 30 |
venv/
|
| 31 |
|
| 32 |
-
# ---------------------------
|
| 33 |
# External repos (cloned in Docker, not copied from local)
|
| 34 |
-
# ---------------------------
|
| 35 |
third_party/
|
| 36 |
|
| 37 |
-
#
|
| 38 |
-
# Hugging Face / Torch caches (but allow model files that might be needed)
|
| 39 |
-
# ---------------------------
|
| 40 |
.cache/
|
| 41 |
huggingface/
|
| 42 |
torch/
|
| 43 |
data/
|
| 44 |
|
| 45 |
-
# ---------------------------
|
| 46 |
# HF Space metadata/state
|
| 47 |
-
# ---------------------------
|
| 48 |
.hf_space/
|
| 49 |
space.log
|
| 50 |
gradio_cached_examples/
|
| 51 |
gradio_static/
|
| 52 |
__outputs__/
|
| 53 |
|
| 54 |
-
# ---------------------------
|
| 55 |
# Logs & temp files
|
| 56 |
-
# ---------------------------
|
| 57 |
*.log
|
| 58 |
logs/
|
| 59 |
tmp/
|
| 60 |
temp/
|
| 61 |
*.swp
|
|
|
|
|
|
|
| 62 |
|
| 63 |
-
# ---------------------------
|
| 64 |
# Media test assets
|
| 65 |
-
# ---------------------------
|
| 66 |
*.mp4
|
| 67 |
*.avi
|
| 68 |
*.mov
|
|
@@ -72,9 +60,7 @@ temp/
|
|
| 72 |
*.jpeg
|
| 73 |
*.gif
|
| 74 |
|
| 75 |
-
# ---------------------------
|
| 76 |
# OS / IDE cruft
|
| 77 |
-
# ---------------------------
|
| 78 |
.DS_Store
|
| 79 |
Thumbs.db
|
| 80 |
.vscode/
|
|
@@ -82,10 +68,11 @@ Thumbs.db
|
|
| 82 |
*.sublime-project
|
| 83 |
*.sublime-workspace
|
| 84 |
|
| 85 |
-
# ---------------------------
|
| 86 |
# Node / frontend (if present)
|
| 87 |
-
# ---------------------------
|
| 88 |
node_modules/
|
| 89 |
npm-debug.log
|
| 90 |
yarn-debug.log
|
| 91 |
-
yarn-error.log
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
# .dockerignore for HF Spaces
|
| 3 |
# ===========================
|
| 4 |
|
| 5 |
+
# VCS
|
|
|
|
|
|
|
| 6 |
.git
|
| 7 |
.gitignore
|
| 8 |
.gitattributes
|
| 9 |
|
| 10 |
+
# Python cache / build
|
|
|
|
|
|
|
| 11 |
__pycache__/
|
| 12 |
*.py[cod]
|
| 13 |
*.pyo
|
|
|
|
| 16 |
*.egg-info/
|
| 17 |
dist/
|
| 18 |
build/
|
| 19 |
+
.pytest_cache/
|
| 20 |
+
.python-version
|
| 21 |
|
|
|
|
| 22 |
# Virtual environments
|
|
|
|
| 23 |
.env
|
| 24 |
.venv/
|
| 25 |
env/
|
| 26 |
venv/
|
| 27 |
|
|
|
|
| 28 |
# External repos (cloned in Docker, not copied from local)
|
|
|
|
| 29 |
third_party/
|
| 30 |
|
| 31 |
+
# Hugging Face / Torch caches
|
|
|
|
|
|
|
| 32 |
.cache/
|
| 33 |
huggingface/
|
| 34 |
torch/
|
| 35 |
data/
|
| 36 |
|
|
|
|
| 37 |
# HF Space metadata/state
|
|
|
|
| 38 |
.hf_space/
|
| 39 |
space.log
|
| 40 |
gradio_cached_examples/
|
| 41 |
gradio_static/
|
| 42 |
__outputs__/
|
| 43 |
|
|
|
|
| 44 |
# Logs & temp files
|
|
|
|
| 45 |
*.log
|
| 46 |
logs/
|
| 47 |
tmp/
|
| 48 |
temp/
|
| 49 |
*.swp
|
| 50 |
+
.coverage
|
| 51 |
+
coverage.xml
|
| 52 |
|
|
|
|
| 53 |
# Media test assets
|
|
|
|
| 54 |
*.mp4
|
| 55 |
*.avi
|
| 56 |
*.mov
|
|
|
|
| 60 |
*.jpeg
|
| 61 |
*.gif
|
| 62 |
|
|
|
|
| 63 |
# OS / IDE cruft
|
|
|
|
| 64 |
.DS_Store
|
| 65 |
Thumbs.db
|
| 66 |
.vscode/
|
|
|
|
| 68 |
*.sublime-project
|
| 69 |
*.sublime-workspace
|
| 70 |
|
|
|
|
| 71 |
# Node / frontend (if present)
|
|
|
|
| 72 |
node_modules/
|
| 73 |
npm-debug.log
|
| 74 |
yarn-debug.log
|
| 75 |
+
yarn-error.log
|
| 76 |
+
|
| 77 |
+
# ---- Optional: allow specific checkpoints if needed ----
|
| 78 |
+
!checkpoints/
|
Dockerfile
CHANGED
|
@@ -1,119 +1,105 @@
|
|
| 1 |
# ===============================
|
| 2 |
-
# BackgroundFX Pro — Dockerfile (
|
| 3 |
-
# Hugging Face Spaces Pro (GPU)
|
| 4 |
# ===============================
|
| 5 |
|
| 6 |
-
#
|
| 7 |
-
FROM nvidia/cuda:12.
|
| 8 |
|
| 9 |
-
# --- Build args (
|
| 10 |
-
# Pin external repos for reproducible builds
|
| 11 |
ARG SAM2_SHA=__PIN_ME__
|
| 12 |
ARG MATANYONE_SHA=__PIN_ME__
|
| 13 |
-
|
| 14 |
-
# (legacy/optional) Model IDs — you can still use these elsewhere if you want
|
| 15 |
ARG SAM2_MODEL_ID=facebook/sam2
|
| 16 |
-
ARG SAM2_VARIANT=sam2_hiera_large
|
| 17 |
ARG MATANY_REPO_ID=PeiqingYang/MatAnyone
|
| 18 |
ARG MATANY_FILENAME=matanyone_v1.0.pth
|
| 19 |
|
| 20 |
-
# ---
|
| 21 |
RUN useradd -m -u 1000 user
|
| 22 |
ENV HOME=/home/user
|
| 23 |
ENV PATH=/home/user/.local/bin:$PATH
|
| 24 |
-
RUN mkdir -p /home/user/app && chown -R user:user /home/user
|
| 25 |
WORKDIR /home/user/app
|
| 26 |
|
| 27 |
# --- System packages ---
|
| 28 |
USER root
|
| 29 |
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
|
| 30 |
git ffmpeg python3 python3-pip python3-venv \
|
|
|
|
| 31 |
libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev libgomp1 \
|
| 32 |
&& rm -rf /var/lib/apt/lists/*
|
| 33 |
|
| 34 |
-
#
|
| 35 |
RUN mkdir -p /data/.cache && chown -R user:user /data
|
| 36 |
USER user
|
| 37 |
|
| 38 |
-
# --- Python
|
| 39 |
RUN python3 -m pip install --no-cache-dir --upgrade pip
|
| 40 |
RUN python3 -m pip install --no-cache-dir --index-url https://download.pytorch.org/whl/cu121 \
|
| 41 |
torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1
|
| 42 |
|
| 43 |
-
# --- App
|
| 44 |
COPY --chown=user requirements.txt ./requirements.txt
|
| 45 |
RUN python3 -m pip install --no-cache-dir -r requirements.txt
|
| 46 |
-
# Optional
|
| 47 |
RUN python3 -m pip install --no-cache-dir mediapipe==0.10.14
|
| 48 |
|
| 49 |
-
# ---
|
| 50 |
RUN mkdir -p third_party
|
| 51 |
|
| 52 |
-
# SAM2
|
| 53 |
-
RUN git clone https://github.com/facebookresearch/segment-anything-2.git third_party/sam2 && \
|
| 54 |
cd third_party/sam2 && \
|
| 55 |
-
if [ "${SAM2_SHA}" != "__PIN_ME__" ]; then git checkout
|
| 56 |
|
| 57 |
-
#
|
| 58 |
-
RUN echo "=== DEBUG: SAM2
|
| 59 |
ls -la third_party/sam2/ && \
|
| 60 |
-
echo "=== DEBUG: Config directory ===" && \
|
| 61 |
-
ls -la third_party/sam2/configs/ || echo "configs directory not found" && \
|
| 62 |
echo "=== DEBUG: SAM2 configs ===" && \
|
| 63 |
-
ls -la third_party/sam2/configs/sam2/ || echo "
|
| 64 |
|
| 65 |
-
# Install SAM2
|
| 66 |
RUN cd third_party/sam2 && python3 -m pip install --no-cache-dir -e .
|
| 67 |
|
| 68 |
-
# MatAnyone (pq-yang fork
|
| 69 |
-
RUN git clone https://github.com/pq-yang/MatAnyone.git third_party/matanyone && \
|
| 70 |
cd third_party/matanyone && \
|
| 71 |
-
if [ "${MATANYONE_SHA}" != "__PIN_ME__" ]; then git checkout
|
| 72 |
|
| 73 |
-
# Install MatAnyone requirements if
|
| 74 |
RUN cd third_party/matanyone && \
|
| 75 |
if [ -f requirements.txt ]; then python3 -m pip install --no-cache-dir -r requirements.txt; fi
|
| 76 |
|
| 77 |
-
# --- App code ---
|
| 78 |
COPY --chown=user . /home/user/app
|
| 79 |
|
| 80 |
-
#
|
| 81 |
-
RUN echo "=== DEBUG: After
|
| 82 |
ls -la third_party/sam2/ && \
|
| 83 |
-
|
| 84 |
-
ls -la third_party/sam2/configs/sam2/ || echo "Config directory missing after copy"
|
| 85 |
|
| 86 |
-
# --- Runtime environment
|
| 87 |
ENV PYTHONUNBUFFERED=1 \
|
| 88 |
-
OMP_NUM_THREADS=
|
| 89 |
TOKENIZERS_PARALLELISM=false \
|
| 90 |
HF_HOME=/data/.cache/huggingface \
|
| 91 |
TORCH_HOME=/data/.cache/torch \
|
| 92 |
MPLCONFIGDIR=/data/.cache/matplotlib \
|
| 93 |
PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 \
|
| 94 |
PYTHONPATH="$PYTHONPATH:/home/user/app/third_party/sam2:/home/user/app/third_party/matanyone" \
|
| 95 |
-
PORT=7860 \
|
| 96 |
FFMPEG_BIN=ffmpeg \
|
| 97 |
-
\
|
| 98 |
-
# Let pipeline.py discover these dynamically (no hard-coded paths)
|
| 99 |
THIRD_PARTY_SAM2_DIR=/home/user/app/third_party/sam2 \
|
| 100 |
THIRD_PARTY_MATANY_DIR=/home/user/app/third_party/matanyone \
|
| 101 |
-
\
|
| 102 |
-
# --- SAM2 dynamic config (FIXED: relative path within SAM2 repo) ---
|
| 103 |
SAM2_MODEL_CFG="configs/sam2/sam2_hiera_l.yaml" \
|
| 104 |
SAM2_CHECKPOINT="" \
|
| 105 |
-
\
|
| 106 |
-
# --- MatAnyone dynamic config (used by pipeline.py) ---
|
| 107 |
MATANY_REPO_ID=PeiqingYang/MatAnyone \
|
| 108 |
MATANY_CHECKPOINT="" \
|
| 109 |
ENABLE_MATANY=1
|
| 110 |
|
| 111 |
-
#
|
| 112 |
-
RUN echo "=== FINAL DEBUG: SAM2 status ===" && \
|
| 113 |
-
pwd && \
|
| 114 |
-
ls -la /home/user/app/third_party/sam2/ || echo "SAM2 directory missing" && \
|
| 115 |
-
ls -la /home/user/app/third_party/sam2/configs/sam2/ || echo "Config dir missing"
|
| 116 |
|
| 117 |
-
# --- Networking / Entrypoint ---
|
| 118 |
EXPOSE 7860
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# ===============================
|
| 2 |
+
# BackgroundFX Pro — Dockerfile (Hardened for Spaces GPU)
|
|
|
|
| 3 |
# ===============================
|
| 4 |
|
| 5 |
+
# Match PyTorch cu121 wheels (critical to avoid CUDA probe stalls)
|
| 6 |
+
FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04
|
| 7 |
|
| 8 |
+
# --- Build args (optional pins) ---
|
|
|
|
| 9 |
ARG SAM2_SHA=__PIN_ME__
|
| 10 |
ARG MATANYONE_SHA=__PIN_ME__
|
|
|
|
|
|
|
| 11 |
ARG SAM2_MODEL_ID=facebook/sam2
|
| 12 |
+
ARG SAM2_VARIANT=sam2_hiera_large
|
| 13 |
ARG MATANY_REPO_ID=PeiqingYang/MatAnyone
|
| 14 |
ARG MATANY_FILENAME=matanyone_v1.0.pth
|
| 15 |
|
| 16 |
+
# --- Non-root user (HF expects uid 1000) ---
|
| 17 |
RUN useradd -m -u 1000 user
|
| 18 |
ENV HOME=/home/user
|
| 19 |
ENV PATH=/home/user/.local/bin:$PATH
|
|
|
|
| 20 |
WORKDIR /home/user/app
|
| 21 |
|
| 22 |
# --- System packages ---
|
| 23 |
USER root
|
| 24 |
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
|
| 25 |
git ffmpeg python3 python3-pip python3-venv \
|
| 26 |
+
wget curl ca-certificates \
|
| 27 |
libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev libgomp1 \
|
| 28 |
&& rm -rf /var/lib/apt/lists/*
|
| 29 |
|
| 30 |
+
# Caches (writable)
|
| 31 |
RUN mkdir -p /data/.cache && chown -R user:user /data
|
| 32 |
USER user
|
| 33 |
|
| 34 |
+
# --- Python + Torch (cu121) ---
|
| 35 |
RUN python3 -m pip install --no-cache-dir --upgrade pip
|
| 36 |
RUN python3 -m pip install --no-cache-dir --index-url https://download.pytorch.org/whl/cu121 \
|
| 37 |
torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1
|
| 38 |
|
| 39 |
+
# --- App deps ---
|
| 40 |
COPY --chown=user requirements.txt ./requirements.txt
|
| 41 |
RUN python3 -m pip install --no-cache-dir -r requirements.txt
|
| 42 |
+
# Optional nice fallback
|
| 43 |
RUN python3 -m pip install --no-cache-dir mediapipe==0.10.14
|
| 44 |
|
| 45 |
+
# --- Third-party repos (build-time, never at runtime) ---
|
| 46 |
RUN mkdir -p third_party
|
| 47 |
|
| 48 |
+
# SAM2 (shallow clone; optional SHA pin)
|
| 49 |
+
RUN git clone --depth=1 https://github.com/facebookresearch/segment-anything-2.git third_party/sam2 && \
|
| 50 |
cd third_party/sam2 && \
|
| 51 |
+
if [ "${SAM2_SHA}" != "__PIN_ME__" ]; then git fetch --depth=1 origin ${SAM2_SHA} && git checkout ${SAM2_SHA}; fi
|
| 52 |
|
| 53 |
+
# Show what we got
|
| 54 |
+
RUN echo "=== DEBUG: SAM2 contents ===" && \
|
| 55 |
ls -la third_party/sam2/ && \
|
|
|
|
|
|
|
| 56 |
echo "=== DEBUG: SAM2 configs ===" && \
|
| 57 |
+
(ls -la third_party/sam2/configs/sam2/ || echo "configs missing")
|
| 58 |
|
| 59 |
+
# Install SAM2 (editable ok)
|
| 60 |
RUN cd third_party/sam2 && python3 -m pip install --no-cache-dir -e .
|
| 61 |
|
| 62 |
+
# MatAnyone (pq-yang fork per your setup)
|
| 63 |
+
RUN git clone --depth=1 https://github.com/pq-yang/MatAnyone.git third_party/matanyone && \
|
| 64 |
cd third_party/matanyone && \
|
| 65 |
+
if [ "${MATANYONE_SHA}" != "__PIN_ME__" ]; then git fetch --depth=1 origin ${MATANYONE_SHA} && git checkout ${MATANYONE_SHA}; fi
|
| 66 |
|
| 67 |
+
# Install MatAnyone requirements if present
|
| 68 |
RUN cd third_party/matanyone && \
|
| 69 |
if [ -f requirements.txt ]; then python3 -m pip install --no-cache-dir -r requirements.txt; fi
|
| 70 |
|
| 71 |
+
# --- App code last (so code changes don't invalidate heavy layers) ---
|
| 72 |
COPY --chown=user . /home/user/app
|
| 73 |
|
| 74 |
+
# Verify clone not overwritten by COPY
|
| 75 |
+
RUN echo "=== DEBUG: After COPY ===" && \
|
| 76 |
ls -la third_party/sam2/ && \
|
| 77 |
+
(ls -la third_party/sam2/configs/sam2/ || echo "SAM2 configs missing")
|
|
|
|
| 78 |
|
| 79 |
+
# --- Runtime environment ---
|
| 80 |
ENV PYTHONUNBUFFERED=1 \
|
| 81 |
+
OMP_NUM_THREADS=1 OPENBLAS_NUM_THREADS=1 MKL_NUM_THREADS=1 NUMEXPR_NUM_THREADS=1 \
|
| 82 |
TOKENIZERS_PARALLELISM=false \
|
| 83 |
HF_HOME=/data/.cache/huggingface \
|
| 84 |
TORCH_HOME=/data/.cache/torch \
|
| 85 |
MPLCONFIGDIR=/data/.cache/matplotlib \
|
| 86 |
PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 \
|
| 87 |
PYTHONPATH="$PYTHONPATH:/home/user/app/third_party/sam2:/home/user/app/third_party/matanyone" \
|
|
|
|
| 88 |
FFMPEG_BIN=ffmpeg \
|
|
|
|
|
|
|
| 89 |
THIRD_PARTY_SAM2_DIR=/home/user/app/third_party/sam2 \
|
| 90 |
THIRD_PARTY_MATANY_DIR=/home/user/app/third_party/matanyone \
|
|
|
|
|
|
|
| 91 |
SAM2_MODEL_CFG="configs/sam2/sam2_hiera_l.yaml" \
|
| 92 |
SAM2_CHECKPOINT="" \
|
|
|
|
|
|
|
| 93 |
MATANY_REPO_ID=PeiqingYang/MatAnyone \
|
| 94 |
MATANY_CHECKPOINT="" \
|
| 95 |
ENABLE_MATANY=1
|
| 96 |
|
| 97 |
+
# Do NOT set PORT here; Spaces injects it.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
|
|
|
| 99 |
EXPOSE 7860
|
| 100 |
+
|
| 101 |
+
# Optional: basic health check to see if the server bound
|
| 102 |
+
HEALTHCHECK --interval=30s --timeout=5s --retries=5 CMD wget -qO- "http://127.0.0.1:${PORT:-7860}/" || exit 1
|
| 103 |
+
|
| 104 |
+
# Use exec form + unbuffered
|
| 105 |
+
CMD ["python3","-u","app.py"]
|
app.py
CHANGED
|
@@ -1,28 +1,21 @@
|
|
| 1 |
-
# app.py
|
| 2 |
#!/usr/bin/env python3
|
| 3 |
"""
|
| 4 |
-
BackgroundFX Pro
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
- Uses pipeline.process() which orchestrates:
|
| 8 |
-
SAM2 first-frame segmentation → MatAnyone temporal matting → compositing
|
| 9 |
-
- Robust fallbacks (MediaPipe / GrabCut; static-mask compositing)
|
| 10 |
-
- Diagnostics JSON shows which engines ran and on which device
|
| 11 |
-
- All paths/devices set by environment variables (see pipeline.py header)
|
| 12 |
"""
|
| 13 |
|
| 14 |
import os
|
| 15 |
-
import json
|
| 16 |
import logging
|
|
|
|
|
|
|
| 17 |
import subprocess
|
| 18 |
-
from pathlib import Path
|
| 19 |
-
from typing import Optional, Tuple
|
| 20 |
|
| 21 |
import gradio as gr
|
| 22 |
|
| 23 |
-
# -----------------------------------------------------------------------------
|
| 24 |
-
#
|
| 25 |
-
# -----------------------------------------------------------------------------
|
| 26 |
logger = logging.getLogger("backgroundfx_pro")
|
| 27 |
if not logger.handlers:
|
| 28 |
h = logging.StreamHandler()
|
|
@@ -30,116 +23,90 @@
|
|
| 30 |
logger.addHandler(h)
|
| 31 |
logger.setLevel(logging.INFO)
|
| 32 |
|
| 33 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
try:
|
| 35 |
import perf_tuning # noqa: F401
|
| 36 |
logger.info("perf_tuning imported successfully.")
|
| 37 |
except Exception as e:
|
| 38 |
-
logger.warning(
|
|
|
|
|
|
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
| 42 |
try:
|
| 43 |
import torch
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
try:
|
| 48 |
idx = torch.cuda.current_device()
|
| 49 |
name = torch.cuda.get_device_name(idx)
|
| 50 |
cap = torch.cuda.get_device_capability(idx)
|
| 51 |
-
logger.info(
|
| 52 |
except Exception as e:
|
| 53 |
-
logger.
|
| 54 |
-
except Exception as e:
|
| 55 |
-
logger.warning(f"Could not import torch for GPU diag: {e}")
|
| 56 |
-
|
| 57 |
-
# nvidia-smi
|
| 58 |
-
try:
|
| 59 |
-
out = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True)
|
| 60 |
-
if out.returncode == 0:
|
| 61 |
-
logger.info("nvidia-smi -L:\n" + out.stdout.strip())
|
| 62 |
-
else:
|
| 63 |
-
logger.warning("nvidia-smi -L failed or unavailable.")
|
| 64 |
except Exception as e:
|
| 65 |
-
logger.warning(
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
# -----------------------------------------------------------------------------
|
| 70 |
-
|
| 71 |
-
#
|
| 72 |
-
import
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
def _process_entry(video, bg_image, point_x, point_y, auto_box, progress=gr.Progress(track_tqdm=True)):
|
| 76 |
-
"""
|
| 77 |
-
Gradio wrapper → returns (video_path, diagnostics_json_str)
|
| 78 |
-
"""
|
| 79 |
-
if video is None or bg_image is None:
|
| 80 |
-
return None, json.dumps({"error": "Please provide both a video and a background image."}, indent=2)
|
| 81 |
-
|
| 82 |
-
# Gradio can pass dict-like objects for file with 'name' key, normalize to path
|
| 83 |
-
vpath = video if isinstance(video, (str, Path)) else getattr(video, "name", None) or video.get("name")
|
| 84 |
-
bpath = bg_image if isinstance(bg_image, (str, Path)) else getattr(bg_image, "name", None) or bg_image.get("name")
|
| 85 |
-
|
| 86 |
-
progress(0.05, desc="Starting…")
|
| 87 |
-
out_path, diag = pipeline.process(
|
| 88 |
-
video_path=vpath,
|
| 89 |
-
bg_image_path=bpath,
|
| 90 |
-
point_x=point_x if point_x not in (None, "") else None,
|
| 91 |
-
point_y=point_y if point_y not in (None, "") else None,
|
| 92 |
-
auto_box=bool(auto_box),
|
| 93 |
-
work_dir=None # pipeline will create a temp dir
|
| 94 |
-
)
|
| 95 |
-
progress(0.95, desc="Finalizing…")
|
| 96 |
-
|
| 97 |
-
return (out_path if out_path else None), json.dumps(diag, indent=2)
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
with gr.Blocks(title="BackgroundFX Pro (SAM2 + MatAnyone)", theme=gr.themes.Soft()) as demo:
|
| 101 |
-
gr.Markdown(
|
| 102 |
-
"""
|
| 103 |
-
# 🎬 BackgroundFX Pro
|
| 104 |
-
**SAM2 + MatAnyone** with robust fallbacks. All configs/devices are dynamic via environment variables.
|
| 105 |
-
|
| 106 |
-
- Upload a video and a background image.
|
| 107 |
-
- Optionally provide a foreground point (x, y) in pixels for the first frame **or** tick *Auto subject box*.
|
| 108 |
-
- Click **Process**. The app will try SAM2 → MatAnyone; if anything fails, it falls back automatically.
|
| 109 |
-
"""
|
| 110 |
-
)
|
| 111 |
-
|
| 112 |
-
with gr.Row():
|
| 113 |
-
with gr.Column(scale=2):
|
| 114 |
-
in_video = gr.Video(label="Input Video", sources=["upload"], interactive=True)
|
| 115 |
-
in_bg = gr.Image(label="Background Image", type="filepath", interactive=True)
|
| 116 |
-
with gr.Column(scale=1):
|
| 117 |
-
point_x = gr.Number(label="Foreground point X (optional)", value=None, precision=0)
|
| 118 |
-
point_y = gr.Number(label="Foreground point Y (optional)", value=None, precision=0)
|
| 119 |
-
auto_box = gr.Checkbox(label="Auto subject box (ignore point)", value=True)
|
| 120 |
-
process_btn = gr.Button("Process", variant="primary")
|
| 121 |
-
|
| 122 |
-
with gr.Row():
|
| 123 |
-
out_video = gr.Video(label="Output (H.264 MP4)")
|
| 124 |
-
out_diag = gr.JSON(label="Diagnostics")
|
| 125 |
-
|
| 126 |
-
def _on_click(video, bg, px, py, auto):
|
| 127 |
-
v, d = _process_entry(video, bg, px, py, auto)
|
| 128 |
-
try:
|
| 129 |
-
d_dict = json.loads(d)
|
| 130 |
-
except Exception:
|
| 131 |
-
d_dict = {"raw": d}
|
| 132 |
-
return v, d_dict
|
| 133 |
-
|
| 134 |
-
process_btn.click(
|
| 135 |
-
_on_click,
|
| 136 |
-
inputs=[in_video, in_bg, point_x, point_y, auto_box],
|
| 137 |
-
outputs=[out_video, out_diag]
|
| 138 |
-
)
|
| 139 |
|
| 140 |
if __name__ == "__main__":
|
| 141 |
-
# Dynamic host/port via env; suitable defaults for Hugging Face Spaces
|
| 142 |
host = os.environ.get("HOST", "0.0.0.0")
|
| 143 |
port = int(os.environ.get("PORT", "7860"))
|
| 144 |
-
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
BackgroundFX Pro — App Entrypoint (UI separated)
|
| 4 |
+
- UI is built in ui.py (create_interface)
|
| 5 |
+
- Hardened startup: heartbeat, safe diag, bind to $PORT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
import os
|
|
|
|
| 9 |
import logging
|
| 10 |
+
import threading
|
| 11 |
+
import time
|
| 12 |
import subprocess
|
|
|
|
|
|
|
| 13 |
|
| 14 |
import gradio as gr
|
| 15 |
|
| 16 |
+
# -----------------------------------------------------------------------------
|
| 17 |
+
# Logging early
|
| 18 |
+
# -----------------------------------------------------------------------------
|
| 19 |
logger = logging.getLogger("backgroundfx_pro")
|
| 20 |
if not logger.handlers:
|
| 21 |
h = logging.StreamHandler()
|
|
|
|
| 23 |
logger.addHandler(h)
|
| 24 |
logger.setLevel(logging.INFO)
|
| 25 |
|
| 26 |
+
# Heartbeat so logs never go silent during startup/imports
|
| 27 |
+
def _heartbeat():
|
| 28 |
+
i = 0
|
| 29 |
+
while True:
|
| 30 |
+
i += 1
|
| 31 |
+
print(f"[startup-heartbeat] {i*5}s…", flush=True)
|
| 32 |
+
time.sleep(5)
|
| 33 |
+
|
| 34 |
+
threading.Thread(target=_heartbeat, daemon=True).start()
|
| 35 |
+
|
| 36 |
+
# -----------------------------------------------------------------------------
|
| 37 |
+
# Safe, minimal startup diagnostics (no long CUDA probes)
|
| 38 |
+
# -----------------------------------------------------------------------------
|
| 39 |
+
def _safe_startup_diag():
|
| 40 |
+
# Torch version only; defer CUDA availability checks to post-launch
|
| 41 |
+
try:
|
| 42 |
+
import torch # noqa: F401
|
| 43 |
+
import importlib
|
| 44 |
+
t = importlib.import_module("torch")
|
| 45 |
+
logger.info("torch imported: %s | torch.version.cuda=%s",
|
| 46 |
+
getattr(t, "__version__", "?"),
|
| 47 |
+
getattr(getattr(t, "version", None), "cuda", None))
|
| 48 |
+
except Exception as e:
|
| 49 |
+
logger.warning("Torch not available at startup: %s", e)
|
| 50 |
+
|
| 51 |
+
# nvidia-smi with short timeout (avoid indefinite block)
|
| 52 |
+
try:
|
| 53 |
+
out = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True, timeout=2)
|
| 54 |
+
if out.returncode == 0:
|
| 55 |
+
logger.info("nvidia-smi -L:\n%s", out.stdout.strip())
|
| 56 |
+
else:
|
| 57 |
+
logger.warning("nvidia-smi -L failed or unavailable (rc=%s).", out.returncode)
|
| 58 |
+
except subprocess.TimeoutExpired:
|
| 59 |
+
logger.warning("nvidia-smi -L timed out (skipping).")
|
| 60 |
+
except Exception as e:
|
| 61 |
+
logger.warning("nvidia-smi not runnable: %s", e)
|
| 62 |
+
|
| 63 |
+
# Optional perf tuning; never block startup
|
| 64 |
try:
|
| 65 |
import perf_tuning # noqa: F401
|
| 66 |
logger.info("perf_tuning imported successfully.")
|
| 67 |
except Exception as e:
|
| 68 |
+
logger.warning("perf_tuning not loaded: %s", e)
|
| 69 |
+
|
| 70 |
+
_safe_startup_diag()
|
| 71 |
|
| 72 |
+
# -----------------------------------------------------------------------------
|
| 73 |
+
# Post-launch CUDA diag in background (so it never blocks binding the port)
|
| 74 |
+
# -----------------------------------------------------------------------------
|
| 75 |
+
def _post_launch_diag():
|
| 76 |
try:
|
| 77 |
import torch
|
| 78 |
+
try:
|
| 79 |
+
avail = torch.cuda.is_available()
|
| 80 |
+
except Exception as e:
|
| 81 |
+
logger.warning("torch.cuda.is_available() failed: %s", e)
|
| 82 |
+
avail = False
|
| 83 |
+
logger.info("CUDA available: %s", avail)
|
| 84 |
+
if avail:
|
| 85 |
try:
|
| 86 |
idx = torch.cuda.current_device()
|
| 87 |
name = torch.cuda.get_device_name(idx)
|
| 88 |
cap = torch.cuda.get_device_capability(idx)
|
| 89 |
+
logger.info("CUDA device %d: %s (cc %d.%d)", idx, name, cap[0], cap[1])
|
| 90 |
except Exception as e:
|
| 91 |
+
logger.warning("CUDA device query failed: %s", e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
except Exception as e:
|
| 93 |
+
logger.warning("Post-launch torch diag failed: %s", e)
|
| 94 |
|
| 95 |
+
# -----------------------------------------------------------------------------
|
| 96 |
+
# Build UI (in separate module) and launch
|
| 97 |
+
# -----------------------------------------------------------------------------
|
| 98 |
+
def build_ui() -> gr.Blocks:
|
| 99 |
+
# Import here so any heavy imports inside ui.py (it shouldn’t) would show up after logs are configured
|
| 100 |
+
from ui import create_interface
|
| 101 |
+
return create_interface()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
if __name__ == "__main__":
|
|
|
|
| 104 |
host = os.environ.get("HOST", "0.0.0.0")
|
| 105 |
port = int(os.environ.get("PORT", "7860"))
|
| 106 |
+
logger.info("Launching Gradio on %s:%s …", host, port)
|
| 107 |
+
|
| 108 |
+
demo = build_ui()
|
| 109 |
+
demo.queue(max_size=16)
|
| 110 |
+
|
| 111 |
+
threading.Thread(target=_post_launch_diag, daemon=True).start()
|
| 112 |
+
demo.launch(server_name=host, server_port=port, show_error=True)
|
models/__init__.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
BackgroundFX Pro - Model Loading & Utilities
|
| 4 |
-
===========================================
|
| 5 |
-
|
| 6 |
-
|
|
|
|
| 7 |
"""
|
| 8 |
|
| 9 |
from __future__ import annotations
|
|
@@ -19,12 +20,24 @@
|
|
| 19 |
|
| 20 |
import numpy as np
|
| 21 |
import yaml
|
| 22 |
-
import torch # For memory management and CUDA operations
|
| 23 |
|
| 24 |
# --------------------------------------------------------------------------------------
|
| 25 |
-
# Logging
|
| 26 |
# --------------------------------------------------------------------------------------
|
| 27 |
logger = logging.getLogger("backgroundfx_pro")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
# --------------------------------------------------------------------------------------
|
| 30 |
# Optional dependencies
|
|
@@ -38,35 +51,40 @@
|
|
| 38 |
# --------------------------------------------------------------------------------------
|
| 39 |
# Path setup for third_party repos
|
| 40 |
# --------------------------------------------------------------------------------------
|
| 41 |
-
ROOT = Path(__file__).resolve().parent.parent #
|
| 42 |
TP_SAM2 = Path(os.environ.get("THIRD_PARTY_SAM2_DIR", ROOT / "third_party" / "sam2")).resolve()
|
| 43 |
TP_MATANY = Path(os.environ.get("THIRD_PARTY_MATANY_DIR", ROOT / "third_party" / "matanyone")).resolve()
|
| 44 |
|
| 45 |
def _add_sys_path(p: Path) -> None:
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
sys.path
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
_add_sys_path(TP_SAM2)
|
| 51 |
_add_sys_path(TP_MATANY)
|
| 52 |
|
| 53 |
# --------------------------------------------------------------------------------------
|
| 54 |
-
#
|
| 55 |
# --------------------------------------------------------------------------------------
|
| 56 |
-
def
|
| 57 |
-
return os.environ.get("FFMPEG_BIN", "ffmpeg")
|
| 58 |
-
|
| 59 |
-
def _probe_ffmpeg() -> bool:
|
| 60 |
try:
|
| 61 |
-
|
| 62 |
-
return
|
| 63 |
-
except Exception:
|
| 64 |
-
|
|
|
|
| 65 |
|
| 66 |
def _has_cuda() -> bool:
|
|
|
|
|
|
|
|
|
|
| 67 |
try:
|
| 68 |
-
return
|
| 69 |
-
except Exception:
|
|
|
|
| 70 |
return False
|
| 71 |
|
| 72 |
def _pick_device(env_key: str) -> str:
|
|
@@ -75,6 +93,19 @@ def _pick_device(env_key: str) -> str:
|
|
| 75 |
return requested
|
| 76 |
return "cuda" if _has_cuda() else "cpu"
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
def _ensure_dir(p: Path) -> None:
|
| 79 |
p.mkdir(parents=True, exist_ok=True)
|
| 80 |
|
|
@@ -141,7 +172,6 @@ def _mux_audio(src_video: Union[str, Path], silent_video: Union[str, Path], out_
|
|
| 141 |
# Compositing & Image Processing
|
| 142 |
# --------------------------------------------------------------------------------------
|
| 143 |
def _refine_alpha(alpha: np.ndarray, erode_px: int = 1, dilate_px: int = 2, blur_px: float = 1.5) -> np.ndarray:
|
| 144 |
-
"""Erode→dilate + gentle blur → float alpha in [0,1]."""
|
| 145 |
if alpha.dtype != np.float32:
|
| 146 |
a = alpha.astype(np.float32)
|
| 147 |
if a.max() > 1.0:
|
|
@@ -173,7 +203,6 @@ def _to_srgb(lin: np.ndarray, gamma: float = 2.2) -> np.ndarray:
|
|
| 173 |
return np.clip(np.power(x, 1.0 / gamma) * 255.0, 0, 255).astype(np.uint8)
|
| 174 |
|
| 175 |
def _light_wrap(bg_rgb: np.ndarray, alpha01: np.ndarray, radius: int = 5, amount: float = 0.18) -> np.ndarray:
|
| 176 |
-
"""Simple light wrap from background into subject edges."""
|
| 177 |
r = max(1, int(radius))
|
| 178 |
inv = 1.0 - alpha01
|
| 179 |
inv_blur = cv2.GaussianBlur(inv, (r | 1, r | 1), 0)
|
|
@@ -181,8 +210,7 @@ def _light_wrap(bg_rgb: np.ndarray, alpha01: np.ndarray, radius: int = 5, amount
|
|
| 181 |
return lw
|
| 182 |
|
| 183 |
def _despill_edges(fg_rgb: np.ndarray, alpha01: np.ndarray, amount: float = 0.35) -> np.ndarray:
|
| 184 |
-
|
| 185 |
-
w = 1.0 - 2.0 * np.abs(alpha01 - 0.5) # bell-shaped weight
|
| 186 |
w = np.clip(w, 0.0, 1.0)
|
| 187 |
hsv = cv2.cvtColor(fg_rgb.astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
|
| 188 |
H, S, V = cv2.split(hsv)
|
|
@@ -191,11 +219,11 @@ def _despill_edges(fg_rgb: np.ndarray, alpha01: np.ndarray, amount: float = 0.35
|
|
| 191 |
out = cv2.cvtColor(hsv2.astype(np.uint8), cv2.COLOR_HSV2RGB)
|
| 192 |
return out
|
| 193 |
|
| 194 |
-
def _composite_frame_pro(
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
erode_px = erode_px if erode_px is not None else int(os.environ.get("EDGE_ERODE", "1"))
|
| 200 |
dilate_px = dilate_px if dilate_px is not None else int(os.environ.get("EDGE_DILATE", "2"))
|
| 201 |
blur_px = blur_px if blur_px is not None else float(os.environ.get("EDGE_BLUR", "1.5"))
|
|
@@ -203,17 +231,11 @@ def _composite_frame_pro(fg_rgb: np.ndarray, alpha: np.ndarray, bg_rgb: np.ndarr
|
|
| 203 |
lw_amount = lw_amount if lw_amount is not None else float(os.environ.get("LIGHTWRAP_AMOUNT", "0.18"))
|
| 204 |
despill_amount = despill_amount if despill_amount is not None else float(os.environ.get("DESPILL_AMOUNT", "0.35"))
|
| 205 |
|
| 206 |
-
# refine alpha [0,1]
|
| 207 |
a = _refine_alpha(alpha, erode_px=erode_px, dilate_px=dilate_px, blur_px=blur_px)
|
| 208 |
-
|
| 209 |
-
# edge de-spill: temper saturation where a≈0.5
|
| 210 |
fg_rgb = _despill_edges(fg_rgb, a, amount=despill_amount)
|
| 211 |
|
| 212 |
-
# linearize for better blending
|
| 213 |
fg_lin = _to_linear(fg_rgb)
|
| 214 |
bg_lin = _to_linear(bg_rgb)
|
| 215 |
-
|
| 216 |
-
# light wrap
|
| 217 |
lw = _light_wrap(bg_rgb, a, radius=lw_radius, amount=lw_amount)
|
| 218 |
lw_lin = _to_linear(np.clip(lw, 0, 255).astype(np.uint8))
|
| 219 |
|
|
@@ -233,30 +255,27 @@ def _resolve_sam2_cfg(cfg_str: str) -> str:
|
|
| 233 |
return str(candidate)
|
| 234 |
if cfg_path.exists():
|
| 235 |
return str(cfg_path)
|
| 236 |
-
# Last resort: common defaults inside the repo
|
| 237 |
for name in ["configs/sam2/sam2_hiera_l.yaml", "configs/sam2/sam2_hiera_b.yaml", "configs/sam2/sam2_hiera_s.yaml"]:
|
| 238 |
p = TP_SAM2 / name
|
| 239 |
if p.exists():
|
| 240 |
return str(p)
|
| 241 |
-
return str(cfg_str)
|
| 242 |
|
| 243 |
def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
|
| 244 |
"""If config references 'hieradet', try to find a 'hiera' config."""
|
| 245 |
try:
|
| 246 |
with open(cfg_path, "r") as f:
|
| 247 |
data = yaml.safe_load(f)
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
trunk = (enc.get("trunk") or {})
|
| 252 |
target = trunk.get("_target_") or trunk.get("target")
|
| 253 |
if isinstance(target, str) and "hieradet" in target:
|
| 254 |
for y in TP_SAM2.rglob("*.yaml"):
|
| 255 |
try:
|
| 256 |
with open(y, "r") as f2:
|
| 257 |
-
d2 = yaml.safe_load(f2)
|
| 258 |
-
|
| 259 |
-
e2 = (m2.get("image_encoder") or {})
|
| 260 |
t2 = (e2.get("trunk") or {})
|
| 261 |
tgt2 = t2.get("_target_") or t2.get("target")
|
| 262 |
if isinstance(tgt2, str) and ".hiera." in tgt2:
|
|
@@ -313,7 +332,7 @@ def _try_build(cfg_path: str):
|
|
| 313 |
try:
|
| 314 |
try:
|
| 315 |
sam = _try_build(cfg)
|
| 316 |
-
except Exception
|
| 317 |
alt_cfg = _find_hiera_config_if_hieradet(cfg)
|
| 318 |
if alt_cfg:
|
| 319 |
logger.info(f"SAM2: retrying with alt config: {alt_cfg}")
|
|
@@ -426,7 +445,6 @@ def load_matany() -> Tuple[Optional[object], bool, Dict[str, Any]]:
|
|
| 426 |
repo_id = os.environ.get("MATANY_REPO_ID", "")
|
| 427 |
ckpt = os.environ.get("MATANY_CHECKPOINT", "")
|
| 428 |
|
| 429 |
-
# Check if this fork needs a prebuilt network
|
| 430 |
try:
|
| 431 |
sig = inspect.signature(InferenceCore)
|
| 432 |
if "network" in sig.parameters and sig.parameters["network"].default is inspect._empty:
|
|
@@ -656,7 +674,6 @@ def fallback_composite(video_path: Union[str, Path],
|
|
| 656 |
# Stage-A (Transparent Export) Functions
|
| 657 |
# --------------------------------------------------------------------------------------
|
| 658 |
def _checkerboard_bg(w: int, h: int, tile: int = 32) -> np.ndarray:
|
| 659 |
-
"""RGB checkerboard for preview when no real alpha is possible."""
|
| 660 |
y, x = np.mgrid[0:h, 0:w]
|
| 661 |
c = ((x // tile) + (y // tile)) % 2
|
| 662 |
a = np.where(c == 0, 200, 150).astype(np.uint8)
|
|
@@ -670,7 +687,6 @@ def _build_stage_a_rgba_vp9_from_fg_alpha(
|
|
| 670 |
size: Tuple[int, int],
|
| 671 |
src_audio: Optional[Union[str, Path]] = None,
|
| 672 |
) -> bool:
|
| 673 |
-
"""Merge FG+ALPHA → RGBA WebM (VP9 with alpha)."""
|
| 674 |
if not _probe_ffmpeg():
|
| 675 |
return False
|
| 676 |
w, h = size
|
|
@@ -702,7 +718,6 @@ def _build_stage_a_rgba_vp9_from_mask(
|
|
| 702 |
fps: int,
|
| 703 |
size: Tuple[int, int],
|
| 704 |
) -> bool:
|
| 705 |
-
"""Merge original video + static mask → RGBA WebM (VP9 with alpha)."""
|
| 706 |
if not _probe_ffmpeg():
|
| 707 |
return False
|
| 708 |
w, h = size
|
|
@@ -733,7 +748,6 @@ def _build_stage_a_checkerboard_from_fg_alpha(
|
|
| 733 |
fps: int,
|
| 734 |
size: Tuple[int, int],
|
| 735 |
) -> bool:
|
| 736 |
-
"""Preview: FG+ALPHA over checkerboard → MP4 (no real alpha)."""
|
| 737 |
fg_cap = cv2.VideoCapture(str(fg_path))
|
| 738 |
al_cap = cv2.VideoCapture(str(alpha_path))
|
| 739 |
if not fg_cap.isOpened() or not al_cap.isOpened():
|
|
@@ -766,7 +780,6 @@ def _build_stage_a_checkerboard_from_mask(
|
|
| 766 |
fps: int,
|
| 767 |
size: Tuple[int, int],
|
| 768 |
) -> bool:
|
| 769 |
-
"""Preview: original video + static mask over checkerboard → MP4."""
|
| 770 |
cap = cv2.VideoCapture(str(video_path))
|
| 771 |
if not cap.isOpened():
|
| 772 |
return False
|
|
@@ -790,4 +803,4 @@ def _build_stage_a_checkerboard_from_mask(
|
|
| 790 |
finally:
|
| 791 |
cap.release()
|
| 792 |
writer.release()
|
| 793 |
-
return ok_any
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
BackgroundFX Pro - Model Loading & Utilities (Hardened)
|
| 4 |
+
======================================================
|
| 5 |
+
- Avoids heavy CUDA/Hydra work at import time
|
| 6 |
+
- Adds timeouts to subprocess probes
|
| 7 |
+
- Safer sys.path wiring for third_party repos
|
| 8 |
"""
|
| 9 |
|
| 10 |
from __future__ import annotations
|
|
|
|
| 20 |
|
| 21 |
import numpy as np
|
| 22 |
import yaml
|
|
|
|
| 23 |
|
| 24 |
# --------------------------------------------------------------------------------------
|
| 25 |
+
# Logging (ensure a handler exists very early)
|
| 26 |
# --------------------------------------------------------------------------------------
|
| 27 |
logger = logging.getLogger("backgroundfx_pro")
|
| 28 |
+
if not logger.handlers:
|
| 29 |
+
_h = logging.StreamHandler()
|
| 30 |
+
_h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
|
| 31 |
+
logger.addHandler(_h)
|
| 32 |
+
logger.setLevel(logging.INFO)
|
| 33 |
+
|
| 34 |
+
# Pin OpenCV threads (helps libgomp stability in Spaces)
|
| 35 |
+
try:
|
| 36 |
+
cv_threads = int(os.environ.get("CV_THREADS", "1"))
|
| 37 |
+
if hasattr(cv2, "setNumThreads"):
|
| 38 |
+
cv2.setNumThreads(cv_threads)
|
| 39 |
+
except Exception:
|
| 40 |
+
pass
|
| 41 |
|
| 42 |
# --------------------------------------------------------------------------------------
|
| 43 |
# Optional dependencies
|
|
|
|
| 51 |
# --------------------------------------------------------------------------------------
|
| 52 |
# Path setup for third_party repos
|
| 53 |
# --------------------------------------------------------------------------------------
|
| 54 |
+
ROOT = Path(__file__).resolve().parent.parent # project root
|
| 55 |
TP_SAM2 = Path(os.environ.get("THIRD_PARTY_SAM2_DIR", ROOT / "third_party" / "sam2")).resolve()
|
| 56 |
TP_MATANY = Path(os.environ.get("THIRD_PARTY_MATANY_DIR", ROOT / "third_party" / "matanyone")).resolve()
|
| 57 |
|
| 58 |
def _add_sys_path(p: Path) -> None:
|
| 59 |
+
if p.exists():
|
| 60 |
+
p_str = str(p)
|
| 61 |
+
if p_str not in sys.path:
|
| 62 |
+
sys.path.insert(0, p_str)
|
| 63 |
+
else:
|
| 64 |
+
logger.warning(f"third_party path not found: {p}")
|
| 65 |
|
| 66 |
_add_sys_path(TP_SAM2)
|
| 67 |
_add_sys_path(TP_MATANY)
|
| 68 |
|
| 69 |
# --------------------------------------------------------------------------------------
|
| 70 |
+
# Safe Torch accessors (no top-level import)
|
| 71 |
# --------------------------------------------------------------------------------------
|
| 72 |
+
def _torch():
|
|
|
|
|
|
|
|
|
|
| 73 |
try:
|
| 74 |
+
import torch # local import avoids early CUDA init during module import
|
| 75 |
+
return torch
|
| 76 |
+
except Exception as e:
|
| 77 |
+
logger.warning(f"[models.safe-torch] import failed: {e}")
|
| 78 |
+
return None
|
| 79 |
|
| 80 |
def _has_cuda() -> bool:
    """Best-effort CUDA availability probe that never raises."""
    t = _torch()
    if t is None:
        return False
    try:
        available = t.cuda.is_available()
    except Exception as e:
        logger.warning(f"[models.safe-torch] cuda.is_available() failed: {e}")
        return False
    return bool(available)
|
| 89 |
|
| 90 |
def _pick_device(env_key: str) -> str:
|
|
|
|
| 93 |
return requested
|
| 94 |
return "cuda" if _has_cuda() else "cpu"
|
| 95 |
|
| 96 |
+
# --------------------------------------------------------------------------------------
|
| 97 |
+
# Basic Utilities
|
| 98 |
+
# --------------------------------------------------------------------------------------
|
| 99 |
+
def _ffmpeg_bin() -> str:
|
| 100 |
+
return os.environ.get("FFMPEG_BIN", "ffmpeg")
|
| 101 |
+
|
| 102 |
+
def _probe_ffmpeg(timeout: int = 2) -> bool:
    """Return True when the ffmpeg binary answers ``-version`` within *timeout*.

    Any failure (missing binary, non-zero exit, timeout) yields False rather
    than raising, so callers can branch on ffmpeg availability.
    """
    cmd = [_ffmpeg_bin(), "-version"]
    try:
        subprocess.run(
            cmd,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=True,
            timeout=timeout,
        )
    except Exception:
        return False
    return True
|
| 108 |
+
|
| 109 |
def _ensure_dir(p: Path) -> None:
|
| 110 |
p.mkdir(parents=True, exist_ok=True)
|
| 111 |
|
|
|
|
| 172 |
# Compositing & Image Processing
|
| 173 |
# --------------------------------------------------------------------------------------
|
| 174 |
def _refine_alpha(alpha: np.ndarray, erode_px: int = 1, dilate_px: int = 2, blur_px: float = 1.5) -> np.ndarray:
|
|
|
|
| 175 |
if alpha.dtype != np.float32:
|
| 176 |
a = alpha.astype(np.float32)
|
| 177 |
if a.max() > 1.0:
|
|
|
|
| 203 |
return np.clip(np.power(x, 1.0 / gamma) * 255.0, 0, 255).astype(np.uint8)
|
| 204 |
|
| 205 |
def _light_wrap(bg_rgb: np.ndarray, alpha01: np.ndarray, radius: int = 5, amount: float = 0.18) -> np.ndarray:
|
|
|
|
| 206 |
r = max(1, int(radius))
|
| 207 |
inv = 1.0 - alpha01
|
| 208 |
inv_blur = cv2.GaussianBlur(inv, (r | 1, r | 1), 0)
|
|
|
|
| 210 |
return lw
|
| 211 |
|
| 212 |
def _despill_edges(fg_rgb: np.ndarray, alpha01: np.ndarray, amount: float = 0.35) -> np.ndarray:
|
| 213 |
+
w = 1.0 - 2.0 * np.abs(alpha01 - 0.5)
|
|
|
|
| 214 |
w = np.clip(w, 0.0, 1.0)
|
| 215 |
hsv = cv2.cvtColor(fg_rgb.astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
|
| 216 |
H, S, V = cv2.split(hsv)
|
|
|
|
| 219 |
out = cv2.cvtColor(hsv2.astype(np.uint8), cv2.COLOR_HSV2RGB)
|
| 220 |
return out
|
| 221 |
|
| 222 |
+
def _composite_frame_pro(
|
| 223 |
+
fg_rgb: np.ndarray, alpha: np.ndarray, bg_rgb: np.ndarray,
|
| 224 |
+
erode_px: int = None, dilate_px: int = None, blur_px: float = None,
|
| 225 |
+
lw_radius: int = None, lw_amount: float = None, despill_amount: float = None
|
| 226 |
+
) -> np.ndarray:
|
| 227 |
erode_px = erode_px if erode_px is not None else int(os.environ.get("EDGE_ERODE", "1"))
|
| 228 |
dilate_px = dilate_px if dilate_px is not None else int(os.environ.get("EDGE_DILATE", "2"))
|
| 229 |
blur_px = blur_px if blur_px is not None else float(os.environ.get("EDGE_BLUR", "1.5"))
|
|
|
|
| 231 |
lw_amount = lw_amount if lw_amount is not None else float(os.environ.get("LIGHTWRAP_AMOUNT", "0.18"))
|
| 232 |
despill_amount = despill_amount if despill_amount is not None else float(os.environ.get("DESPILL_AMOUNT", "0.35"))
|
| 233 |
|
|
|
|
| 234 |
a = _refine_alpha(alpha, erode_px=erode_px, dilate_px=dilate_px, blur_px=blur_px)
|
|
|
|
|
|
|
| 235 |
fg_rgb = _despill_edges(fg_rgb, a, amount=despill_amount)
|
| 236 |
|
|
|
|
| 237 |
fg_lin = _to_linear(fg_rgb)
|
| 238 |
bg_lin = _to_linear(bg_rgb)
|
|
|
|
|
|
|
| 239 |
lw = _light_wrap(bg_rgb, a, radius=lw_radius, amount=lw_amount)
|
| 240 |
lw_lin = _to_linear(np.clip(lw, 0, 255).astype(np.uint8))
|
| 241 |
|
|
|
|
| 255 |
return str(candidate)
|
| 256 |
if cfg_path.exists():
|
| 257 |
return str(cfg_path)
|
|
|
|
| 258 |
for name in ["configs/sam2/sam2_hiera_l.yaml", "configs/sam2/sam2_hiera_b.yaml", "configs/sam2/sam2_hiera_s.yaml"]:
|
| 259 |
p = TP_SAM2 / name
|
| 260 |
if p.exists():
|
| 261 |
return str(p)
|
| 262 |
+
return str(cfg_str)
|
| 263 |
|
| 264 |
def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
|
| 265 |
"""If config references 'hieradet', try to find a 'hiera' config."""
|
| 266 |
try:
|
| 267 |
with open(cfg_path, "r") as f:
|
| 268 |
data = yaml.safe_load(f)
|
| 269 |
+
model = data.get("model", {}) or {}
|
| 270 |
+
enc = model.get("image_encoder") or {}
|
| 271 |
+
trunk = enc.get("trunk") or {}
|
|
|
|
| 272 |
target = trunk.get("_target_") or trunk.get("target")
|
| 273 |
if isinstance(target, str) and "hieradet" in target:
|
| 274 |
for y in TP_SAM2.rglob("*.yaml"):
|
| 275 |
try:
|
| 276 |
with open(y, "r") as f2:
|
| 277 |
+
d2 = yaml.safe_load(f2) or {}
|
| 278 |
+
e2 = (d2.get("model", {}) or {}).get("image_encoder") or {}
|
|
|
|
| 279 |
t2 = (e2.get("trunk") or {})
|
| 280 |
tgt2 = t2.get("_target_") or t2.get("target")
|
| 281 |
if isinstance(tgt2, str) and ".hiera." in tgt2:
|
|
|
|
| 332 |
try:
|
| 333 |
try:
|
| 334 |
sam = _try_build(cfg)
|
| 335 |
+
except Exception:
|
| 336 |
alt_cfg = _find_hiera_config_if_hieradet(cfg)
|
| 337 |
if alt_cfg:
|
| 338 |
logger.info(f"SAM2: retrying with alt config: {alt_cfg}")
|
|
|
|
| 445 |
repo_id = os.environ.get("MATANY_REPO_ID", "")
|
| 446 |
ckpt = os.environ.get("MATANY_CHECKPOINT", "")
|
| 447 |
|
|
|
|
| 448 |
try:
|
| 449 |
sig = inspect.signature(InferenceCore)
|
| 450 |
if "network" in sig.parameters and sig.parameters["network"].default is inspect._empty:
|
|
|
|
| 674 |
# Stage-A (Transparent Export) Functions
|
| 675 |
# --------------------------------------------------------------------------------------
|
| 676 |
def _checkerboard_bg(w: int, h: int, tile: int = 32) -> np.ndarray:
|
|
|
|
| 677 |
y, x = np.mgrid[0:h, 0:w]
|
| 678 |
c = ((x // tile) + (y // tile)) % 2
|
| 679 |
a = np.where(c == 0, 200, 150).astype(np.uint8)
|
|
|
|
| 687 |
size: Tuple[int, int],
|
| 688 |
src_audio: Optional[Union[str, Path]] = None,
|
| 689 |
) -> bool:
|
|
|
|
| 690 |
if not _probe_ffmpeg():
|
| 691 |
return False
|
| 692 |
w, h = size
|
|
|
|
| 718 |
fps: int,
|
| 719 |
size: Tuple[int, int],
|
| 720 |
) -> bool:
|
|
|
|
| 721 |
if not _probe_ffmpeg():
|
| 722 |
return False
|
| 723 |
w, h = size
|
|
|
|
| 748 |
fps: int,
|
| 749 |
size: Tuple[int, int],
|
| 750 |
) -> bool:
|
|
|
|
| 751 |
fg_cap = cv2.VideoCapture(str(fg_path))
|
| 752 |
al_cap = cv2.VideoCapture(str(alpha_path))
|
| 753 |
if not fg_cap.isOpened() or not al_cap.isOpened():
|
|
|
|
| 780 |
fps: int,
|
| 781 |
size: Tuple[int, int],
|
| 782 |
) -> bool:
|
|
|
|
| 783 |
cap = cv2.VideoCapture(str(video_path))
|
| 784 |
if not cap.isOpened():
|
| 785 |
return False
|
|
|
|
| 803 |
finally:
|
| 804 |
cap.release()
|
| 805 |
writer.release()
|
| 806 |
+
return ok_any
|
perf_tuning.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
| 1 |
-
# perf_tuning.py
|
| 2 |
#!/usr/bin/env python3
|
| 3 |
"""
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
import os
|
|
@@ -15,59 +17,107 @@
|
|
| 15 |
log.addHandler(h)
|
| 16 |
log.setLevel(logging.INFO)
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
mem_frac = float(os.environ.get("CUDA_MEMORY_FRACTION", "0.98"))
|
| 26 |
-
|
| 27 |
-
if not torch.cuda.is_available():
|
| 28 |
-
if require_cuda:
|
| 29 |
-
raise RuntimeError("CUDA is NOT available, but REQUIRE_CUDA=1. "
|
| 30 |
-
"Make sure the Space is on GPU and the container runs with --gpus all.")
|
| 31 |
-
else:
|
| 32 |
-
log.warning("CUDA not available; running on CPU. Set REQUIRE_CUDA=1 to fail fast.")
|
| 33 |
else:
|
| 34 |
-
#
|
| 35 |
try:
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
-
#
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
except Exception:
|
| 53 |
-
pass
|
| 54 |
|
| 55 |
-
#
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
-
#
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
# Optional: limit OpenCV threads if provided
|
| 73 |
threads = os.environ.get("OPENCV_NUM_THREADS")
|
|
|
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
perf_tuning.py (Hardened)
|
| 4 |
+
- No hard CUDA touching at import time (prevents startup hangs on Spaces).
|
| 5 |
+
- Optional "strict" modes via env if you *really* want fail-fast behavior.
|
| 6 |
+
- Applies safe flags (TF32/cudnn.benchmark) best-effort.
|
| 7 |
+
- Short, defensive GPU banner (only if explicitly enabled).
|
| 8 |
"""
|
| 9 |
|
| 10 |
import os
|
|
|
|
| 17 |
log.addHandler(h)
|
| 18 |
log.setLevel(logging.INFO)
|
| 19 |
|
| 20 |
+
# ---- Feature flags (env) -----------------------------------------------------
|
| 21 |
+
DISABLED = os.getenv("PERF_TUNING_DISABLED", "0").strip() == "1"
|
| 22 |
+
STRICT_IMPORT_FAIL = os.getenv("PERF_TUNING_IMPORT_STRICT", "0").strip() == "1" # if 1, may raise on import
|
| 23 |
+
EAGER_CUDA = os.getenv("PERF_TUNING_EAGER_CUDA", "0").strip() == "1" # if 1, do CUDA probing now
|
| 24 |
+
REQUIRE_CUDA = os.getenv("REQUIRE_CUDA", "0").strip() == "1" # prefer not to fail at import
|
| 25 |
+
FORCE_IDX_ENV = os.getenv("FORCE_CUDA_DEVICE", "").strip()
|
| 26 |
+
MEM_FRAC_STR = os.getenv("CUDA_MEMORY_FRACTION", "0.98").strip()
|
| 27 |
|
| 28 |
+
if DISABLED:
|
| 29 |
+
log.info("perf_tuning: disabled by PERF_TUNING_DISABLED=1")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
else:
|
| 31 |
+
# Import torch defensively (do NOT crash the app if it's not there)
|
| 32 |
try:
|
| 33 |
+
import importlib
|
| 34 |
+
torch = importlib.import_module("torch")
|
| 35 |
+
except Exception as e:
|
| 36 |
+
msg = f"perf_tuning: PyTorch not importable at import-time: {e}"
|
| 37 |
+
if STRICT_IMPORT_FAIL:
|
| 38 |
+
raise RuntimeError(msg)
|
| 39 |
+
else:
|
| 40 |
+
log.warning(msg)
|
| 41 |
+
torch = None
|
| 42 |
|
| 43 |
+
def _bool_cuda_available():
    """Return True only if torch was imported and CUDA probing succeeds.

    Relies on the module-level ``torch`` binding (None when the defensive
    import above failed); never raises.
    """
    if torch is None:
        return False
    try:
        ok = torch.cuda.is_available()
    except Exception as e:
        log.warning(f"perf_tuning: cuda.is_available() failed: {e}")
        return False
    return bool(ok)
|
| 51 |
|
| 52 |
+
# Soft gating: if user *requires* CUDA, set a marker we can read later
|
| 53 |
+
if REQUIRE_CUDA and not _bool_cuda_available():
|
| 54 |
+
os.environ["BFX_REQUIRE_CUDA_FAILED"] = "1"
|
| 55 |
+
msg = "CUDA NOT available but REQUIRE_CUDA=1 (will run on CPU unless app checks this later)."
|
| 56 |
+
if STRICT_IMPORT_FAIL:
|
| 57 |
+
raise RuntimeError(msg)
|
| 58 |
+
else:
|
| 59 |
+
log.warning(msg)
|
|
|
|
|
|
|
| 60 |
|
| 61 |
+
# Always try “cheap” flags that won’t touch devices
|
| 62 |
+
if torch is not None:
|
| 63 |
+
try:
|
| 64 |
+
# These do not require an active CUDA context
|
| 65 |
+
if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "matmul"):
|
| 66 |
+
try:
|
| 67 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 68 |
+
except Exception:
|
| 69 |
+
pass
|
| 70 |
+
if hasattr(torch.backends, "cudnn"):
|
| 71 |
+
try:
|
| 72 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 73 |
+
torch.backends.cudnn.benchmark = True
|
| 74 |
+
except Exception:
|
| 75 |
+
pass
|
| 76 |
+
except Exception as e:
|
| 77 |
+
log.debug(f"perf_tuning: backend flags suppressed: {e}")
|
| 78 |
|
| 79 |
+
# Only do potentially blocking CUDA work if explicitly requested
|
| 80 |
+
if EAGER_CUDA and torch is not None:
|
| 81 |
+
try:
|
| 82 |
+
# Choose device (optional)
|
| 83 |
+
try:
|
| 84 |
+
idx = int(FORCE_IDX_ENV) if FORCE_IDX_ENV != "" else 0
|
| 85 |
+
except Exception:
|
| 86 |
+
idx = 0
|
| 87 |
+
try:
|
| 88 |
+
torch.cuda.set_device(idx)
|
| 89 |
+
except Exception as e:
|
| 90 |
+
log.warning(f"perf_tuning: set_device({idx}) failed: {e}")
|
| 91 |
+
|
| 92 |
+
# Memory fraction is optional and sometimes flaky—guard it
|
| 93 |
+
try:
|
| 94 |
+
mem_frac = float(MEM_FRAC_STR)
|
| 95 |
+
torch.cuda.set_per_process_memory_fraction(mem_frac, idx)
|
| 96 |
+
except Exception as e:
|
| 97 |
+
log.debug(f"perf_tuning: set_per_process_memory_fraction skipped: {e}")
|
| 98 |
+
|
| 99 |
+
# Best-effort banner; every call is wrapped so nothing blocks startup
|
| 100 |
+
try:
|
| 101 |
+
name = torch.cuda.get_device_name(idx)
|
| 102 |
+
except Exception as e:
|
| 103 |
+
name = f"? ({e})"
|
| 104 |
+
try:
|
| 105 |
+
cap = torch.cuda.get_device_capability(idx)
|
| 106 |
+
cap_s = f"{cap[0]}.{cap[1]}"
|
| 107 |
+
except Exception as e:
|
| 108 |
+
cap_s = f"? ({e})"
|
| 109 |
+
try:
|
| 110 |
+
total_gb = torch.cuda.get_device_properties(idx).total_memory / (1024**3)
|
| 111 |
+
except Exception as e:
|
| 112 |
+
total_gb = f"? ({e})"
|
| 113 |
+
try:
|
| 114 |
+
free_gb = torch.cuda.mem_get_info()[0] / (1024**3)
|
| 115 |
+
except Exception as e:
|
| 116 |
+
free_gb = f"? ({e})"
|
| 117 |
+
|
| 118 |
+
log.info(f"CUDA device {idx}: {name} | cc {cap_s} | VRAM {total_gb} GB (free ~{free_gb} GB) | TF32:ON | cuDNN benchmark:ON")
|
| 119 |
+
except Exception as e:
|
| 120 |
+
log.warning(f"perf_tuning: eager CUDA probe failed (non-fatal): {e}")
|
| 121 |
|
| 122 |
# Optional: limit OpenCV threads if provided
|
| 123 |
threads = os.environ.get("OPENCV_NUM_THREADS")
|
pipeline.py
CHANGED
|
@@ -1,9 +1,11 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
BackgroundFX Pro - Memory-Optimized Pipeline
|
| 4 |
-
===========================================
|
| 5 |
-
|
| 6 |
-
|
|
|
|
|
|
|
| 7 |
"""
|
| 8 |
|
| 9 |
from __future__ import annotations
|
|
@@ -13,85 +15,128 @@
|
|
| 13 |
import time
|
| 14 |
import tempfile
|
| 15 |
import logging
|
|
|
|
| 16 |
from pathlib import Path
|
| 17 |
from typing import Optional, Tuple, Dict, Any, Union
|
| 18 |
|
| 19 |
-
import torch
|
| 20 |
-
from models import (
|
| 21 |
-
load_sam2, run_sam2_mask, load_matany, run_matany,
|
| 22 |
-
fallback_mask, fallback_composite, composite_video,
|
| 23 |
-
_cv_read_first_frame, _save_mask_png, _ensure_dir, _mux_audio, _probe_ffmpeg,
|
| 24 |
-
_refine_mask_grabcut, _build_stage_a_rgba_vp9_from_fg_alpha,
|
| 25 |
-
_build_stage_a_rgba_vp9_from_mask, _build_stage_a_checkerboard_from_fg_alpha,
|
| 26 |
-
_build_stage_a_checkerboard_from_mask
|
| 27 |
-
)
|
| 28 |
-
|
| 29 |
-
# Try to apply GPU/perf tuning early
|
| 30 |
-
try:
|
| 31 |
-
import perf_tuning # noqa: F401
|
| 32 |
-
except Exception:
|
| 33 |
-
pass
|
| 34 |
-
|
| 35 |
# --------------------------------------------------------------------------------------
|
| 36 |
# Logging
|
| 37 |
# --------------------------------------------------------------------------------------
|
| 38 |
logger = logging.getLogger("backgroundfx_pro")
|
| 39 |
-
logger.setLevel(logging.INFO)
|
| 40 |
if not logger.handlers:
|
| 41 |
_h = logging.StreamHandler()
|
| 42 |
_h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
|
| 43 |
logger.addHandler(_h)
|
|
|
|
| 44 |
|
| 45 |
# --------------------------------------------------------------------------------------
|
| 46 |
-
#
|
| 47 |
# --------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
def _cleanup_temp_files(tmp_root: Path) -> None:
|
| 49 |
-
"""Clean up temporary files aggressively"""
|
| 50 |
try:
|
| 51 |
-
for pattern in
|
| 52 |
for f in tmp_root.glob(pattern):
|
| 53 |
f.unlink(missing_ok=True)
|
| 54 |
except Exception:
|
| 55 |
pass
|
| 56 |
|
| 57 |
def _log_memory() -> float:
|
| 58 |
-
"""
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
logger.info(f"GPU memory: {allocated:.1f}GB allocated, {reserved:.1f}GB reserved")
|
| 64 |
-
return allocated
|
| 65 |
-
|
| 66 |
-
|
| 67 |
return 0.0
|
| 68 |
|
| 69 |
def _force_cleanup() -> None:
|
| 70 |
-
"""Aggressive memory cleanup"""
|
| 71 |
try:
|
| 72 |
gc.collect()
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
except Exception as e:
|
| 77 |
-
logger.
|
|
|
|
| 78 |
|
| 79 |
# --------------------------------------------------------------------------------------
|
| 80 |
-
# Main Processing Function
|
| 81 |
# --------------------------------------------------------------------------------------
|
| 82 |
-
def process(
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
| 88 |
"""
|
| 89 |
Memory-optimized orchestration: lazy loading, sequential model usage, aggressive cleanup.
|
| 90 |
-
|
| 91 |
Flow:
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
| 95 |
"""
|
| 96 |
t0 = time.time()
|
| 97 |
diagnostics: Dict[str, Any] = {
|
|
@@ -110,105 +155,130 @@ def process(video_path: Union[str, Path],
|
|
| 110 |
tmp_root = Path(work_dir) if work_dir else Path(tempfile.mkdtemp(prefix="bfx_"))
|
| 111 |
_ensure_dir(tmp_root)
|
| 112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
try:
|
| 114 |
# 0) Basic video info
|
| 115 |
-
logger.info("Reading video metadata
|
| 116 |
first_frame, fps, (vw, vh) = _cv_read_first_frame(video_path)
|
| 117 |
diagnostics["fps"] = int(fps or 25)
|
| 118 |
diagnostics["resolution"] = [int(vw), int(vh)]
|
| 119 |
-
|
| 120 |
if first_frame is None or vw == 0 or vh == 0:
|
| 121 |
diagnostics["fallback_used"] = "invalid_video"
|
| 122 |
return None, diagnostics
|
| 123 |
|
| 124 |
diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
|
| 125 |
|
| 126 |
-
# 1) PHASE 1: SAM2
|
| 127 |
-
logger.info("
|
| 128 |
predictor, sam2_ok, sam_meta = load_sam2()
|
| 129 |
-
diagnostics["sam2_meta"] = sam_meta
|
| 130 |
-
diagnostics["device_sam2"] = sam_meta.get("sam2_device")
|
| 131 |
-
|
| 132 |
diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
|
| 133 |
|
| 134 |
seed_mask = None
|
| 135 |
mask_png = tmp_root / "seed_mask.png"
|
| 136 |
-
|
|
|
|
| 137 |
if sam2_ok and predictor is not None:
|
| 138 |
-
logger.info("Running SAM2 segmentation
|
| 139 |
px = int(point_x) if point_x is not None else None
|
| 140 |
py = int(point_y) if point_y is not None else None
|
| 141 |
-
|
| 142 |
seed_mask, ok_mask = run_sam2_mask(
|
| 143 |
predictor, first_frame,
|
| 144 |
point=(px, py) if (px is not None and py is not None) else None,
|
| 145 |
auto=auto_box
|
| 146 |
)
|
| 147 |
diagnostics["sam2_ok"] = bool(ok_mask)
|
| 148 |
-
|
| 149 |
-
# CRITICAL: Free SAM2 immediately after getting the mask
|
| 150 |
-
logger.info("Freeing SAM2 memory...")
|
| 151 |
-
del predictor
|
| 152 |
-
predictor = None
|
| 153 |
-
_force_cleanup()
|
| 154 |
-
diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
|
| 155 |
-
|
| 156 |
else:
|
| 157 |
-
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
# Fallback mask generation if SAM2 failed
|
| 161 |
if not ok_mask or seed_mask is None:
|
| 162 |
-
logger.info("Using fallback mask generation
|
| 163 |
seed_mask = fallback_mask(first_frame)
|
| 164 |
diagnostics["fallback_used"] = "mask_generation"
|
| 165 |
_force_cleanup()
|
| 166 |
|
| 167 |
# Optional GrabCut refinement
|
| 168 |
if int(os.environ.get("REFINE_GRABCUT", "1")) == 1:
|
| 169 |
-
logger.info("Refining mask with GrabCut
|
| 170 |
seed_mask = _refine_mask_grabcut(first_frame, seed_mask)
|
| 171 |
_force_cleanup()
|
| 172 |
|
| 173 |
_save_mask_png(seed_mask, mask_png)
|
| 174 |
-
|
| 175 |
-
#
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
| 177 |
_force_cleanup()
|
| 178 |
_cleanup_temp_files(tmp_root)
|
| 179 |
|
| 180 |
-
# 2) PHASE 2: MatAnyone
|
| 181 |
-
logger.info("
|
| 182 |
matany, mat_ok, mat_meta = load_matany()
|
| 183 |
-
diagnostics["matany_meta"] = mat_meta
|
| 184 |
-
diagnostics["device_matany"] = mat_meta.get("matany_device")
|
| 185 |
-
|
| 186 |
diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
|
| 187 |
|
| 188 |
fg_path, al_path = None, None
|
| 189 |
out_dir = tmp_root / "matany_out"
|
| 190 |
_ensure_dir(out_dir)
|
| 191 |
-
|
|
|
|
| 192 |
if mat_ok and matany is not None:
|
| 193 |
-
logger.info("Running MatAnyone processing
|
| 194 |
fg_path, al_path, ran = run_matany(matany, video_path, mask_png, out_dir)
|
| 195 |
diagnostics["matany_ok"] = bool(ran)
|
| 196 |
-
|
| 197 |
-
# CRITICAL: Free MatAnyone immediately after processing
|
| 198 |
-
logger.info("Freeing MatAnyone memory...")
|
| 199 |
-
del matany
|
| 200 |
-
matany = None
|
| 201 |
-
_force_cleanup()
|
| 202 |
-
diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
|
| 203 |
else:
|
| 204 |
-
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
|
| 207 |
-
# 3) PHASE 3: Stage-A
|
| 208 |
-
logger.info("
|
| 209 |
stageA_path = None
|
| 210 |
stageA_ok = False
|
| 211 |
-
|
| 212 |
if diagnostics["matany_ok"] and fg_path and al_path:
|
| 213 |
stageA_path = tmp_root / "stageA_transparent.webm"
|
| 214 |
if _probe_ffmpeg():
|
|
@@ -238,57 +308,56 @@ def process(video_path: Union[str, Path],
|
|
| 238 |
else ("MP4 checkerboard preview (no real alpha)" if stageA_ok else "Stage-A build failed")
|
| 239 |
)
|
| 240 |
|
| 241 |
-
# Optional: return Stage-A instead of final composite
|
| 242 |
if os.environ.get("RETURN_STAGE_A", "0").strip() == "1" and stageA_ok:
|
| 243 |
_force_cleanup()
|
| 244 |
_cleanup_temp_files(tmp_root)
|
|
|
|
|
|
|
| 245 |
return str(stageA_path), diagnostics
|
| 246 |
|
| 247 |
-
# 4) PHASE 4: Final Compositing
|
| 248 |
-
logger.info("
|
| 249 |
output_path = tmp_root / "output.mp4"
|
| 250 |
-
|
| 251 |
if diagnostics["matany_ok"] and fg_path and al_path:
|
| 252 |
-
logger.info("Compositing with MatAnyone outputs
|
| 253 |
ok_comp = composite_video(fg_path, al_path, bg_image_path, output_path, diagnostics["fps"], (vw, vh))
|
| 254 |
if not ok_comp:
|
| 255 |
-
logger.info("
|
| 256 |
fallback_composite(video_path, mask_png, bg_image_path, output_path)
|
| 257 |
diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") + "+composite_static"
|
| 258 |
else:
|
| 259 |
-
logger.info("Using static mask composite
|
| 260 |
fallback_composite(video_path, mask_png, bg_image_path, output_path)
|
| 261 |
diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") or "composite_static"
|
| 262 |
|
| 263 |
-
# Clean up intermediate files
|
| 264 |
_cleanup_temp_files(tmp_root)
|
| 265 |
_force_cleanup()
|
| 266 |
|
| 267 |
-
# 5) PHASE 5: Audio
|
| 268 |
-
logger.info("
|
| 269 |
final_path = tmp_root / "output_with_audio.mp4"
|
| 270 |
if _probe_ffmpeg():
|
| 271 |
mux_ok = _mux_audio(video_path, output_path, final_path)
|
| 272 |
if mux_ok:
|
| 273 |
-
# Clean up the silent version
|
| 274 |
output_path.unlink(missing_ok=True)
|
| 275 |
_force_cleanup()
|
| 276 |
diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
|
| 277 |
-
logger.info(f"
|
| 278 |
-
logger.info(f"Peak GPU memory usage: {diagnostics['memory_peak_gb']:.1f}GB")
|
| 279 |
return str(final_path), diagnostics
|
| 280 |
|
| 281 |
-
#
|
| 282 |
_force_cleanup()
|
| 283 |
diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
|
| 284 |
-
logger.info(f"
|
| 285 |
-
logger.info(f"Peak GPU memory usage: {diagnostics['memory_peak_gb']:.1f}GB")
|
| 286 |
return str(output_path), diagnostics
|
| 287 |
|
| 288 |
except Exception as e:
|
| 289 |
-
logger.error(f"Processing failed: {e}")
|
| 290 |
import traceback
|
| 291 |
-
logger.error(f"Traceback:
|
| 292 |
_force_cleanup()
|
| 293 |
diagnostics["error"] = str(e)
|
| 294 |
diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
|
|
@@ -297,4 +366,4 @@ def process(video_path: Union[str, Path],
|
|
| 297 |
finally:
|
| 298 |
# Ensure cleanup even if something goes wrong
|
| 299 |
_force_cleanup()
|
| 300 |
-
_cleanup_temp_files(tmp_root)
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
BackgroundFX Pro - Memory-Optimized Pipeline (Hardened)
|
| 4 |
+
======================================================
|
| 5 |
+
- Lazy-imports heavy 'models' module to avoid Space boot stalls
|
| 6 |
+
- Sequential load → run → free (SAM2 then MatAnyone)
|
| 7 |
+
- Aggressive but non-blocking GPU cleanup (no synchronize())
|
| 8 |
+
- Verbose breadcrumbs for pinpointing stalls
|
| 9 |
"""
|
| 10 |
|
| 11 |
from __future__ import annotations
|
|
|
|
| 15 |
import time
|
| 16 |
import tempfile
|
| 17 |
import logging
|
| 18 |
+
import importlib
|
| 19 |
from pathlib import Path
|
| 20 |
from typing import Optional, Tuple, Dict, Any, Union
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
# --------------------------------------------------------------------------------------
|
| 23 |
# Logging
|
| 24 |
# --------------------------------------------------------------------------------------
|
| 25 |
logger = logging.getLogger("backgroundfx_pro")
|
|
|
|
| 26 |
if not logger.handlers:
|
| 27 |
_h = logging.StreamHandler()
|
| 28 |
_h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
|
| 29 |
logger.addHandler(_h)
|
| 30 |
+
logger.setLevel(logging.INFO)
|
| 31 |
|
| 32 |
# --------------------------------------------------------------------------------------
|
| 33 |
+
# Safe Torch accessors (avoid import-time CUDA touches)
|
| 34 |
# --------------------------------------------------------------------------------------
|
| 35 |
+
def _torch():
|
| 36 |
+
try:
|
| 37 |
+
import torch # local import to avoid early CUDA init in module scope
|
| 38 |
+
return torch
|
| 39 |
+
except Exception as e:
|
| 40 |
+
logger.warning(f"[safe-torch] import failed: {e}")
|
| 41 |
+
return None
|
| 42 |
+
|
| 43 |
+
def _cuda_available() -> Optional[bool]:
    """Probe CUDA availability without letting a driver error propagate.

    Returns:
        True/False as reported by torch, or None when torch is missing
        or the probe itself raises.
    """
    torch_mod = _torch()
    if torch_mod is None:
        return None
    try:
        return torch_mod.cuda.is_available()
    except Exception as exc:
        logger.warning(f"[safe-torch] torch.cuda.is_available() failed: {exc}")
        return None
|
| 52 |
+
|
| 53 |
+
# --------------------------------------------------------------------------------------
|
| 54 |
+
# Lightweight utilities
|
| 55 |
+
# --------------------------------------------------------------------------------------
|
| 56 |
+
def _ensure_dir(p: Union[str, Path]) -> None:
|
| 57 |
+
Path(p).mkdir(parents=True, exist_ok=True)
|
| 58 |
+
|
| 59 |
def _cleanup_temp_files(tmp_root: Path) -> None:
|
| 60 |
+
"""Clean up temporary files aggressively."""
|
| 61 |
try:
|
| 62 |
+
for pattern in ("*.tmp", "*.temp", "*.bak"):
|
| 63 |
for f in tmp_root.glob(pattern):
|
| 64 |
f.unlink(missing_ok=True)
|
| 65 |
except Exception:
|
| 66 |
pass
|
| 67 |
|
| 68 |
def _log_memory() -> float:
    """Log current GPU memory usage and return allocated GB (0.0 if unknown).

    Never blocks and never raises; intended as a debugging breadcrumb,
    not an exact metric.
    """
    torch_mod = _torch()
    if torch_mod is None:
        return 0.0
    try:
        if _cuda_available():
            allocated = torch_mod.cuda.memory_allocated() / 1e9
            reserved = torch_mod.cuda.memory_reserved() / 1e9
            logger.info(f"GPU memory: {allocated:.1f}GB allocated, {reserved:.1f}GB reserved")
            return float(allocated)
    except Exception as exc:
        logger.debug(f"[mem-log] suppressed: {exc}")
    return 0.0
|
| 83 |
|
| 84 |
def _force_cleanup() -> None:
    """Reclaim Python and GPU memory aggressively, without any blocking call."""
    try:
        gc.collect()
    except Exception:
        pass
    torch_mod = _torch()
    if torch_mod is None:
        return
    try:
        if _cuda_available():
            # Deliberately no torch.cuda.synchronize() — it can hang on driver issues.
            torch_mod.cuda.empty_cache()
    except Exception as exc:
        logger.debug(f"[cleanup] suppressed: {exc}")
|
| 99 |
+
|
| 100 |
+
# --------------------------------------------------------------------------------------
|
| 101 |
+
# Lazy import of heavy models module
|
| 102 |
+
# --------------------------------------------------------------------------------------
|
| 103 |
+
_models_ref = None  # cached handle to the heavy 'models' module (filled on first use)

def _models():
    """Return the 'models' module, importing it lazily on the first call.

    Deferring this import keeps Space startup fast; a failed import is
    logged with traceback and re-raised for the caller to handle.
    """
    global _models_ref
    if _models_ref is None:
        logger.info("[init] Importing models module lazily…")
        try:
            _models_ref = importlib.import_module("models")
        except Exception as e:
            logger.exception(f"[init] Failed to import models: {e}")
            raise
        logger.info("[init] models imported OK.")
    return _models_ref
|
| 118 |
|
| 119 |
# --------------------------------------------------------------------------------------
|
| 120 |
+
# Main Processing Function
|
| 121 |
# --------------------------------------------------------------------------------------
|
| 122 |
+
def process(
|
| 123 |
+
video_path: Union[str, Path],
|
| 124 |
+
bg_image_path: Union[str, Path],
|
| 125 |
+
point_x: Optional[float] = None,
|
| 126 |
+
point_y: Optional[float] = None,
|
| 127 |
+
auto_box: bool = False,
|
| 128 |
+
work_dir: Optional[Union[str, Path]] = None
|
| 129 |
+
) -> Tuple[Optional[str], Dict[str, Any]]:
|
| 130 |
"""
|
| 131 |
Memory-optimized orchestration: lazy loading, sequential model usage, aggressive cleanup.
|
| 132 |
+
|
| 133 |
Flow:
|
| 134 |
+
0. Read video metadata
|
| 135 |
+
1. SAM2 → mask (free immediately)
|
| 136 |
+
2. MatAnyone → FG/alpha (free immediately)
|
| 137 |
+
3. Stage-A build (transparent or checkerboard)
|
| 138 |
+
4. Final composite
|
| 139 |
+
5. Audio mux
|
| 140 |
"""
|
| 141 |
t0 = time.time()
|
| 142 |
diagnostics: Dict[str, Any] = {
|
|
|
|
| 155 |
tmp_root = Path(work_dir) if work_dir else Path(tempfile.mkdtemp(prefix="bfx_"))
|
| 156 |
_ensure_dir(tmp_root)
|
| 157 |
|
| 158 |
+
# Defer heavy function imports until inside the call
|
| 159 |
+
M = _models()
|
| 160 |
+
# pull only the needed callables
|
| 161 |
+
_cv_read_first_frame = M._cv_read_first_frame
|
| 162 |
+
_save_mask_png = M._save_mask_png
|
| 163 |
+
_probe_ffmpeg = M._probe_ffmpeg
|
| 164 |
+
_mux_audio = M._mux_audio
|
| 165 |
+
_refine_mask_grabcut = M._refine_mask_grabcut
|
| 166 |
+
fallback_mask = M.fallback_mask
|
| 167 |
+
fallback_composite = M.fallback_composite
|
| 168 |
+
composite_video = M.composite_video
|
| 169 |
+
load_sam2 = M.load_sam2
|
| 170 |
+
run_sam2_mask = M.run_sam2_mask
|
| 171 |
+
load_matany = M.load_matany
|
| 172 |
+
run_matany = M.run_matany
|
| 173 |
+
_build_stage_a_rgba_vp9_from_fg_alpha = M._build_stage_a_rgba_vp9_from_fg_alpha
|
| 174 |
+
_build_stage_a_rgba_vp9_from_mask = M._build_stage_a_rgba_vp9_from_mask
|
| 175 |
+
_build_stage_a_checkerboard_from_fg_alpha = M._build_stage_a_checkerboard_from_fg_alpha
|
| 176 |
+
_build_stage_a_checkerboard_from_mask = M._build_stage_a_checkerboard_from_mask
|
| 177 |
+
|
| 178 |
try:
|
| 179 |
# 0) Basic video info
|
| 180 |
+
logger.info("[0] Reading video metadata…")
|
| 181 |
first_frame, fps, (vw, vh) = _cv_read_first_frame(video_path)
|
| 182 |
diagnostics["fps"] = int(fps or 25)
|
| 183 |
diagnostics["resolution"] = [int(vw), int(vh)]
|
| 184 |
+
|
| 185 |
if first_frame is None or vw == 0 or vh == 0:
|
| 186 |
diagnostics["fallback_used"] = "invalid_video"
|
| 187 |
return None, diagnostics
|
| 188 |
|
| 189 |
diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
|
| 190 |
|
| 191 |
+
# 1) PHASE 1: SAM2
|
| 192 |
+
logger.info("[1] Loading SAM2…")
|
| 193 |
predictor, sam2_ok, sam_meta = load_sam2()
|
| 194 |
+
diagnostics["sam2_meta"] = sam_meta or {}
|
| 195 |
+
diagnostics["device_sam2"] = (sam_meta or {}).get("sam2_device")
|
| 196 |
+
|
| 197 |
diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
|
| 198 |
|
| 199 |
seed_mask = None
|
| 200 |
mask_png = tmp_root / "seed_mask.png"
|
| 201 |
+
|
| 202 |
+
ok_mask = False
|
| 203 |
if sam2_ok and predictor is not None:
|
| 204 |
+
logger.info("[1] Running SAM2 segmentation…")
|
| 205 |
px = int(point_x) if point_x is not None else None
|
| 206 |
py = int(point_y) if point_y is not None else None
|
|
|
|
| 207 |
seed_mask, ok_mask = run_sam2_mask(
|
| 208 |
predictor, first_frame,
|
| 209 |
point=(px, py) if (px is not None and py is not None) else None,
|
| 210 |
auto=auto_box
|
| 211 |
)
|
| 212 |
diagnostics["sam2_ok"] = bool(ok_mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
else:
|
| 214 |
+
logger.info("[1] SAM2 unavailable or failed to load.")
|
| 215 |
+
|
| 216 |
+
# Free SAM2 ASAP
|
| 217 |
+
try:
|
| 218 |
+
del predictor
|
| 219 |
+
except Exception:
|
| 220 |
+
pass
|
| 221 |
+
predictor = None
|
| 222 |
+
_force_cleanup()
|
| 223 |
+
diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
|
| 224 |
|
| 225 |
# Fallback mask generation if SAM2 failed
|
| 226 |
if not ok_mask or seed_mask is None:
|
| 227 |
+
logger.info("[1] Using fallback mask generation…")
|
| 228 |
seed_mask = fallback_mask(first_frame)
|
| 229 |
diagnostics["fallback_used"] = "mask_generation"
|
| 230 |
_force_cleanup()
|
| 231 |
|
| 232 |
# Optional GrabCut refinement
|
| 233 |
if int(os.environ.get("REFINE_GRABCUT", "1")) == 1:
|
| 234 |
+
logger.info("[1] Refining mask with GrabCut…")
|
| 235 |
seed_mask = _refine_mask_grabcut(first_frame, seed_mask)
|
| 236 |
_force_cleanup()
|
| 237 |
|
| 238 |
_save_mask_png(seed_mask, mask_png)
|
| 239 |
+
|
| 240 |
+
# Free first frame
|
| 241 |
+
try:
|
| 242 |
+
del first_frame
|
| 243 |
+
except Exception:
|
| 244 |
+
pass
|
| 245 |
_force_cleanup()
|
| 246 |
_cleanup_temp_files(tmp_root)
|
| 247 |
|
| 248 |
+
# 2) PHASE 2: MatAnyone
|
| 249 |
+
logger.info("[2] Loading MatAnyone…")
|
| 250 |
matany, mat_ok, mat_meta = load_matany()
|
| 251 |
+
diagnostics["matany_meta"] = mat_meta or {}
|
| 252 |
+
diagnostics["device_matany"] = (mat_meta or {}).get("matany_device")
|
| 253 |
+
|
| 254 |
diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
|
| 255 |
|
| 256 |
fg_path, al_path = None, None
|
| 257 |
out_dir = tmp_root / "matany_out"
|
| 258 |
_ensure_dir(out_dir)
|
| 259 |
+
|
| 260 |
+
ran = False
|
| 261 |
if mat_ok and matany is not None:
|
| 262 |
+
logger.info("[2] Running MatAnyone processing…")
|
| 263 |
fg_path, al_path, ran = run_matany(matany, video_path, mask_png, out_dir)
|
| 264 |
diagnostics["matany_ok"] = bool(ran)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
else:
|
| 266 |
+
logger.info("[2] MatAnyone unavailable/disabled/failed to load.")
|
| 267 |
+
|
| 268 |
+
# Free MatAnyone ASAP
|
| 269 |
+
try:
|
| 270 |
+
del matany
|
| 271 |
+
except Exception:
|
| 272 |
+
pass
|
| 273 |
+
matany = None
|
| 274 |
+
_force_cleanup()
|
| 275 |
+
diagnostics["memory_peak_gb"] = max(diagnostics["memory_peak_gb"], _log_memory())
|
| 276 |
|
| 277 |
+
# 3) PHASE 3: Stage-A
|
| 278 |
+
logger.info("[3] Building Stage-A (transparent or checkerboard)…")
|
| 279 |
stageA_path = None
|
| 280 |
stageA_ok = False
|
| 281 |
+
|
| 282 |
if diagnostics["matany_ok"] and fg_path and al_path:
|
| 283 |
stageA_path = tmp_root / "stageA_transparent.webm"
|
| 284 |
if _probe_ffmpeg():
|
|
|
|
| 308 |
else ("MP4 checkerboard preview (no real alpha)" if stageA_ok else "Stage-A build failed")
|
| 309 |
)
|
| 310 |
|
|
|
|
| 311 |
if os.environ.get("RETURN_STAGE_A", "0").strip() == "1" and stageA_ok:
|
| 312 |
_force_cleanup()
|
| 313 |
_cleanup_temp_files(tmp_root)
|
| 314 |
+
diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
|
| 315 |
+
logger.info(f"[done] Returned Stage-A in {diagnostics['elapsed_sec']}s")
|
| 316 |
return str(stageA_path), diagnostics
|
| 317 |
|
| 318 |
+
# 4) PHASE 4: Final Compositing
|
| 319 |
+
logger.info("[4] Creating final composite…")
|
| 320 |
output_path = tmp_root / "output.mp4"
|
| 321 |
+
|
| 322 |
if diagnostics["matany_ok"] and fg_path and al_path:
|
| 323 |
+
logger.info("[4] Compositing with MatAnyone outputs…")
|
| 324 |
ok_comp = composite_video(fg_path, al_path, bg_image_path, output_path, diagnostics["fps"], (vw, vh))
|
| 325 |
if not ok_comp:
|
| 326 |
+
logger.info("[4] Composite failed; falling back to static mask composite.")
|
| 327 |
fallback_composite(video_path, mask_png, bg_image_path, output_path)
|
| 328 |
diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") + "+composite_static"
|
| 329 |
else:
|
| 330 |
+
logger.info("[4] Using static mask composite…")
|
| 331 |
fallback_composite(video_path, mask_png, bg_image_path, output_path)
|
| 332 |
diagnostics["fallback_used"] = (diagnostics["fallback_used"] or "") or "composite_static"
|
| 333 |
|
|
|
|
| 334 |
_cleanup_temp_files(tmp_root)
|
| 335 |
_force_cleanup()
|
| 336 |
|
| 337 |
+
# 5) PHASE 5: Audio Mux
|
| 338 |
+
logger.info("[5] Adding audio track…")
|
| 339 |
final_path = tmp_root / "output_with_audio.mp4"
|
| 340 |
if _probe_ffmpeg():
|
| 341 |
mux_ok = _mux_audio(video_path, output_path, final_path)
|
| 342 |
if mux_ok:
|
|
|
|
| 343 |
output_path.unlink(missing_ok=True)
|
| 344 |
_force_cleanup()
|
| 345 |
diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
|
| 346 |
+
logger.info(f"[done] Success in {diagnostics['elapsed_sec']}s")
|
| 347 |
+
logger.info(f"[done] Peak GPU memory usage: {diagnostics['memory_peak_gb']:.1f}GB")
|
| 348 |
return str(final_path), diagnostics
|
| 349 |
|
| 350 |
+
# Fallback return without audio
|
| 351 |
_force_cleanup()
|
| 352 |
diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
|
| 353 |
+
logger.info(f"[done] Completed (no audio) in {diagnostics['elapsed_sec']}s")
|
| 354 |
+
logger.info(f"[done] Peak GPU memory usage: {diagnostics['memory_peak_gb']:.1f}GB")
|
| 355 |
return str(output_path), diagnostics
|
| 356 |
|
| 357 |
except Exception as e:
|
| 358 |
+
logger.error(f"[error] Processing failed: {e}")
|
| 359 |
import traceback
|
| 360 |
+
logger.error(f"[error] Traceback:\n{traceback.format_exc()}")
|
| 361 |
_force_cleanup()
|
| 362 |
diagnostics["error"] = str(e)
|
| 363 |
diagnostics["elapsed_sec"] = round(time.time() - t0, 3)
|
|
|
|
| 366 |
finally:
|
| 367 |
# Ensure cleanup even if something goes wrong
|
| 368 |
_force_cleanup()
|
| 369 |
+
_cleanup_temp_files(tmp_root)
|
requirements.txt
CHANGED
|
@@ -9,7 +9,8 @@ moviepy==1.0.3
|
|
| 9 |
decord==0.6.0
|
| 10 |
Pillow==10.4.0
|
| 11 |
numpy==1.26.4
|
| 12 |
-
mediapipe==0.10.14
|
|
|
|
| 13 |
|
| 14 |
# ===== Gradio UI =====
|
| 15 |
gradio==5.42.0
|
|
@@ -28,10 +29,10 @@ scikit-image==0.24.0
|
|
| 28 |
tqdm==4.66.5
|
| 29 |
|
| 30 |
# ===== Helpers / caching =====
|
| 31 |
-
huggingface_hub
|
| 32 |
ffmpeg-python==0.2.0
|
| 33 |
psutil==6.0.0
|
| 34 |
-
requests==2.
|
| 35 |
scikit-learn==1.5.1
|
| 36 |
|
| 37 |
# ===== (Optional) Extras =====
|
|
|
|
| 9 |
decord==0.6.0
|
| 10 |
Pillow==10.4.0
|
| 11 |
numpy==1.26.4
|
| 12 |
+
mediapipe==0.10.14
|
| 13 |
+
protobuf==4.25.3
|
| 14 |
|
| 15 |
# ===== Gradio UI =====
|
| 16 |
gradio==5.42.0
|
|
|
|
| 29 |
tqdm==4.66.5
|
| 30 |
|
| 31 |
# ===== Helpers / caching =====
|
| 32 |
+
huggingface_hub==0.33.5
|
| 33 |
ffmpeg-python==0.2.0
|
| 34 |
psutil==6.0.0
|
| 35 |
+
requests==2.32.3
|
| 36 |
scikit-learn==1.5.1
|
| 37 |
|
| 38 |
# ===== (Optional) Extras =====
|
ui.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
-
#
|
| 2 |
"""
|
| 3 |
-
BackgroundFX Pro — Gradio UI, background generators, and data sources
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import io
|
|
@@ -14,13 +16,12 @@
|
|
| 14 |
from PIL import Image
|
| 15 |
import gradio as gr
|
| 16 |
|
| 17 |
-
from pipeline import (
|
| 18 |
-
process_video_gpu_optimized, stop_processing, processing_active,
|
| 19 |
-
SAM2_ENABLED, MATANY_ENABLED, GPU_NAME, GPU_MEMORY
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
logger = logging.getLogger("ui")
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# ---- Background generators ----
|
| 26 |
def create_gradient_background(gradient_type: str, width: int, height: int) -> Image.Image:
|
|
@@ -51,7 +52,6 @@ def create_gradient_background(gradient_type: str, width: int, height: int) -> I
|
|
| 51 |
img[i, :] = [r, g, b]
|
| 52 |
return Image.fromarray(img)
|
| 53 |
|
| 54 |
-
|
| 55 |
def create_solid_color(color: str, width: int, height: int) -> Image.Image:
|
| 56 |
color_map = {
|
| 57 |
"white": (255, 255, 255),
|
|
@@ -66,22 +66,25 @@ def create_solid_color(color: str, width: int, height: int) -> Image.Image:
|
|
| 66 |
rgb = color_map.get(color, (70, 130, 180))
|
| 67 |
return Image.fromarray(np.full((height, width, 3), rgb, dtype=np.uint8))
|
| 68 |
|
| 69 |
-
|
| 70 |
def generate_ai_background(prompt: str) -> Tuple[Optional[Image.Image], str]:
|
| 71 |
try:
|
| 72 |
-
if not prompt.strip():
|
| 73 |
return None, "Please enter a prompt"
|
| 74 |
models = [
|
| 75 |
"black-forest-labs/FLUX.1-schnell",
|
| 76 |
"stabilityai/stable-diffusion-xl-base-1.0",
|
| 77 |
-
"runwayml/stable-diffusion-v1-5"
|
| 78 |
]
|
| 79 |
enhanced_prompt = f"professional video background, {prompt}, high quality, 16:9, cinematic lighting, detailed"
|
|
|
|
|
|
|
| 80 |
for model in models:
|
| 81 |
try:
|
| 82 |
url = f"https://api-inference.huggingface.co/models/{model}"
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
| 85 |
r = requests.post(url, headers=headers, json=payload, timeout=60, stream=True)
|
| 86 |
if r.status_code == 200 and "image" in r.headers.get("content-type", "").lower():
|
| 87 |
buf = io.BytesIO(r.content if r.raw is None else r.raw.read())
|
|
@@ -95,11 +98,10 @@ def generate_ai_background(prompt: str) -> Tuple[Optional[Image.Image], str]:
|
|
| 95 |
logger.error(f"AI background error: {e}")
|
| 96 |
return create_gradient_background("default", 1920, 1080), "Default due to error"
|
| 97 |
|
| 98 |
-
|
| 99 |
# ---- MyAvatar API ----
|
| 100 |
class MyAvatarAPI:
|
| 101 |
def __init__(self):
|
| 102 |
-
self.api_base = "https://app.myavatar.dk/api"
|
| 103 |
self.videos_cache: List[Dict[str, Any]] = []
|
| 104 |
self.last_refresh = 0
|
| 105 |
|
|
@@ -140,11 +142,20 @@ def get_video_url(self, selection: str) -> Optional[str]:
|
|
| 140 |
logger.error(f"Parse selection failed: {e}")
|
| 141 |
return None
|
| 142 |
|
| 143 |
-
|
| 144 |
myavatar_api = MyAvatarAPI()
|
| 145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
def process_video_with_background_stoppable(
|
| 149 |
input_video: Optional[str],
|
| 150 |
myavatar_selection: str,
|
|
@@ -154,15 +165,12 @@ def process_video_with_background_stoppable(
|
|
| 154 |
custom_background: Optional[str],
|
| 155 |
ai_prompt: str
|
| 156 |
):
|
| 157 |
-
|
| 158 |
-
from pipeline import processing_active as _active_ref # ensure we use the module global
|
| 159 |
-
import pipeline # to toggle the flag
|
| 160 |
-
|
| 161 |
-
pipeline.processing_active = True
|
| 162 |
try:
|
| 163 |
-
|
|
|
|
| 164 |
|
| 165 |
-
#
|
| 166 |
video_path = None
|
| 167 |
if input_video:
|
| 168 |
video_path = input_video
|
|
@@ -173,16 +181,23 @@ def process_video_with_background_stoppable(
|
|
| 173 |
r.raise_for_status()
|
| 174 |
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
|
| 175 |
for chunk in r.iter_content(chunk_size=1 << 20):
|
|
|
|
|
|
|
|
|
|
| 176 |
if chunk:
|
| 177 |
tmp.write(chunk)
|
| 178 |
video_path = tmp.name
|
| 179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
if not video_path:
|
| 181 |
yield gr.update(visible=True), gr.update(visible=False), None, "No video provided"
|
| 182 |
return
|
| 183 |
|
| 184 |
-
#
|
| 185 |
-
yield gr.update(visible=False), gr.update(visible=True), None, "
|
| 186 |
bg_img = None
|
| 187 |
if background_type == "gradient":
|
| 188 |
bg_img = create_gradient_background(gradient_type, 1920, 1080)
|
|
@@ -190,50 +205,68 @@ def process_video_with_background_stoppable(
|
|
| 190 |
bg_img = create_solid_color(solid_color, 1920, 1080)
|
| 191 |
elif background_type == "custom" and custom_background:
|
| 192 |
try:
|
| 193 |
-
from PIL import Image
|
| 194 |
bg_img = Image.open(custom_background).convert("RGB")
|
| 195 |
except Exception:
|
| 196 |
bg_img = None
|
| 197 |
elif background_type == "ai" and ai_prompt:
|
| 198 |
bg_img, _ = generate_ai_background(ai_prompt)
|
| 199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
if bg_img is None:
|
| 201 |
yield gr.update(visible=True), gr.update(visible=False), None, "No background generated"
|
| 202 |
return
|
| 203 |
|
| 204 |
-
# process
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
except Exception:
|
| 216 |
-
pass
|
| 217 |
|
| 218 |
-
if
|
| 219 |
-
yield gr.update(visible=True), gr.update(visible=False),
|
| 220 |
else:
|
| 221 |
-
yield gr.update(visible=True), gr.update(visible=False), None, "Processing
|
| 222 |
|
| 223 |
except Exception as e:
|
| 224 |
logger.error(f"UI pipeline error: {e}")
|
| 225 |
yield gr.update(visible=True), gr.update(visible=False), None, f"Processing error: {e}"
|
| 226 |
finally:
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
return gr.update(visible=False), "Processing stopped by user"
|
| 234 |
-
|
| 235 |
|
| 236 |
# ---- UI factory ----
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
def create_interface():
|
| 238 |
css = """
|
| 239 |
.main-container { max-width: 1200px; margin: 0 auto; }
|
|
@@ -241,13 +274,12 @@ def create_interface():
|
|
| 241 |
.gradient-preview { border: 2px solid #ddd; border-radius: 10px; }
|
| 242 |
"""
|
| 243 |
|
| 244 |
-
with gr.Blocks(css=css, title="BackgroundFX Pro
|
| 245 |
-
gr.Markdown("# BackgroundFX Pro
|
| 246 |
|
| 247 |
with gr.Row():
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
gr.Markdown(f"**System Status:** Online | **GPU:** {GPU_NAME} | **SAM2:** {sam2_status} | **MatAnyone:** {matany_status}")
|
| 251 |
|
| 252 |
with gr.Row():
|
| 253 |
with gr.Column(scale=1):
|
|
@@ -277,19 +309,19 @@ def create_interface():
|
|
| 277 |
ai_preview = gr.Image(label="AI Generated Background", height=150, visible=False)
|
| 278 |
|
| 279 |
with gr.Row():
|
| 280 |
-
process_btn = gr.Button("Process Video", variant="primary"
|
| 281 |
-
stop_btn = gr.Button("Stop Processing", variant="stop",
|
| 282 |
|
| 283 |
with gr.Column(scale=1):
|
| 284 |
gr.Markdown("## Results")
|
| 285 |
result_video = gr.Video(label="Processed Video", height=400)
|
| 286 |
status_output = gr.Textbox(label="Processing Status", lines=5, max_lines=10, elem_classes=["status-box"])
|
| 287 |
gr.Markdown("""
|
| 288 |
-
###
|
| 289 |
-
1.
|
| 290 |
-
2.
|
| 291 |
-
3.
|
| 292 |
-
4.
|
| 293 |
""")
|
| 294 |
|
| 295 |
# handlers
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
BackgroundFX Pro — Gradio UI, background generators, and data sources (Hardened)
|
| 4 |
+
- No top-level import of pipeline (lazy import in handlers)
|
| 5 |
+
- Compatible with pipeline.process()
|
| 6 |
"""
|
| 7 |
|
| 8 |
import io
|
|
|
|
| 16 |
from PIL import Image
|
| 17 |
import gradio as gr
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
logger = logging.getLogger("ui")
|
| 20 |
+
if not logger.handlers:
|
| 21 |
+
h = logging.StreamHandler()
|
| 22 |
+
h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
|
| 23 |
+
logger.addHandler(h)
|
| 24 |
+
logger.setLevel(logging.INFO)
|
| 25 |
|
| 26 |
# ---- Background generators ----
|
| 27 |
def create_gradient_background(gradient_type: str, width: int, height: int) -> Image.Image:
|
|
|
|
| 52 |
img[i, :] = [r, g, b]
|
| 53 |
return Image.fromarray(img)
|
| 54 |
|
|
|
|
| 55 |
def create_solid_color(color: str, width: int, height: int) -> Image.Image:
|
| 56 |
color_map = {
|
| 57 |
"white": (255, 255, 255),
|
|
|
|
| 66 |
rgb = color_map.get(color, (70, 130, 180))
|
| 67 |
return Image.fromarray(np.full((height, width, 3), rgb, dtype=np.uint8))
|
| 68 |
|
|
|
|
| 69 |
def generate_ai_background(prompt: str) -> Tuple[Optional[Image.Image], str]:
|
| 70 |
try:
|
| 71 |
+
if not prompt or not prompt.strip():
|
| 72 |
return None, "Please enter a prompt"
|
| 73 |
models = [
|
| 74 |
"black-forest-labs/FLUX.1-schnell",
|
| 75 |
"stabilityai/stable-diffusion-xl-base-1.0",
|
| 76 |
+
"runwayml/stable-diffusion-v1-5",
|
| 77 |
]
|
| 78 |
enhanced_prompt = f"professional video background, {prompt}, high quality, 16:9, cinematic lighting, detailed"
|
| 79 |
+
token = os.getenv("HUGGINGFACE_TOKEN", "")
|
| 80 |
+
headers = {"Authorization": f"Bearer {token}"} if token else {}
|
| 81 |
for model in models:
|
| 82 |
try:
|
| 83 |
url = f"https://api-inference.huggingface.co/models/{model}"
|
| 84 |
+
payload = {
|
| 85 |
+
"inputs": enhanced_prompt,
|
| 86 |
+
"parameters": {"width": 1024, "height": 576, "num_inference_steps": 20, "guidance_scale": 7.5},
|
| 87 |
+
}
|
| 88 |
r = requests.post(url, headers=headers, json=payload, timeout=60, stream=True)
|
| 89 |
if r.status_code == 200 and "image" in r.headers.get("content-type", "").lower():
|
| 90 |
buf = io.BytesIO(r.content if r.raw is None else r.raw.read())
|
|
|
|
| 98 |
logger.error(f"AI background error: {e}")
|
| 99 |
return create_gradient_background("default", 1920, 1080), "Default due to error"
|
| 100 |
|
|
|
|
| 101 |
# ---- MyAvatar API ----
|
| 102 |
class MyAvatarAPI:
|
| 103 |
def __init__(self):
|
| 104 |
+
self.api_base = os.getenv("MYAVATAR_API_BASE", "https://app.myavatar.dk/api")
|
| 105 |
self.videos_cache: List[Dict[str, Any]] = []
|
| 106 |
self.last_refresh = 0
|
| 107 |
|
|
|
|
| 142 |
logger.error(f"Parse selection failed: {e}")
|
| 143 |
return None
|
| 144 |
|
|
|
|
| 145 |
myavatar_api = MyAvatarAPI()
|
| 146 |
|
| 147 |
+
# ---- Minimal stop flag (request-scoped) ----
|
| 148 |
+
# We avoid pipeline globals; this just short-circuits the generator.
|
| 149 |
+
class Stopper:
    """Minimal cancellation flag the UI generator polls between phases.

    NOTE(review): this is a module-global instance, so concurrent sessions
    share one flag — acceptable for a single-user Space, confirm otherwise.
    """

    def __init__(self):
        # False until the Stop button handler flips it.
        self.stop = False

STOP = Stopper()

def stop_processing_button():
    """Stop-button handler: raise the flag and hide the stop button."""
    STOP.stop = True
    return gr.update(visible=False), "Processing stopped by user"
|
| 157 |
+
|
| 158 |
+
# ---- UI ↔ Pipeline bridge ----
|
| 159 |
def process_video_with_background_stoppable(
|
| 160 |
input_video: Optional[str],
|
| 161 |
myavatar_selection: str,
|
|
|
|
| 165 |
custom_background: Optional[str],
|
| 166 |
ai_prompt: str
|
| 167 |
):
|
| 168 |
+
import importlib
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
try:
|
| 170 |
+
STOP.stop = False
|
| 171 |
+
yield gr.update(visible=False), gr.update(visible=True), None, "Starting…"
|
| 172 |
|
| 173 |
+
# Resolve video
|
| 174 |
video_path = None
|
| 175 |
if input_video:
|
| 176 |
video_path = input_video
|
|
|
|
| 181 |
r.raise_for_status()
|
| 182 |
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
|
| 183 |
for chunk in r.iter_content(chunk_size=1 << 20):
|
| 184 |
+
if STOP.stop:
|
| 185 |
+
yield gr.update(visible=True), gr.update(visible=False), None, "Stopped."
|
| 186 |
+
return
|
| 187 |
if chunk:
|
| 188 |
tmp.write(chunk)
|
| 189 |
video_path = tmp.name
|
| 190 |
|
| 191 |
+
if STOP.stop:
|
| 192 |
+
yield gr.update(visible=True), gr.update(visible=False), None, "Stopped."
|
| 193 |
+
return
|
| 194 |
+
|
| 195 |
if not video_path:
|
| 196 |
yield gr.update(visible=True), gr.update(visible=False), None, "No video provided"
|
| 197 |
return
|
| 198 |
|
| 199 |
+
# Background
|
| 200 |
+
yield gr.update(visible=False), gr.update(visible=True), None, "Preparing background…"
|
| 201 |
bg_img = None
|
| 202 |
if background_type == "gradient":
|
| 203 |
bg_img = create_gradient_background(gradient_type, 1920, 1080)
|
|
|
|
| 205 |
bg_img = create_solid_color(solid_color, 1920, 1080)
|
| 206 |
elif background_type == "custom" and custom_background:
|
| 207 |
try:
|
|
|
|
| 208 |
bg_img = Image.open(custom_background).convert("RGB")
|
| 209 |
except Exception:
|
| 210 |
bg_img = None
|
| 211 |
elif background_type == "ai" and ai_prompt:
|
| 212 |
bg_img, _ = generate_ai_background(ai_prompt)
|
| 213 |
|
| 214 |
+
if STOP.stop:
|
| 215 |
+
yield gr.update(visible=True), gr.update(visible=False), None, "Stopped."
|
| 216 |
+
return
|
| 217 |
+
|
| 218 |
if bg_img is None:
|
| 219 |
yield gr.update(visible=True), gr.update(visible=False), None, "No background generated"
|
| 220 |
return
|
| 221 |
|
| 222 |
+
# Save background to a temp file for pipeline.process()
|
| 223 |
+
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_bg:
|
| 224 |
+
bg_img.save(tmp_bg.name, format="PNG")
|
| 225 |
+
bg_path = tmp_bg.name
|
| 226 |
+
|
| 227 |
+
# Run pipeline lazily
|
| 228 |
+
yield gr.update(visible=False), gr.update(visible=True), None, "Processing video…"
|
| 229 |
+
pipe = importlib.import_module("pipeline")
|
| 230 |
+
out_path, diag = pipe.process(
|
| 231 |
+
video_path=video_path,
|
| 232 |
+
bg_image_path=bg_path,
|
| 233 |
+
point_x=None,
|
| 234 |
+
point_y=None,
|
| 235 |
+
auto_box=True,
|
| 236 |
+
work_dir=None
|
| 237 |
+
)
|
| 238 |
|
| 239 |
+
if STOP.stop:
|
| 240 |
+
yield gr.update(visible=True), gr.update(visible=False), None, "Stopped."
|
| 241 |
+
return
|
|
|
|
|
|
|
| 242 |
|
| 243 |
+
if out_path:
|
| 244 |
+
yield gr.update(visible=True), gr.update(visible=False), out_path, "Video processing completed successfully!"
|
| 245 |
else:
|
| 246 |
+
yield gr.update(visible=True), gr.update(visible=False), None, f"Processing failed: {diag.get('error','unknown error')}"
|
| 247 |
|
| 248 |
except Exception as e:
|
| 249 |
logger.error(f"UI pipeline error: {e}")
|
| 250 |
yield gr.update(visible=True), gr.update(visible=False), None, f"Processing error: {e}"
|
| 251 |
finally:
|
| 252 |
+
# Best-effort cleanup of any temp download
|
| 253 |
+
try:
|
| 254 |
+
if input_video is None and 'video_path' in locals() and video_path and os.path.exists(video_path):
|
| 255 |
+
os.unlink(video_path)
|
| 256 |
+
except Exception:
|
| 257 |
+
pass
|
|
|
|
|
|
|
| 258 |
|
| 259 |
# ---- UI factory ----
|
| 260 |
+
def _system_status():
|
| 261 |
+
# Avoid early CUDA probing: only show torch version if available
|
| 262 |
+
try:
|
| 263 |
+
import torch
|
| 264 |
+
tver = getattr(torch, "__version__", "?")
|
| 265 |
+
cver = getattr(getattr(torch, "version", None), "cuda", None)
|
| 266 |
+
return f"torch {tver} (CUDA {cver})"
|
| 267 |
+
except Exception:
|
| 268 |
+
return "torch not available"
|
| 269 |
+
|
| 270 |
def create_interface():
|
| 271 |
css = """
|
| 272 |
.main-container { max-width: 1200px; margin: 0 auto; }
|
|
|
|
| 274 |
.gradient-preview { border: 2px solid #ddd; border-radius: 10px; }
|
| 275 |
"""
|
| 276 |
|
| 277 |
+
with gr.Blocks(css=css, title="BackgroundFX Pro") as app:
|
| 278 |
+
gr.Markdown("# BackgroundFX Pro — SAM2 + MatAnyone (Hardened)")
|
| 279 |
|
| 280 |
with gr.Row():
|
| 281 |
+
status = _system_status()
|
| 282 |
+
gr.Markdown(f"**System Status:** Online | **Runtime:** {status}")
|
|
|
|
| 283 |
|
| 284 |
with gr.Row():
|
| 285 |
with gr.Column(scale=1):
|
|
|
|
| 309 |
ai_preview = gr.Image(label="AI Generated Background", height=150, visible=False)
|
| 310 |
|
| 311 |
with gr.Row():
|
| 312 |
+
process_btn = gr.Button("Process Video", variant="primary")
|
| 313 |
+
stop_btn = gr.Button("Stop Processing", variant="stop", visible=False)
|
| 314 |
|
| 315 |
with gr.Column(scale=1):
|
| 316 |
gr.Markdown("## Results")
|
| 317 |
result_video = gr.Video(label="Processed Video", height=400)
|
| 318 |
status_output = gr.Textbox(label="Processing Status", lines=5, max_lines=10, elem_classes=["status-box"])
|
| 319 |
gr.Markdown("""
|
| 320 |
+
### Pipeline
|
| 321 |
+
1. SAM2 Segmentation → mask
|
| 322 |
+
2. MatAnyone Matting → FG + ALPHA
|
| 323 |
+
3. Stage-A export (transparent WebM or checkerboard)
|
| 324 |
+
4. Final compositing (H.264)
|
| 325 |
""")
|
| 326 |
|
| 327 |
# handlers
|