MogensR committed on
Commit
87680ff
Β·
1 Parent(s): 6fd9fe7

Update to Streamlit UI with new features and logo

Browse files
Dockerfile CHANGED
@@ -1,6 +1,6 @@
1
  # ===============================
2
  # Hugging Face Space β€” Stable Dockerfile
3
- # CUDA 12.1.1 + PyTorch 2.5.1 (cu121) + Gradio 4.41.3
4
  # SAM2 installed from source; MatAnyone via pip (repo)
5
  # ===============================
6
 
@@ -20,7 +20,9 @@ ENV DEBIAN_FRONTEND=noninteractive \
20
  NUMEXPR_NUM_THREADS=1 \
21
  HF_HOME=/home/user/app/.hf \
22
  TORCH_HOME=/home/user/app/.torch \
23
- GRADIO_SERVER_PORT=7860
 
 
24
 
25
  # ---- Non-root user ----
26
  RUN useradd -m -u 1000 user
@@ -34,7 +36,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
34
  build-essential gcc g++ pkg-config \
35
  libffi-dev libssl-dev libc6-dev \
36
  libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender1 libgomp1 \
37
- && rm -rf /var/lib/apt/lists/*
38
 
39
  # ---- Python bootstrap ----
40
  RUN python3 -m pip install --upgrade pip setuptools wheel
@@ -42,17 +44,11 @@ RUN python3 -m pip install --upgrade pip setuptools wheel
42
  # ---- Install PyTorch (CUDA 12.1 wheels) ----
43
  RUN python3 -m pip install --no-cache-dir --index-url https://download.pytorch.org/whl/cu121 \
44
  torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 \
45
- && python3 - <<'PY'
46
  import torch
47
  print("PyTorch:", torch.__version__)
48
  print("CUDA available:", torch.cuda.is_available())
49
  print("torch.version.cuda:", getattr(torch.version, "cuda", None))
50
- try:
51
- import torchaudio, torchvision
52
- print("torchaudio:", torchaudio.__version__)
53
- import torchvision as tv; print("torchvision:", tv.__version__)
54
- except Exception as e:
55
- print("aux libs check:", e)
56
  PY
57
 
58
  # ---- Copy deps first (better caching) ----
@@ -92,19 +88,20 @@ RUN mkdir -p /home/user/app/checkpoints /home/user/app/.hf /home/user/app/.torch
92
  chmod -R 755 /home/user/app && \
93
  find /home/user/app -type d -exec chmod 755 {} \; && \
94
  find /home/user/app -type f -exec chmod 644 {} \; && \
95
- chmod +x /home/user/app/ui.py || true
96
 
97
- # ---- Healthcheck (use exec-form, no heredoc) ----
98
  HEALTHCHECK --interval=30s --timeout=5s --retries=3 CMD \
99
  ["python3","-c","import torch; print('torch', torch.__version__, '| cuda', getattr(torch.version,'cuda',None), '| ok=', torch.cuda.is_available())"]
100
 
101
  # ---- Runtime ----
102
  USER user
103
- EXPOSE 7860
104
 
 
105
  CMD ["sh", "-c", "\
106
  echo '===========================================' && \
107
- echo '=== BACKGROUNDFX PRO CONTAINER STARTUP ===' && \
108
  echo '===========================================' && \
109
  echo 'Timestamp:' $(date) && \
110
  echo 'Current directory:' $(pwd) && \
@@ -115,23 +112,22 @@ CMD ["sh", "-c", "\
115
  echo 'Files in app directory:' && \
116
  ls -la && \
117
  echo '' && \
118
- echo '=== UI.PY VERIFICATION ===' && \
119
- if [ -f ui.py ]; then \
120
- echo 'βœ… ui.py found' && \
121
- echo 'File size:' $(wc -c < ui.py) 'bytes' && \
122
- echo 'File permissions:' $(ls -l ui.py) && \
123
  echo 'Testing Python imports...' && \
124
- python3 -B -c 'import gradio; print(\"βœ… Gradio:\", gradio.__version__)' && \
125
  python3 -B -c 'import torch; print(\"βœ… Torch:\", torch.__version__)' && \
126
- echo 'Testing ui.py import...' && \
127
- python3 -B -c 'import sys; sys.path.insert(0, \".\"); import ui; print(\"βœ… ui.py imports successfully\")' && \
128
  echo 'βœ… All checks passed!'; \
129
  else \
130
- echo '❌ ERROR: ui.py not found!' && \
131
  exit 1; \
132
  fi && \
133
  echo '' && \
134
- echo '=== STARTING APPLICATION ===' && \
135
- echo 'Launching ui.py with bytecode disabled...' && \
136
- python3 -B -u ui.py \
137
- "]
 
1
  # ===============================
2
  # Hugging Face Space β€” Stable Dockerfile
3
+ # CUDA 12.1.1 + PyTorch 2.5.1 (cu121) + Streamlit 1.32.0
4
  # SAM2 installed from source; MatAnyone via pip (repo)
5
  # ===============================
6
 
 
20
  NUMEXPR_NUM_THREADS=1 \
21
  HF_HOME=/home/user/app/.hf \
22
  TORCH_HOME=/home/user/app/.torch \
23
+ STREAMLIT_SERVER_PORT=8501 \
24
+ STREAMLIT_SERVER_HEADLESS=true \
25
+ STREAMLIT_BROWSER_GATHER_USAGE_STATS=false
26
 
27
  # ---- Non-root user ----
28
  RUN useradd -m -u 1000 user
 
36
  build-essential gcc g++ pkg-config \
37
  libffi-dev libssl-dev libc6-dev \
38
  libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender1 libgomp1 \
39
+ && rm -rf /var/lib/apt/lists/*
40
 
41
  # ---- Python bootstrap ----
42
  RUN python3 -m pip install --upgrade pip setuptools wheel
 
44
  # ---- Install PyTorch (CUDA 12.1 wheels) ----
45
  RUN python3 -m pip install --no-cache-dir --index-url https://download.pytorch.org/whl/cu121 \
46
  torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 \
47
+ && python3 - <<'PY'
48
  import torch
49
  print("PyTorch:", torch.__version__)
50
  print("CUDA available:", torch.cuda.is_available())
51
  print("torch.version.cuda:", getattr(torch.version, "cuda", None))
 
 
 
 
 
 
52
  PY
53
 
54
  # ---- Copy deps first (better caching) ----
 
88
  chmod -R 755 /home/user/app && \
89
  find /home/user/app -type d -exec chmod 755 {} \; && \
90
  find /home/user/app -type f -exec chmod 644 {} \; && \
91
+ chmod +x /home/user/app/app.py || true
92
 
93
+ # ---- Healthcheck ----
94
  HEALTHCHECK --interval=30s --timeout=5s --retries=3 CMD \
95
  ["python3","-c","import torch; print('torch', torch.__version__, '| cuda', getattr(torch.version,'cuda',None), '| ok=', torch.cuda.is_available())"]
96
 
97
  # ---- Runtime ----
98
  USER user
99
+ EXPOSE 8501
100
 
101
+ # Streamlit server command
102
  CMD ["sh", "-c", "\
103
  echo '===========================================' && \
104
+ echo '=== MYAVATAR STREAMLIT CONTAINER STARTUP ===' && \
105
  echo '===========================================' && \
106
  echo 'Timestamp:' $(date) && \
107
  echo 'Current directory:' $(pwd) && \
 
112
  echo 'Files in app directory:' && \
113
  ls -la && \
114
  echo '' && \
115
+ echo '=== APP.PY VERIFICATION ===' && \
116
+ if [ -f app.py ]; then \
117
+ echo 'βœ… app.py found' && \
118
+ echo 'File size:' $(wc -c < app.py) 'bytes' && \
119
+ echo 'File permissions:' $(ls -l app.py) && \
120
  echo 'Testing Python imports...' && \
121
+ python3 -B -c 'import streamlit; print(\"βœ… Streamlit:\", streamlit.__version__)' && \
122
  python3 -B -c 'import torch; print(\"βœ… Torch:\", torch.__version__)' && \
123
+ echo 'Testing app.py import...' && \
124
+ python3 -B -c 'import sys; sys.path.insert(0, \".\"); import app; print(\"βœ… app.py imports successfully\")' && \
125
  echo 'βœ… All checks passed!'; \
126
  else \
127
+ echo '❌ ERROR: app.py not found!' && \
128
  exit 1; \
129
  fi && \
130
  echo '' && \
131
+ echo '=== STARTING STREAMLIT SERVER ===' && \
132
+ streamlit run --server.port=8501 --server.address=0.0.0.0 app.py \
133
+ "]
 
VideoBackgroundReplacer2/.dockerignore ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ===========================
2
+ # .dockerignore for HF Spaces
3
+ # ===========================
4
+
5
+ # VCS
6
+ .git
7
+ .gitignore
8
+ .gitattributes
9
+
10
+ # Python cache / build
11
+ __pycache__/
12
+ *.py[cod]
13
+ *.pyo
14
+ *.pyd
15
+ *.pdb
16
+ *.egg-info/
17
+ dist/
18
+ build/
19
+ .pytest_cache/
20
+ .python-version
21
+
22
+ # Virtual environments
23
+ .env
24
+ .venv/
25
+ env/
26
+ venv/
27
+
28
+ # External repos (cloned in Docker, not copied from local)
29
+ third_party/
30
+
31
+ # Hugging Face / Torch caches
32
+ .cache/
33
+ huggingface/
34
+ torch/
35
+ data/
36
+
37
+ # HF Space metadata/state
38
+ .hf_space/
39
+ space.log
40
+ gradio_cached_examples/
41
+ gradio_static/
42
+ __outputs__/
43
+
44
+ # Logs & temp files
45
+ *.log
46
+ logs/
47
+ tmp/
48
+ temp/
49
+ *.swp
50
+ .coverage
51
+ coverage.xml
52
+
53
+ # Media test assets
54
+ *.mp4
55
+ *.avi
56
+ *.mov
57
+ *.mkv
58
+ *.png
59
+ *.jpg
60
+ *.jpeg
61
+ *.gif
62
+
63
+ # OS / IDE cruft
64
+ .DS_Store
65
+ Thumbs.db
66
+ .vscode/
67
+ .idea/
68
+ *.sublime-project
69
+ *.sublime-workspace
70
+
71
+ # Node / frontend (if present)
72
+ node_modules/
73
+ npm-debug.log
74
+ yarn-debug.log
75
+ yarn-error.log
76
+
77
+ # ---- Optional: allow specific checkpoints if needed ----
78
+ !checkpoints/
VideoBackgroundReplacer2/5.0.0 ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Defaulting to user installation because normal site-packages is not writeable
2
+ Requirement already satisfied: gradio in c:\users\mogen\appdata\roaming\python\python313\site-packages (4.44.0)
3
+ Requirement already satisfied: aiofiles<24.0,>=22.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (23.2.1)
4
+ Requirement already satisfied: anyio<5.0,>=3.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (4.9.0)
5
+ Requirement already satisfied: fastapi<1.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.115.12)
6
+ Requirement already satisfied: ffmpy in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.6.1)
7
+ Requirement already satisfied: gradio-client==1.3.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (1.3.0)
8
+ Requirement already satisfied: httpx>=0.24.1 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.27.2)
9
+ Requirement already satisfied: huggingface-hub>=0.19.3 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.34.4)
10
+ Requirement already satisfied: importlib-resources<7.0,>=1.3 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (6.5.2)
11
+ Requirement already satisfied: jinja2<4.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (3.1.6)
12
+ Requirement already satisfied: markupsafe~=2.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (2.1.5)
13
+ Requirement already satisfied: matplotlib~=3.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (3.10.5)
14
+ Requirement already satisfied: numpy<3.0,>=1.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (1.26.4)
15
+ Requirement already satisfied: orjson~=3.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (3.11.2)
16
+ Requirement already satisfied: packaging in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (24.2)
17
+ Requirement already satisfied: pandas<3.0,>=1.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (2.2.3)
18
+ Requirement already satisfied: pillow<11.0,>=8.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (10.4.0)
19
+ Requirement already satisfied: pydantic>=2.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (2.11.5)
20
+ Requirement already satisfied: pydub in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.25.1)
21
+ Requirement already satisfied: python-multipart>=0.0.9 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.0.20)
22
+ Requirement already satisfied: pyyaml<7.0,>=5.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (6.0.2)
23
+ Requirement already satisfied: ruff>=0.2.2 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.12.9)
24
+ Requirement already satisfied: semantic-version~=2.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (2.10.0)
25
+ Requirement already satisfied: tomlkit==0.12.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.12.0)
26
+ Requirement already satisfied: typer<1.0,>=0.12 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.16.0)
27
+ Requirement already satisfied: typing-extensions~=4.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (4.14.1)
28
+ Requirement already satisfied: urllib3~=2.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (2.5.0)
29
+ Requirement already satisfied: uvicorn>=0.14.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.34.3)
30
+ Requirement already satisfied: fsspec in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio-client==1.3.0->gradio) (2025.5.1)
31
+ Requirement already satisfied: websockets<13.0,>=10.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio-client==1.3.0->gradio) (10.4)
32
+ Requirement already satisfied: idna>=2.8 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from anyio<5.0,>=3.0->gradio) (3.10)
33
+ Requirement already satisfied: sniffio>=1.1 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from anyio<5.0,>=3.0->gradio) (1.3.1)
34
+ Requirement already satisfied: starlette<0.47.0,>=0.40.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from fastapi<1.0->gradio) (0.46.2)
35
+ Requirement already satisfied: certifi in c:\users\mogen\appdata\roaming\python\python313\site-packages (from httpx>=0.24.1->gradio) (2025.7.9)
36
+ Requirement already satisfied: httpcore==1.* in c:\users\mogen\appdata\roaming\python\python313\site-packages (from httpx>=0.24.1->gradio) (1.0.9)
37
+ Requirement already satisfied: h11>=0.16 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from httpcore==1.*->httpx>=0.24.1->gradio) (0.16.0)
38
+ Requirement already satisfied: filelock in c:\users\mogen\appdata\roaming\python\python313\site-packages (from huggingface-hub>=0.19.3->gradio) (3.18.0)
39
+ Requirement already satisfied: requests in c:\users\mogen\appdata\roaming\python\python313\site-packages (from huggingface-hub>=0.19.3->gradio) (2.32.3)
40
+ Requirement already satisfied: tqdm>=4.42.1 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from huggingface-hub>=0.19.3->gradio) (4.67.1)
41
+ Requirement already satisfied: contourpy>=1.0.1 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from matplotlib~=3.0->gradio) (1.3.3)
42
+ Requirement already satisfied: cycler>=0.10 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from matplotlib~=3.0->gradio) (0.12.1)
43
+ Requirement already satisfied: fonttools>=4.22.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from matplotlib~=3.0->gradio) (4.59.1)
44
+ Requirement already satisfied: kiwisolver>=1.3.1 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from matplotlib~=3.0->gradio) (1.4.9)
45
+ Requirement already satisfied: pyparsing>=2.3.1 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from matplotlib~=3.0->gradio) (3.2.3)
46
+ Requirement already satisfied: python-dateutil>=2.7 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from matplotlib~=3.0->gradio) (2.8.2)
47
+ Requirement already satisfied: pytz>=2020.1 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from pandas<3.0,>=1.0->gradio) (2025.2)
48
+ Requirement already satisfied: tzdata>=2022.7 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from pandas<3.0,>=1.0->gradio) (2025.2)
49
+ Requirement already satisfied: annotated-types>=0.6.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from pydantic>=2.0->gradio) (0.7.0)
50
+ Requirement already satisfied: pydantic-core==2.33.2 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from pydantic>=2.0->gradio) (2.33.2)
51
+ Requirement already satisfied: typing-inspection>=0.4.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from pydantic>=2.0->gradio) (0.4.1)
52
+ Requirement already satisfied: click>=8.0.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from typer<1.0,>=0.12->gradio) (8.2.1)
53
+ Requirement already satisfied: shellingham>=1.3.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from typer<1.0,>=0.12->gradio) (1.5.4)
54
+ Requirement already satisfied: rich>=10.11.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from typer<1.0,>=0.12->gradio) (14.0.0)
55
+ Requirement already satisfied: colorama in c:\users\mogen\appdata\roaming\python\python313\site-packages (from click>=8.0.0->typer<1.0,>=0.12->gradio) (0.4.6)
56
+ Requirement already satisfied: six>=1.5 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from python-dateutil>=2.7->matplotlib~=3.0->gradio) (1.17.0)
57
+ Requirement already satisfied: markdown-it-py>=2.2.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (3.0.0)
58
+ Requirement already satisfied: pygments<3.0.0,>=2.13.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (2.19.1)
59
+ Requirement already satisfied: charset-normalizer<4,>=2 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from requests->huggingface-hub>=0.19.3->gradio) (3.4.2)
60
+ Requirement already satisfied: mdurl~=0.1 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from markdown-it-py>=2.2.0->rich>=10.11.0->typer<1.0,>=0.12->gradio) (0.1.2)
VideoBackgroundReplacer2/DEPLOYMENT.md ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # VideoBackgroundReplacer2 Deployment Guide
2
+
3
+ This guide provides instructions for deploying the VideoBackgroundReplacer2 application to Hugging Face Spaces with GPU acceleration.
4
+
5
+ ## Prerequisites
6
+
7
+ - Docker
8
+ - Git
9
+ - Python 3.9+ (PyTorch 2.5.1 wheels require Python 3.9–3.12)
10
+ - NVIDIA Container Toolkit (for local GPU testing)
11
+ - Hugging Face account with access to GPU Spaces
12
+
13
+ ## Local Development
14
+
15
+ ### 1. Clone the repository
16
+ ```bash
17
+ git clone <repository-url>
18
+ cd VideoBackgroundReplacer2
19
+ ```
20
+
21
+ ### 2. Build the Docker image
22
+ ```bash
23
+ # Make the build script executable
24
+ chmod +x build_and_deploy.sh
25
+
26
+ # Build the image
27
+ ./build_and_deploy.sh
28
+ ```
29
+
30
+ ### 3. Run the container locally
31
+ ```bash
32
+ docker run --gpus all -p 7860:7860 -v $(pwd)/checkpoints:/home/user/app/checkpoints videobackgroundreplacer2:latest
33
+ ```
34
+
35
+ ## Hugging Face Spaces Deployment
36
+
37
+ ### 1. Create a new Space
38
+ - Go to [Hugging Face Spaces](https://huggingface.co/spaces)
39
+ - Click "Create new Space"
40
+ - Select "Docker" as the SDK
41
+ - Choose a name and set the space to private if needed
42
+ - Select GPU as the hardware
43
+
44
+ ### 2. Configure the Space
45
+ Add the following environment variables to your Space settings:
46
+ - `SAM2_DEVICE`: `cuda`
47
+ - `MATANY_DEVICE`: `cuda`
48
+ - `PYTORCH_CUDA_ALLOC_CONF`: `max_split_size_mb:256,garbage_collection_threshold:0.8`
49
+ - `TORCH_CUDA_ARCH_LIST`: `7.5 8.0 8.6+PTX`
50
+
51
+ ### 3. Deploy to Hugging Face
52
+ ```bash
53
+ # Set your Hugging Face token
54
+ export HF_TOKEN=your_hf_token
55
+ export HF_USERNAME=your_username
56
+
57
+ # Build and deploy
58
+ ./build_and_deploy.sh
59
+ ```
60
+
61
+ ## Health Check
62
+
63
+ You can verify the installation by running:
64
+ ```bash
65
+ docker run --rm videobackgroundreplacer2:latest python3 health_check.py
66
+ ```
67
+
68
+ ## Troubleshooting
69
+
70
+ ### Build Failures
71
+ - Ensure you have enough disk space (at least 10GB free)
72
+ - Check Docker logs for specific error messages
73
+ - Verify your internet connection is stable
74
+
75
+ ### Runtime Issues
76
+ - Check container logs: `docker logs <container_id>`
77
+ - Verify GPU is detected: `nvidia-smi` inside the container
78
+ - Check disk space: `df -h`
79
+
80
+ ## Performance Optimization
81
+
82
+ - For faster inference, use the `sam2_hiera_tiny` model
83
+ - Adjust batch size based on available GPU memory
84
+ - Enable gradient checkpointing for large models
85
+
86
+ ## Monitoring
87
+
88
+ - Use `nvidia-smi` to monitor GPU usage
89
+ - Check container logs for any warnings or errors
90
+ - Monitor memory usage with `htop` or similar tools
VideoBackgroundReplacer2/Dockerfile ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ===============================
2
+ # Hugging Face Space β€” Stable Dockerfile
3
+ # CUDA 12.1.1 + PyTorch 2.5.1 (cu121) + Gradio 4.41.3
4
+ # SAM2 installed from source; MatAnyone via pip (repo)
5
+ # ===============================
6
+
7
+ FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04
8
+
9
+ # ---- Environment (runtime hygiene) ----
10
+ ENV DEBIAN_FRONTEND=noninteractive \
11
+ PYTHONUNBUFFERED=1 \
12
+ PYTHONDONTWRITEBYTECODE=1 \
13
+ PIP_NO_CACHE_DIR=1 \
14
+ PIP_DISABLE_PIP_VERSION_CHECK=1 \
15
+ TORCH_CUDA_ARCH_LIST="7.5 8.0 8.6+PTX" \
16
+ CUDA_VISIBLE_DEVICES="0" \
17
+ OMP_NUM_THREADS=4 \
18
+ OPENBLAS_NUM_THREADS=1 \
19
+ MKL_NUM_THREADS=1 \
20
+ NUMEXPR_NUM_THREADS=1 \
21
+ HF_HOME=/home/user/app/.hf \
22
+ TORCH_HOME=/home/user/app/.torch \
23
+ GRADIO_SERVER_PORT=7860
24
+
25
+ # ---- Non-root user ----
26
+ RUN useradd -m -u 1000 user
27
+ ENV HOME=/home/user
28
+ WORKDIR $HOME/app
29
+
30
+ # ---- System deps ----
31
+ RUN apt-get update && apt-get install -y --no-install-recommends \
32
+ git ffmpeg wget curl \
33
+ python3 python3-pip python3-venv python3-dev \
34
+ build-essential gcc g++ pkg-config \
35
+ libffi-dev libssl-dev libc6-dev \
36
+ libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender1 libgomp1 \
37
+ && rm -rf /var/lib/apt/lists/*
38
+
39
+ # ---- Python bootstrap ----
40
+ RUN python3 -m pip install --upgrade pip setuptools wheel
41
+
42
+ # ---- Install PyTorch (CUDA 12.1 wheels) ----
43
+ RUN python3 -m pip install --no-cache-dir --index-url https://download.pytorch.org/whl/cu121 \
44
+ torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 \
45
+ && python3 - <<'PY'
46
+ import torch
47
+ print("PyTorch:", torch.__version__)
48
+ print("CUDA available:", torch.cuda.is_available())
49
+ print("torch.version.cuda:", getattr(torch.version, "cuda", None))
50
+ try:
51
+ import torchaudio, torchvision
52
+ print("torchaudio:", torchaudio.__version__)
53
+ import torchvision as tv; print("torchvision:", tv.__version__)
54
+ except Exception as e:
55
+ print("aux libs check:", e)
56
+ PY
57
+
58
+ # ---- Copy deps first (better caching) ----
59
+ COPY --chown=user:user requirements.txt ./
60
+
61
+ # ---- Install remaining Python deps ----
62
+ RUN python3 -m pip install --no-cache-dir -r requirements.txt
63
+
64
+ # ---- MatAnyone (pip install from repo with retry) ----
65
+ RUN echo "Installing MatAnyone..." && \
66
+ (python3 -m pip install --no-cache-dir -v git+https://github.com/pq-yang/MatAnyone@main#egg=matanyone || \
67
+ (echo "Retrying MatAnyone..." && \
68
+ python3 -m pip install --no-cache-dir -v git+https://github.com/pq-yang/MatAnyone@main#egg=matanyone)) && \
69
+ python3 -c "import matanyone; print('MatAnyone import OK')"
70
+
71
+ # ---- App code ----
72
+ COPY --chown=user:user . .
73
+
74
+ # ---- SAM2 from source (editable) ----
75
+ RUN echo "Installing SAM2 (editable)..." && \
76
+ git clone --depth=1 https://github.com/facebookresearch/segment-anything-2.git third_party/sam2 && \
77
+ cd third_party/sam2 && python3 -m pip install --no-cache-dir -e .
78
+
79
+ # ---- App env ----
80
+ ENV PYTHONPATH=/home/user/app:/home/user/app/third_party:/home/user/app/third_party/sam2 \
81
+ FFMPEG_BIN=ffmpeg \
82
+ THIRD_PARTY_SAM2_DIR=/home/user/app/third_party/sam2 \
83
+ ENABLE_MATANY=1 \
84
+ SAM2_DEVICE=cuda \
85
+ MATANY_DEVICE=cuda \
86
+ TF_CPP_MIN_LOG_LEVEL=2 \
87
+ SAM2_CHECKPOINT=/home/user/app/checkpoints/sam2_hiera_large.pt
88
+
89
+ # ---- Create writable dirs (caches + checkpoints) ----
90
+ RUN mkdir -p /home/user/app/checkpoints /home/user/app/.hf /home/user/app/.torch && \
91
+ chown -R user:user /home/user/app && \
92
+ chmod -R 755 /home/user/app && \
93
+ find /home/user/app -type d -exec chmod 755 {} \; && \
94
+ find /home/user/app -type f -exec chmod 644 {} \; && \
95
+ chmod +x /home/user/app/ui.py || true
96
+
97
+ # ---- Healthcheck (use exec-form, no heredoc) ----
98
+ HEALTHCHECK --interval=30s --timeout=5s --retries=3 CMD \
99
+ ["python3","-c","import torch; print('torch', torch.__version__, '| cuda', getattr(torch.version,'cuda',None), '| ok=', torch.cuda.is_available())"]
100
+
101
+ # ---- Runtime ----
102
+ USER user
103
+ EXPOSE 7860
104
+
105
+ CMD ["sh", "-c", "\
106
+ echo '===========================================' && \
107
+ echo '=== BACKGROUNDFX PRO CONTAINER STARTUP ===' && \
108
+ echo '===========================================' && \
109
+ echo 'Timestamp:' $(date) && \
110
+ echo 'Current directory:' $(pwd) && \
111
+ echo 'Current user:' $(whoami) && \
112
+ echo 'User ID:' $(id) && \
113
+ echo '' && \
114
+ echo '=== FILE SYSTEM CHECK ===' && \
115
+ echo 'Files in app directory:' && \
116
+ ls -la && \
117
+ echo '' && \
118
+ echo '=== UI.PY VERIFICATION ===' && \
119
+ if [ -f ui.py ]; then \
120
+ echo 'βœ… ui.py found' && \
121
+ echo 'File size:' $(wc -c < ui.py) 'bytes' && \
122
+ echo 'File permissions:' $(ls -l ui.py) && \
123
+ echo 'Testing Python imports...' && \
124
+ python3 -B -c 'import gradio; print(\"βœ… Gradio:\", gradio.__version__)' && \
125
+ python3 -B -c 'import torch; print(\"βœ… Torch:\", torch.__version__)' && \
126
+ echo 'Testing ui.py import...' && \
127
+ python3 -B -c 'import sys; sys.path.insert(0, \".\"); import ui; print(\"βœ… ui.py imports successfully\")' && \
128
+ echo 'βœ… All checks passed!'; \
129
+ else \
130
+ echo '❌ ERROR: ui.py not found!' && \
131
+ exit 1; \
132
+ fi && \
133
+ echo '' && \
134
+ echo '=== STARTING APPLICATION ===' && \
135
+ echo 'Launching ui.py with bytecode disabled...' && \
136
+ python3 -B -u ui.py \
137
+ "]
VideoBackgroundReplacer2/README.md ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: 🎬 BackgroundFX Pro - SAM2 + MatAnyone
3
+ emoji: πŸŽ₯
4
+ colorFrom: indigo
5
+ colorTo: purple
6
+ sdk: docker
7
+ app_port: 7860
8
+ license: mit
9
+ tags:
10
+ - video
11
+ - background-removal
12
+ - segmentation
13
+ - matting
14
+ - SAM2
15
+ - MatAnyone
16
+ ---
17
+
18
+ # 🎬 BackgroundFX Pro β€” Professional Video Background Replacement
19
+
20
+ BackgroundFX Pro is a GPU-accelerated app for Hugging Face Spaces (Docker) that replaces video backgrounds using:
21
+ - **SAM2** β€” high-quality object segmentation
22
+ - **MatAnyone** β€” temporal video matting for stable alpha over time
23
+
24
+ Built on: **CUDA 12.1.1**, **PyTorch 2.5.1 (cu121)**, **torchvision 0.20.1**, **Gradio 4.41.3**.
25
+
26
+ ---
27
+
28
+ ## ✨ Features
29
+
30
+ - Replace backgrounds with: **solid color**, **AI-generated** image (procedural), **custom uploaded image**, or **Unsplash** search
31
+ - Optimized for **T4 GPUs** on Hugging Face
32
+ - Caching & logs stored in the repo volume:
33
+ - HF cache β†’ `./.hf`
34
+ - Torch cache β†’ `./.torch`
35
+ - App data & logs β†’ `./data` (see `data/run.log`)
36
+
37
+ ---
38
+
39
+ ## πŸš€ Try It
40
+
41
+ Open the Space in your browser (GPU required):
42
+ https://huggingface.co/spaces/MogensR/VideoBackgroundReplacer2
43
+
44
+ ---
45
+
46
+ ## πŸ–±οΈ How to Use
47
+
48
+ 1. **Upload a video** (`.mp4`, `.avi`, `.mov`, `.mkv`).
49
+ 2. Choose a **Background Type**: Upload Image, AI Generate, Gradient, Solid, or Unsplash.
50
+ 3. If not uploading, enter a prompt and click **Generate Background**.
51
+ 4. Click **Process Video**.
52
+ 5. Preview and **Download Result**.
53
+
54
+ > Tip: Start with 720p/1080p on T4; 4K can exceed memory.
55
+
56
+ ---
57
+
58
+ ## πŸ—‚οΈ Project Structure (key files)
59
+
60
+ - `Dockerfile`
61
+ - `requirements.txt`
62
+ - `ui.py`
63
+ - `ui_core_interface.py`
64
+ - `ui_core_functionality.py`
65
+ - `two_stage_pipeline.py`
66
+ - `models/sam2_loader.py`
67
+ - `models/matanyone_loader.py`
68
+ - `utils/__init__.py`
69
+ - `data/` (created at runtime for logs/outputs)
70
+ - `tmp/` (created at runtime for jobs/temp files)
71
+
72
+ ---
73
+
74
+ ## βš™οΈ Runtime Notes
75
+
76
+ - Binds to `PORT` / `GRADIO_SERVER_PORT` (defaults to **7860**).
77
+ - Heartbeat logs every ~2s with memory & disk stats.
78
+ - If there’s no final β€œPROCESS EXITING” line, it was likely an **OOM** or hard kill.
79
+
80
+ ---
81
+
82
+ ## πŸ§ͺ Local Development (Docker)
83
+
84
+ Requires an NVIDIA GPU with CUDA drivers.
85
+
86
+ ```bash
87
+ git clone https://huggingface.co/spaces/MogensR/VideoBackgroundReplacer2
88
+ cd VideoBackgroundReplacer2
89
+
90
+ # Build (Ubuntu 22.04, CUDA 12.1.1; installs Torch 2.5.1+cu121)
91
+ docker build -t backgroundfx-pro .
92
+
93
+ # Run
94
+ docker run --gpus all -p 7860:7860 backgroundfx-pro
+ ```
VideoBackgroundReplacer2/app.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ VideoBackgroundReplacer2 - SAM2 + MatAnyone Integration
4
+ =======================================================
5
+ - Sets up Gradio UI and launches pipeline
6
+ - Aligned with torch/cu121 stack; runs on HF Spaces (Docker)
7
+
8
+ Changes (2025-09-18):
9
+ - Added precise web-stack probes (FastAPI/Starlette/Pydantic/etc. versions + file paths)
10
+ - Added toggleable "mount mode": run Gradio inside our own FastAPI app
11
+ and provide a safe /config route shim (uses demo.get_config_file()).
12
+ - Kept your startup diagnostics, GPU logging, and heartbeats
13
+ """
14
+
15
+ print("=== APP STARTUP: Initializing VideoBackgroundReplacer2 ===")
16
+
17
+ # ---------------------------------------------------------------------
18
+ # Imports & basic setup
19
+ # ---------------------------------------------------------------------
20
+ import sys
21
+ import os
22
+ import gc
23
+ import json
24
+ import logging
25
+ import threading
26
+ import time
27
+ import warnings
28
+ import traceback
29
+ import subprocess
30
+ from pathlib import Path
31
+ from loguru import logger
32
+
33
+ # Logging (loguru to stderr)
34
+ logger.remove()
35
+ logger.add(
36
+ sys.stderr,
37
+ format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> "
38
+ "| <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
39
+ )
40
+
41
+ # Warnings
42
+ warnings.filterwarnings("ignore", category=UserWarning)
43
+ warnings.filterwarnings("ignore", category=FutureWarning)
44
+ warnings.filterwarnings("ignore", module="torchvision.io._video_deprecation_warning")
45
+
46
+ # Environment (lightweight & safe in Spaces)
47
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
48
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
49
+
50
+ # Paths
51
+ BASE_DIR = Path(__file__).parent.absolute()
52
+ THIRD_PARTY_DIR = BASE_DIR / "third_party"
53
+ SAM2_DIR = THIRD_PARTY_DIR / "sam2"
54
+ CHECKPOINTS_DIR = BASE_DIR / "checkpoints"
55
+
56
+ # Python path extends
57
+ for p in (str(THIRD_PARTY_DIR), str(SAM2_DIR)):
58
+ if p not in sys.path:
59
+ sys.path.insert(0, p)
60
+
61
+ logger.info(f"Base directory: {BASE_DIR}")
62
+ logger.info(f"Python path[0:5]: {sys.path[:5]}")
63
+
64
+ # ---------------------------------------------------------------------
65
+ # GPU / Torch diagnostics (non-blocking)
66
+ # ---------------------------------------------------------------------
67
# ---------------------------------------------------------------------
# GPU / Torch diagnostics (non-blocking)
# ---------------------------------------------------------------------
# Torch is optional at startup: the app can still boot (CPU mode, diagnostics
# only) when the import fails, so failures are logged rather than fatal.
try:
    import torch
except Exception as e:
    # FIX: loguru formats with {}-style placeholders; the original "%s" was
    # never substituted, so the exception text was silently dropped.
    logger.warning("Torch import failed at startup: {}", e)
    torch = None

DEVICE = "cuda" if (torch and torch.cuda.is_available()) else "cpu"
if DEVICE == "cuda":
    # Propagate the device choice to the model loaders via environment.
    os.environ["SAM2_DEVICE"] = "cuda"
    os.environ["MATANY_DEVICE"] = "cuda"
    os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")
    try:
        logger.info(f"CUDA available: {torch.cuda.get_device_name(0)}")
    except Exception:
        logger.info("CUDA device name not available at startup.")
else:
    os.environ["SAM2_DEVICE"] = "cpu"
    os.environ["MATANY_DEVICE"] = "cpu"
    logger.warning("CUDA not available, falling back to CPU")
86
+
87
def verify_models():
    """Verify critical model files exist and are loadable (cheap checks).

    Returns:
        dict: ``{"status": "success"|"error", "details": {"sam2": {...}}}``
        where the per-model entry carries either path/size info or the
        error string and traceback. Never raises.
    """
    results = {"status": "success", "details": {}}
    try:
        sam2_model_path = os.getenv("SAM2_MODEL_PATH", str(CHECKPOINTS_DIR / "sam2_hiera_large.pt"))
        if not os.path.exists(sam2_model_path):
            raise FileNotFoundError(f"SAM2 model not found at {sam2_model_path}")
        # Cheap load test (map to CPU to avoid VRAM use during boot).
        if torch:
            # FIX: weights_only=True — the checkpoint is a plain tensor dict,
            # and this avoids executing arbitrary pickled code from the file
            # (and the torch>=2.4 FutureWarning about the old default).
            sd = torch.load(sam2_model_path, map_location="cpu", weights_only=True)
            if not isinstance(sd, dict):
                raise ValueError("Invalid SAM2 checkpoint format")
            # Release the state dict immediately; only the format check mattered.
            del sd
        results["details"]["sam2"] = {
            "status": "success",
            "path": sam2_model_path,
            "size_mb": round(os.path.getsize(sam2_model_path) / (1024 * 1024), 2),
        }
    except Exception as e:
        results["status"] = "error"
        results["details"]["sam2"] = {
            "status": "error",
            "error": str(e),
            "traceback": traceback.format_exc(),
        }
    return results
112
+
113
def run_startup_diagnostics():
    """Collect a one-shot startup snapshot.

    Gathers interpreter/torch/CUDA details, key filesystem paths, a small
    whitelisted subset of the environment, and the model-file verification
    result from ``verify_models()``.
    """
    cuda_ok = bool(torch and torch.cuda.is_available())
    watched_env = ("HOST", "PORT", "SPACE_ID", "SPACE_AUTHOR_NAME")
    report = {
        "system": {
            "python": sys.version,
            "pytorch": getattr(torch, "__version__", None) if torch else None,
            "cuda_available": cuda_ok,
            "device_count": torch.cuda.device_count() if cuda_ok else 0,
            "cuda_version": getattr(getattr(torch, "version", None), "cuda", None) if torch else None,
        },
        "paths": {
            "base_dir": str(BASE_DIR),
            "checkpoints_dir": str(CHECKPOINTS_DIR),
            "sam2_dir": str(SAM2_DIR),
        },
        # Only a whitelisted subset — never dump the full environment to logs.
        "env_subset": {name: os.environ[name] for name in watched_env if name in os.environ},
    }
    report["model_verification"] = verify_models()
    return report
131
+
132
+ startup_diag = run_startup_diagnostics()
133
+ logger.info("Startup diagnostics completed")
134
+
135
+ # Noisy heartbeat so logs show life during import time
136
def _heartbeat():
    """Print a liveness line every 5 s so Space logs show progress during slow imports.

    Runs forever; intended to be launched on a daemon thread.
    """
    tick = 0
    while True:
        tick += 1
        print(f"[startup-heartbeat] {tick*5}s…", flush=True)
        time.sleep(5)
142
+
143
+ threading.Thread(target=_heartbeat, daemon=True).start()
144
+
145
+ # Optional perf tuning import (non-fatal)
146
+ try:
147
+ import perf_tuning # noqa: F401
148
+ logger.info("perf_tuning imported successfully.")
149
+ except Exception as e:
150
+ logger.info("perf_tuning not available: %s", e)
151
+
152
+ # MatAnyone non-instantiating probe
153
+ try:
154
+ import inspect
155
+ from matanyone.inference import inference_core as ic # type: ignore
156
+ sigs = {}
157
+ for name in ("InferenceCore",):
158
+ obj = getattr(ic, name, None)
159
+ if obj:
160
+ sigs[name] = "callable" if callable(obj) else "present"
161
+ logger.info(f"[MATANY] probe (non-instantiating): {sigs}")
162
+ except Exception as e:
163
+ logger.info(f"[MATANY] probe skipped: {e}")
164
+
165
+ # ---------------------------------------------------------------------
166
+ # Gradio import and web-stack probes
167
+ # ---------------------------------------------------------------------
168
+ import gradio as gr
169
+
170
+ # Standard logger for some libs that use stdlib logging
171
+ py_logger = logging.getLogger("backgroundfx_pro")
172
+ if not py_logger.handlers:
173
+ h = logging.StreamHandler()
174
+ h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
175
+ py_logger.addHandler(h)
176
+ py_logger.setLevel(logging.INFO)
177
+
178
def _log_web_stack_versions_and_paths():
    """Log versions and on-disk locations of the HTTP stack.

    Covers FastAPI/Starlette/Pydantic/httpx/anyio versions plus the file
    paths of the gradio and starlette modules, making dependency mismatches
    diagnosable from the Space logs. All probes are best-effort: any failure
    is logged as a warning and never prevents startup.
    """
    import inspect
    try:
        import fastapi, starlette, pydantic, httpx, anyio
        try:
            import pydantic_core
            pc_ver = pydantic_core.__version__
        except Exception:
            pc_ver = "unknown"
        # FIX: loguru uses {}-style placeholders; with the original "%s" style
        # the literal "%s" was logged and every version argument was dropped.
        logger.info(
            "[WEB-STACK] fastapi={} | starlette={} | pydantic={} | pydantic-core={} | httpx={} | anyio={}",
            getattr(fastapi, "__version__", "?"),
            getattr(starlette, "__version__", "?"),
            getattr(pydantic, "__version__", "?"),
            pc_ver,
            getattr(httpx, "__version__", "?"),
            getattr(anyio, "__version__", "?"),
        )
    except Exception as e:
        logger.warning("[WEB-STACK] version probe failed: {}", e)

    try:
        import gradio
        import gradio.routes as gr_routes
        import gradio.queueing as gr_queueing
        logger.info("[PATH] gradio.__file__ = {}", getattr(gradio, "__file__", "?"))
        logger.info("[PATH] gradio.routes = {}", inspect.getfile(gr_routes))
        logger.info("[PATH] gradio.queueing = {}", inspect.getfile(gr_queueing))
        import starlette.exceptions as st_exc
        logger.info("[PATH] starlette.exceptions= {}", inspect.getfile(st_exc))
    except Exception as e:
        logger.warning("[PATH] probe failed: {}", e)
210
+
211
def _post_launch_diag():
    """Log CUDA availability and device details after launch.

    Best-effort only: intended to run on a background thread, never raises.
    Does nothing when torch failed to import at startup.
    """
    try:
        if not torch:
            return
        avail = torch.cuda.is_available()
        # FIX: loguru formats with {} placeholders, not printf-style "%s";
        # the original calls logged literal "%s"/"%d" and dropped the values.
        logger.info("CUDA available (post-launch): {}", avail)
        if avail:
            idx = torch.cuda.current_device()
            name = torch.cuda.get_device_name(idx)
            cap = torch.cuda.get_device_capability(idx)
            logger.info("CUDA device {}: {} (cc {}.{})", idx, name, cap[0], cap[1])
    except Exception as e:
        logger.warning("Post-launch CUDA diag failed: {}", e)
224
+
225
+ # ---------------------------------------------------------------------
226
+ # UI factory (uses your existing builder)
227
+ # ---------------------------------------------------------------------
228
def build_ui() -> gr.Blocks:
    """Construct and return the Gradio Blocks UI.

    Delegates to the project's interface factory; imported lazily so the UI
    modules are only loaded when the app actually starts.
    """
    from ui_core_interface import create_interface  # FIX: correct source module (not ui)
    ui = create_interface()
    return ui
232
+
233
+ # ---------------------------------------------------------------------
234
+ # Optional: custom FastAPI mount mode
235
+ # ---------------------------------------------------------------------
236
def build_fastapi_with_gradio(demo: gr.Blocks):
    """Wrap *demo* in a FastAPI application mounted at root.

    Adds two parent-level routes before mounting Gradio:
      * ``/healthz`` — JSON liveness probe with a timestamp.
      * ``/config`` — shim serving the Blocks config via
        ``demo.get_config_file()``; failures come back as HTTP 500 JSON.
    """
    from fastapi import FastAPI
    from fastapi.responses import JSONResponse

    api = FastAPI(title="VideoBackgroundReplacer2")

    @api.get("/healthz")
    def _healthz():
        return {"ok": True, "ts": time.time()}

    @api.get("/config")
    def _config():
        try:
            return JSONResponse(content=demo.get_config_file())
        except Exception as e:
            payload = {"error": "config_generation_failed", "detail": str(e)}
            return JSONResponse(status_code=500, content=payload)

    # Mount the Gradio UI last so the routes above stay at the parent level.
    return gr.mount_gradio_app(api, demo, path="/")
264
+
265
+ # ---------------------------------------------------------------------
266
+ # Entrypoint
267
+ # ---------------------------------------------------------------------
268
if __name__ == "__main__":
    # Bind address/port come from the Space environment (7860 is the HF default).
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "7860"))
    # GRADIO_MOUNT_MODE=1 runs Gradio inside our own FastAPI app (adds /config shim).
    mount_mode = os.environ.get("GRADIO_MOUNT_MODE", "0") == "1"

    # FIX: loguru formats with {} placeholders; the original "%s:%s" logged
    # literally and the host/port values were dropped.
    logger.info("Launching on {}:{} (mount_mode={})…", host, port, mount_mode)
    _log_web_stack_versions_and_paths()

    demo = build_ui()
    demo.queue(max_size=16, api_open=False)

    # Post-launch CUDA diagnostics run in the background so they never block serving.
    threading.Thread(target=_post_launch_diag, daemon=True).start()

    if mount_mode:
        try:
            from uvicorn import run as uvicorn_run
        except Exception:
            logger.error("uvicorn is not installed; mount mode cannot start.")
            raise

        app = build_fastapi_with_gradio(demo)
        uvicorn_run(app=app, host=host, port=port, log_level="info")
    else:
        demo.launch(
            server_name=host,
            server_port=port,
            share=False,
            show_api=False,
            show_error=True,
            quiet=False,
            debug=True,
            max_threads=1,
        )
VideoBackgroundReplacer2/integrated_pipeline.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ integrated_pipeline.py - Two-stage pipeline with fallback compatibility
4
+ - Stage 1: SAM2 -> lossless mask stream + metadata, then unload SAM2
5
+ - Stage 2: Read masks -> MatAnyone -> composite -> final output
6
+ - Maintains compatibility with existing UI calls
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import gc
12
+ import json
13
+ import subprocess
14
+ import tempfile
15
+ from pathlib import Path
16
+ from typing import Dict, Any, Optional, Tuple
17
+ import numpy as np
18
+ import cv2
19
+
20
+ # Add the parent directory to Python path for imports
21
+ current_dir = Path(__file__).parent
22
+ parent_dir = current_dir.parent
23
+ sys.path.append(str(parent_dir))
24
+
25
class TwoStageProcessor:
    """Two-stage video background replacement.

    Stage 1: SAM2 segments the subject from click prompts and streams one
    binary mask per frame into a lossless FFV1 MKV, after which SAM2 is fully
    unloaded (model, CUDA cache, sys.modules) to free GPU memory.
    Stage 2: the masks are read back, optionally refined with MatAnyone, and
    composited over the background video.

    Public interface (``process_video``/``cleanup``) is unchanged from the
    original single-pass pipeline, so existing UI callers keep working.
    """

    def __init__(self, temp_dir: Optional[str] = None):
        """Create the scratch directory and derive stage-1 output paths.

        Args:
            temp_dir: Optional existing directory for intermediates; a fresh
                ``tempfile.mkdtemp()`` directory is created when omitted.
        """
        self.temp_dir = Path(temp_dir) if temp_dir else Path(tempfile.mkdtemp())
        self.temp_dir.mkdir(parents=True, exist_ok=True)

        # Stage-1 artifacts consumed by stage 2.
        self.masks_path = self.temp_dir / "masks.mkv"      # lossless FFV1 mask stream
        self.metadata_path = self.temp_dir / "meta.json"   # fps/size/frame-count/prompts

    def process_video(self, input_video: str, background_video: str,
                      click_points: list, output_path: str,
                      use_matanyone: bool = True, progress_callback=None) -> bool:
        """Run both stages and return True on success.

        Maintains compatibility with the existing UI entry point (same
        signature/semantics). ``progress_callback`` is called with a message
        and, during frame loops, an optional percentage.
        """
        try:
            if progress_callback:
                progress_callback("Stage 1: Generating masks with SAM2...")
            if not self._stage1_generate_masks(input_video, click_points, progress_callback):
                return False

            if progress_callback:
                progress_callback("Stage 2: Processing and compositing...")
            return self._stage2_composite(input_video, background_video,
                                          output_path, use_matanyone, progress_callback)
        except Exception as e:
            print(f"Two-stage processing failed: {e}")
            return False

    def _stage1_generate_masks(self, input_video: str, click_points: list,
                               progress_callback=None) -> bool:
        """Stage 1: SAM2 mask generation with complete memory cleanup.

        Streams an 8-bit grayscale mask per frame into an FFV1-encoded MKV
        (lossless), then tears SAM2 down completely so stage 2 starts with a
        clean GPU.
        """
        try:
            # Import SAM2 only when needed so stage 2 never pays for it.
            print("Loading SAM2...")
            import torch
            from sam2.build_sam import build_sam2_video_predictor

            checkpoint = "checkpoints/sam2.1_hiera_large.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
            if not os.path.exists(checkpoint):
                print(f"SAM2 checkpoint not found: {checkpoint}")
                return False

            predictor = build_sam2_video_predictor(model_cfg, checkpoint)

            # Probe source geometry/timing for metadata and the FFmpeg rawvideo pipe.
            cap = cv2.VideoCapture(input_video)
            fps = cap.get(cv2.CAP_PROP_FPS)
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            cap.release()

            metadata = {
                "fps": fps,
                "frame_count": frame_count,
                "width": width,
                "height": height,
                "click_points": click_points,
            }
            with open(self.metadata_path, 'w') as f:
                json.dump(metadata, f, indent=2)

            inference_state = predictor.init_state(video_path=input_video)

            # One tracked object per click, all prompted as foreground on frame 0.
            for i, (x, y) in enumerate(click_points):
                predictor.add_new_points_or_box(
                    inference_state=inference_state,
                    frame_idx=0,
                    obj_id=i,
                    points=np.array([[x, y]], dtype=np.float32),
                    labels=np.array([1], np.int32),
                )

            # Lossless grayscale mask encoder. FIX: stderr goes to a log file —
            # with stderr=PIPE and no reader, a chatty FFmpeg fills the pipe
            # buffer and deadlocks the whole stage.
            encode_log = self.temp_dir / "masks_encode.log"
            ffmpeg_cmd = [
                'ffmpeg', '-y', '-f', 'rawvideo',
                '-pix_fmt', 'gray', '-s', f'{width}x{height}',
                '-r', str(fps), '-i', '-',
                '-c:v', 'ffv1', '-level', '3', '-pix_fmt', 'gray',
                str(self.masks_path)
            ]
            with open(encode_log, 'wb') as log_f:
                ffmpeg_process = subprocess.Popen(
                    ffmpeg_cmd, stdin=subprocess.PIPE,
                    stdout=subprocess.DEVNULL, stderr=log_f,
                )

            print(f"Processing {frame_count} frames...")
            for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
                if progress_callback:
                    progress = (out_frame_idx + 1) / frame_count * 50  # stage 1 = first 50%
                    progress_callback(f"Generating masks... Frame {out_frame_idx + 1}/{frame_count}", progress)

                # Union of all objects' masks. FIX: logits are indexed by
                # POSITION within out_obj_ids (not by obj id value), and must
                # be moved to CPU before numpy can consume them — the original
                # fed CUDA tensors straight into np.logical_or, which crashes
                # on GPU.
                combined_mask = np.zeros((height, width), dtype=bool)
                for pos in range(len(out_obj_ids)):
                    obj_mask = (out_mask_logits[pos] > 0.0).squeeze().cpu().numpy()
                    combined_mask |= obj_mask.astype(bool)

                ffmpeg_process.stdin.write((combined_mask.astype(np.uint8) * 255).tobytes())

            ffmpeg_process.stdin.close()
            ffmpeg_process.wait()
            if ffmpeg_process.returncode != 0:
                print(f"FFmpeg error: {encode_log.read_text(errors='replace')}")
                return False

            print("Stage 1 complete: Masks saved")

            # CRITICAL: fully release SAM2 before stage 2. (The original's
            # "if 'torch' in locals()" guard was dead code — torch is imported
            # unconditionally above.)
            del predictor
            del inference_state
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            gc.collect()

            # Drop SAM2's modules so any module-level state stops pinning memory.
            for mod in [m for m in sys.modules if 'sam2' in m.lower()]:
                del sys.modules[mod]

            print("SAM2 completely unloaded from memory")
            return True

        except Exception as e:
            print(f"Stage 1 failed: {e}")
            return False

    def _stage2_composite(self, input_video: str, background_video: str,
                          output_path: str, use_matanyone: bool, progress_callback=None) -> bool:
        """Stage 2: read masks back, optionally refine with MatAnyone, composite."""
        try:
            with open(self.metadata_path, 'r') as f:
                metadata = json.load(f)

            masks = self._read_mask_stream()
            if masks is None:
                return False

            # Optional MatAnyone refinement (falls back to raw masks on failure).
            if use_matanyone:
                if progress_callback:
                    progress_callback("Refining masks with MatAnyone...")
                masks = self._refine_with_matanyone(input_video, masks, progress_callback)
                if masks is None:
                    return False

            if progress_callback:
                progress_callback("Compositing final video...")
            return self._composite_final_video(input_video, background_video,
                                               masks, output_path, metadata, progress_callback)

        except Exception as e:
            print(f"Stage 2 failed: {e}")
            return False

    @staticmethod
    def _read_exact(stream, num_bytes: int) -> bytes:
        """Read exactly *num_bytes* from a pipe, looping over short reads.

        Returns fewer bytes only at EOF. Pipe ``read(n)`` may legally return
        less than ``n``; a single read would silently truncate frames.
        """
        chunks = []
        remaining = num_bytes
        while remaining > 0:
            chunk = stream.read(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def _read_mask_stream(self) -> Optional[list]:
        """Decode the FFV1 mask stream into a list of (H, W) uint8 arrays.

        Returns None on decode failure.
        """
        try:
            with open(self.metadata_path, 'r') as f:
                metadata = json.load(f)
            width = metadata["width"]
            height = metadata["height"]
            frame_count = metadata["frame_count"]

            ffmpeg_cmd = [
                'ffmpeg', '-i', str(self.masks_path),
                '-f', 'rawvideo', '-pix_fmt', 'gray', '-'
            ]
            # FIX: stderr to a log file — an unread stderr PIPE can fill up and
            # deadlock FFmpeg mid-stream.
            decode_log = self.temp_dir / "masks_decode.log"
            with open(decode_log, 'wb') as log_f:
                process = subprocess.Popen(ffmpeg_cmd, stdout=subprocess.PIPE, stderr=log_f)

                masks = []
                frame_size = width * height
                for frame_idx in range(frame_count):
                    # FIX: loop over short pipe reads instead of one read() call.
                    frame_data = self._read_exact(process.stdout, frame_size)
                    if len(frame_data) != frame_size:
                        print(f"Unexpected frame size at frame {frame_idx}")
                        break
                    masks.append(np.frombuffer(frame_data, dtype=np.uint8).reshape((height, width)))

                process.stdout.close()
                process.wait()

            if process.returncode != 0:
                print(f"FFmpeg decode error: {decode_log.read_text(errors='replace')}")
                return None

            print(f"Successfully read {len(masks)} masks from stream")
            return masks

        except Exception as e:
            print(f"Failed to read mask stream: {e}")
            return None

    def _refine_with_matanyone(self, input_video: str, masks: list, progress_callback=None) -> Optional[list]:
        """Apply MatAnyone refinement to *masks*.

        Best-effort: on any failure the original masks are returned unchanged,
        and per-frame fallbacks fill in any refined frame MatAnyone did not emit.
        """
        try:
            # Import MatAnyone only when needed.
            from matanyone.mat_anywhere import matting_inference_video

            matanyone_temp = self.temp_dir / "matanyone"
            matanyone_temp.mkdir(exist_ok=True)

            # MatAnyone consumes masks as individual PNG frames.
            mask_dir = matanyone_temp / "masks"
            mask_dir.mkdir(exist_ok=True)
            for i, mask in enumerate(masks):
                cv2.imwrite(str(mask_dir / f"mask_{i:06d}.png"), mask)

            refined_masks_dir = matanyone_temp / "refined"
            refined_masks_dir.mkdir(exist_ok=True)

            success = matting_inference_video(
                video_path=input_video,
                mask_dir=str(mask_dir),
                output_dir=str(refined_masks_dir),
                progress_callback=progress_callback
            )
            if not success:
                print("MatAnyone refinement failed, using original masks")
                return masks

            refined_masks = []
            for i in range(len(masks)):
                refined_path = refined_masks_dir / f"refined_{i:06d}.png"
                if refined_path.exists():
                    refined_masks.append(cv2.imread(str(refined_path), cv2.IMREAD_GRAYSCALE))
                else:
                    refined_masks.append(masks[i])  # fallback to the SAM2 mask
            return refined_masks

        except Exception as e:
            print(f"MatAnyone refinement failed: {e}, using original masks")
            return masks

    def _composite_final_video(self, input_video: str, background_video: str,
                               masks: list, output_path: str, metadata: Dict[str, Any],
                               progress_callback=None) -> bool:
        """Blend foreground over background per-frame using the alpha masks.

        The background is resized to the foreground geometry and looped when
        shorter than the foreground clip.
        """
        try:
            fg_cap = cv2.VideoCapture(input_video)
            bg_cap = cv2.VideoCapture(background_video)

            fps = metadata["fps"]
            width = metadata["width"]
            height = metadata["height"]

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

            frame_idx = 0
            total_frames = len(masks)

            while frame_idx < total_frames:
                ret_fg, fg_frame = fg_cap.read()
                ret_bg, bg_frame = bg_cap.read()

                if not ret_fg:
                    break

                if not ret_bg:
                    # Loop the background when it is shorter than the foreground.
                    bg_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                    ret_bg, bg_frame = bg_cap.read()
                    if not ret_bg:
                        print("No background frame available")
                        break

                bg_frame = cv2.resize(bg_frame, (width, height))

                # Normalize mask to [0,1] and broadcast to 3 channels for blending.
                mask_norm = masks[frame_idx].astype(np.float32) / 255.0
                mask_3ch = np.stack([mask_norm, mask_norm, mask_norm], axis=-1)

                composite = (fg_frame * mask_3ch + bg_frame * (1 - mask_3ch)).astype(np.uint8)
                out.write(composite)

                frame_idx += 1
                if progress_callback and frame_idx % 10 == 0:
                    progress = 50 + (frame_idx / total_frames) * 50  # stage 2 = 50-100%
                    progress_callback(f"Compositing... Frame {frame_idx}/{total_frames}", progress)

            fg_cap.release()
            bg_cap.release()
            out.release()

            print(f"Final video saved to: {output_path}")
            return True

        except Exception as e:
            print(f"Final composition failed: {e}")
            return False

    def cleanup(self):
        """Remove the scratch directory and every intermediate artifact in it."""
        try:
            if self.temp_dir.exists():
                import shutil
                shutil.rmtree(self.temp_dir)
        except Exception as e:
            print(f"Cleanup failed: {e}")
382
+
383
+ # Compatibility wrapper for existing UI
384
def process_video_two_stage(input_video: str, background_video: str,
                            click_points: list, output_path: str,
                            use_matanyone: bool = True, progress_callback=None) -> bool:
    """Drop-in replacement for the legacy ``process_video`` function.

    Instantiates a throwaway TwoStageProcessor, runs the pipeline, and
    guarantees the scratch directory is removed even on failure.
    """
    processor = TwoStageProcessor()
    try:
        return processor.process_video(
            input_video, background_video, click_points,
            output_path, use_matanyone, progress_callback,
        )
    finally:
        processor.cleanup()
399
+
400
if __name__ == "__main__":
    # Minimal CLI harness for exercising the pipeline outside the UI.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--input", required=True)
    cli.add_argument("--background", required=True)
    cli.add_argument("--output", required=True)
    cli.add_argument("--clicks", required=True, help="JSON string of click points")
    cli.add_argument("--no-matanyone", action="store_true")
    args = cli.parse_args()

    def _report(msg, prog=None):
        # Mirror the UI progress callback: message plus optional percentage.
        print(f"Progress: {msg} ({prog}%)" if prog else msg)

    success = process_video_two_stage(
        args.input, args.background, json.loads(args.clicks),
        args.output, not args.no_matanyone, _report,
    )
    print("Processing completed!" if success else "Processing failed!")
VideoBackgroundReplacer2/models/__init__.py ADDED
@@ -0,0 +1,868 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ BackgroundFX Pro - Model Loading & Utilities (Hardened)
4
+ ======================================================
5
+ - Avoids heavy CUDA/Hydra work at import time
6
+ - Adds timeouts to subprocess probes
7
+ - Safer sys.path wiring for third_party repos
8
+ - MatAnyone loader is probe-only here; actual run happens in matanyone_loader.MatAnyoneSession
9
+
10
+ Changes (2025-09-16):
11
+ - Aligned with torch==2.3.1+cu121 and MatAnyone v1.0.0
12
+ - Updated load_matany to apply T=1 squeeze patch before InferenceCore import
13
+ - Added patch status logging and MatAnyone version
14
+ - Added InferenceCore attributes logging for debugging
15
+ - Fixed InferenceCore import path to matanyone.inference.inference_core
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import os
21
+ import sys
22
+ import cv2
23
+ import subprocess
24
+ import inspect
25
+ import logging
26
+ import importlib.metadata
27
+ from pathlib import Path
28
+ from typing import Optional, Tuple, Dict, Any, Union, Callable
29
+
30
+ import numpy as np
31
+ import yaml
32
+
33
+ # Import torch for GPU memory monitoring
34
+ try:
35
+ import torch
36
+ except ImportError:
37
+ torch = None
38
+
39
+ # --------------------------------------------------------------------------------------
40
+ # Logging (ensure a handler exists very early)
41
+ # --------------------------------------------------------------------------------------
42
+ logger = logging.getLogger("backgroundfx_pro")
43
+ if not logger.handlers:
44
+ _h = logging.StreamHandler()
45
+ _h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
46
+ logger.addHandler(_h)
47
+ logger.setLevel(logging.INFO)
48
+
49
+ # Pin OpenCV threads (helps libgomp stability in Spaces)
50
+ try:
51
+ cv_threads = int(os.environ.get("CV_THREADS", "1"))
52
+ if hasattr(cv2, "setNumThreads"):
53
+ cv2.setNumThreads(cv_threads)
54
+ except Exception:
55
+ pass
56
+
57
+ # --------------------------------------------------------------------------------------
58
+ # Optional dependencies
59
+ # --------------------------------------------------------------------------------------
60
+ try:
61
+ import mediapipe as mp # type: ignore
62
+ _HAS_MEDIAPIPE = True
63
+ except Exception:
64
+ _HAS_MEDIAPIPE = False
65
+
66
+ # --------------------------------------------------------------------------------------
67
+ # Path setup for third_party repos
68
+ # --------------------------------------------------------------------------------------
69
+ ROOT = Path(__file__).resolve().parent.parent # project root
70
+ TP_SAM2 = Path(os.environ.get("THIRD_PARTY_SAM2_DIR", ROOT / "third_party" / "sam2")).resolve()
71
+ TP_MATANY = Path(os.environ.get("THIRD_PARTY_MATANY_DIR", ROOT / "third_party" / "matanyone")).resolve()
72
+
73
def _add_sys_path(p: Path) -> None:
    """Prepend *p* to ``sys.path`` if it exists; warn (once) when it does not."""
    if not p.exists():
        logger.warning(f"third_party path not found: {p}")
        return
    entry = str(p)
    if entry not in sys.path:
        # Front of the path so vendored repos shadow any pip-installed copies.
        sys.path.insert(0, entry)
80
+
81
+ _add_sys_path(TP_SAM2)
82
+ _add_sys_path(TP_MATANY)
83
+
84
+ # --------------------------------------------------------------------------------------
85
+ # Safe Torch accessors (no top-level import)
86
+ # --------------------------------------------------------------------------------------
87
+ def _torch():
88
+ try:
89
+ import torch # local import avoids early CUDA init during module import
90
+ return torch
91
+ except Exception as e:
92
+ logger.warning(f"[models.safe-torch] import failed: {e}")
93
+ return None
94
+
95
def _has_cuda() -> bool:
    """Report whether a usable CUDA runtime is present (False on any failure)."""
    mod = _torch()
    if mod is None:
        return False
    try:
        available = mod.cuda.is_available()
    except Exception as e:
        logger.warning(f"[models.safe-torch] cuda.is_available() failed: {e}")
        return False
    return bool(available)
104
+
105
def _pick_device(env_key: str) -> str:
    """Resolve the compute device ("cuda"/"cpu") for the component named by *env_key*.

    Selection order: CUDA is used whenever available unless the environment
    explicitly requests "cpu"; an explicit "cuda" request is honored even if
    probing failed; otherwise auto-select based on availability. All
    CUDA-related environment knobs are logged up front for debuggability.
    """
    requested = os.environ.get(env_key, "").strip().lower()
    has_cuda = _has_cuda()

    # Snapshot every CUDA-related knob so a bad device pick is diagnosable.
    watched = (
        'FORCE_CUDA_DEVICE', 'CUDA_MEMORY_FRACTION', 'PYTORCH_CUDA_ALLOC_CONF',
        'REQUIRE_CUDA', 'SAM2_DEVICE', 'MATANY_DEVICE',
    )
    cuda_env_vars = {name: os.environ.get(name, '') for name in watched}
    logger.info(f"CUDA environment variables: {cuda_env_vars}")

    logger.info(f"_pick_device({env_key}): requested='{requested}', has_cuda={has_cuda}")

    if has_cuda and requested != "cpu":
        # GPU present and the caller did not explicitly opt out.
        logger.info(f"FORCING CUDA device (GPU available, requested='{requested}')")
        return "cuda"
    if requested in ("cuda", "cpu"):
        # Honor an explicit request even when the CUDA probe said no.
        logger.info(f"Using explicitly requested device: {requested}")
        return requested

    result = "cuda" if has_cuda else "cpu"
    logger.info(f"Auto-selected device: {result}")
    return result
133
+
134
+ # --------------------------------------------------------------------------------------
135
+ # Basic Utilities
136
+ # --------------------------------------------------------------------------------------
137
+ def _ffmpeg_bin() -> str:
138
+ return os.environ.get("FFMPEG_BIN", "ffmpeg")
139
+
140
+ def _probe_ffmpeg(timeout: int = 2) -> bool:
141
+ try:
142
+ subprocess.run([_ffmpeg_bin(), "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True, timeout=timeout)
143
+ return True
144
+ except Exception:
145
+ return False
146
+
147
+ def _ensure_dir(p: Path) -> None:
148
+ p.mkdir(parents=True, exist_ok=True)
149
+
150
def _cv_read_first_frame(video_path: Union[str, Path]) -> Tuple[Optional[np.ndarray], int, Tuple[int, int]]:
    """Read frame 0 of *video_path*.

    Returns ``(frame_or_None, fps, (width, height))``; an unopenable file
    yields ``(None, 0, (0, 0))`` and an unreadable first frame yields
    ``(None, fps, (w, h))``.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        return None, 0, (0, 0)
    fps = int(round(cap.get(cv2.CAP_PROP_FPS) or 25))  # 25 fps fallback when unknown
    ok, frame = cap.read()
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
    cap.release()
    if ok:
        return frame, fps, (width, height)
    return None, fps, (width, height)
162
+
163
def _save_mask_png(mask: np.ndarray, path: Union[str, Path]) -> str:
    """Write *mask* to *path* as an 8-bit image; returns the path as str."""
    if mask.dtype == bool:
        out = mask.astype(np.uint8) * 255
    elif mask.dtype != np.uint8:
        out = np.clip(mask, 0, 255).astype(np.uint8)
    else:
        out = mask
    cv2.imwrite(str(path), out)
    return str(path)
170
+
171
def _resize_keep_ar(image: np.ndarray, target_wh: Tuple[int, int]) -> np.ndarray:
    """Resize *image* to fit inside *target_wh* (w, h), preserving aspect ratio.

    The scaled image is centred (letterboxed) on a zero-filled canvas of
    exactly the target size. Degenerate dimensions return the input unchanged.

    Fix: the canvas previously hard-coded 3 channels, which raised a shape
    mismatch for grayscale or 4-channel input; it now mirrors the input's
    channel layout.
    """
    tw, th = target_wh
    h, w = image.shape[:2]
    if h == 0 or w == 0 or tw == 0 or th == 0:
        return image
    scale = min(tw / w, th / h)
    nw, nh = max(1, int(round(w * scale))), max(1, int(round(h * scale)))
    resized = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_CUBIC)
    canvas_shape = (th, tw) if image.ndim == 2 else (th, tw, image.shape[2])
    canvas = np.zeros(canvas_shape, dtype=resized.dtype)
    x0 = (tw - nw) // 2
    y0 = (th - nh) // 2
    canvas[y0:y0 + nh, x0:x0 + nw] = resized
    return canvas
184
+
185
def _video_writer(out_path: Path, fps: int, size: Tuple[int, int]) -> cv2.VideoWriter:
    """Open an MP4 (mp4v) writer at *size*; fps is clamped to at least 1."""
    codec = cv2.VideoWriter_fourcc(*"mp4v")
    return cv2.VideoWriter(str(out_path), codec, max(1, fps), size)
188
+
189
def _mux_audio(src_video: Union[str, Path], silent_video: Union[str, Path], out_path: Union[str, Path]) -> bool:
    """Copy video from silent_video + audio from src_video into out_path (AAC).

    Returns False (caller keeps the silent video) when ffmpeg fails.
    """
    cmd = [
        _ffmpeg_bin(), "-y",
        "-i", str(silent_video),
        "-i", str(src_video),
        "-map", "0:v:0",   # video track: the processed (silent) render
        "-map", "1:a:0?",  # audio track: original clip; '?' tolerates its absence
        "-c:v", "copy",
        "-c:a", "aac", "-b:a", "192k",
        "-shortest",
        str(out_path),
    ]
    try:
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception as e:
        logger.warning(f"Audio mux failed; returning silent video. Reason: {e}")
        return False
    return True
208
+
209
+ # --------------------------------------------------------------------------------------
210
+ # Compositing & Image Processing
211
+ # --------------------------------------------------------------------------------------
212
def _refine_alpha(alpha: np.ndarray, erode_px: int = 1, dilate_px: int = 2, blur_px: float = 1.5) -> np.ndarray:
    """Clean a matte: normalise to float32 [0,1], erode/dilate, then feather.

    Fix: the original only rescaled 0-255 data when the dtype was not float32,
    so a float32 matte holding 0-255 values was silently clipped to 1.0.
    Normalisation now depends on the value range, not the dtype.

    Returns a float32 array in [0, 1].
    """
    a = alpha.astype(np.float32, copy=True)
    if a.max() > 1.0:  # values look like a 0-255 matte, whatever the dtype
        a = a / 255.0

    # Morphology runs on uint8 for speed and exact kernel semantics.
    a_u8 = np.clip(np.round(a * 255.0), 0, 255).astype(np.uint8)
    if erode_px > 0:
        k = max(1, int(erode_px))
        a_u8 = cv2.erode(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1)
    if dilate_px > 0:
        k = max(1, int(dilate_px))
        a_u8 = cv2.dilate(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1)
    a = a_u8.astype(np.float32) / 255.0

    if blur_px and blur_px > 0:
        rad = max(1, int(round(blur_px)))
        # Gaussian kernel size must be odd; `| 1` guarantees that.
        a = cv2.GaussianBlur(a, (rad | 1, rad | 1), 0)

    return np.clip(a, 0.0, 1.0)
234
+
235
+ def _to_linear(rgb: np.ndarray, gamma: float = 2.2) -> np.ndarray:
236
+ x = np.clip(rgb.astype(np.float32) / 255.0, 0.0, 1.0)
237
+ return np.power(x, gamma)
238
+
239
+ def _to_srgb(lin: np.ndarray, gamma: float = 2.2) -> np.ndarray:
240
+ x = np.clip(lin, 0.0, 1.0)
241
+ return np.clip(np.power(x, 1.0 / gamma) * 255.0, 0, 255).astype(np.uint8)
242
+
243
def _light_wrap(bg_rgb: np.ndarray, alpha01: np.ndarray, radius: int = 5, amount: float = 0.18) -> np.ndarray:
    """Additive 'light wrap' term: blurred background bleed along the matte edge."""
    r = max(1, int(radius))
    inv_blur = cv2.GaussianBlur(1.0 - alpha01, (r | 1, r | 1), 0)
    return bg_rgb.astype(np.float32) * inv_blur[..., None] * float(amount)
249
+
250
def _despill_edges(fg_rgb: np.ndarray, alpha01: np.ndarray, amount: float = 0.35) -> np.ndarray:
    """Desaturate foreground pixels near the matte edge to suppress colour spill."""
    # Edge weight peaks at alpha == 0.5 and falls to zero at fully FG/BG pixels.
    edge_w = np.clip(1.0 - 2.0 * np.abs(alpha01 - 0.5), 0.0, 1.0)
    hsv = cv2.cvtColor(fg_rgb.astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
    hue, sat, val = cv2.split(hsv)
    sat = sat * (1.0 - amount * edge_w)
    merged = cv2.merge([hue, np.clip(sat, 0, 255), val])
    return cv2.cvtColor(merged.astype(np.uint8), cv2.COLOR_HSV2RGB)
259
+
260
def _composite_frame_pro(
    fg_rgb: np.ndarray, alpha: np.ndarray, bg_rgb: np.ndarray,
    erode_px: Optional[int] = None, dilate_px: Optional[int] = None, blur_px: Optional[float] = None,
    lw_radius: Optional[int] = None, lw_amount: Optional[float] = None, despill_amount: Optional[float] = None
) -> np.ndarray:
    """Composite one RGB frame over a background with edge refinement.

    Pipeline: refine the alpha (erode/dilate/feather), despill the foreground
    edge, blend foreground and background in linear light, then add a
    light-wrap term from the background. Any parameter left as ``None`` falls
    back to its EDGE_* / LIGHTWRAP_* / DESPILL_AMOUNT environment variable
    (defaults shown in the lookups below). Returns a uint8 RGB frame.
    """
    # Env fallbacks keep edge tuning possible without code changes.
    erode_px = erode_px if erode_px is not None else int(os.environ.get("EDGE_ERODE", "1"))
    dilate_px = dilate_px if dilate_px is not None else int(os.environ.get("EDGE_DILATE", "2"))
    blur_px = blur_px if blur_px is not None else float(os.environ.get("EDGE_BLUR", "1.5"))
    lw_radius = lw_radius if lw_radius is not None else int(os.environ.get("LIGHTWRAP_RADIUS", "5"))
    lw_amount = lw_amount if lw_amount is not None else float(os.environ.get("LIGHTWRAP_AMOUNT", "0.18"))
    despill_amount = despill_amount if despill_amount is not None else float(os.environ.get("DESPILL_AMOUNT", "0.35"))

    a = _refine_alpha(alpha, erode_px=erode_px, dilate_px=dilate_px, blur_px=blur_px)
    fg_rgb = _despill_edges(fg_rgb, a, amount=despill_amount)

    # Blend in linear light so the over operation is physically plausible.
    fg_lin = _to_linear(fg_rgb)
    bg_lin = _to_linear(bg_rgb)
    lw = _light_wrap(bg_rgb, a, radius=lw_radius, amount=lw_amount)
    lw_lin = _to_linear(np.clip(lw, 0, 255).astype(np.uint8))

    # Standard "over" plus the additive light-wrap contribution.
    comp_lin = fg_lin * a[..., None] + bg_lin * (1.0 - a[..., None]) + lw_lin
    comp = _to_srgb(comp_lin)
    return comp
283
+
284
+ # --------------------------------------------------------------------------------------
285
+ # SAM2 Integration
286
+ # --------------------------------------------------------------------------------------
287
def _resolve_sam2_cfg(cfg_str: str) -> str:
    """Resolve SAM2 config path - return relative path for Hydra compatibility.

    SAM2 loads configs through Hydra, which resolves them relative to the
    sam2 package rather than as filesystem paths, so once a config file is
    confirmed to exist this returns the package-relative form (``configs/…``).
    Falls back to the original string if nothing can be located.
    """
    logger.info(f"_resolve_sam2_cfg called with cfg_str={cfg_str}")

    # Get the third-party SAM2 directory (default matches the container layout).
    tp_sam2 = os.environ.get("THIRD_PARTY_SAM2_DIR", "/home/user/app/third_party/sam2")
    logger.info(f"TP_SAM2 = {tp_sam2}")

    # Check if the full path exists on disk before trusting the name.
    candidate = os.path.join(tp_sam2, cfg_str)
    logger.info(f"Candidate path: {candidate}")
    logger.info(f"Candidate exists: {os.path.exists(candidate)}")

    if os.path.exists(candidate):
        # For Hydra compatibility, return just the relative path within sam2 package.
        if cfg_str.startswith("sam2/configs/"):
            relative_path = cfg_str.replace("sam2/configs/", "configs/")
        else:
            relative_path = cfg_str
        logger.info(f"Returning Hydra-compatible relative path: {relative_path}")
        return relative_path

    # If not found, try some fallback locations used by different SAM2 checkouts.
    fallbacks = [
        os.path.join(tp_sam2, "sam2", cfg_str),
        os.path.join(tp_sam2, "configs", cfg_str),
    ]

    for fallback in fallbacks:
        logger.info(f"Trying fallback: {fallback}")
        if os.path.exists(fallback):
            # Extract the package-relative portion for Hydra.
            if "configs/" in fallback:
                relative_path = "configs/" + fallback.split("configs/")[-1]
                logger.info(f"Returning fallback relative path: {relative_path}")
                return relative_path

    # Nothing matched: hand the original back and let the loader surface the error.
    logger.warning(f"Config not found, returning original: {cfg_str}")
    return cfg_str
326
+
327
def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
    """If config references 'hieradet', try to find a 'hiera' config.

    Reads the YAML at *cfg_path*; when its image-encoder trunk target mentions
    ``hieradet`` (a trunk class some SAM2 builds do not ship), scans every
    YAML under TP_SAM2 for one whose trunk target lives in a ``.hiera.``
    module and returns its path. Returns ``None`` when no substitution is
    needed or nothing suitable is found. Entirely best-effort: all failures
    are swallowed.
    """
    try:
        with open(cfg_path, "r") as f:
            data = yaml.safe_load(f)
        # Walk model -> image_encoder -> trunk defensively; any level may be absent.
        model = data.get("model", {}) or {}
        enc = model.get("image_encoder") or {}
        trunk = enc.get("trunk") or {}
        target = trunk.get("_target_") or trunk.get("target")
        if isinstance(target, str) and "hieradet" in target:
            for y in TP_SAM2.rglob("*.yaml"):
                try:
                    with open(y, "r") as f2:
                        d2 = yaml.safe_load(f2) or {}
                    e2 = (d2.get("model", {}) or {}).get("image_encoder") or {}
                    t2 = (e2.get("trunk") or {})
                    tgt2 = t2.get("_target_") or t2.get("target")
                    if isinstance(tgt2, str) and ".hiera." in tgt2:
                        logger.info(f"SAM2: switching config from 'hieradet' → 'hiera': {y}")
                        return str(y)
                except Exception:
                    # Unparseable candidate config: skip it and keep scanning.
                    continue
    except Exception:
        # Best-effort only; any failure simply means "no substitution".
        pass
    return None
352
+
353
def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
    """Robust SAM2 loader with config resolution and error handling.

    Returns ``(predictor_or_None, ok, meta)`` where ``meta`` records how far
    loading progressed (import vs. init) and which device was selected, so
    callers can report a precise diagnostic instead of a generic failure.
    """
    meta = {"sam2_import_ok": False, "sam2_init_ok": False}
    try:
        from sam2.build_sam import build_sam2  # type: ignore
        from sam2.sam2_image_predictor import SAM2ImagePredictor  # type: ignore
        meta["sam2_import_ok"] = True
    except Exception as e:
        logger.warning(f"SAM2 import failed: {e}")
        return None, False, meta

    # Check GPU memory before loading.
    # Fix: use the lazy _torch() accessor — this module deliberately has no
    # top-level `torch` import, so the previous bare `torch` reference could
    # raise NameError here.
    t = _torch()
    if t is not None and t.cuda.is_available():
        mem_before = t.cuda.memory_allocated() / 1024**3
        logger.info(f"🔍 GPU memory before SAM2 load: {mem_before:.2f}GB")

    device = _pick_device("SAM2_DEVICE")
    cfg_env = os.environ.get("SAM2_MODEL_CFG", "sam2/configs/sam2/sam2_hiera_l.yaml")
    cfg = _resolve_sam2_cfg(cfg_env)
    ckpt = os.environ.get("SAM2_CHECKPOINT", "")

    def _try_build(cfg_path: str):
        """Call build_sam2 adapting to whichever signature this SAM2 version exposes."""
        logger.info(f"_try_build called with cfg_path: {cfg_path}")
        params = set(inspect.signature(build_sam2).parameters.keys())
        logger.info(f"build_sam2 parameters: {list(params)}")
        kwargs = {}
        # Config argument has been renamed across SAM2 releases.
        if "config_file" in params:
            kwargs["config_file"] = cfg_path
            logger.info(f"Using config_file parameter: {cfg_path}")
        elif "model_cfg" in params:
            kwargs["model_cfg"] = cfg_path
            logger.info(f"Using model_cfg parameter: {cfg_path}")
        # Same story for the checkpoint argument.
        if ckpt:
            if "checkpoint" in params:
                kwargs["checkpoint"] = ckpt
            elif "ckpt_path" in params:
                kwargs["ckpt_path"] = ckpt
            elif "weights" in params:
                kwargs["weights"] = ckpt
        if "device" in params:
            kwargs["device"] = device
        try:
            logger.info(f"Calling build_sam2 with kwargs: {kwargs}")
            result = build_sam2(**kwargs)
            logger.info(f"build_sam2 succeeded with kwargs")
            # Log the actual device of the model when it is discoverable.
            if hasattr(result, 'device'):
                logger.info(f"SAM2 model device: {result.device}")
            elif hasattr(result, 'image_encoder') and hasattr(result.image_encoder, 'device'):
                logger.info(f"SAM2 model device: {result.image_encoder.device}")
            return result
        except TypeError as e:
            # Signature mismatch: retry positionally (cfg, [ckpt], [device]).
            logger.info(f"build_sam2 kwargs failed: {e}, trying positional args")
            pos = [cfg_path]
            if ckpt:
                pos.append(ckpt)
            if "device" not in kwargs:
                pos.append(device)
            logger.info(f"Calling build_sam2 with positional args: {pos}")
            result = build_sam2(*pos)
            logger.info(f"build_sam2 succeeded with positional args")
            return result

    try:
        try:
            sam = _try_build(cfg)
        except Exception:
            # Some SAM2 builds lack the 'hieradet' trunk; retry with a 'hiera' config.
            alt_cfg = _find_hiera_config_if_hieradet(cfg)
            if alt_cfg:
                sam = _try_build(alt_cfg)
            else:
                raise

        if sam is not None:
            predictor = SAM2ImagePredictor(sam)
            meta["sam2_init_ok"] = True
            meta["sam2_device"] = device
            return predictor, True, meta
        else:
            return None, False, meta

    except Exception as e:
        logger.error(f"SAM2 loading failed: {e}")
        return None, False, meta
437
+
438
def run_sam2_mask(predictor: object,
                  first_frame_bgr: np.ndarray,
                  point: Optional[Tuple[int, int]] = None,
                  auto: bool = False) -> Tuple[Optional[np.ndarray], bool]:
    """Return (mask_uint8_0_255, ok).

    Prompts SAM2 with either a foreground click (*point*), a wide 90% box
    (*auto*), or a default 80% centre box, and returns the first mask.
    """
    if predictor is None:
        return None, False
    try:
        rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
        predictor.set_image(rgb)
        h, w = rgb.shape[:2]

        if auto:
            prompt_box = np.array([int(0.05*w), int(0.05*h), int(0.95*w), int(0.95*h)])
            masks, _, _ = predictor.predict(box=prompt_box)
        elif point is not None:
            coords = np.array([[int(point[0]), int(point[1])]], dtype=np.int32)
            labels = np.array([1], dtype=np.int32)  # 1 = positive (foreground) click
            masks, _, _ = predictor.predict(point_coords=coords, point_labels=labels)
        else:
            prompt_box = np.array([int(0.1*w), int(0.1*h), int(0.9*w), int(0.9*h)])
            masks, _, _ = predictor.predict(box=prompt_box)

        if masks is None or len(masks) == 0:
            return None, False
        return masks[0].astype(np.uint8) * 255, True
    except Exception as e:
        logger.warning(f"SAM2 mask failed: {e}")
        return None, False
471
+
472
def _refine_mask_grabcut(image_bgr: np.ndarray,
                         mask_u8: np.ndarray,
                         iters: Optional[int] = None,
                         trimap_erode: Optional[int] = None,
                         trimap_dilate: Optional[int] = None) -> np.ndarray:
    """Use SAM2 seed as initialization for GrabCut refinement.

    Builds a trimap from the binary seed (eroded mask -> sure foreground,
    eroded inverse -> sure background, everything else probable background)
    and runs mask-initialised GrabCut. ``None`` parameters fall back to the
    REFINE_* environment variables. Returns a 0/255 uint8 mask; on GrabCut
    failure the thresholded seed is returned unchanged.
    """
    iters = int(os.environ.get("REFINE_GRABCUT_ITERS", "2")) if iters is None else int(iters)
    e = int(os.environ.get("REFINE_TRIMAP_ERODE", "3")) if trimap_erode is None else int(trimap_erode)
    d = int(os.environ.get("REFINE_TRIMAP_DILATE", "6")) if trimap_dilate is None else int(trimap_dilate)

    h, w = mask_u8.shape[:2]
    m = (mask_u8 > 127).astype(np.uint8) * 255  # binarise the incoming seed

    # Shrinking the mask leaves confidently-foreground pixels; shrinking the
    # inverse leaves confidently-background pixels.
    sure_fg = cv2.erode(m, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (max(1, e), max(1, e))), iterations=1)
    sure_bg = cv2.erode(255 - m, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (max(1, d), max(1, d))), iterations=1)

    gc_mask = np.full((h, w), cv2.GC_PR_BGD, dtype=np.uint8)
    gc_mask[sure_bg > 0] = cv2.GC_BGD
    gc_mask[sure_fg > 0] = cv2.GC_FGD

    # Scratch model buffers required by cv2.grabCut's API.
    bgdModel = np.zeros((1, 65), np.float64)
    fgdModel = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(image_bgr, gc_mask, None, bgdModel, fgdModel, iters, cv2.GC_INIT_WITH_MASK)
        out = np.where((gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
        out = cv2.medianBlur(out, 5)  # remove speckle left by per-pixel labelling
        return out
    except Exception as e:
        logger.warning(f"GrabCut refinement failed; using original mask. Reason: {e}")
        return m
502
+
503
+ # --------------------------------------------------------------------------------------
504
+ # MatAnyone Integration
505
+ # --------------------------------------------------------------------------------------
506
def load_matany() -> Tuple[Optional[object], bool, Dict[str, Any]]:
    """
    Probe MatAnyone availability with T=1 squeeze patch for conv2d compatibility.
    Returns (None, available, meta); actual instantiation happens in MatAnyoneSession.
    """
    meta = {"matany_import_ok": False, "matany_init_ok": False}
    # Operators can switch MatAnyone off entirely via ENABLE_MATANY.
    enable_env = os.environ.get("ENABLE_MATANY", "1").strip().lower()
    if enable_env in {"0", "false", "off", "no"}:
        logger.info("MatAnyone disabled by ENABLE_MATANY=0.")
        meta["disabled"] = True
        return None, False, meta

    # Apply T=1 squeeze patch before importing InferenceCore — the patch must
    # be in place before MatAnyone's modules bind their references.
    try:
        from .matany_compat_patch import apply_matany_t1_squeeze_guard
        if apply_matany_t1_squeeze_guard():
            logger.info("[MatAnyCompat] T=1 squeeze guard applied")
            meta["patch_applied"] = True
        else:
            logger.warning("[MatAnyCompat] T=1 squeeze patch failed; conv2d errors may occur")
            meta["patch_applied"] = False
    except Exception as e:
        logger.warning(f"[MatAnyCompat] Patch import failed: {e}")
        meta["patch_applied"] = False

    try:
        from matanyone.inference.inference_core import InferenceCore  # type: ignore
        meta["matany_import_ok"] = True
        # Log MatAnyone version and InferenceCore attributes for diagnostics.
        try:
            version = importlib.metadata.version("matanyone")
            logger.info(f"[MATANY] MatAnyone version: {version}")
        except Exception:
            logger.info("[MATANY] MatAnyone version unknown")
        logger.debug(f"[MATANY] InferenceCore attributes: {dir(InferenceCore)}")
        device = _pick_device("MATANY_DEVICE")
        repo_id = os.environ.get("MATANY_REPO_ID", "PeiqingYang/MatAnyone")
        meta["matany_repo_id"] = repo_id
        meta["matany_device"] = device
        # NOTE: returns None for the handle by design — construction is deferred
        # to MatAnyoneSession; "available" (True) only means imports succeeded.
        return None, True, meta
    except Exception as e:
        logger.warning(f"MatAnyone import failed: {e}")
        return None, False, meta
549
+
550
+ # --------------------------------------------------------------------------------------
551
+ # Fallback Functions
552
+ # --------------------------------------------------------------------------------------
553
def fallback_mask(first_frame_bgr: np.ndarray) -> np.ndarray:
    """Prefer MediaPipe; fallback to GrabCut. Returns uint8 mask 0/255.

    Used when SAM2 is unavailable or failed. GrabCut is seeded with a
    rectangle covering the central 80% of the frame; on total failure an
    all-zero (all-background) mask is returned.
    """
    h, w = first_frame_bgr.shape[:2]
    if _HAS_MEDIAPIPE:
        try:
            mp_selfie = mp.solutions.selfie_segmentation
            with mp_selfie.SelfieSegmentation(model_selection=1) as segmenter:
                rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
                res = segmenter.process(rgb)
                # Threshold the soft segmentation at 0.5 to get a hard matte.
                m = (np.clip(res.segmentation_mask, 0, 1) > 0.5).astype(np.uint8) * 255
                m = cv2.medianBlur(m, 5)  # smooth ragged edges
                return m
        except Exception as e:
            logger.warning(f"MediaPipe fallback failed: {e}")

    # Ultimate fallback: GrabCut
    mask = np.zeros((h, w), np.uint8)
    rect = (int(0.1*w), int(0.1*h), int(0.8*w), int(0.8*h))  # (x, y, width, height)
    # Scratch model buffers required by cv2.grabCut's API.
    bgdModel = np.zeros((1, 65), np.float64)
    fgdModel = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(first_frame_bgr, mask, rect, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_RECT)
        mask_bin = np.where((mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
        return mask_bin
    except Exception as e:
        logger.warning(f"GrabCut failed: {e}")
        return np.zeros((h, w), dtype=np.uint8)
580
+
581
def composite_video(fg_path: Union[str, Path],
                    alpha_path: Union[str, Path],
                    bg_image_path: Union[str, Path],
                    out_path: Union[str, Path],
                    fps: int,
                    size: Tuple[int, int]) -> bool:
    """Blend MatAnyone FG+ALPHA over background using pro compositor.

    Reads the foreground and alpha videos in lockstep, composites each frame
    over the letterboxed background image, and writes the result. When ffmpeg
    is available the frames go to a temporary mp4v file that is then
    re-encoded to H.264; otherwise the mp4v output is kept directly.
    Returns True if at least one frame was written.
    """
    fg_cap = cv2.VideoCapture(str(fg_path))
    al_cap = cv2.VideoCapture(str(alpha_path))
    if not fg_cap.isOpened() or not al_cap.isOpened():
        return False

    w, h = size
    bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR)
    if bg is None:
        # Missing/unreadable background: fall back to flat mid-grey.
        bg = np.full((h, w, 3), 127, dtype=np.uint8)
    bg_f = _resize_keep_ar(bg, (w, h))

    if _probe_ffmpeg():
        # Write mp4v to a temp file first, then finalize to H.264 below.
        tmp_out = Path(str(out_path) + ".tmp.mp4")
        writer = _video_writer(tmp_out, fps, (w, h))
        post_h264 = True
    else:
        writer = _video_writer(Path(out_path), fps, (w, h))
        post_h264 = False

    ok_any = False
    try:
        while True:
            ok_fg, fg = fg_cap.read()
            ok_al, al = al_cap.read()
            if not ok_fg or not ok_al:
                break  # stop at the shorter of the two streams
            fg = cv2.resize(fg, (w, h), interpolation=cv2.INTER_CUBIC)
            al_gray = cv2.cvtColor(cv2.resize(al, (w, h)), cv2.COLOR_BGR2GRAY)

            # Compositor works in RGB; OpenCV I/O is BGR, hence the conversions.
            comp = _composite_frame_pro(
                cv2.cvtColor(fg, cv2.COLOR_BGR2RGB),
                al_gray,
                cv2.cvtColor(bg_f, cv2.COLOR_BGR2RGB)
            )
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            ok_any = True
    finally:
        fg_cap.release()
        al_cap.release()
        writer.release()

    if post_h264 and ok_any:
        try:
            # Re-encode to H.264 + faststart for broad browser playback.
            cmd = [
                _ffmpeg_bin(), "-y",
                "-i", str(tmp_out),
                "-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart",
                str(out_path)
            ]
            subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            tmp_out.unlink(missing_ok=True)
        except Exception as e:
            logger.warning(f"ffmpeg finalize failed: {e}")
            # Keep the mp4v temp render as the final output rather than failing.
            Path(out_path).unlink(missing_ok=True)
            tmp_out.replace(out_path)

    return ok_any
645
+
646
def fallback_composite(video_path: Union[str, Path],
                       mask_path: Union[str, Path],
                       bg_image_path: Union[str, Path],
                       out_path: Union[str, Path]) -> bool:
    """Static-mask compositing using pro compositor.

    Like :func:`composite_video`, but applies a single still mask to every
    frame of the source video (used when per-frame matting is unavailable).
    Returns True if at least one frame was written.
    """
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    cap = cv2.VideoCapture(str(video_path))
    if mask is None or not cap.isOpened():
        return False

    # Output keeps the source video's geometry and frame rate.
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
    fps = int(round(cap.get(cv2.CAP_PROP_FPS) or 25))

    bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR)
    if bg is None:
        # Missing/unreadable background: fall back to flat mid-grey.
        bg = np.full((h, w, 3), 127, dtype=np.uint8)

    # NEAREST keeps the mask hard-edged; _composite_frame_pro feathers it later.
    mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
    bg_f = _resize_keep_ar(bg, (w, h))

    if _probe_ffmpeg():
        # Write mp4v to a temp file first, then finalize to H.264 below.
        tmp_out = Path(str(out_path) + ".tmp.mp4")
        writer = _video_writer(tmp_out, fps, (w, h))
        use_post_ffmpeg = True
    else:
        writer = _video_writer(Path(out_path), fps, (w, h))
        use_post_ffmpeg = False

    ok_any = False
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            # Compositor works in RGB; OpenCV I/O is BGR, hence the conversions.
            comp = _composite_frame_pro(
                cv2.cvtColor(frame, cv2.COLOR_BGR2RGB),
                mask_resized,
                cv2.cvtColor(bg_f, cv2.COLOR_BGR2RGB)
            )
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            ok_any = True
    finally:
        cap.release()
        writer.release()

    if use_post_ffmpeg and ok_any:
        try:
            # Re-encode to H.264 + faststart for broad browser playback.
            cmd = [
                _ffmpeg_bin(), "-y",
                "-i", str(tmp_out),
                "-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart",
                str(out_path)
            ]
            subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            tmp_out.unlink(missing_ok=True)
        except Exception as e:
            logger.warning(f"ffmpeg H.264 finalize failed: {e}")
            # Keep the mp4v temp render as the final output rather than failing.
            Path(out_path).unlink(missing_ok=True)
            tmp_out.replace(out_path)

    return ok_any
708
+
709
+ # --------------------------------------------------------------------------------------
710
+ # Stage-A (Transparent Export) Functions
711
+ # --------------------------------------------------------------------------------------
712
+ def _checkerboard_bg(w: int, h: int, tile: int = 32) -> np.ndarray:
713
+ y, x = np.mgrid[0:h, 0:w]
714
+ c = ((x // tile) + (y // tile)) % 2
715
+ a = np.where(c == 0, 200, 150).astype(np.uint8)
716
+ return np.stack([a, a, a], axis=-1)
717
+
718
def _build_stage_a_rgba_vp9_from_fg_alpha(
    fg_path: Union[str, Path],
    alpha_path: Union[str, Path],
    out_webm: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
    src_audio: Optional[Union[str, Path]] = None,
) -> bool:
    """Merge FG + alpha videos into a transparent VP9 WebM (optionally with audio)."""
    if not _probe_ffmpeg():
        return False
    w, h = size
    try:
        cmd = [_ffmpeg_bin(), "-y", "-i", str(fg_path), "-i", str(alpha_path)]
        if src_audio:
            cmd += ["-i", str(src_audio)]
        # alphamerge fuses the grayscale alpha stream into the FG as its A channel.
        graph = (
            f"[1:v]format=gray,scale={w}:{h},fps={fps}[al];"
            f"[0:v]scale={w}:{h},fps={fps}[fg];"
            f"[fg][al]alphamerge[outv]"
        )
        cmd += ["-filter_complex", graph, "-map", "[outv]"]
        if src_audio:
            cmd += ["-map", "2:a:0?", "-c:a", "libopus", "-b:a", "128k"]
        cmd += [
            "-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p",
            "-crf", os.environ.get("STAGEA_VP9_CRF", "28"),
            "-b:v", "0", "-row-mt", "1", "-shortest", str(out_webm),
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return True
    except Exception as e:
        logger.warning(f"Stage-A VP9(alpha) build failed: {e}")
        return False
749
+
750
def _build_stage_a_rgba_vp9_from_mask(
    video_path: Union[str, Path],
    mask_png: Union[str, Path],
    out_webm: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Build a transparent VP9 WebM by applying one still mask to every frame."""
    if not _probe_ffmpeg():
        return False
    w, h = size
    try:
        cmd = [
            _ffmpeg_bin(), "-y",
            "-i", str(video_path),
            # Loop the still PNG so it covers the whole video duration.
            "-loop", "1", "-i", str(mask_png),
            "-filter_complex",
            f"[1:v]format=gray,scale={w}:{h},fps={fps}[al];"
            f"[0:v]scale={w}:{h},fps={fps}[fg];"
            f"[fg][al]alphamerge[outv]",
            "-map", "[outv]",
            "-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p",
            "-crf", os.environ.get("STAGEA_VP9_CRF", "28"),
            "-b:v", "0", "-row-mt", "1", "-shortest", str(out_webm),
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception as e:
        logger.warning(f"Stage-A VP9(alpha) (mask) build failed: {e}")
        return False
    return True
779
+
780
def _build_stage_a_checkerboard_from_fg_alpha(
    fg_path: Union[str, Path],
    alpha_path: Union[str, Path],
    out_mp4: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Preview render: FG+alpha composited over a checkerboard, written as MP4."""
    fg_cap = cv2.VideoCapture(str(fg_path))
    al_cap = cv2.VideoCapture(str(alpha_path))
    if not fg_cap.isOpened() or not al_cap.isOpened():
        return False
    w, h = size
    writer = _video_writer(Path(out_mp4), fps, (w, h))
    board = _checkerboard_bg(w, h)
    wrote_any = False
    try:
        while True:
            got_fg, fg = fg_cap.read()
            got_al, al = al_cap.read()
            if not (got_fg and got_al):
                break  # stop at the shorter stream
            fg = cv2.resize(fg, (w, h))
            alpha_gray = cv2.cvtColor(cv2.resize(al, (w, h)), cv2.COLOR_BGR2GRAY)
            comp = _composite_frame_pro(cv2.cvtColor(fg, cv2.COLOR_BGR2RGB), alpha_gray, board)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            wrote_any = True
    finally:
        fg_cap.release()
        al_cap.release()
        writer.release()
    return wrote_any
811
+
812
def _build_stage_a_checkerboard_from_mask(
    video_path: Union[str, Path],
    mask_png: Union[str, Path],
    out_mp4: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Preview render: every frame masked by one still PNG over a checkerboard."""
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        return False
    w, h = size
    mask = cv2.imread(str(mask_png), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        return False
    # NEAREST keeps the mask hard-edged; the compositor feathers it later.
    mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
    writer = _video_writer(Path(out_mp4), fps, (w, h))
    board = _checkerboard_bg(w, h)
    wrote_any = False
    try:
        while True:
            got, frame = cap.read()
            if not got:
                break
            frame = cv2.resize(frame, (w, h))
            comp = _composite_frame_pro(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), mask, board)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            wrote_any = True
    finally:
        cap.release()
        writer.release()
    return wrote_any
843
+
844
+ # --------------------------------------------------------------------------------------
845
+ # MatAnyone Integration
846
+ # --------------------------------------------------------------------------------------
847
def run_matany(
    video_path: Union[str, Path],
    mask_path: Optional[Union[str, Path]],
    out_dir: Union[str, Path],
    device: Optional[str] = None,
    progress_callback: Optional[Callable[[float, str], None]] = None,
) -> Tuple[Path, Path]:
    """
    Run MatAnyone streaming matting via our shape-guarded adapter.
    Returns (alpha_mp4_path, fg_mp4_path).
    Raises MatAnyError on failure.
    """
    # Imported lazily so this module stays importable when MatAnyone is absent.
    from .matanyone_loader import MatAnyoneSession, MatAnyError

    # precision="auto" lets the session pick dtype; device=None defers to its default.
    session = MatAnyoneSession(device=device, precision="auto")
    alpha_p, fg_p = session.process_stream(
        video_path=Path(video_path),
        seed_mask_path=Path(mask_path) if mask_path else None,
        out_dir=Path(out_dir),
        progress_cb=progress_callback,
    )
    return alpha_p, fg_p
VideoBackgroundReplacer2/models/matanyone_loader.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ MatAnyone adapter β€” Using Official API (File-Based)
5
+
6
+ Fixed to use MatAnyone's official process_video() API instead of
7
+ bypassing it with internal tensor manipulation. This eliminates
8
+ all 5D tensor dimension issues.
9
+
10
+ Changes (2025-09-17):
11
+ - Replaced custom tensor processing with official MatAnyone API
12
+ - Uses file-based input/output as designed by MatAnyone authors
13
+ - Eliminates all tensor dimension compatibility issues
14
+ - Simplified error handling and logging
15
+ """
16
+
17
+ from __future__ import annotations
18
+ import os
19
+ import time
20
+ import logging
21
+ import tempfile
22
+ import importlib.metadata
23
+ from pathlib import Path
24
+ from typing import Optional, Callable, Tuple
25
+
26
+ log = logging.getLogger(__name__)
27
+
28
+ # ---------- Progress helper ----------
29
+ def _env_flag(name: str, default: str = "0") -> bool:
30
+ return os.getenv(name, default).strip().lower() in {"1", "true", "yes", "on"}
31
+
32
# Master switch for progress callbacks (MATANY_PROGRESS env var, default on).
_PROGRESS_CB_ENABLED = _env_flag("MATANY_PROGRESS", "1")
# Minimum seconds between repeated emissions of the same message.
_PROGRESS_MIN_INTERVAL = float(os.getenv("MATANY_PROGRESS_MIN_SEC", "0.25"))
# Throttling state shared across _emit_progress() calls (module-level, not thread-safe).
_progress_last = 0.0
_progress_last_msg = None
# Latched to True after a callback raises; see _emit_progress().
_progress_disabled = False
37
+
38
def _emit_progress(cb, pct: float, msg: str):
    """
    Invoke the progress callback *cb* with (pct, msg), falling back to the
    legacy single-argument form. Repeats of the same message are throttled
    to _PROGRESS_MIN_INTERVAL seconds; a raising callback permanently
    disables further emissions.
    """
    global _progress_last, _progress_last_msg, _progress_disabled
    if not cb or not _PROGRESS_CB_ENABLED or _progress_disabled:
        return
    now = time.time()
    within_interval = (now - _progress_last) < _PROGRESS_MIN_INTERVAL
    if within_interval and msg == _progress_last_msg:
        return  # throttle duplicate messages
    try:
        try:
            cb(pct, msg)  # preferred (pct, msg)
        except TypeError:
            cb(msg)  # legacy (msg)
    except Exception as exc:
        # Never let a broken UI callback take down processing.
        _progress_disabled = True
        log.warning("[progress-cb] disabled due to exception: %s", exc)
    else:
        _progress_last = now
        _progress_last_msg = msg
55
+
56
# ---------- Errors ----------
class MatAnyError(RuntimeError):
    """Raised when MatAnyone initialization or processing fails."""
    pass
59
+
60
+ # ---------- CUDA helpers ----------
61
+ def _cuda_snapshot(device: Optional[str]) -> str:
62
+ try:
63
+ import torch
64
+ if not torch.cuda.is_available():
65
+ return "CUDA: N/A"
66
+ idx = 0
67
+ if device and device.startswith("cuda:"):
68
+ try:
69
+ idx = int(device.split(":")[1])
70
+ except (ValueError, IndexError):
71
+ idx = 0
72
+ name = torch.cuda.get_device_name(idx)
73
+ alloc = torch.cuda.memory_allocated(idx) / (1024**3)
74
+ resv = torch.cuda.memory_reserved(idx) / (1024**3)
75
+ return f"device={idx}, name={name}, alloc={alloc:.2f}GB, reserved={resv:.2f}GB"
76
+ except Exception as e:
77
+ return f"CUDA snapshot error: {e!r}"
78
+
79
def _safe_empty_cache():
    """Best-effort torch.cuda.empty_cache() with before/after memory logging; never raises."""
    try:
        import torch
        if not torch.cuda.is_available():
            return
        log.info(f"[MATANY] CUDA memory before empty_cache: {_cuda_snapshot('cuda:0')}")
        torch.cuda.empty_cache()
        log.info(f"[MATANY] CUDA memory after empty_cache: {_cuda_snapshot('cuda:0')}")
    except Exception:
        pass
88
+
89
# ============================================================================

class MatAnyoneSession:
    """
    Simple wrapper around MatAnyone's official API.
    Uses file-based input/output as designed by the MatAnyone authors.
    """
    def __init__(self, device: Optional[str] = None, precision: str = "auto"):
        """
        Args:
            device: Torch device string ("cuda", "cuda:0", "cpu"); auto-detected when None.
            precision: Requested precision mode (stored; currently informational).

        Raises:
            MatAnyError: if MatAnyone's InferenceCore cannot be initialized.
        """
        self.device = device or ("cuda" if self._cuda_available() else "cpu")
        self.precision = precision.lower()

        # Log MatAnyone version (informational only; failure is non-fatal).
        try:
            version = importlib.metadata.version("matanyone")
            log.info("[MATANY] MatAnyone version: %s", version)
        except Exception:
            log.info("[MATANY] MatAnyone version unknown")

        # Initialize MatAnyone's official API.
        try:
            from matanyone import InferenceCore
            self.processor = InferenceCore("PeiqingYang/MatAnyone")
            log.info("[MATANY] MatAnyone InferenceCore initialized successfully")
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise MatAnyError(f"Failed to initialize MatAnyone: {e}") from e

    def _cuda_available(self) -> bool:
        """Return True if torch reports an available CUDA device (False on any error)."""
        try:
            import torch
            return torch.cuda.is_available()
        except Exception:
            return False

    def process_stream(
        self,
        video_path: Path,
        seed_mask_path: Optional[Path] = None,
        out_dir: Optional[Path] = None,
        progress_cb: Optional[Callable] = None,
    ) -> Tuple[Path, Path]:
        """
        Process video using MatAnyone's official API.

        Args:
            video_path: Path to input video file
            seed_mask_path: Path to first-frame mask PNG (white=foreground, black=background)
            out_dir: Output directory for results (defaults to <video dir>/matanyone_output)
            progress_cb: Progress callback function

        Returns:
            Tuple of (alpha_path, foreground_path)

        Raises:
            MatAnyError: on missing inputs, processing failure, or missing outputs.
        """
        video_path = Path(video_path)
        if not video_path.exists():
            raise MatAnyError(f"Video file not found: {video_path}")

        if seed_mask_path and not Path(seed_mask_path).exists():
            raise MatAnyError(f"Seed mask not found: {seed_mask_path}")

        out_dir = Path(out_dir) if out_dir else video_path.parent / "matanyone_output"
        out_dir.mkdir(parents=True, exist_ok=True)

        log.info("[MATANY] Processing video: %s", video_path)
        log.info("[MATANY] Using mask: %s", seed_mask_path)
        log.info("[MATANY] Output directory: %s", out_dir)

        _emit_progress(progress_cb, 0.0, "Initializing MatAnyone processing...")

        try:
            start_time = time.time()

            _emit_progress(progress_cb, 0.1, "Running MatAnyone video matting...")

            # Official file-based API; returns (foreground_path, alpha_path).
            foreground_path, alpha_path = self.processor.process_video(
                input_path=str(video_path),
                mask_path=str(seed_mask_path) if seed_mask_path else None,
                output_path=str(out_dir),
            )

            processing_time = time.time() - start_time
            log.info("[MATANY] Processing completed in %.1fs", processing_time)
            log.info("[MATANY] Foreground output: %s", foreground_path)
            log.info("[MATANY] Alpha output: %s", alpha_path)

            # Convert to Path objects (the API may return None on silent failure).
            fg_path = Path(foreground_path) if foreground_path else None
            al_path = Path(alpha_path) if alpha_path else None

            # Verify outputs exist before declaring success.
            if not fg_path or not fg_path.exists():
                raise MatAnyError(f"Foreground output not created: {fg_path}")
            if not al_path or not al_path.exists():
                raise MatAnyError(f"Alpha output not created: {al_path}")

            _emit_progress(progress_cb, 1.0, "MatAnyone processing complete")

            return al_path, fg_path  # Return (alpha, foreground) to match expected order

        except MatAnyError:
            # BUG FIX: previously our own MatAnyError (e.g. "output not created")
            # was swallowed by the generic handler and re-wrapped, garbling the
            # message. Re-raise it untouched.
            raise
        except Exception as e:
            log.error("[MATANY] Processing failed: %s", e)
            raise MatAnyError(f"MatAnyone processing failed: {e}") from e

        finally:
            _safe_empty_cache()
195
+
196
# ============================================================================
# MatAnyoneModel Wrapper Class for app_hf.py compatibility
# ============================================================================

class MatAnyoneModel:
    """Wrapper class for MatAnyone to match app_hf.py interface"""

    def __init__(self, device="cuda"):
        """Create the wrapper and eagerly load a MatAnyoneSession.

        Args:
            device: Device string passed through to MatAnyoneSession.
        """
        self.device = device
        self.session = None   # MatAnyoneSession once loaded
        self.loaded = False   # True only after a successful _load_model()
        log.info("Initializing MatAnyoneModel on device: %s", device)

        # Initialize the session
        self._load_model()

    def _load_model(self):
        """Load the MatAnyone session; failures are logged and leave loaded=False."""
        try:
            self.session = MatAnyoneSession(device=self.device, precision="auto")
            self.loaded = True
            log.info("MatAnyoneModel loaded successfully")
        except Exception as e:
            log.error("Error loading MatAnyoneModel: %s", e)
            self.loaded = False

    def replace_background(self, video_path, masks, background_path):
        """Replace background in video using MatAnyone.

        Args:
            video_path: Path to the input video.
            masks: Path (str or Path) to the first-frame mask PNG.
            background_path: Background asset; currently unused (compositing TODO).

        Returns:
            str path to the foreground video produced by MatAnyone.

        Raises:
            MatAnyError: if the model is not loaded or processing fails.
        """
        if not self.loaded:
            raise MatAnyError("MatAnyoneModel not loaded")

        try:
            video_path = Path(video_path)

            # We expect masks to be a path to the first-frame mask.
            mask_path = Path(masks) if isinstance(masks, (str, Path)) else None

            # BUG FIX: the previous implementation used
            # tempfile.TemporaryDirectory() as a context manager, which deleted
            # the outputs when the block exited — the returned path was always
            # dangling. Use a persistent directory instead; the caller owns
            # cleanup of the returned file.
            output_dir = Path(tempfile.mkdtemp(prefix="matanyone_"))

            # Process the video stream
            alpha_path, fg_path = self.session.process_stream(
                video_path=video_path,
                seed_mask_path=mask_path,
                out_dir=output_dir,
                progress_cb=None,
            )

            # Return the foreground video path.
            # In a full implementation, you'd composite with the background_path.
            return str(fg_path)

        except MatAnyError:
            raise  # already our error type; don't double-wrap
        except Exception as e:
            log.error("Error in replace_background: %s", e)
            raise MatAnyError(f"Background replacement failed: {e}") from e
256
+
257
# ============================================================================
# Helper function for pipeline integration
# ============================================================================

def create_matanyone_session(device=None):
    """Create a MatAnyone session for use in pipeline"""
    session = MatAnyoneSession(device=device)
    return session
264
+
265
def run_matanyone_on_files(video_path, mask_path, output_dir, device="cuda", progress_callback=None):
    """
    Run MatAnyone on video and mask files.

    Args:
        video_path: Path to input video
        mask_path: Path to first-frame mask PNG (may be None/empty for no seed)
        output_dir: Directory for outputs
        device: Device to use (cuda/cpu)
        progress_callback: Progress callback function

    Returns:
        Tuple of (alpha_path, foreground_path) as strings, or (None, None) on failure
    """
    try:
        seed_mask = Path(mask_path) if mask_path else None
        session = MatAnyoneSession(device=device)
        alpha_out, fg_out = session.process_stream(
            video_path=Path(video_path),
            seed_mask_path=seed_mask,
            out_dir=Path(output_dir),
            progress_cb=progress_callback,
        )
        return str(alpha_out), str(fg_out)
    except Exception as e:
        log.error(f"MatAnyone processing failed: {e}")
        return None, None
VideoBackgroundReplacer2/models/sam2_loader.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ SAM2 Loader with T4-optimized predictor wrapper
4
+ Provides SAM2Predictor class with memory management and optimization features
5
+ """
6
+
7
+ import os
8
+ import gc
9
+ import torch
10
+ import logging
11
+ import numpy as np
12
+ from pathlib import Path
13
+ from typing import Optional, Any, Dict, List, Tuple
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
class SAM2Predictor:
    """
    T4-optimized SAM2 video predictor wrapper with memory management
    """

    def __init__(self, device: torch.device, model_size: str = "small"):
        # device: torch device inference runs on.
        # model_size: "small" | "base" | "large" (see _download_checkpoint urls).
        self.device = device
        self.model_size = model_size
        self.predictor = None  # set by _load_predictor()
        self.model = None      # underlying module, cached by _optimize_for_t4()
        self._load_predictor()

    def _load_predictor(self):
        """Load SAM2 predictor with optimizations"""
        try:
            from sam2.build_sam import build_sam2_video_predictor

            # Download checkpoint if needed
            checkpoint_path = f"./checkpoints/sam2_hiera_{self.model_size}.pt"
            if not self._ensure_checkpoint(checkpoint_path):
                raise RuntimeError(f"Failed to get SAM2 {self.model_size} checkpoint")

            # Build predictor
            # NOTE(review): for model_size "base" this yields "sam2_hiera_b.yaml",
            # but the checkpoint downloaded below is base_plus, whose official
            # config is "sam2_hiera_b+.yaml" — TODO confirm against the SAM2 repo.
            model_cfg = f"sam2_hiera_{self.model_size[0]}.yaml"  # small -> s, base -> b, large -> l
            self.predictor = build_sam2_video_predictor(model_cfg, checkpoint_path, device=self.device)

            # Apply T4 optimizations
            self._optimize_for_t4()

            logger.info(f"SAM2 {self.model_size} predictor loaded successfully")

        except ImportError as e:
            logger.error(f"SAM2 import failed: {e}")
            raise RuntimeError("SAM2 not available - check third_party/sam2 installation")
        except Exception as e:
            logger.error(f"SAM2 loading failed: {e}")
            raise

    def _ensure_checkpoint(self, checkpoint_path: str) -> bool:
        """Ensure checkpoint exists, download if needed.

        A file under 50MB is treated as a truncated download and replaced.
        Returns True when a plausible checkpoint is present.
        """
        checkpoint_file = Path(checkpoint_path)

        if checkpoint_file.exists():
            file_size = checkpoint_file.stat().st_size / (1024**2)
            if file_size > 50:  # At least 50MB
                logger.info(f"SAM2 checkpoint exists: {file_size:.1f}MB")
                return True
            else:
                logger.warning(f"Checkpoint too small ({file_size:.1f}MB), re-downloading")
                checkpoint_file.unlink()

        return self._download_checkpoint(checkpoint_path)

    def _download_checkpoint(self, checkpoint_path: str, timeout_seconds: int = 600) -> bool:
        """Download SAM2 checkpoint.

        Streams to a ".download" temp file, renames on success, and returns
        False (after removing partial files) on any failure.
        """
        try:
            logger.info(f"Downloading SAM2 {self.model_size} checkpoint...")

            checkpoint_file = Path(checkpoint_path)
            checkpoint_file.parent.mkdir(parents=True, exist_ok=True)

            import requests

            # Checkpoint URLs
            urls = {
                "small": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
                "base": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
                "large": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"
            }

            if self.model_size not in urls:
                raise ValueError(f"Unknown model size: {self.model_size}")

            checkpoint_url = urls[self.model_size]

            import time
            start_time = time.time()
            response = requests.get(checkpoint_url, stream=True, timeout=30)
            response.raise_for_status()

            total_size = int(response.headers.get('content-length', 0))

            # Write to a temp name so an interrupted download never looks valid.
            temp_path = checkpoint_file.with_suffix('.download')
            downloaded = 0
            last_log = start_time

            with open(temp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024*1024):
                    if chunk:
                        f.write(chunk)
                        downloaded += len(chunk)

                        current_time = time.time()
                        # Overall wall-clock timeout for the whole transfer.
                        if current_time - start_time > timeout_seconds:
                            raise TimeoutError(f"Download timeout after {timeout_seconds}s")

                        # Progress logging every 15 seconds
                        if current_time - last_log > 15:
                            progress = (downloaded / total_size * 100) if total_size > 0 else 0
                            speed = downloaded / (current_time - start_time) / (1024**2)
                            logger.info(f"Download: {progress:.1f}% ({speed:.1f}MB/s)")
                            last_log = current_time

            temp_path.rename(checkpoint_file)

            download_time = time.time() - start_time
            speed = downloaded / download_time / (1024**2)
            logger.info(f"Download complete: {downloaded/(1024**2):.1f}MB in {download_time:.1f}s ({speed:.1f}MB/s)")

            return True

        except Exception as e:
            logger.error(f"Checkpoint download failed: {e}")
            # Remove whatever partial file may exist at the final path.
            if Path(checkpoint_path).exists():
                Path(checkpoint_path).unlink()
            return False

    def _optimize_for_t4(self):
        """Apply T4-specific optimizations.

        NOTE(review): .half() is applied regardless of self.device — fp16 on a
        CPU device can fail or be very slow; confirm callers always pass CUDA.
        Failures are logged and swallowed (optimization is best-effort).
        """
        try:
            if hasattr(self.predictor, "model") and self.predictor.model is not None:
                self.model = self.predictor.model

                # Apply fp16 and channels_last for T4 efficiency
                self.model = self.model.half().to(self.device)
                self.model = self.model.to(memory_format=torch.channels_last)

                logger.info("SAM2: fp16 + channels_last applied for T4 optimization")

        except Exception as e:
            logger.warning(f"SAM2 T4 optimization warning: {e}")

    def init_state(self, video_path: str):
        """Initialize video processing state.

        Delegates to the underlying predictor's init_state; re-raises failures
        after logging. Raises RuntimeError if the predictor never loaded.
        """
        if self.predictor is None:
            raise RuntimeError("Predictor not loaded")

        try:
            return self.predictor.init_state(video_path=video_path)
        except Exception as e:
            logger.error(f"Failed to initialize video state: {e}")
            raise

    def add_new_points(self, inference_state, frame_idx: int, obj_id: int,
                       points: np.ndarray, labels: np.ndarray):
        """Add new points for tracking.

        Thin pass-through to the predictor's add_new_points; logs and
        re-raises any failure.
        """
        if self.predictor is None:
            raise RuntimeError("Predictor not loaded")

        try:
            return self.predictor.add_new_points(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                points=points,
                labels=labels
            )
        except Exception as e:
            logger.error(f"Failed to add new points: {e}")
            raise

    def propagate_in_video(self, inference_state, scale: float = 1.0, **kwargs):
        """Propagate through video with optional scaling.

        NOTE(review): the `scale` parameter is currently unused — it is not
        forwarded to the predictor. Kept for interface compatibility; confirm
        whether downscaled propagation was intended here.
        """
        if self.predictor is None:
            raise RuntimeError("Predictor not loaded")

        try:
            # Use the predictor's propagate_in_video method
            return self.predictor.propagate_in_video(inference_state, **kwargs)
        except Exception as e:
            logger.error(f"Failed to propagate in video: {e}")
            raise

    def prune_state(self, inference_state, keep: int):
        """Prune SAM2 state to keep only recent frames in memory.

        Best-effort heuristic over undocumented SAM2 internals; every step is
        guarded and failures are only logged at debug level.
        """
        try:
            # Try to access and prune internal caches
            # This is model-specific and may need adjustment based on SAM2 internals
            if hasattr(inference_state, 'cached_features'):
                # Keep only the most recent 'keep' frames
                # (assumes dict insertion order reflects frame recency — TODO confirm)
                cached_keys = list(inference_state.cached_features.keys())
                if len(cached_keys) > keep:
                    keys_to_remove = cached_keys[:-keep]
                    for key in keys_to_remove:
                        if key in inference_state.cached_features:
                            del inference_state.cached_features[key]
                    logger.debug(f"Pruned {len(keys_to_remove)} old cached features")

            # Clear other potential caches
            if hasattr(inference_state, 'point_inputs_per_obj'):
                # Keep recent point inputs only
                for obj_id in list(inference_state.point_inputs_per_obj.keys()):
                    obj_inputs = inference_state.point_inputs_per_obj[obj_id]
                    if len(obj_inputs) > keep:
                        # Keep only recent entries
                        recent_keys = sorted(obj_inputs.keys())[-keep:]
                        new_inputs = {k: obj_inputs[k] for k in recent_keys}
                        inference_state.point_inputs_per_obj[obj_id] = new_inputs

            # Force garbage collection
            torch.cuda.empty_cache() if self.device.type == 'cuda' else None

        except Exception as e:
            logger.debug(f"State pruning warning: {e}")

    def clear_memory(self):
        """Clear GPU memory aggressively (empty_cache + sync + ipc_collect + gc)."""
        try:
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                torch.cuda.ipc_collect()
            gc.collect()
        except Exception as e:
            logger.warning(f"Memory clearing warning: {e}")

    def get_memory_usage(self) -> Dict[str, float]:
        """Get current memory usage statistics.

        Returns allocated/reserved/free (and total, when available) in GiB;
        zeros on CPU or on any CUDA query failure.
        """
        if self.device.type != 'cuda':
            return {"allocated_gb": 0.0, "reserved_gb": 0.0, "free_gb": 0.0}

        try:
            allocated = torch.cuda.memory_allocated(self.device) / (1024**3)
            reserved = torch.cuda.memory_reserved(self.device) / (1024**3)
            free, total = torch.cuda.mem_get_info(self.device)
            free_gb = free / (1024**3)

            return {
                "allocated_gb": allocated,
                "reserved_gb": reserved,
                "free_gb": free_gb,
                "total_gb": total / (1024**3)
            }
        except Exception:
            return {"allocated_gb": 0.0, "reserved_gb": 0.0, "free_gb": 0.0}

    def __del__(self):
        """Cleanup on deletion.

        Deliberately swallows all exceptions — interpreter shutdown may have
        already torn down torch/logging when this runs.
        """
        try:
            if hasattr(self, 'predictor') and self.predictor is not None:
                del self.predictor
            if hasattr(self, 'model') and self.model is not None:
                del self.model
            self.clear_memory()
        except Exception:
            pass
VideoBackgroundReplacer2/pipeline.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ pipeline.py β€” Production SAM2 + MatAnyone (T4-optimized, single-pass streaming)
4
+
5
+ Key features
6
+ ------------
7
+ - One SAM2 inference state for the entire video (no per-chunk reinit).
8
+ - In-stream pipeline: Read β†’ SAM2 β†’ MatAnyone β†’ Compose β†’ Write (no big RAM dicts).
9
+ - Bounded memory everywhere (deque/window); optional CPU spill.
10
+ - fp16 + channels_last on SAM2; mixed precision blocks.
11
+ - VRAM-aware controller adjusts memory window/scale.
12
+ - Heartbeat logger to prevent HF watchdog restarts.
13
+ - Safer FFmpeg audio re-mux.
14
+
15
+ Compatible with Tesla T4 (β‰ˆ15–16 GB) and PyTorch 2.5.x + CUDA 12.4 wheels.
16
+ """
17
+
18
+ import os
19
+ import gc
20
+ import cv2
21
+ import time
22
+ import uuid
23
+ import torch
24
+ import queue
25
+ import shutil
26
+ import logging
27
+ import tempfile
28
+ import subprocess
29
+ import threading
30
+ import numpy as np
31
+ from PIL import Image
32
+ from pathlib import Path
33
+ from typing import Optional, Tuple, Dict, Any, Callable
34
+ from collections import deque
35
+
36
# ----------------------------------------------------------------------------------------------------------------------
# Logging
# ----------------------------------------------------------------------------------------------------------------------
logger = logging.getLogger("backgroundfx_pro")
if not logger.handlers:
    # Attach a single stream handler only once, even across re-imports.
    _stream_handler = logging.StreamHandler()
    _stream_handler.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s: %(message)s"))
    logger.addHandler(_stream_handler)
    logger.setLevel(logging.INFO)
45
+
46
# ----------------------------------------------------------------------------------------------------------------------
# Environment & Torch tuning for T4
# ----------------------------------------------------------------------------------------------------------------------
def setup_t4_environment():
    """Configure allocator env vars and torch backend flags for T4 inference."""
    env_defaults = {
        "PYTORCH_CUDA_ALLOC_CONF":
            "expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7",
        "OMP_NUM_THREADS": "1",
        "OPENBLAS_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
        "OPENCV_OPENCL_RUNTIME": "disabled",
        "OPENCV_IO_ENABLE_OPENEXR": "0",
    }
    for key, value in env_defaults.items():
        os.environ.setdefault(key, value)

    # Inference only — no autograd anywhere in the pipeline.
    torch.set_grad_enabled(False)
    try:
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass  # older torch builds may lack some of these knobs

    if torch.cuda.is_available():
        try:
            frac = float(os.getenv("CUDA_MEMORY_FRACTION", "0.88"))
            torch.cuda.set_per_process_memory_fraction(frac)
            logger.info(f"CUDA per-process memory fraction = {frac:.2f}")
        except Exception as e:
            logger.warning(f"Could not set CUDA memory fraction: {e}")
74
+
75
def vram_gb() -> Tuple[float, float]:
    """Return (free_gb, total_gb) of CUDA memory; (0.0, 0.0) without CUDA."""
    if not torch.cuda.is_available():
        return 0.0, 0.0
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    gib = 1024 ** 3
    return free_bytes / gib, total_bytes / gib
80
+
81
# ----------------------------------------------------------------------------------------------------------------------
# Heartbeat (prevents Spaces watchdog killing the job)
# ----------------------------------------------------------------------------------------------------------------------
def heartbeat_monitor(running_flag: Dict[str, bool], interval: float = 8.0):
    """Print a heartbeat line every *interval* seconds while running_flag['running'] is True."""
    while True:
        if not running_flag.get("running", False):
            break
        print(f"[HB] t={int(time.time())}", flush=True)
        time.sleep(interval)
88
+
89
# ----------------------------------------------------------------------------------------------------------------------
# Streaming video I/O
# ----------------------------------------------------------------------------------------------------------------------
class StreamingVideoIO:
    """Context manager pairing a cv2 video reader with a same-size mp4 writer."""

    def __init__(self, video_path: str, out_path: str, fps: float):
        self.video_path = video_path
        self.out_path = out_path
        self.fps = fps
        self.cap = None      # cv2.VideoCapture, opened in __enter__
        self.writer = None   # cv2.VideoWriter, opened in __enter__
        self.size = None     # (width, height) of the source

    def __enter__(self):
        capture = cv2.VideoCapture(self.video_path)
        if not capture.isOpened():
            raise RuntimeError(f"Cannot open video: {self.video_path}")
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.cap = capture
        self.size = (width, height)
        codec = cv2.VideoWriter_fourcc(*'mp4v')
        self.writer = cv2.VideoWriter(self.out_path, codec, self.fps, (width, height))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Release both handles regardless of how the block exited.
        if self.cap:
            self.cap.release()
        if self.writer:
            self.writer.release()

    def read_frame(self):
        """Return (ok, frame_bgr); (False, None) when the capture is not open."""
        if not self.cap:
            return False, None
        return self.cap.read()

    def write_frame(self, frame_bgr: np.ndarray):
        """Write one BGR frame; silently no-op when the writer is not open."""
        if not self.writer:
            return
        self.writer.write(frame_bgr)
127
+
128
# ----------------------------------------------------------------------------------------------------------------------
# Models: loaders and safe optimizations
# ----------------------------------------------------------------------------------------------------------------------
def load_sam2_predictor(device: torch.device):
    """
    Prefer your local wrapper to keep interfaces stable.
    """
    try:
        from models.sam2_loader import SAM2Predictor  # your wrapper
        predictor = SAM2Predictor(device=device)
        # Optional: try to access underlying model to set fp16 + channels_last
        try:
            inner_model = getattr(predictor, "model", None)
            if inner_model is not None:
                inner_model = inner_model.half().to(device)
                predictor.model = inner_model.to(memory_format=torch.channels_last)
                logger.info("SAM2: fp16 + channels_last applied (wrapper model).")
        except Exception as e:
            logger.warning(f"SAM2 fp16 optimization warning: {e}")
        return predictor
    except Exception as e:
        logger.error(f"Failed to import SAM2Predictor: {e}")
        raise
150
+
151
def load_matany_session(device: torch.device):
    """
    Supports either MatAnyoneSession or MatAnyoneLoader (your code has varied).
    """
    try:
        try:
            from models.matanyone_loader import MatAnyoneSession as _MatAny
        except Exception:
            from models.matanyone_loader import MatAnyoneLoader as _MatAny
        session = _MatAny(device=device)
        # Try fp16 eval where safe
        if getattr(session, "model", None) is not None:
            session.model.eval()
            try:
                session.model = session.model.half().to(device)
                logger.info("MatAnyone: fp16 + eval applied.")
            except Exception:
                logger.info("MatAnyone: using fp32 (fp16 not supported for some layers).")
        return session
    except Exception as e:
        logger.warning(f"MatAnyone not available ({e}). Proceeding without refinement.")
        return None
173
+
174
# ----------------------------------------------------------------------------------------------------------------------
# SAM2 state pruning (adapter): we call predictor.prune_state if present, else best-effort
# ----------------------------------------------------------------------------------------------------------------------
def prune_sam2_state(predictor, state: Any, keep: int):
    """
    Try to prune SAM2 temporal caches to a fixed window length.
    Your SAM2Predictor should implement prune_state(state, keep=N). If not, we do nothing.
    """
    try:
        if hasattr(predictor, "prune_state"):
            predictor.prune_state(state, keep=keep)
        elif hasattr(state, "prune") and callable(getattr(state, "prune")):
            state.prune(keep=keep)
        # Otherwise: no-op; rely on model internals and GC.
    except Exception as exc:
        logger.debug(f"SAM2 prune_state warning: {exc}")
192
+
193
# ----------------------------------------------------------------------------------------------------------------------
# VRAM-aware controller
# ----------------------------------------------------------------------------------------------------------------------
class VRAMAdaptiveController:
    """Tunes SAM2 state window, propagation scale and cleanup cadence to free VRAM."""

    def __init__(self):
        # frames to keep in model state
        self.memory_window = int(os.getenv("SAM2_WINDOW", "96"))
        # e.g., downscale factor for propagation
        self.propagation_scale = float(os.getenv("SAM2_PROP_SCALE", "0.90"))
        self.cleanup_every = 20  # frames

    def adapt(self):
        """Tighten limits when free VRAM dips under ~1.6 GB; relax above 3 GB."""
        free, total = vram_gb()
        if free == 0.0:
            return  # CPU-only (or query failed): nothing to adapt
        if free < 1.6:
            # Tighten: smaller window/scale, clean up more often.
            self.memory_window = max(48, self.memory_window - 8)
            self.propagation_scale = max(0.75, self.propagation_scale - 0.03)
            self.cleanup_every = max(12, self.cleanup_every - 2)
            logger.warning(f"Low VRAM ({free:.2f} GB free) β†’ window={self.memory_window}, scale={self.propagation_scale:.2f}")
        elif free > 3.0:
            # Relax: allow a larger window and less frequent cleanup.
            self.memory_window = min(128, self.memory_window + 4)
            self.propagation_scale = min(1.0, self.propagation_scale + 0.01)
            self.cleanup_every = min(40, self.cleanup_every + 2)
217
+
218
# ----------------------------------------------------------------------------------------------------------------------
# Audio mux helper (safer stream mapping)
# ----------------------------------------------------------------------------------------------------------------------
def mux_audio(video_path_no_audio: str, source_with_audio: str, out_path: str) -> bool:
    """Copy the video stream from one file and the audio from another into out_path; True on success."""
    cmd = [
        "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
        "-i", video_path_no_audio,
        "-i", source_with_audio,
        # Explicit stream mapping: video from input 0, audio from input 1.
        "-map", "0:v:0", "-map", "1:a:0",
        "-c:v", "copy", "-c:a", "aac", "-shortest",
        out_path,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=180)
    except Exception as e:
        logger.warning(f"FFmpeg mux error: {e}")
        return False
    if result.returncode != 0:
        logger.warning(f"FFmpeg mux failed: {result.stderr.strip()}")
        return False
    return True
239
+
240
+ # ----------------------------------------------------------------------------------------------------------------------
241
+ # Main processing
242
+ # ----------------------------------------------------------------------------------------------------------------------
243
def process(
    video_path: str,
    background_image: Optional[Image.Image] = None,
    background_type: str = "custom",
    background_prompt: str = "",
    job_directory: Optional[Path] = None,
    progress_callback: Optional[Callable[[str, float], None]] = None
) -> str:
    """
    Production SAM2 + MatAnyone pipeline for T4.
    - Single-pass streaming (no large mask dicts)
    - Bounded memory windows

    Args:
        video_path: path to the input video file.
        background_image: replacement background (required; resized to video size).
        background_type / background_prompt: accepted but not used in this body.
        job_directory: working dir for outputs; auto-created under ./tmp if None.
        progress_callback: optional (step, fraction) reporter; errors are swallowed.

    Returns:
        Path to the final video (with audio if muxing succeeded, else raw composite).

    Raises:
        FileNotFoundError: input video missing.
        RuntimeError: OpenCV cannot open the video.
        ValueError: no background image supplied.
    """
    setup_t4_environment()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Heartbeat: daemon thread logs liveness every 8s while hb_flag["running"].
    # NOTE(review): hb_flag is only cleared on the happy path and early
    # validation failures — an exception mid-pipeline leaves the (daemon)
    # heartbeat running until process exit.
    hb_flag = {"running": True}
    hb_thread = threading.Thread(target=heartbeat_monitor, args=(hb_flag, 8.0), daemon=True)
    hb_thread.start()

    def report(step: str, p: Optional[float] = None):
        # Log locally, then forward to the caller's callback (best-effort).
        if p is None:
            logger.info(step)
        else:
            logger.info(f"{step} [{p:.1%}]")
        if progress_callback:
            try:
                progress_callback(step, p)
            except Exception as e:
                logger.debug(f"progress_callback error: {e}")

    # Validate I/O
    src = Path(video_path)
    if not src.exists():
        hb_flag["running"] = False
        raise FileNotFoundError(f"Video not found: {video_path}")

    if job_directory is None:
        job_directory = Path.cwd() / "tmp" / f"job_{uuid.uuid4().hex[:8]}"
    job_directory.mkdir(parents=True, exist_ok=True)

    # Probe video geometry/frame-rate once, then release the capture.
    cap_probe = cv2.VideoCapture(str(src))
    if not cap_probe.isOpened():
        hb_flag["running"] = False
        raise RuntimeError(f"Cannot open video: {video_path}")
    fps = cap_probe.get(cv2.CAP_PROP_FPS) or 25.0  # containers may report 0
    width = int(cap_probe.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap_probe.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_count = int(cap_probe.get(cv2.CAP_PROP_FRAME_COUNT))
    duration = frame_count / fps if fps > 0 else 0.0
    cap_probe.release()
    logger.info(f"Video: {width}x{height} @ {fps:.2f} fps | {frame_count} frames ({duration:.1f}s)")

    # Prepare background: resize once to frame size, keep as float32 RGB.
    if background_image is None:
        hb_flag["running"] = False
        raise ValueError("background_image is required")
    bg = background_image.resize((width, height), Image.LANCZOS)
    bg_np = np.array(bg).astype(np.float32)

    # Load models
    report("Loading SAM2 + MatAnyone", 0.05)
    predictor = load_sam2_predictor(device)
    matany = load_matany_session(device)

    # Init SAM2 state (single)
    report("Initializing SAM2 video state", 0.08)
    state = predictor.init_state(video_path=str(src))

    # Minimal prompt: single positive point at center (replace with your prompt UI if needed)
    center_pt = np.array([[width // 2, height // 2]], dtype=np.float32)
    labels = np.array([1], dtype=np.int32)
    ann_obj_id = 1
    with torch.inference_mode():
        _ = predictor.add_new_points(
            inference_state=state,
            frame_idx=0,
            obj_id=ann_obj_id,
            points=center_pt,
            labels=labels,
        )

    # Controller
    ctrl = VRAMAdaptiveController()

    # Output paths
    out_raw = str(job_directory / f"composite_{int(time.time())}.mp4")
    out_final = str(job_directory / f"final_{int(time.time())}.mp4")

    # Windows/buffers (bounded)
    # For completeness we keep a tiny deque for any auxiliary temporal ops (e.g., matting history)
    aux_window = deque(maxlen=max(32, min(96, ctrl.memory_window // 2)))

    # Stream processing
    start = time.time()
    frames_done = 0
    next_cleanup_at = ctrl.cleanup_every

    report("Streaming: SAM2 → MatAnyone → Compose → Write", 0.12)
    with StreamingVideoIO(str(src), out_raw, fps) as vio:
        # iterate SAM2 propagation alongside reading frames
        # NOTE(review): torch.autocast(device_type="cuda") is entered even on
        # CPU-only hosts (dtype=None) — confirm this is a no-op in the target
        # torch build rather than an error.
        with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16 if device.type == "cuda" else None):
            # NOTE(review): `scale=` is not a kwarg of stock SAM2
            # propagate_in_video — confirm the project wrapper accepts it.
            for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(state, scale=ctrl.propagation_scale):
                # Read the matching frame
                ret, frame_bgr = vio.read_frame()
                if not ret:
                    break

                # Get mask for ann_obj_id; keep on GPU as long as possible
                mask_t = None
                try:
                    if isinstance(out_obj_ids, torch.Tensor):
                        # find index where id == ann_obj_id
                        idxs = (out_obj_ids == ann_obj_id).nonzero(as_tuple=False)
                        if idxs.numel() > 0:
                            i = idxs[0].item()
                            logits = out_mask_logits[i]
                        else:
                            logits = None
                    else:
                        # list/array fallback
                        ids_list = list(out_obj_ids)
                        i = ids_list.index(ann_obj_id) if ann_obj_id in ids_list else -1
                        logits = out_mask_logits[i] if i >= 0 else None

                    if logits is not None:
                        # logits → prob → binary mask (threshold 0)
                        mask_t = (logits > 0).float()  # HxW on CUDA fp16 → fp32 float
                except Exception as e:
                    logger.debug(f"Mask extraction warning @frame {out_frame_idx}: {e}")
                    mask_t = None

                # Optional: MatAnyone refinement
                if mask_t is not None and matany is not None:
                    try:
                        # MatAnyone APIs vary — try common forms
                        # Convert RGB because many mattors expect RGB
                        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
                        # Move frame to GPU only if your matting backend supports it
                        refined = None
                        if hasattr(matany, "refine_mask"):
                            refined = matany.refine_mask(frame_rgb, mask_t)  # allow handler to decide device
                        elif hasattr(matany, "process_frame"):
                            refined = matany.process_frame(frame_rgb, mask_t)
                        if refined is not None:
                            # ensure float mask 0..1 on CUDA or CPU
                            if isinstance(refined, torch.Tensor):
                                mask_t = refined.float()
                            else:
                                # numpy → torch
                                mask_t = torch.from_numpy(refined.astype(np.float32))
                                if device.type == "cuda":
                                    mask_t = mask_t.to(device)
                    except Exception as e:
                        logger.debug(f"MatAnyone refinement failed (frame {out_frame_idx}): {e}")

                # Compose and write (convert once, keep math sane)
                if mask_t is not None:
                    # bring mask to CPU for np composition; keep as float [0,1]
                    # NOTE(review): composition assumes the mask resolution
                    # matches the frame — confirm when propagation_scale < 1.0.
                    mask_np = mask_t.detach().clamp(0, 1).to("cpu", non_blocking=True).float().numpy()
                    m3 = mask_np[..., None]  # HxWx1
                    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
                    comp = frame_rgb * m3 + bg_np * (1.0 - m3)
                    comp_bgr = cv2.cvtColor(comp.astype(np.uint8), cv2.COLOR_RGB2BGR)
                    vio.write_frame(comp_bgr)
                else:
                    # No mask — write original frame
                    vio.write_frame(frame_bgr)

                # Periodic maintenance
                frames_done += 1
                if frames_done >= next_cleanup_at:
                    ctrl.adapt()
                    prune_sam2_state(predictor, state, keep=ctrl.memory_window)
                    # Clear small aux buffers
                    aux_window.clear()
                    if device.type == "cuda":
                        torch.cuda.ipc_collect()
                        torch.cuda.empty_cache()
                    next_cleanup_at = frames_done + ctrl.cleanup_every

                # Progress
                if frames_done % 25 == 0 and frame_count > 0:
                    p = 0.12 + 0.75 * (frames_done / frame_count)
                    report(f"Processing frame {frames_done}/{frame_count} | win={ctrl.memory_window} scale={ctrl.propagation_scale:.2f}", p)

    # Audio mux
    report("Restoring audio", 0.93)
    ok = mux_audio(out_raw, str(src), out_final)
    final_path = out_final if ok else out_raw

    # Cleanup models/state promptly
    try:
        del predictor
        del state
        if matany is not None:
            del matany
    except Exception:
        pass

    if device.type == "cuda":
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
    gc.collect()

    hb_flag["running"] = False
    elapsed = time.time() - start
    try:
        peak = torch.cuda.max_memory_allocated() / (1024 ** 3) if device.type == "cuda" else 0.0
        logger.info(f"Peak GPU memory: {peak:.2f} GB")
    except Exception:
        pass
    report(f"Done in {elapsed:.1f}s", 1.0)
    logger.info(f"Output: {final_path}")
    logger.info(f"Artifacts: {job_directory}")
    return final_path
461
+
462
+
463
+ # -------------------------------------------------------------------------------------------------
464
+ # CLI entry (optional)
465
+ # -------------------------------------------------------------------------------------------------
466
if __name__ == "__main__":
    import argparse

    # Minimal command-line front end around process().
    cli = argparse.ArgumentParser(description="BackgroundFX Pro pipeline")
    cli.add_argument("--video", required=True, help="Path to input video")
    cli.add_argument("--background", required=True, help="Path to background image")
    cli.add_argument("--outdir", default=None, help="Job directory (optional)")
    ns = cli.parse_args()

    background = Image.open(ns.background).convert("RGB")
    job_dir = Path(ns.outdir) if ns.outdir else None
    print(process(ns.video, background_image=background, job_directory=job_dir))
VideoBackgroundReplacer2/requirements.txt ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ===== Core Dependencies =====
2
+ # PyTorch is installed in Dockerfile with CUDA 12.1 β€” REQUIRED for SAM2
3
+ # torch==2.5.1
4
+ # torchvision==0.20.1
5
+ # torchaudio==2.5.1
6
+
7
+ # ===== Base Dependencies =====
8
+ numpy>=1.24.0,<2.1.0
9
+ Pillow>=10.0.0,<12.0.0
10
+ protobuf>=4.25.0,<6.0.0
11
+
12
+ # ===== Image/Video Processing =====
13
+ opencv-python-headless>=4.8.0,<4.11.0
14
+ imageio>=2.25.0,<3.0.0
15
+ imageio-ffmpeg>=0.4.7,<0.6.0
16
+ moviepy>=1.0.3,<2.0.0
17
+ decord>=0.6.0,<0.7.0
18
+ scikit-image>=0.19.3,<0.22.0
19
+
20
+ # ===== MediaPipe =====
21
+ mediapipe>=0.10.0,<0.11.0
22
+
23
+ # ===== SAM2 Dependencies =====
24
+ # SAM2 is installed via git clone in Dockerfile
25
+ hydra-core>=1.3.2,<2.0.0
26
+ omegaconf>=2.3.0,<3.0.0
27
+ einops>=0.6.0,<0.9.0
28
+ timm>=0.9.0,<1.1.0
29
+ pyyaml>=6.0.0,<7.0.0
30
+ matplotlib>=3.5.0,<4.0.0
31
+ iopath>=0.1.10,<0.2.0
32
+
33
+ # ===== MatAnyone Dependencies =====
34
+ # MatAnyone is installed separately in Dockerfile
35
+ kornia>=0.7.0,<0.8.0
36
+ tqdm>=4.60.0,<5.0.0
37
+
38
+ # ===== UI and API =====
39
+ # Bump to avoid gradio_client 1.3.0 bug ("bool is not iterable")
40
+ gradio==4.42.0
41
+
42
+ # ===== Web stack pins for Gradio 4.42.0 =====
43
+ fastapi==0.109.2
44
+ starlette==0.36.3
45
+ uvicorn==0.29.0
46
+ httpx==0.27.2
47
+ anyio==4.4.0
48
+ orjson>=3.10.0
49
+
50
+ # ===== Pydantic family (avoid breaking core 2.23.x) =====
51
+ pydantic==2.8.2
52
+ pydantic-core==2.20.1
53
+ annotated-types==0.6.0
54
+ typing-extensions==4.12.2
55
+
56
+ # ===== Helpers and Utilities =====
57
+ huggingface-hub>=0.20.0,<1.0.0
58
+ ffmpeg-python>=0.2.0,<1.0.0
59
+ psutil>=5.8.0,<7.0.0
60
+ requests>=2.25.0,<3.0.0
61
+ scikit-learn>=1.3.0,<2.0.0
62
+
63
+ # ===== Additional Dependencies =====
64
+ # Performance and monitoring
65
+ gputil>=1.4.0,<2.0.0
66
+ nvidia-ml-py3>=7.352.0,<12.0.0
67
+
68
+ # Error handling and logging
69
+ loguru>=0.6.0,<1.0.0
70
+
71
+ # File handling
72
+ python-multipart>=0.0.5,<1.0.0
VideoBackgroundReplacer2/two_stage_pipeline.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ two_stage_pipeline.py β€” Ephemeral SAM2 stage + MatAnyone stage
4
+ - Stage 1: SAM2 -> lossless mask stream (FFV1 .mkv) + meta.json, then unload SAM2
5
+ - Stage 2: read mask stream -> (optional) MatAnyone refine -> composite -> mux audio
6
+ """
7
+
8
+ import os, sys, gc, json, cv2, time, uuid, torch, shutil, logging, subprocess, threading
9
+ import numpy as np
10
+ from pathlib import Path
11
+ from typing import Optional, Callable, Tuple, Dict, Any
12
+ from PIL import Image
13
+
14
+ logger = logging.getLogger("backgroundfx_pro.two_stage")
15
+ if not logger.handlers:
16
+ h = logging.StreamHandler()
17
+ h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s: %(message)s"))
18
+ logger.addHandler(h)
19
+ logger.setLevel(logging.INFO)
20
+
21
+ # ---------------------------
22
+ # Env & CUDA helpers
23
+ # ---------------------------
24
def setup_env():
    """One-time runtime hygiene: CUDA allocator hints, single-threaded BLAS,
    inference-only PyTorch with TF32 enabled, and a capped CUDA memory fraction.
    """
    env_defaults = {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7",
        "OMP_NUM_THREADS": "1",
        "OPENBLAS_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
    }
    for key, value in env_defaults.items():
        os.environ.setdefault(key, value)  # operator-provided values win
    torch.set_grad_enabled(False)  # this module never trains
    try:
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass  # older torch builds may lack some of these switches
    if torch.cuda.is_available():
        try:
            torch.cuda.set_per_process_memory_fraction(float(os.getenv("CUDA_MEMORY_FRACTION", "0.88")))
        except Exception:
            pass
42
+
43
def free_cuda():
    """Release cached CUDA allocations; no-op on CPU-only hosts."""
    if not torch.cuda.is_available():
        return
    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()
47
+
48
def unload_sam2_modules():
    """Aggressively drop every imported `sam2*` module to reduce RSS.

    Safe to call even if SAM2 was never imported; failures are logged,
    never raised.
    """
    try:
        import importlib
        stale = [name for name in list(sys.modules) if name.startswith("sam2")]
        for name in stale:
            sys.modules.pop(name, None)
        importlib.invalidate_caches()
        gc.collect()
        free_cuda()
        logger.info("SAM2 modules unloaded.")
    except Exception as e:
        logger.warning(f"Unloading SAM2 modules: {e}")
61
+
62
+ # ---------------------------
63
+ # Video probing
64
+ # ---------------------------
65
def probe_video(path: str) -> Tuple[int, int, float, int]:
    """Return (width, height, fps, frame_count) for a video file.

    Raises RuntimeError if OpenCV cannot open the file; fps falls back to
    25.0 when the container reports 0/unknown.
    """
    capture = cv2.VideoCapture(path)
    if not capture.isOpened():
        raise RuntimeError(f"Cannot open video: {path}")
    frame_rate = capture.get(cv2.CAP_PROP_FPS) or 25.0
    width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    capture.release()
    return width, height, float(frame_rate), count
75
+
76
+ # ---------------------------
77
+ # FFmpeg mask writers/readers
78
+ # ---------------------------
79
class MaskFFV1Writer:
    """Stream uint8 grayscale masks into a lossless FFV1 .mkv via an ffmpeg pipe.

    Usage:
        with MaskFFV1Writer(path, w, h, fps) as wr:
            wr.write(mask_u8)   # HxW uint8, 0/255

    The ffmpeg child consumes raw `gray` frames on stdin; `-g 1` makes every
    frame an intra frame so the stream stays per-frame seekable.
    """
    def __init__(self, path:str, w:int, h:int, fps:float):
        self.path = path
        self.w, self.h, self.fps = w,h,fps
        self.proc = None  # ffmpeg Popen handle, created in __enter__

    def __enter__(self):
        cmd = [
            "ffmpeg","-y","-hide_banner","-loglevel","error",
            "-f","rawvideo","-pix_fmt","gray","-s",f"{self.w}x{self.h}","-r",f"{self.fps}",
            "-i","-",
            "-c:v","ffv1","-level","3","-g","1", self.path
        ]
        self.proc = subprocess.Popen(cmd, stdin=subprocess.PIPE)
        return self

    def write(self, mask_u8: np.ndarray):
        """Write one HxW mask frame (values 0/255 expected).

        Fix: the previous code only converted dtype; ascontiguousarray
        normalizes dtype AND memory layout (e.g. sliced/transposed views)
        in one call, copying only when necessary.
        """
        mask_u8 = np.ascontiguousarray(mask_u8, dtype=np.uint8)
        self.proc.stdin.write(mask_u8.tobytes())

    def __exit__(self, exc_type, exc, tb):
        if self.proc:
            try:
                # Close stdin so ffmpeg sees EOF, then wait for encoding to finish.
                self.proc.stdin.flush()
                self.proc.stdin.close()
                self.proc.wait(timeout=120)
            except Exception:
                self.proc.kill()
110
+
111
class MaskFFV1Reader:
    """Decode an FFV1 mask stream back into raw HxW uint8 frames via ffmpeg."""
    def __init__(self, path:str, w:int, h:int):
        self.path = path
        self.w, self.h = w, h
        self.proc = None
        self.frame_bytes = w * h  # one gray8 frame

    def __enter__(self):
        cmd = [
            "ffmpeg","-hide_banner","-loglevel","error","-i", self.path,
            "-f","rawvideo","-pix_fmt","gray","-"
        ]
        self.proc = subprocess.Popen(cmd, stdout=subprocess.PIPE)
        return self

    def read(self) -> Optional[np.ndarray]:
        """Return the next mask frame, or None at end-of-stream / short read."""
        raw = self.proc.stdout.read(self.frame_bytes)
        if not raw or len(raw) < self.frame_bytes:
            return None
        return np.frombuffer(raw, dtype=np.uint8).reshape(self.h, self.w)

    def __exit__(self, exc_type, exc, tb):
        if self.proc:
            try:
                self.proc.stdout.close()
                self.proc.wait(timeout=30)
            except Exception:
                self.proc.kill()
140
+
141
+ # Fallback: PNG sequence (disk heavy but simple & robust)
142
class MaskPNGWriter:
    """Fallback mask sink: one zero-padded PNG per frame (disk heavy, robust)."""
    def __init__(self, dirpath: Path):
        self.dir = dirpath
        self.dir.mkdir(parents=True, exist_ok=True)
        self.idx = 0  # next frame index

    def write(self, mask_u8: np.ndarray):
        cv2.imwrite(str(self.dir / f"{self.idx:06d}.png"), mask_u8)
        self.idx += 1
148
+
149
class MaskPNGReader:
    """Reads back the PNG sequence written by MaskPNGWriter, in order."""
    def __init__(self, dirpath: Path):
        self.dir = dirpath
        self.idx = 0

    def read(self) -> Optional[np.ndarray]:
        """Return the next grayscale mask, or None once the sequence ends."""
        frame_path = self.dir / f"{self.idx:06d}.png"
        if not frame_path.exists():
            return None
        image = cv2.imread(str(frame_path), cv2.IMREAD_GRAYSCALE)
        self.idx += 1
        return image
158
+
159
+ # ---------------------------
160
+ # Stage 1 β€” SAM2 β†’ mask dump
161
+ # ---------------------------
162
def stage1_dump_masks(video_path:str, out_dir:Path, obj_point:Tuple[int,int]=None) -> Dict[str,Any]:
    """
    Run only SAM2, save masks as FFV1 (preferred) or PNG sequence + meta.json.
    Returns meta dict.

    Args:
        video_path: input video file.
        out_dir: directory receiving mask.mkv (or masks_png/) and meta.json.
        obj_point: (x, y) positive prompt; defaults to the frame center.

    Returns:
        meta dict with keys: video, width, height, fps, frames, storage
        ("ffv1" or "png") and mask_path.

    Side effects: SAM2 is loaded for this call and fully unloaded afterwards
    (free_cuda + unload_sam2_modules) so Stage 2 starts with a clean GPU.
    """
    setup_env()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    w,h,fps,n = probe_video(video_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    meta = {"video":video_path, "width":w,"height":h,"fps":fps,"frames":n, "storage":None}
    logger.info(f"[Stage1] {w}x{h}@{fps:.2f} | frames={n}")

    # Load SAM2 (your wrapper)
    from models.sam2_loader import SAM2Predictor
    predictor = SAM2Predictor(device=device)
    state = predictor.init_state(video_path=video_path)

    # Prompt: center positive if not provided
    if obj_point is None:
        obj_point = (w//2, h//2)
    pts = np.array([[obj_point[0], obj_point[1]]], dtype=np.float32)
    labels = np.array([1], dtype=np.int32)
    ann_obj_id = 1
    with torch.inference_mode():
        predictor.add_new_points(state, 0, ann_obj_id, pts, labels)

    # Preferred: FFV1 mask stream
    # NOTE(review): torch.autocast("cuda") is entered even on CPU-only hosts
    # (dtype=None) — confirm this is a no-op in the target torch build.
    mask_mkv = out_dir / "mask.mkv"
    use_png = False  # NOTE(review): dead variable — the PNG fallback is driven by the except branch below
    try:
        with MaskFFV1Writer(str(mask_mkv), w, h, fps) as writer, \
             torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16 if device.type=="cuda" else None):
            for _, out_ids, out_logits in predictor.propagate_in_video(state):
                # pick ann_obj_id
                i = None
                if isinstance(out_ids, torch.Tensor):
                    nz = (out_ids == ann_obj_id).nonzero(as_tuple=False)
                    if nz.numel() > 0: i = nz[0].item()
                else:
                    ids = list(out_ids); i = ids.index(ann_obj_id) if ann_obj_id in ids else None
                if i is None:
                    # write empty
                    writer.write(np.zeros((h,w), np.uint8))
                    continue
                mask = (out_logits[i] > 0).detach()
                mask_u8 = (mask.float().mul_(255).to("cpu", non_blocking=True).numpy()).astype(np.uint8)
                writer.write(mask_u8)
        meta["storage"] = "ffv1"
        meta["mask_path"] = str(mask_mkv)
        logger.info("[Stage1] Masks saved as FFV1 .mkv")
    except Exception as e:
        # Disk-heavy but dependency-free fallback when the ffmpeg pipe fails.
        logger.warning(f"FFV1 writer failed ({e}), falling back to PNG sequence.")
        png_dir = out_dir / "masks_png"
        wr = MaskPNGWriter(png_dir)
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16 if device.type=="cuda" else None):
            for _, out_ids, out_logits in predictor.propagate_in_video(state):
                i = None
                if isinstance(out_ids, torch.Tensor):
                    nz = (out_ids == ann_obj_id).nonzero(as_tuple=False)
                    if nz.numel() > 0: i = nz[0].item()
                else:
                    ids = list(out_ids); i = ids.index(ann_obj_id) if ann_obj_id in ids else None
                if i is None:
                    wr.write(np.zeros((h,w), np.uint8)); continue
                mask = (out_logits[i] > 0).detach()
                wr.write((mask.float().mul_(255).to("cpu").numpy()).astype(np.uint8))
        meta["storage"] = "png"
        meta["mask_path"] = str(png_dir)

    # Persist meta
    with open(out_dir / "meta.json","w") as f:
        json.dump(meta, f)
    # Unload SAM2 completely
    del predictor, state
    free_cuda(); unload_sam2_modules()
    return meta
238
+
239
+ # ---------------------------
240
+ # Stage 2 β€” refine + compose
241
+ # ---------------------------
242
def stage2_refine_and_compose(video_path:str, mask_dir:Path, background_image:Image.Image,
                              out_path:str, use_matany:bool=True) -> str:
    """Read the Stage-1 mask stream, optionally refine each mask with
    MatAnyone, composite the background, and mux the original audio back in.

    Args:
        video_path: original input video (also the audio source).
        mask_dir: Stage-1 output dir containing meta.json + mask stream.
        background_image: replacement background; resized to video size.
        out_path: desired final video path.
        use_matany: attempt MatAnyone refinement when a loader is importable.

    Returns:
        Path to the final video (audio-less composite if muxing failed).
    """
    w,h,fps,n = probe_video(video_path)
    bg = background_image.resize((w,h), Image.LANCZOS)
    bg_np = np.array(bg).astype(np.float32)

    # Read meta
    with open(mask_dir / "meta.json","r") as f:
        meta = json.load(f)
    storage = meta["storage"]; mask_path = meta["mask_path"]

    # Optional MatAnyone — loader class name has varied across versions, so
    # try both; refinement is skipped entirely when neither imports.
    session = None
    if use_matany:
        try:
            from models.matanyone_loader import MatAnyoneSession as _M
        except Exception:
            try:
                from models.matanyone_loader import MatAnyoneLoader as _M
            except Exception:
                _M = None
        if _M:
            session = _M(device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
            if hasattr(session,"model") and session.model is not None:
                session.model.eval()

    # Open video + writer (write to a .noaudio temp, mux audio at the end)
    cap = cv2.VideoCapture(video_path)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    tmp_out = str(Path(out_path).with_suffix(".noaudio.mp4"))
    writer = cv2.VideoWriter(tmp_out, fourcc, fps, (w,h))

    # Open mask reader (FFV1 pipe preferred; PNG sequence fallback)
    if storage == "ffv1":
        mreader = MaskFFV1Reader(mask_path, w, h)
        mreader.__enter__()
        read_mask = lambda : mreader.read()
    else:
        mreader = MaskPNGReader(Path(mask_path))
        read_mask = lambda : mreader.read()

    i = 0
    try:
        while True:
            ok, frame_bgr = cap.read()
            if not ok: break
            mask_u8 = read_mask()
            if mask_u8 is None:
                # out of masks; write original
                writer.write(frame_bgr); i+=1; continue

            # Optional refine
            if session is not None:
                try:
                    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
                    # Provide a float mask 0..1 to session; adapt if your API differs
                    mask_f = (mask_u8.astype(np.float32) / 255.0)
                    if hasattr(session,"refine_mask"):
                        mask_refined = session.refine_mask(frame_rgb, mask_f)
                    elif hasattr(session,"process_frame"):
                        mask_refined = session.process_frame(frame_rgb, mask_f)
                    else:
                        mask_refined = mask_f
                    # Normalize whatever the backend returned back to uint8 0..255.
                    if isinstance(mask_refined, torch.Tensor):
                        mask_u8 = (mask_refined.detach().clamp(0,1).mul(255).to("cpu").numpy()).astype(np.uint8)
                    elif isinstance(mask_refined, np.ndarray):
                        mask_u8 = (np.clip(mask_refined,0,1)*255).astype(np.uint8)
                except Exception as e:
                    # Refinement is best-effort; fall back to the raw SAM2 mask.
                    logger.debug(f"MatAnyone refine failed @frame {i}: {e}")

            # Composite: alpha-blend frame over background in float RGB.
            # NOTE(review): assumes mask dimensions equal the frame's — holds
            # when Stage 1 wrote full-resolution masks; confirm otherwise.
            m = (mask_u8.astype(np.float32)/255.0)[...,None]  # HxWx1
            fr = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
            comp = fr*m + bg_np*(1.0-m)
            comp_bgr = cv2.cvtColor(comp.astype(np.uint8), cv2.COLOR_RGB2BGR)
            writer.write(comp_bgr)

            if i % 50 == 0:
                logger.info(f"[Stage2] frame {i}/{n}")
            i += 1
    finally:
        # Release in this order: capture, writer, then the ffmpeg mask pipe.
        cap.release(); writer.release()
        if isinstance(mreader, MaskFFV1Reader):
            mreader.__exit__(None,None,None)

    # Mux audio; on any failure keep the audio-less composite as the result.
    final_out = str(Path(out_path))
    cmd = [
        "ffmpeg","-y","-hide_banner","-loglevel","error",
        "-i", tmp_out, "-i", video_path,
        "-map","0:v:0","-map","1:a:0","-c:v","copy","-c:a","aac","-shortest", final_out
    ]
    try:
        r = subprocess.run(cmd, capture_output=True, text=True, timeout=180)
        if r.returncode != 0:
            logger.warning(f"Audio mux failed: {r.stderr.strip()}")
            shutil.move(tmp_out, final_out)
        else:
            os.remove(tmp_out)
    except Exception:
        shutil.move(tmp_out, final_out)
    return final_out
344
+
345
+ # ---------------------------
346
+ # Orchestrator
347
+ # ---------------------------
348
def process_two_stage(
    video_path:str,
    background_image: Image.Image,
    workdir: Optional[Path]=None,
    progress: Optional[Callable[[str,float],None]] = None,
    use_matany: bool = True,
) -> str:
    """Orchestrate the two-stage pipeline.

    Stage 1 runs SAM2 alone and dumps a lossless mask stream (SAM2 is then
    fully unloaded); Stage 2 optionally refines with MatAnyone, composites
    the background, and restores the original audio.

    Returns the path to the final video.
    """
    setup_env()
    if workdir is None:
        workdir = Path.cwd()/ "tmp" / f"job_{uuid.uuid4().hex[:8]}"
    workdir.mkdir(parents=True, exist_ok=True)

    def tick(message: str, fraction: float):
        # Best-effort progress reporting; no-op when no callback given.
        if progress:
            progress(message, fraction)

    # Stage 1: SAM2-only mask pass.
    tick("Stage 1: SAM2 mask pass", 0.05)
    mask_dir = workdir / "sam2_masks"
    meta = stage1_dump_masks(video_path, mask_dir)
    tick("Stage 1 complete", 0.45)

    # Stage 2: refine + composite + audio.
    tick("Stage 2: refine + compose", 0.50)
    out_path = workdir / f"final_{int(time.time())}.mp4"
    final_video = stage2_refine_and_compose(video_path, mask_dir, background_image, str(out_path), use_matany=use_matany)
    tick("Done", 1.0)
    logger.info(f"Output: {final_video}")
    return final_video
373
+
374
+ # ---------------------------
375
+ # CLI
376
+ # ---------------------------
377
if __name__ == "__main__":
    import argparse

    # Thin CLI around process_two_stage().
    cli = argparse.ArgumentParser(description="Two-stage BackgroundFX Pro")
    cli.add_argument("--video", required=True)
    cli.add_argument("--background", required=True)
    cli.add_argument("--outdir", default=None)
    cli.add_argument("--no-matany", action="store_true")
    ns = cli.parse_args()

    background = Image.open(ns.background).convert("RGB")
    result = process_two_stage(
        ns.video,
        background,
        Path(ns.outdir) if ns.outdir else None,
        use_matany=not ns.no_matany,
    )
    print(result)
VideoBackgroundReplacer2/ui.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ BackgroundFX Pro β€” Main UI Application (Gradio 4.42.x)
4
+ Clean, focused main file that coordinates the application
5
+ """
6
+
7
+ # ============================================================
8
+ # Mount-mode handoff: delegate to app.py when enabled
9
+ # (So we can serve a safe /config JSON via our FastAPI shim)
10
+ # ============================================================
11
import os, runpy
# When mount mode is enabled, app.py owns the server (a FastAPI shim that
# serves a safe /config JSON); run it and exit before this module does
# anything else.
if os.getenv("GRADIO_MOUNT_MODE") == "1":
    runpy.run_module("app", run_name="__main__")
    raise SystemExit
15
+
16
+ # ==== Runtime hygiene & paths (very high in file) ====
17
+ import sys
18
+ import logging
19
+ from pathlib import Path
20
+
21
+ # --- Sanitize OMP/BLAS threads early (avoids "libgomp: Invalid value..." issues)
22
+ def _sanitize_omp_env():
23
+ import multiprocessing as _mp
24
+ cpu = max(1, _mp.cpu_count())
25
+ default_omp = max(1, cpu // 2)
26
+
27
+ raw = os.environ.get("OMP_NUM_THREADS", "").strip()
28
+ try:
29
+ n = int(raw)
30
+ if n <= 0 or n > cpu * 2:
31
+ raise ValueError
32
+ omp_val = n
33
+ except Exception:
34
+ omp_val = default_omp
35
+ os.environ["OMP_NUM_THREADS"] = str(omp_val)
36
+
37
+ os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
38
+ os.environ.setdefault("MKL_NUM_THREADS", "1")
39
+ os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
40
+
41
+ _sanitize_omp_env()
42
+
43
# Stable app dirs (avoid /tmp surprises on HF)
APP_ROOT = Path(__file__).resolve().parent  # repo root for this app
DATA_ROOT = APP_ROOT / "data"               # persistent data
TMP_ROOT = APP_ROOT / "tmp"                 # scratch space
JOB_ROOT = TMP_ROOT / "backgroundfx_jobs"   # per-job working dirs
for p in (DATA_ROOT, TMP_ROOT, JOB_ROOT):
    p.mkdir(parents=True, exist_ok=True)

# Keep model/caches local to repo volume
# (setdefault so an operator-provided HF_HOME/TORCH_HOME still wins)
os.environ.setdefault("HF_HOME", str(APP_ROOT / ".hf"))
os.environ.setdefault("TORCH_HOME", str(APP_ROOT / ".torch"))
Path(os.environ["HF_HOME"]).mkdir(parents=True, exist_ok=True)
Path(os.environ["TORCH_HOME"]).mkdir(parents=True, exist_ok=True)

# Make Gradio a bit quieter / safer in Spaces
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
59
+
60
+ # Standard imports (after env is sane)
61
+ import torch
62
+ import gradio as gr
63
+
64
+ # Import our modules
65
+ from ui_core_functionality import startup_probe, logger
66
+ from ui_core_interface import create_interface
67
+
68
+ # Optional: patch a Gradio client util to tolerate boolean JSON Schemas
69
# Optional: patch a Gradio client util to tolerate boolean JSON Schemas
def _patch_gradio_client_bool_schema():
    """Monkey-patch gradio_client so boolean JSON Schemas (bare true/false)
    don't crash type resolution (known gradio_client 1.3.x bug).

    Best-effort: logs a warning and continues if the internals have changed.
    """
    try:
        import gradio_client.utils as _gc_utils  # type: ignore

        def _tolerant(fn):
            # A bare boolean schema means "anything" (true) / "nothing" (false).
            def _safe(schema, *rest):
                if isinstance(schema, bool):
                    return "Any" if schema else "None"
                return fn(schema, *rest)
            return _safe

        _gc_utils.get_type = _tolerant(_gc_utils.get_type)  # type: ignore[attr-defined]
        if hasattr(_gc_utils, "_json_schema_to_python_type"):
            _gc_utils._json_schema_to_python_type = _tolerant(
                _gc_utils._json_schema_to_python_type  # type: ignore[attr-defined]
            )

        logger.info("🩹 Patched gradio_client.utils to handle boolean JSON Schemas.")
    except Exception as e:
        logger.warning("Could not patch gradio_client boolean schema handling: %s", e)

_patch_gradio_client_bool_schema()
94
+
95
+ # =======================================================================
96
+ # MAIN APPLICATION
97
+ # =======================================================================
98
+
99
def main():
    """Main application entry point.

    Probes the environment, builds the Gradio interface, and serves it on
    the port the host provides (HF Spaces sets PORT; falls back to 7860).
    Re-raises on startup failure so the process exits non-zero.
    """
    try:
        startup_probe()

        logger.info("🚀 Launching Gradio interface...")
        logger.info(
            "Gradio=%s | torch=%s | cu=%s | cuda_available=%s",
            getattr(gr, "__version__", "?"),
            torch.__version__,
            getattr(torch.version, "cuda", None),
            torch.cuda.is_available(),
        )

        demo = create_interface()

        # Gradio 4.x: keep queue small to avoid RAM spikes (no concurrency_count here)
        demo.queue(max_size=2)

        # Port from env (HF sets PORT)
        port = int(os.environ.get("PORT", os.environ.get("GRADIO_SERVER_PORT", "7860")))

        demo.launch(
            server_name="0.0.0.0",
            server_port=port,
            # Fix: was `False if in_space else False` — a tautology. Never use
            # share=True here: on Spaces it triggers an frpc download / 500s,
            # and locally the app is already reachable on the bound port.
            share=False,
            show_api=False,   # safer on public Spaces
            show_error=True,
            quiet=True,
            debug=False,
            max_threads=1     # worker threads; per-listener concurrency set in ui_core_interface.py
        )

    except Exception as e:
        logger.error("❌ Application startup failed: %s", e)
        raise
138
+
139
if __name__ == "__main__":
    main()  # run the Gradio app when executed directly (python ui.py)
VideoBackgroundReplacer2/ui_core_functionality.py ADDED
@@ -0,0 +1,662 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ BackgroundFX Pro β€” Core Functionality
4
+ All processing logic, utilities, background generators, and handlers
5
+ Enhanced with file safety, robust logging, and runtime diagnostics.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import io
11
+ import gc
12
+ import time
13
+ import json
14
+ import uuid
15
+ import shutil
16
+ import logging
17
+ import tempfile
18
+ import requests
19
+ import threading
20
+ import traceback
21
+ import subprocess
22
+ from datetime import datetime
23
+ from concurrent.futures import ThreadPoolExecutor
24
+ from typing import Optional, Tuple, List, Dict, Any, Union, Callable
25
+ from pathlib import Path
26
+
27
+ import torch
28
+ import numpy as np
29
+ from PIL import Image, ImageDraw, ImageFont
30
+ import cv2
31
+
32
+ # ==============================================================================
33
+ # PATHS & ENV
34
+ # ==============================================================================
35
+
36
# Repo root (…/app)
APP_ROOT = Path(__file__).resolve().parent
DATA_ROOT = APP_ROOT / "data"
TMP_ROOT = APP_ROOT / "tmp"
JOB_ROOT = TMP_ROOT / "backgroundfx_jobs"

# Create every directory the app writes to, up front (fail fast on perms).
for p in (
    DATA_ROOT,
    TMP_ROOT,
    JOB_ROOT,
    APP_ROOT / ".hf",
    APP_ROOT / ".torch",
    APP_ROOT / "checkpoints",
    APP_ROOT / "models",
    APP_ROOT / "utils",
):
    p.mkdir(parents=True, exist_ok=True)

# Cache dirs (stable on Spaces)
os.environ.setdefault("HF_HOME", str(APP_ROOT / ".hf"))
os.environ.setdefault("TORCH_HOME", str(APP_ROOT / ".torch"))

# Quiet BLAS/OpenMP spam (in case ui.py wasn't first).
# Force OMP_NUM_THREADS to a sane numeric value when unset or malformed;
# after this branch the variable is always a digit string, so the redundant
# setdefault that used to follow has been removed.
if not os.environ.get("OMP_NUM_THREADS", "").isdigit():
    os.environ["OMP_NUM_THREADS"] = "4"
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
os.environ.setdefault("PYTHONFAULTHANDLER", "1")
66
+
67
# ==============================================================================
# LOGGING + DIAGNOSTICS (console + file + heartbeat)
# ==============================================================================

# Line-buffer logs so Space UI shows them promptly.
# reconfigure() exists on Python 3.7+ text streams; wrapped because some
# hosts replace stdout/stderr with objects that lack it.
try:
    sys.stdout.reconfigure(line_buffering=True)
    sys.stderr.reconfigure(line_buffering=True)
except Exception:
    pass

# Log to both the console and a persistent file under DATA_ROOT.
# force=True tears down any handlers installed by earlier imports so this
# configuration always wins.
LOG_FILE = DATA_ROOT / "run.log"
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    handlers=[logging.StreamHandler(sys.stdout),
              logging.FileHandler(LOG_FILE, encoding="utf-8")],
    force=True,
)
logger = logging.getLogger("bgfx")

# Faulthandler (native crashes -> stacks). SIGUSR1 registration lets an
# operator dump all-thread stacks on demand; guarded because the signal is
# absent on Windows.
try:
    import faulthandler, signal  # type: ignore
    faulthandler.enable(all_threads=True)
    if hasattr(signal, "SIGUSR1"):
        faulthandler.register(signal.SIGUSR1, file=sys.stderr, all_threads=True)
except Exception as e:
    logger.warning("faulthandler setup skipped: %s", e)
96
+
97
+ def _disk_stats(p: Path) -> str:
98
+ try:
99
+ total, used, free = shutil.disk_usage(str(p))
100
+ mb = lambda x: x // (1024 * 1024)
101
+ return f"disk(total={mb(total)}MB, used={mb(used)}MB, free={mb(free)}MB)"
102
+ except Exception:
103
+ return "disk(n/a)"
104
+
105
def _cgroup_limit_bytes():
    """Return the container memory limit in bytes, or None when unknown.

    Checks the cgroup v2 file first, then the v1 file; a value of "max"
    (v2's "unlimited") or an unreadable file yields None.
    """
    candidate_files = (
        "/sys/fs/cgroup/memory.max",                      # cgroup v2
        "/sys/fs/cgroup/memory/memory.limit_in_bytes",    # cgroup v1
    )
    for candidate in candidate_files:
        try:
            raw = Path(candidate).read_text().strip()
            if raw and raw != "max":
                return int(raw)
        except Exception:
            continue
    return None
113
+
114
def _rss_bytes():
    """Return this process's resident-set size in bytes (Linux), else None.

    Parses the VmRSS line of /proc/self/status, which reports kilobytes.
    """
    try:
        status_text = Path("/proc/self/status").read_text()
        for entry in status_text.splitlines():
            if entry.startswith("VmRSS:"):
                kilobytes = int(entry.split()[1])
                return kilobytes * 1024
    except Exception:
        return None
    return None
121
+
122
def _heartbeat():
    """Log RSS / memory-limit / disk stats every 2 seconds, forever.

    Runs as a daemon thread (started below) so a hard kill leaves a trail of
    the last known memory state — useful for diagnosing OOM kills on Spaces.
    The cgroup limit is read once; RSS and disk are re-sampled each tick.
    """
    lim = _cgroup_limit_bytes()
    while True:
        rss = _rss_bytes()
        logger.info(
            "HEARTBEAT | rss=%s MB | limit=%s MB | %s",
            # Render "n/a" when the probe returned None (non-Linux, etc.).
            f"{rss//2**20}" if rss else "n/a",
            f"{lim//2**20}" if lim else "n/a",
            _disk_stats(APP_ROOT),
        )
        time.sleep(2)
133
+
134
# Start heartbeat as a daemon thread (only once). Daemon=True means the
# thread never blocks interpreter shutdown.
try:
    threading.Thread(target=_heartbeat, name="heartbeat", daemon=True).start()
except Exception as e:
    logger.warning("heartbeat skipped: %s", e)

import atexit

@atexit.register
def _on_exit():
    # Marker log for clean interpreter shutdown; its absence in the logs
    # indicates the process was killed without running atexit hooks.
    logger.info("PROCESS EXITING (atexit) β€” if you don't see this, it was a hard kill (OOM/SIGKILL)")
144
+
145
+ # ==============================================================================
146
+ # STARTUP VALIDATION
147
+ # ==============================================================================
148
+
149
def startup_probe():
    """Comprehensive startup probe - validates system readiness.

    Performs, in order: a write/read round-trip under TMP_ROOT, a Torch/GPU
    capability report, creation of required directories, a job-isolation
    write test under JOB_ROOT, and an environment summary. Raises
    RuntimeError (chained to the original error) on any failure so callers
    abort before serving traffic.

    NOTE(review): the probes use `assert`, which is stripped under `python -O`;
    in that mode the read-back checks silently vanish.
    """
    try:
        logger.info("πŸš€ BACKGROUNDFX PRO STARTUP PROBE")
        logger.info("πŸ“ Working directory: %s", os.getcwd())
        logger.info("🐍 Python executable: %s", sys.executable)

        # Write probe (fail fast if not writable)
        probe_file = TMP_ROOT / "startup_probe.txt"
        probe_file.write_text("startup_test_ok", encoding="utf-8")
        assert probe_file.read_text(encoding="utf-8") == "startup_test_ok"
        logger.info("βœ… WRITE PROBE OK: %s | %s", probe_file, _disk_stats(APP_ROOT))
        probe_file.unlink(missing_ok=True)

        # GPU/Torch status — informational only; failures here do not abort.
        try:
            logger.info("πŸ”§ Torch=%s | cu=%s | cuda_available=%s",
                        torch.__version__, getattr(torch.version, "cuda", None), torch.cuda.is_available())
            if torch.cuda.is_available():
                gpu_count = torch.cuda.device_count()
                name = torch.cuda.get_device_name(0) if gpu_count else "Unknown"
                vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3) if gpu_count else 0
                logger.info("πŸ”₯ GPU Available: %s (%d device(s)) β€” VRAM %.1f GB", name, gpu_count, vram_gb)
            else:
                logger.warning("⚠️ No GPU available β€” using CPU")
        except Exception as e:
            logger.warning("⚠️ Torch check failed: %s", e)

        # Directory verification (and creation if missing)
        for d in ("checkpoints", "models", "utils"):
            dp = APP_ROOT / d
            dp.mkdir(parents=True, exist_ok=True)
            logger.info("βœ… Directory %s: %s", d, dp)

        # Job dir isolation test: prove per-job directories can be created,
        # written, read back, and removed.
        test_job = JOB_ROOT / "startup_test_job"
        test_job.mkdir(parents=True, exist_ok=True)
        tfile = test_job / "test.tmp"
        tfile.write_text("job_isolation_test")
        assert tfile.read_text() == "job_isolation_test"
        logger.info("βœ… Job isolation directory ready: %s", JOB_ROOT)
        shutil.rmtree(test_job, ignore_errors=True)

        # Env summary
        logger.info("🌍 Env: OMP_NUM_THREADS=%s | HF_HOME=%s | TORCH_HOME=%s",
                    os.environ.get("OMP_NUM_THREADS", "unset"),
                    os.environ.get("HF_HOME", "default"),
                    os.environ.get("TORCH_HOME", "default"))

        logger.info("🎯 Startup probe completed β€” system ready!")

    except Exception as e:
        logger.error("❌ STARTUP PROBE FAILED: %s", e)
        logger.error("πŸ“Š %s", _disk_stats(APP_ROOT))
        raise RuntimeError(f"Startup probe failed β€” system not ready: {e}") from e
204
+
205
+ # ==============================================================================
206
+ # FILE SAFETY UTILITIES
207
+ # ==============================================================================
208
+
209
def new_tmp_path(suffix: str) -> Path:
    """Return a collision-free path under TMP_ROOT ending with *suffix*."""
    unique = uuid.uuid4().hex
    return TMP_ROOT / (unique + suffix)
212
+
213
def atomic_write_bytes(dst: Path, data: bytes):
    """Atomically replace *dst* with *data*.

    The temp file is created in dst's own directory so Path.replace() is a
    same-filesystem rename (atomic on POSIX). The previous version staged the
    temp file under TMP_ROOT, which breaks atomicity — and replace() itself —
    whenever *dst* lives on a different mount.

    Raises whatever the underlying write/replace raises; the partial file is
    removed (best effort) on failure.
    """
    tmp = dst.with_name(f".{dst.name}.{uuid.uuid4().hex}.part")
    try:
        with open(tmp, "wb") as f:
            f.write(data)
        tmp.replace(dst)  # atomic within one filesystem
        logger.debug("βœ… Atomic write: %s", dst)
    except Exception:
        tmp.unlink(missing_ok=True)  # best-effort cleanup of the partial file
        raise  # bare raise preserves the original traceback (was `raise e`)
225
+
226
def safe_name(name: str, default="file") -> str:
    """Sanitize a filename to prevent traversal/unicode issues.

    Runs of characters outside [A-Za-z0-9._-] collapse to a single
    underscore; the result is capped at 120 characters and falls back to
    *default* when the input is empty or sanitizes to nothing.
    """
    import re
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", name or default)
    return cleaned[:120] or default
231
+
232
def place_uploaded(in_path: str, sub="uploads") -> Path:
    """Copy an uploaded file into DATA_ROOT/<sub> under a sanitized name.

    Returns the destination path inside the app-controlled area.
    """
    dest_dir = DATA_ROOT / sub
    dest_dir.mkdir(exist_ok=True, parents=True)
    dest = dest_dir / safe_name(Path(in_path).name)
    shutil.copy2(in_path, dest)
    logger.info("πŸ“ Uploaded file placed: %s", dest)
    return dest
240
+
241
def tmp_video_path(ext=".mp4") -> Path:
    """Fresh temporary path for a video file (default .mp4)."""
    return new_tmp_path(ext)

def tmp_image_path(ext=".png") -> Path:
    """Fresh temporary path for an image file (default .png)."""
    return new_tmp_path(ext)
246
+
247
def run_safely(fn: Callable, *args, **kwargs):
    """Call *fn* and return its result; on failure, log rich diagnostics
    (traceback, working dirs, disk stats, env/torch versions) and re-raise.
    """
    try:
        return fn(*args, **kwargs)
    except Exception:
        # Full traceback first, then environment context for post-mortems.
        logger.error("PROCESSING FAILED\n%s", traceback.format_exc())
        logger.error("CWD=%s | DATA_ROOT=%s | TMP_ROOT=%s | %s",
                     os.getcwd(), DATA_ROOT, TMP_ROOT, _disk_stats(APP_ROOT))
        try:
            logger.error("Env: OMP_NUM_THREADS=%s | CUDA=%s | torch=%s | cu=%s",
                         os.environ.get("OMP_NUM_THREADS"),
                         os.environ.get("CUDA_VISIBLE_DEVICES", "default"),
                         torch.__version__,
                         getattr(torch.version, "cuda", None))
        except Exception:
            # Diagnostics must never mask the original failure.
            pass
        raise
264
+
265
+ # ==============================================================================
266
+ # SYSTEM UTILITIES
267
+ # ==============================================================================
268
+
269
def get_device():
    """Return torch.device('cuda') when a GPU is available, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
272
+
273
def clear_gpu_memory():
    """Aggressive GPU memory cleanup.

    Empties the CUDA caching allocator and synchronizes when a GPU is
    present, then triggers a Python GC pass. Never raises — cleanup
    failures are only logged.

    NOTE(review): original indentation was lost in transit; this reading
    runs gc.collect() unconditionally (also on CPU-only hosts) — confirm
    against the canonical source.
    """
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()
        logger.info("🧹 GPU memory cleared")
    except Exception as e:
        logger.warning("GPU cleanup warning: %s", e)
283
+
284
def safe_file_operation(operation: Callable, *args, max_retries: int = 3, **kwargs):
    """Run a flaky file operation with linear back-off retries.

    Sleeps 0.1s, 0.2s, ... between attempts. After *max_retries* failures
    the last exception is re-raised.
    """
    last_error = None
    for attempt_no in range(1, max_retries + 1):
        try:
            return operation(*args, **kwargs)
        except Exception as exc:
            last_error = exc
            if attempt_no < max_retries:
                time.sleep(0.1 * attempt_no)
                logger.warning("File op retry %d: %s", attempt_no, exc)
            else:
                logger.error("File op failed after %d attempts: %s", max_retries, exc)
                raise last_error
298
+
299
+ # ==============================================================================
300
+ # BACKGROUND GENERATORS
301
+ # ==============================================================================
302
+
303
def generate_ai_background(prompt: str, width: int, height: int) -> Image.Image:
    """Generate AI-like background using prompt cues (procedural).

    Not a real generative model: keyword matching on *prompt* selects one of
    several vertical color ramps (city/beach/forest/space/desert), a default
    pastel shade otherwise. Uniform noise in [-15, 15) is added per channel
    for texture. Falls back to a sunset gradient on any error.
    """
    try:
        logger.info("Generating AI background: '%s' (%dx%d)", prompt, width, height)
        img = np.zeros((height, width, 3), dtype=np.uint8)
        prompt_lower = prompt.lower()

        # Each branch interpolates per-row from a top color toward a bottom
        # color; i/height < 1 for all rows, so the ramp endpoints stay in range.
        if any(w in prompt_lower for w in ('city', 'urban', 'futuristic', 'cyberpunk')):
            # Dark blue urban dusk: (20,30,60) -> toward (100,130,180)
            for i in range(height):
                r = int(20 + 80 * (i / height))
                g = int(30 + 100 * (i / height))
                b = int(60 + 120 * (i / height))
                img[i, :] = [r, g, b]
        elif any(w in prompt_lower for w in ('beach', 'tropical', 'ocean', 'sea')):
            # Sky-blue toward near-white sand tones.
            for i in range(height):
                r = int(135 + 120 * (i / height))
                g = int(206 + 49 * (i / height))
                b = int(235 + 20 * (i / height))
                img[i, :] = [r, g, b]
        elif any(w in prompt_lower for w in ('forest', 'jungle', 'nature', 'green')):
            # Forest green toward earthy brown; blue channel decreases, so
            # clamp at 0 via max().
            for i in range(height):
                r = int(34 + 105 * (i / height))
                g = int(139 + 30 * (i / height))
                b = int(34 - 15 * (i / height))
                img[i, :] = [max(0, r), max(0, g), max(0, b)]
        elif any(w in prompt_lower for w in ('space', 'galaxy', 'stars', 'cosmic')):
            # Near-black toward deep violet.
            for i in range(height):
                r = int(10 + 50 * (i / height))
                g = int(0 + 30 * (i / height))
                b = int(30 + 100 * (i / height))
                img[i, :] = [r, g, b]
        elif any(w in prompt_lower for w in ('desert', 'sand', 'canyon')):
            # Sandy tones brightening downward; clamp at 255 via min().
            for i in range(height):
                r = int(238 + 17 * (i / height))
                g = int(203 + 52 * (i / height))
                b = int(173 + 82 * (i / height))
                img[i, :] = [min(255, r), min(255, g), min(255, b)]
        else:
            # Default: pastel picked deterministically by prompt length,
            # darkened up to 30% toward the bottom.
            colors = [(255, 182, 193), (255, 218, 185), (176, 224, 230)]
            color = colors[len(prompt) % len(colors)]
            for i in range(height):
                t = 1 - (i / height) * 0.3
                img[i, :] = [int(color[0] * t), int(color[1] * t), int(color[2] * t)]

        # Add grain; widen to int16 before adding so the clip is correct.
        noise = np.random.randint(-15, 15, (height, width, 3))
        img = np.clip(img.astype(np.int16) + noise, 0, 255).astype(np.uint8)
        return Image.fromarray(img)

    except Exception as e:
        logger.warning("AI background generation failed: %s β€” using fallback", e)
        return create_gradient_background("sunset", width, height)
354
+
355
def create_gradient_background(gradient_type: str, width: int, height: int) -> Image.Image:
    """Vertical two-color gradient image; unknown types yield flat mid-gray."""
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    presets = {
        "sunset": [(255, 165, 0), (128, 64, 128)],
        "ocean": [(0, 100, 255), (30, 144, 255)],
        "forest": [(34, 139, 34), (139, 69, 19)],
        "sky": [(135, 206, 235), (206, 235, 255)],
    }
    preset = presets.get(gradient_type)
    if preset is None:
        canvas.fill(128)
        return Image.fromarray(canvas)
    top, bottom = preset
    for row in range(height):
        # Linear interpolation from the top color to the bottom color.
        t = row / height
        canvas[row, :] = [
            int(top[0] * (1 - t) + bottom[0] * t),
            int(top[1] * (1 - t) + bottom[1] * t),
            int(top[2] * (1 - t) + bottom[2] * t),
        ]
    return Image.fromarray(canvas)
373
+
374
def create_solid_background(color: str, width: int, height: int) -> Image.Image:
    """Solid-color RGB image; unrecognized color names fall back to gray."""
    palette = {
        "white": (255, 255, 255), "black": (0, 0, 0), "red": (255, 0, 0),
        "green": (0, 255, 0), "blue": (0, 0, 255), "yellow": (255, 255, 0),
        "purple": (128, 0, 128), "orange": (255, 165, 0), "pink": (255, 192, 203),
        "gray": (128, 128, 128)
    }
    fill = palette.get(color.lower(), (128, 128, 128))
    return Image.new("RGB", (width, height), fill)
383
+
384
def download_unsplash_image(query: str, width: int, height: int) -> Image.Image:
    """Fetch a random Unsplash photo for *query* at width x height.

    Resizes when the service returns a different size; any failure
    (network, decode) falls back to a solid gray background.
    """
    try:
        url = f"https://source.unsplash.com/{width}x{height}/?{query}"
        resp = requests.get(url, timeout=10)
        resp.raise_for_status()
        photo = Image.open(io.BytesIO(resp.content))
        if photo.size != (width, height):
            photo = photo.resize((width, height), Image.Resampling.LANCZOS)
        return photo.convert("RGB")
    except Exception as e:
        logger.warning("Unsplash download failed: %s", e)
        return create_solid_background("gray", width, height)
396
+
397
+ # ==============================================================================
398
+ # VIDEO UTILITIES
399
+ # ==============================================================================
400
+
401
def get_video_info(video_path: str) -> Dict[str, Any]:
    """Probe fps/frame-count/size/duration of a video via OpenCV.

    Returns a dict with keys fps, frame_count, width, height, duration.
    On any failure the error is logged and safe 1080p defaults are returned.
    The capture handle is now released in a finally block — the original
    leaked it when isOpened() failed or a property read raised.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Cannot open video file")
        fps = cap.get(cv2.CAP_PROP_FPS)
        frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        return {"fps": fps, "frame_count": frames, "width": w, "height": h,
                "duration": (frames / fps if fps > 0 else 0)}
    except Exception as e:
        logger.error("get_video_info failed: %s", e)
        return {"fps": 30.0, "frame_count": 0, "width": 1920, "height": 1080, "duration": 0}
    finally:
        if cap is not None:
            cap.release()
416
+
417
def extract_frame(video_path: str, frame_number: int) -> Optional[np.ndarray]:
    """Return frame *frame_number* of the video as an RGB ndarray, or None.

    Adds an isOpened() guard (the original seeked into an unopened capture)
    and releases the handle in a finally block so it is not leaked when
    read/convert raises.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return None
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
        ret, frame = cap.read()
        if ret:
            # OpenCV decodes BGR; callers expect RGB.
            return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return None
    except Exception as e:
        logger.error("extract_frame failed: %s", e)
        return None
    finally:
        if cap is not None:
            cap.release()
429
+
430
def ffmpeg_safe_call(inp: Path, out: Path, extra=()):
    """Run ffmpeg on *inp* producing *out*, with optional extra args.

    Quiet output (-hide_banner, -loglevel error), overwrite allowed (-y).
    Raises CalledProcessError on a non-zero exit and TimeoutExpired after
    300 seconds.
    """
    cmd = ["ffmpeg", "-y", "-hide_banner", "-loglevel", "error", "-i", str(inp)]
    cmd.extend(extra)
    cmd.append(str(out))
    logger.info("FFMPEG %s", " ".join(cmd))
    subprocess.run(cmd, check=True, timeout=300)
434
+
435
+ # ==============================================================================
436
+ # PROGRESS TRACKING
437
+ # ==============================================================================
438
+
439
class ProgressTracker:
    """Thread-safe progress tracking for video processing.

    All mutators and readers take an internal lock, so one instance may be
    shared between the processing thread and UI callbacks. Fixed the
    annotations that claimed ``float``/``int`` but defaulted to ``None``
    (now ``Optional[...]``).
    """

    def __init__(self):
        self.current_step = ""          # human-readable label of the current phase
        self.progress = 0.0             # fraction complete, clamped to [0, 1]
        self.total_frames = 0
        self.processed_frames = 0
        self.start_time = time.time()   # basis for elapsed/ETA
        self.lock = threading.Lock()

    def update(self, step: str, progress: Optional[float] = None):
        """Set the current step label and, when given, the clamped progress."""
        with self.lock:
            self.current_step = step
            if progress is not None:
                self.progress = max(0.0, min(1.0, progress))

    def update_frames(self, processed: int, total: Optional[int] = None):
        """Record frame counts; recompute progress when the total is known."""
        with self.lock:
            self.processed_frames = processed
            if total is not None:
                self.total_frames = total
            if self.total_frames > 0:
                self.progress = self.processed_frames / self.total_frames

    def get_status(self) -> Dict[str, Any]:
        """Return a consistent snapshot: step, progress, frames, elapsed, ETA."""
        with self.lock:
            elapsed = time.time() - self.start_time
            eta = 0
            # Guard against wild ETAs while progress is effectively zero.
            if self.progress > 0.01:
                eta = elapsed * (1.0 - self.progress) / self.progress
            return {
                "step": self.current_step, "progress": self.progress,
                "processed_frames": self.processed_frames, "total_frames": self.total_frames,
                "elapsed": elapsed, "eta": eta
            }

# Global tracker
progress_tracker = ProgressTracker()
477
+
478
+ # ==============================================================================
479
+ # SAFE FILE OPS
480
+ # ==============================================================================
481
+
482
def create_job_directory() -> Path:
    """Create and return an isolated per-job working directory.

    Name combines a short random id with the current unix time so
    concurrent jobs never collide.
    """
    short_id = str(uuid.uuid4())[:8]
    job_dir = JOB_ROOT / f"job_{short_id}_{int(time.time())}"
    job_dir.mkdir(parents=True, exist_ok=True)
    logger.info("πŸ“ Created job directory: %s", job_dir)
    return job_dir
488
+
489
def atomic_file_write(filepath: Path, content: bytes):
    """Atomically write *content* to *filepath* via a same-directory temp file.

    Uses Path.replace() instead of rename(): replace() atomically overwrites
    an existing target on every platform, whereas rename() raises
    FileExistsError on Windows — and it matches the sibling
    atomic_write_bytes() helper. The partial file is removed on failure.
    """
    # with_name appends ".tmp" without breaking pathlib suffix rules.
    temp_path = filepath.with_name(f"{filepath.name}.tmp")
    try:
        with open(temp_path, 'wb') as f:
            f.write(content)
        temp_path.replace(filepath)
        logger.debug("βœ… Atomic write: %s", filepath)
    except Exception:
        temp_path.unlink(missing_ok=True)  # best-effort cleanup
        raise  # bare raise preserves the original traceback (was `raise e`)
501
+
502
def safe_download(url: str, filepath: Path, max_size: int = 500 * 1024 * 1024):
    """Stream *url* into *filepath* with a hard size cap (default 500 MB).

    Downloads to a sibling ".download" temp file, validates it is non-empty,
    then atomically moves it into place with Path.replace() (rename() would
    fail on Windows if the target exists). The HTTP response is closed
    deterministically via a context manager — the original leaked the
    connection. Raises ValueError for oversized/empty downloads and
    re-raises network errors; the temp file is removed on any failure.
    """
    # with_name appends ".download" safely (e.g. "video.mp4.download").
    temp_path = filepath.with_name(f"{filepath.name}.download")

    try:
        with requests.get(url, stream=True, timeout=30) as r:
            r.raise_for_status()
            # Reject early when the server declares an oversized body.
            cl = r.headers.get('content-length')
            if cl and int(cl) > max_size:
                raise ValueError(f"File too large: {cl} bytes")

            downloaded = 0
            with open(temp_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:
                        downloaded += len(chunk)
                        # Enforce the cap even when content-length lied/was absent.
                        if downloaded > max_size:
                            raise ValueError(f"Download exceeded size limit: {downloaded} bytes")
                        f.write(chunk)

        if not temp_path.exists() or temp_path.stat().st_size == 0:
            raise ValueError("Download resulted in empty file")

        temp_path.replace(filepath)
        logger.info("βœ… Downloaded: %s (%d bytes)", filepath, downloaded)

    except Exception as e:
        temp_path.unlink(missing_ok=True)
        logger.error("❌ Download failed: %s", e)
        raise
533
+
534
+ # ==============================================================================
535
+ # ENHANCED PIPELINE INTEGRATION
536
+ # ==============================================================================
537
+
538
def process_video_pipeline(
    video_path: str,
    background_image: Optional[Image.Image],
    background_type: str,
    background_prompt: str,
    job_dir: Path,
    progress_callback: Optional[Callable] = None
) -> str:
    """Process video using the two-stage pipeline with enhanced safety and monitoring.

    Copies the input into the app-controlled upload area, runs the imported
    ``two_stage_pipeline.process_two_stage`` (SAM2 segmentation, then
    composition), validates the output file, and returns its path. All work
    happens inside ``run_safely`` for rich failure logging; on error the GPU
    cache is cleared and the job-dir contents are logged before re-raising.

    NOTE(review): ``background_type`` and ``background_prompt`` are accepted
    but never used here — only logged indirectly; kept for interface
    compatibility with callers.

    Raises FileNotFoundError / ValueError on bad inputs, ImportError when
    the pipeline module is missing, RuntimeError on empty/missing output.
    """

    def _inner_process():
        logger.info("=" * 60)
        logger.info("=== ENHANCED TWO-STAGE PIPELINE (WITH SAFETY) ===")
        logger.info("=" * 60)

        # Input diagnostics up front so failures are easy to triage from logs.
        logger.info("DEBUG video_path=%s exists=%s size=%s bytes",
                    video_path, Path(video_path).exists(),
                    (Path(video_path).stat().st_size if Path(video_path).exists() else "N/A"))
        logger.info("DEBUG job_dir=%s writable=%s", job_dir, os.access(job_dir, os.W_OK))
        logger.info("DEBUG bg_image=%s bg_type=%s | %s",
                    (background_image.size if background_image else None),
                    background_type, _disk_stats(APP_ROOT))

        if not Path(video_path).exists():
            raise FileNotFoundError(f"Video file not found: {video_path}")

        # Copy into controlled area (sanitized name under DATA_ROOT/videos).
        safe_video_path = place_uploaded(video_path, "videos")
        logger.info("DEBUG safe_video_path=%s", safe_video_path)

        # Import lazily so the module loads even when the heavy pipeline
        # dependencies are absent; failure here is fatal for processing.
        logger.info("DEBUG importing two-stage pipeline…")
        try:
            from two_stage_pipeline import process_two_stage as pipeline_process
            logger.info("βœ“ two-stage pipeline import OK")
        except ImportError as e:
            logger.error("Import two_stage_pipeline failed: %s", e)
            raise

        progress_tracker.update("Initializing enhanced two-stage pipeline…")

        # Mutable closure state: which stage we are in and when it started,
        # used to time stage transitions from progress messages alone.
        current_stage = {"stage": "init", "start_time": time.time()}

        def safe_progress_callback(step: str, progress: Optional[float] = None):
            # Never let a progress/logging failure kill the pipeline itself.
            try:
                now = time.time()
                elapsed = now - current_stage["start_time"]

                # Stage transitions are detected by substring-matching the
                # step text emitted by the pipeline.
                if "Stage 1" in step and current_stage["stage"] != "stage1":
                    current_stage["stage"] = "stage1"
                    current_stage["start_time"] = now
                    logger.info("πŸ”„ Entering Stage 1 (SAM2) | %s", _disk_stats(APP_ROOT))
                elif "Stage 2" in step and current_stage["stage"] != "stage2":
                    d1 = now - current_stage["start_time"]
                    current_stage["stage"] = "stage2"
                    current_stage["start_time"] = now
                    logger.info("πŸ”„ Entering Stage 2 (Composition) β€” Stage 1 time %.1fs | %s", d1, _disk_stats(APP_ROOT))
                elif "Done" in step and current_stage["stage"] != "complete":
                    d2 = now - current_stage["start_time"]
                    current_stage["stage"] = "complete"
                    logger.info("πŸ”„ Pipeline complete β€” Stage 2 time %.1fs | %s", d2, _disk_stats(APP_ROOT))

                logger.info("PROGRESS [%s] (%.1fs): %s (%s)",
                            current_stage['stage'].upper(), elapsed, step, progress)
                progress_tracker.update(step, progress)

                # Forward a formatted string to the caller-supplied callback.
                if progress_callback:
                    progress_callback(f"Progress: {progress:.1%} - {step}" if progress is not None else step)

                # Stage 1 (SAM2) is the memory-heavy phase; warn if it drags on.
                if current_stage["stage"] == "stage1" and elapsed > 15:
                    logger.warning("⚠️ Stage 1 running for %.1fs β€” monitoring memory", elapsed)

            except Exception as e:
                logger.error("Progress callback error: %s", e)

        if background_image is None:
            raise ValueError("Background image is required")

        logger.info("DEBUG: calling two-stage pipeline…")
        result_path = pipeline_process(
            video_path=str(safe_video_path),
            background_image=background_image,
            workdir=job_dir,
            progress=safe_progress_callback,
            use_matany=True
        )

        logger.info("DEBUG: pipeline returned %s (%s)", result_path, type(result_path))

        # Output validation: existence, non-zero size, and (best-effort)
        # that OpenCV can open it and count frames.
        if result_path:
            result_file = Path(result_path)
            logger.info("DEBUG: result exists=%s", result_file.exists())
            if result_file.exists():
                size = result_file.stat().st_size
                logger.info("DEBUG: result size=%d bytes", size)
                if size == 0:
                    raise RuntimeError("Pipeline produced empty output file")

                # Quick validity check
                try:
                    cap = cv2.VideoCapture(str(result_file))
                    if cap.isOpened():
                        frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                        logger.info("DEBUG: output frame_count=%d", frames)
                        cap.release()
                    else:
                        logger.warning("⚠️ Output may not be a valid video (cannot open)")
                except Exception as e:
                    logger.warning("⚠️ Could not verify output video: %s", e)

        if not result_path or not Path(result_path).exists():
            raise RuntimeError("Two-stage pipeline failed β€” no output produced")

        logger.info("=" * 60)
        logger.info("βœ… ENHANCED TWO-STAGE PIPELINE COMPLETED: %s", result_path)
        logger.info("=" * 60)
        return result_path

    try:
        return run_safely(_inner_process)
    except Exception as e:
        # Best-effort cleanup + state dump before propagating the failure.
        logger.error("🧹 Error cleanup…")
        clear_gpu_memory()
        logger.error("Job dir state: %s",
                     (list(job_dir.iterdir()) if job_dir.exists() else "does not exist"))
        raise
VideoBackgroundReplacer2/ui_core_interface.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ BackgroundFX Pro β€” Gradio Interface & Event Handlers
4
+ UI components, event handlers, and interface creation
5
+ """
6
+
7
+ import logging
8
+ import shutil
9
+ import traceback
10
+ from typing import Optional, Tuple
11
+ from pathlib import Path
12
+
13
+ import torch
14
+ import gradio as gr
15
+ from PIL import Image
16
+
17
+ # Import our functionality
18
+ from ui_core_functionality import (
19
+ get_device, clear_gpu_memory, get_video_info, extract_frame,
20
+ create_gradient_background, create_solid_background, download_unsplash_image,
21
+ generate_ai_background, create_job_directory, safe_file_operation, process_video_pipeline,
22
+ progress_tracker, JOB_ROOT, APP_ROOT, logger
23
+ )
24
+
25
+ # ===============================================================================
26
+ # GRADIO HANDLERS
27
+ # ===============================================================================
28
+
29
def handle_custom_background_upload(image: Optional[Image.Image]) -> Tuple[Optional[Image.Image], str]:
    """Normalize an uploaded background image to RGB and report status."""
    if image is None:
        return None, "No image uploaded"
    try:
        rgb_image = image if image.mode == "RGB" else image.convert("RGB")
        status = f"βœ… Custom background uploaded: {rgb_image.size[0]}x{rgb_image.size[1]}"
        logger.info(status)
        return rgb_image, status
    except Exception as e:
        error_msg = f"❌ Background upload failed: {str(e)}"
        logger.error(error_msg)
        return None, error_msg
43
+
44
def handle_background_type_change(bg_type: str):
    """Toggle UI controls for the selected background type.

    Returns three gr.update objects: (upload image, prompt textbox,
    generate button). "upload" shows only the image input; every other
    type shows the prompt box with a type-specific placeholder plus the
    generate button.
    """
    logger.info(f"🎨 Background type changed to: {bg_type}")
    if bg_type == "upload":
        return (
            gr.update(visible=True, label="Upload Custom Background Image"),
            gr.update(visible=False),
            gr.update(visible=False),
        )

    placeholders = {
        "ai_generate": "Describe the scene: 'futuristic city', 'tropical beach', 'mystical forest'...",
        "gradient": "Choose style: 'sunset', 'ocean', 'forest', 'sky'",
        "solid": "Choose color: 'red', 'blue', 'green', 'white', 'black'...",
        "unsplash": "Search query: 'mountain landscape', 'city skyline', 'nature'..."
    }
    return (
        gr.update(visible=False),
        gr.update(visible=True, placeholder=placeholders.get(bg_type, "Enter your prompt...")),
        gr.update(visible=True, value=f"Generate {bg_type.replace('_', ' ').title()} Background"),
    )
65
+
66
def handle_video_upload(video_file) -> Tuple[Optional[str], str]:
    """Copy an uploaded video into a fresh job directory and report its stats.

    Returns (path-in-job-dir, status message); (None, error message) on failure.
    """
    if video_file is None:
        return None, "No video file provided"
    try:
        job_dir = create_job_directory()
        source = Path(video_file)
        # Preserve original extension if possible
        suffix = source.suffix if source.suffix else ".mp4"
        dest = job_dir / f"input_video{suffix}"
        safe_file_operation(lambda src, dst: shutil.copy2(src, dst), str(source), str(dest))

        info = get_video_info(str(dest))
        duration_text = f"{info['duration']:.1f}s"
        status = f"βœ… Video uploaded: {info['width']}x{info['height']}, {info['fps']:.1f}fps, {duration_text}"
        logger.info(status)
        return str(dest), status
    except Exception as e:
        error_msg = f"❌ Video upload failed: {str(e)}"
        logger.error(error_msg)
        return None, error_msg
87
+
88
def handle_background_generation(bg_type: str, bg_prompt: str, video_path: str) -> Tuple[Optional[Image.Image], str]:
    """Generate a background matching the loaded video's dimensions.

    Dispatches on *bg_type* (ai_generate / gradient / solid / unsplash);
    the "upload" type is handled elsewhere. Returns (image, status) or
    (None, error message).
    """
    if not video_path:
        return None, "No video loaded"
    if bg_type == "upload":
        return None, "Use the upload field above for custom backgrounds"

    try:
        # Match the background to the source video's resolution.
        info = get_video_info(video_path)
        vid_w, vid_h = info['width'], info['height']

        if bg_type == "ai_generate":
            generated = generate_ai_background(bg_prompt, vid_w, vid_h)
            status = f"βœ… Generated AI background: '{bg_prompt}'"

        elif bg_type == "gradient":
            known_gradients = ["sunset", "ocean", "forest", "sky"]
            chosen = next((g for g in known_gradients if g in bg_prompt.lower()), known_gradients[0])
            generated = create_gradient_background(chosen, vid_w, vid_h)
            status = f"βœ… Generated {chosen} gradient background"

        elif bg_type == "solid":
            known_colors = ["white", "black", "red", "green", "blue", "yellow", "purple", "orange", "pink", "gray"]
            chosen_color = next((c for c in known_colors if c in bg_prompt.lower()), "white")
            generated = create_solid_background(chosen_color, vid_w, vid_h)
            status = f"βœ… Generated {chosen_color} solid background"

        elif bg_type == "unsplash":
            query = bg_prompt.strip() or "nature"
            generated = download_unsplash_image(query, vid_w, vid_h)
            status = f"βœ… Downloaded background from Unsplash: '{query}'"

        else:
            generated = create_solid_background("gray", vid_w, vid_h)
            status = "βœ… Generated default gray background"

        logger.info(status)
        return generated, status

    except Exception as e:
        error_msg = f"❌ Background generation failed: {str(e)}"
        logger.error(error_msg)
        return None, error_msg
131
+
132
def handle_video_processing(
    video_path: str,
    background_image: Optional[Image.Image],
    background_type: str,
    background_prompt: str,
    progress=gr.Progress()
) -> Tuple[Optional[str], str]:
    """Run the full background-replacement pipeline on the uploaded video.

    Args:
        video_path: Path to the previously uploaded input video.
        background_image: Replacement background (PIL image) or None.
        background_type: Background mode selected in the UI.
        background_prompt: Free-text prompt used by generated backgrounds.
        progress: Gradio progress handle (injected by Gradio; the call-time
            default is the documented Gradio idiom).

    Returns:
        (result_video_path, status_message); the path is None on failure.
    """
    if not video_path:
        return None, "❌ No video provided"
    # Explicit None check: image-like objects (e.g. numpy arrays) have
    # ambiguous or raising truthiness, so `not background_image` is fragile.
    if background_image is None:
        return None, "❌ No background provided"

    try:
        progress(0, "Starting video processing...")
        logger.info("🎬 Starting video processing")

        job_dir = create_job_directory()
        progress_tracker.update("Creating job directory...")

        def update_progress(message: str):
            # Forward pipeline progress into the Gradio progress bar; never
            # let a UI update failure abort the actual processing.
            try:
                status = progress_tracker.get_status()
                progress_val = status['progress']
                progress(progress_val, message)
                logger.info(f"Progress: {progress_val:.1%} - {message}")
            except Exception as e:
                logger.warning(f"Progress update failed: {e}")

        result_path = process_video_pipeline(
            video_path=video_path,
            background_image=background_image,
            background_type=background_type,
            background_prompt=background_prompt,
            job_dir=job_dir,
            progress_callback=update_progress
        )

        progress(1.0, "Processing complete!")
        clear_gpu_memory()

        status = "✅ Video processing completed successfully!"
        logger.info(status)
        return result_path, status

    except Exception as e:
        error_msg = f"❌ Processing failed: {str(e)}"
        logger.error(error_msg)
        logger.error("Traceback: %s", traceback.format_exc())
        clear_gpu_memory()  # free VRAM even on failure
        return None, error_msg
183
+
184
def handle_preview_generation(video_path: str, frame_number: int = 0) -> Tuple[Optional[Image.Image], str]:
    """Extract one frame of the loaded video for the preview pane.

    Returns (preview_image, status_message); the image is None on failure.
    """
    if not video_path:
        return None, "No video loaded"
    try:
        pixels = extract_frame(video_path, frame_number)
        if pixels is None:
            return None, "Failed to extract frame"
        return Image.fromarray(pixels), f"✅ Preview generated (frame {frame_number})"
    except Exception as e:
        error_msg = f"❌ Preview generation failed: {str(e)}"
        logger.error(error_msg)
        return None, error_msg
198
+
199
+ # ===============================================================================
200
+ # GRADIO INTERFACE
201
+ # ===============================================================================
202
+
203
def create_interface():
    """Create the main Gradio interface.

    Layout: left column = video input + background selection; right column =
    processing controls + results. Two gr.State values carry the uploaded
    video path and the chosen background image between event handlers.
    Returns the gr.Blocks app (caller queues and launches it).
    """

    custom_css = """
    .container { max-width: 1200px; margin: auto; }
    .header { text-align: center; margin-bottom: 30px; }
    .section { margin: 20px 0; padding: 20px; border-radius: 10px; }
    .status { font-family: monospace; font-size: 12px; }
    .progress-bar { margin: 10px 0; }
    """

    with gr.Blocks(
        title="BackgroundFX Pro",
        css=custom_css,
        theme=gr.themes.Soft(),
        analytics_enabled=False,  # keep things quiet/stable on 4.x
    ) as demo:

        gr.HTML("""
        <div class="header">
            <h1>🎬 BackgroundFX Pro</h1>
            <p>Professional AI-powered video background replacement using SAM2 and MatAnyone</p>
        </div>
        """)

        # Cross-event state: input video path and selected background image.
        video_path_state = gr.State(value=None)
        background_image_state = gr.State(value=None)

        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    gr.HTML("<h3>📹 Video Input</h3>")
                    video_upload = gr.File(
                        label="Upload Video",
                        file_types=[".mp4", ".avi", ".mov", ".mkv"],
                        type="filepath"
                    )
                    video_preview = gr.Image(
                        label="Video Preview",
                        interactive=False,
                        height=300
                    )
                    # Fixed preview status box (hidden); only a .then() target.
                    preview_status = gr.Textbox(
                        label="Preview Status",
                        interactive=False,
                        visible=False,
                        elem_classes=["status"]
                    )
                    video_status = gr.Textbox(
                        label="Video Status",
                        interactive=False,
                        elem_classes=["status"]
                    )

                with gr.Group():
                    gr.HTML("<h3>🎨 Background Selection</h3>")

                    gr.HTML("""
                    <div style='background: #f0f8ff; padding: 10px; border-radius: 5px; margin-bottom: 15px;'>
                        <b>Choose your background method:</b><br>
                        • <b>Upload:</b> Use your own image<br>
                        • <b>AI Generate:</b> Create with AI prompt<br>
                        • <b>Gradient/Solid/Unsplash:</b> Quick generation
                    </div>
                    """)

                    # (label, value) pairs; values match handle_background_generation's dispatch.
                    background_type = gr.Radio(
                        choices=[
                            ("📤 Upload Image", "upload"),
                            ("🤖 AI Generate", "ai_generate"),
                            ("🌈 Gradient", "gradient"),
                            ("🎯 Solid Color", "solid"),
                            ("📸 Unsplash Photo", "unsplash")
                        ],
                        label="Background Type",
                        value="upload"
                    )

                    # Visible only for the "upload" type (toggled by the change handler).
                    custom_bg_upload = gr.Image(
                        label="Upload Custom Background",
                        type="pil",
                        interactive=True,
                        height=250,
                        visible=True
                    )

                    # Hidden initially; shown for generated background types.
                    background_prompt = gr.Textbox(
                        label="Background Prompt",
                        placeholder=("AI: 'futuristic city', 'tropical beach' | Gradient: 'sunset', 'ocean' | "
                                     "Solid: 'red', 'blue' | Unsplash: 'mountain landscape'"),
                        value="futuristic city skyline at sunset",
                        visible=False
                    )

                    generate_bg_btn = gr.Button(
                        "Generate Background",
                        variant="secondary",
                    )

                    background_preview = gr.Image(
                        label="Background Preview",
                        interactive=False,
                        height=300
                    )

                    background_status = gr.Textbox(
                        label="Background Status",
                        interactive=False,
                        elem_classes=["status"]
                    )

            with gr.Column(scale=1):
                with gr.Group():
                    gr.HTML("<h3>⚡ Processing</h3>")

                    process_btn = gr.Button(
                        "🚀 Process Video",
                        variant="primary",
                    )

                    processing_status = gr.Textbox(
                        label="Processing Status",
                        interactive=False,
                        elem_classes=["status"]
                    )

                with gr.Group():
                    gr.HTML("<h3>📽️ Results</h3>")

                    result_video = gr.Video(
                        label="Processed Video",
                        height=400
                    )

                    # Real downloadable output; revealed once a result path exists.
                    download_btn = gr.DownloadButton(
                        "📥 Download Result",
                        visible=False
                    )

        with gr.Accordion("🔧 System Information", open=False):
            # NOTE(review): f-string evaluated once at interface build time,
            # not per page load.
            system_info = gr.HTML(f"""
            <div class="system-info">
                <p><strong>Device:</strong> {get_device()}</p>
                <p><strong>Torch Version:</strong> {torch.__version__}</p>
                <p><strong>CUDA Available:</strong> {torch.cuda.is_available()}</p>
                <p><strong>Job Directory:</strong> {JOB_ROOT}</p>
                <p><strong>App Root:</strong> {APP_ROOT}</p>
            </div>
            """)

        # =========================
        # Event Handlers (4.42.x)
        # =========================

        # Lightweight; no queue needed
        background_type.change(
            fn=handle_background_type_change,
            inputs=[background_type],
            outputs=[custom_bg_upload, background_prompt, generate_bg_btn],
            queue=False,
            concurrency_limit=4,
        )

        # Small, immediate state update; no queue
        custom_bg_upload.change(
            fn=handle_custom_background_upload,
            inputs=[custom_bg_upload],
            outputs=[background_image_state, background_status],
            queue=False,
            concurrency_limit=2,
        ).then(
            # Mirror the stored background into the preview pane.
            fn=lambda img: img,
            inputs=[background_image_state],
            outputs=[background_preview],
            queue=False,
        )

        # Copy to job dir + probe video info; keep queued but single flight
        video_upload.change(
            fn=handle_video_upload,
            inputs=[video_upload],
            outputs=[video_path_state, video_status],
            queue=True,
            concurrency_limit=1,
        ).then(
            fn=handle_preview_generation,
            inputs=[video_path_state],
            outputs=[video_preview, preview_status],
            queue=False,
        )

        # Background generation can be heavier; single-flight
        generate_bg_btn.click(
            fn=handle_background_generation,
            inputs=[background_type, background_prompt, video_path_state],
            outputs=[background_image_state, background_status],
            queue=True,
            concurrency_limit=1,
        ).then(
            fn=lambda img: img,
            inputs=[background_image_state],
            outputs=[background_preview],
            queue=False,
        )

        # The heavy pipeline — single-flight
        process_btn.click(
            fn=handle_video_processing,
            inputs=[
                video_path_state,
                background_image_state,
                background_type,
                background_prompt
            ],
            outputs=[result_video, processing_status],
            queue=True,
            concurrency_limit=1,
        ).then(
            # Wire the download button (set value=path and visibility)
            fn=lambda path: gr.update(value=path, visible=bool(path)),
            inputs=[result_video],
            outputs=[download_btn],
            queue=False,
        )

    return demo
VideoBackgroundReplacer2/update_pins.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ update_pins.py
4
+ - Fetch newest SHAs (release tag or default branch) for SAM2 + MatAnyone
5
+ - Update ARG lines in Dockerfile: SAM2_SHA / MATANYONE_SHA
6
+ - Supports dry-run and manual pins
7
+ - Uses GitHub API; set GITHUB_TOKEN to avoid rate limits (optional)
8
+ """
9
+
10
+ import os
11
+ import re
12
+ import sys
13
+ import json
14
+ import argparse
15
+ from urllib.parse import urlparse
16
+ import requests
17
+ from datetime import datetime, timezone
18
+ from shutil import copyfile
19
+
20
# Default target file; overridable via --dockerfile.
DOCKERFILE_PATH = "Dockerfile"

# Default repos (must match your Dockerfile ARGs)
SAM2_REPO_URL = "https://github.com/facebookresearch/segment-anything-2"
MATANY_REPO_URL = "https://github.com/pq-yang/MatAnyone"

# Shared HTTP session for all GitHub API calls; an optional GITHUB_TOKEN
# raises the unauthenticated rate limit.
SESSION = requests.Session()
if os.getenv("GITHUB_TOKEN"):
    SESSION.headers.update({"Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}"})
SESSION.headers.update({
    "Accept": "application/vnd.github+json",
    "User-Agent": "update-pins-script"
})
33
+
34
def gh_owner_repo(repo_url: str):
    """Split a GitHub repository URL into its (owner, repo) path components."""
    segments = urlparse(repo_url).path.strip("/").split("/")
    if len(segments) < 2:
        raise ValueError(f"Invalid repo URL: {repo_url}")
    owner, repo = segments[0], segments[1]
    return owner, repo
40
+
41
def gh_api(path: str):
    """GET a GitHub REST endpoint and return the decoded JSON payload.

    Raises RuntimeError for any 4xx/5xx response.
    """
    response = SESSION.get(f"https://api.github.com{path}", timeout=30)
    if response.status_code >= 400:
        raise RuntimeError(f"GitHub API error {response.status_code}: {response.text}")
    return response.json()
47
+
48
def get_latest_release_sha(repo_url: str) -> tuple[str, str]:
    """Return (ref_desc, commit_sha) using latest release tag."""
    owner, repo = gh_owner_repo(repo_url)
    try:
        tag = gh_api(f"/repos/{owner}/{repo}/releases/latest")["tag_name"]
        ref_obj = gh_api(f"/repos/{owner}/{repo}/git/ref/tags/{tag}")["object"]
        # Annotated tags point at a tag object that must be dereferenced to
        # reach the underlying commit; lightweight tags point at it directly.
        if ref_obj["type"] == "tag":
            sha = gh_api(f"/repos/{owner}/{repo}/git/tags/{ref_obj['sha']}")["object"]["sha"]
        else:
            sha = ref_obj["sha"]
        return (f"release:{tag}", sha)
    except Exception as e:
        raise RuntimeError(f"Could not get latest release for {repo}: {e}")
65
+
66
def get_latest_default_branch_sha(repo_url: str) -> tuple[str, str]:
    """Return (ref_desc, commit_sha) using the default branch head."""
    owner, repo = gh_owner_repo(repo_url)
    default_branch = gh_api(f"/repos/{owner}/{repo}")["default_branch"]
    head_sha = gh_api(f"/repos/{owner}/{repo}/branches/{default_branch}")["commit"]["sha"]
    return (f"branch:{default_branch}", head_sha)
74
+
75
def get_sha_for_ref(repo_url: str, ref: str) -> tuple[str, str]:
    """
    Resolve any Git ref (branch name, tag name, or commit SHA) to a commit SHA.
    """
    owner, repo = gh_owner_repo(repo_url)

    # A full 40-hex SHA needs no API round-trip.
    if re.fullmatch(r"[0-9a-f]{40}", ref):
        return (f"commit:{ref[:7]}", ref)

    # Probe branch, then tag, then bare commit; first successful lookup wins.
    candidates = (
        ("branch", f"/repos/{owner}/{repo}/branches/{ref}"),
        ("tag", f"/repos/{owner}/{repo}/git/ref/tags/{ref}"),
        ("commit", f"/repos/{owner}/{repo}/commits/{ref}"),
    )
    for kind, endpoint in candidates:
        try:
            payload = gh_api(endpoint)
            if kind == "branch":
                return (f"branch:{ref}", payload["commit"]["sha"])
            if kind == "tag":
                obj = payload["object"]
                if obj["type"] == "tag":
                    # Annotated tag: dereference to the tagged commit.
                    tag_obj = gh_api(f"/repos/{owner}/{repo}/git/tags/{obj['sha']}")
                    return (f"tag:{ref}", tag_obj["object"]["sha"])
                return (f"tag:{ref}", obj["sha"])
            return (f"commit:{ref[:7]}", payload["sha"])
        except Exception:
            # Any failure for this kind just means "try the next one".
            continue
    raise RuntimeError(f"Could not resolve ref '{ref}' for {repo}")
105
+
106
def update_dockerfile_arg(dockerfile_text: str, arg_name: str, new_value: str) -> str:
    """Rewrite the value of an `ARG <arg_name>=...` line in Dockerfile text.

    Replaces every matching line's value with *new_value* and returns the
    updated text. Raises RuntimeError when no such ARG line exists.
    """
    arg_line = re.compile(rf"^(ARG\s+{re.escape(arg_name)}=).*$", re.MULTILINE)
    # Callable replacement avoids backreference ambiguity (e.g. \12).
    updated, hits = arg_line.subn(lambda m: m.group(1) + new_value, dockerfile_text)
    if hits == 0:
        raise RuntimeError(f"ARG {arg_name}=… line not found in Dockerfile.")
    return updated
123
+
124
def main():
    """CLI entry point: resolve pins and rewrite the Dockerfile ARG lines.

    Flow: parse args -> resolve one commit SHA per repo (an explicit
    --*-ref overrides --mode) -> print the pins (text or JSON) -> patch the
    Dockerfile in place (with optional .bak backup), unless --dry-run.
    """
    ap = argparse.ArgumentParser(description="Update pinned SHAs in Dockerfile.")
    ap.add_argument("--mode", choices=["release", "default-branch"], default="release",
                    help="Where to pull pins from (latest GitHub release tag or default branch head).")
    ap.add_argument("--sam2-ref", help="Explicit ref for SAM2 (tag/branch/sha). Overrides --mode.")
    ap.add_argument("--matany-ref", help="Explicit ref for MatAnyone (tag/branch/sha). Overrides --mode.")
    ap.add_argument("--dockerfile", default=DOCKERFILE_PATH, help="Path to Dockerfile.")
    ap.add_argument("--dry-run", action="store_true", help="Show changes but do not write file.")
    ap.add_argument("--json", action="store_true", help="Print resulting pins as JSON.")
    ap.add_argument("--no-backup", action="store_true", help="Do not create a Dockerfile.bak backup.")
    args = ap.parse_args()

    # Resolve SHAs
    if args.sam2_ref:
        sam2_refdesc, sam2_sha = get_sha_for_ref(SAM2_REPO_URL, args.sam2_ref)
    else:
        sam2_refdesc, sam2_sha = (
            get_latest_release_sha(SAM2_REPO_URL) if args.mode == "release"
            else get_latest_default_branch_sha(SAM2_REPO_URL)
        )

    if args.matany_ref:
        mat_refdesc, mat_sha = get_sha_for_ref(MATANY_REPO_URL, args.matany_ref)
    else:
        mat_refdesc, mat_sha = (
            get_latest_release_sha(MATANY_REPO_URL) if args.mode == "release"
            else get_latest_default_branch_sha(MATANY_REPO_URL)
        )

    # Summary record; also the JSON output shape.
    result = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "mode": args.mode,
        "SAM2": {"repo": SAM2_REPO_URL, "ref": sam2_refdesc, "sha": sam2_sha},
        "MatAnyone": {"repo": MATANY_REPO_URL, "ref": mat_refdesc, "sha": mat_sha},
    }

    # Show pins
    if args.json:
        print(json.dumps(result, indent=2))
    else:
        print(f"[Pins] SAM2 -> {sam2_refdesc} -> {sam2_sha}")
        print(f"[Pins] MatAnyone -> {mat_refdesc} -> {mat_sha}")

    # Read Dockerfile
    if not os.path.isfile(args.dockerfile):
        raise FileNotFoundError(f"Dockerfile not found at: {args.dockerfile}")
    with open(args.dockerfile, "r", encoding="utf-8") as f:
        text = f.read()

    # Update lines
    text = update_dockerfile_arg(text, "SAM2_SHA", sam2_sha)
    text = update_dockerfile_arg(text, "MATANYONE_SHA", mat_sha)

    if args.dry_run:
        print("\n--- Dockerfile (preview) ---\n")
        print(text)
        return

    # Backup
    if not args.no_backup:
        copyfile(args.dockerfile, args.dockerfile + ".bak")

    # Write
    with open(args.dockerfile, "w", encoding="utf-8") as f:
        f.write(text)

    print(f"\n✅ Updated {args.dockerfile} with new pins.")
191
+
192
if __name__ == "__main__":
    # Surface any failure as a single stderr line and a non-zero exit code.
    try:
        main()
    except Exception as e:
        print(f"\n❌ Error: {e}", file=sys.stderr)
        sys.exit(1)
VideoBackgroundReplacer2/utils/__init__.py ADDED
File without changes
VideoBackgroundReplacer2/utils/paths.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# utils/paths.py
from pathlib import Path
import os, re, uuid, shutil

# APP_ROOT = repository root (one level above this utils/ package).
APP_ROOT = Path(__file__).resolve().parents[1]
DATA_ROOT = APP_ROOT / "data"
TMP_ROOT = APP_ROOT / "tmp"
# Ensure all writable locations exist at import time, before any consumer
# tries to use them.
for p in (DATA_ROOT, TMP_ROOT, APP_ROOT / ".hf", APP_ROOT / ".torch"):
    p.mkdir(parents=True, exist_ok=True)

# Point HF/torch caches inside the app tree unless the environment already
# chose locations.
os.environ.setdefault("HF_HOME", str(APP_ROOT / ".hf"))
os.environ.setdefault("TORCH_HOME", str(APP_ROOT / ".torch"))
13
+
14
def safe_name(name: str, default="file"):
    """Sanitize a filename: keep [A-Za-z0-9._-], collapse runs of anything
    else to '_', fall back to *default* when empty, cap at 120 chars."""
    candidate = name or default
    sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", candidate)
    if not sanitized:
        sanitized = default
    return sanitized[:120]
17
+
18
def job_dir(prefix="job"):
    """Create and return a uniquely named job directory under DATA_ROOT."""
    path = DATA_ROOT / f"{prefix}-{uuid.uuid4().hex[:8]}"
    path.mkdir(parents=True, exist_ok=True)
    return path
22
+
23
def disk_stats(p: Path = APP_ROOT) -> str:
    """Return a one-line human-readable summary of disk usage at *p* (in MB)."""
    try:
        usage = shutil.disk_usage(str(p))
    except Exception:
        return "disk(n/a)"
    total_mb, used_mb, free_mb = (x // (1024 * 1024) for x in usage)
    return f"disk(total={total_mb}MB, used={used_mb}MB, free={free_mb}MB)"
VideoBackgroundReplacer2/utils/perf_tuning.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# utils/perf_tuning.py
import os, logging
# cv2 is optional at runtime; fall back to None so apply() can skip it.
try:
    import cv2
except Exception:
    cv2 = None
import torch

def apply():
    """Apply best-effort performance defaults: CPU thread caps and cuDNN autotune."""
    # Respect an operator-provided value; only supply a default.
    os.environ.setdefault("OMP_NUM_THREADS", "4")
    if cv2:
        try:
            cv2.setNumThreads(4)
        except Exception as e:
            logging.info("cv2 threads not set: %s", e)
    if torch.cuda.is_available():
        # cuDNN autotunes kernels; pays off when input sizes are stable.
        torch.backends.cudnn.benchmark = True
        try:
            logging.info("CUDA device %s — cuDNN benchmark ON", torch.cuda.get_device_name(0))
        except Exception:
            logging.info("CUDA available — cuDNN benchmark ON")
app.py CHANGED
@@ -1,300 +1,570 @@
1
- #!/usr/bin/env python3
2
- """
3
- VideoBackgroundReplacer2 - SAM2 + MatAnyone Integration
4
- =======================================================
5
- - Sets up Gradio UI and launches pipeline
6
- - Aligned with torch/cu121 stack; runs on HF Spaces (Docker)
7
-
8
- Changes (2025-09-18):
9
- - Added precise web-stack probes (FastAPI/Starlette/Pydantic/etc. versions + file paths)
10
- - Added toggleable "mount mode": run Gradio inside our own FastAPI app
11
- and provide a safe /config route shim (uses demo.get_config_file()).
12
- - Kept your startup diagnostics, GPU logging, and heartbeats
13
- """
14
-
15
- print("=== APP STARTUP: Initializing VideoBackgroundReplacer2 ===")
16
-
17
- # ---------------------------------------------------------------------
18
- # Imports & basic setup
19
- # ---------------------------------------------------------------------
20
- import sys
21
  import os
22
- import gc
23
- import json
24
- import logging
25
- import threading
26
  import time
27
- import warnings
28
- import traceback
29
- import subprocess
30
  from pathlib import Path
31
- from loguru import logger
32
-
33
- # Logging (loguru to stderr)
34
- logger.remove()
35
- logger.add(
36
- sys.stderr,
37
- format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> "
38
- "| <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
 
 
 
 
 
 
 
 
 
 
 
 
39
  )
40
 
41
- # Warnings
42
- warnings.filterwarnings("ignore", category=UserWarning)
43
- warnings.filterwarnings("ignore", category=FutureWarning)
44
- warnings.filterwarnings("ignore", module="torchvision.io._video_deprecation_warning")
45
-
46
- # Environment (lightweight & safe in Spaces)
47
- os.environ.setdefault("OMP_NUM_THREADS", "1")
48
- os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
49
-
50
- # Paths
51
- BASE_DIR = Path(__file__).parent.absolute()
52
- THIRD_PARTY_DIR = BASE_DIR / "third_party"
53
- SAM2_DIR = THIRD_PARTY_DIR / "sam2"
54
- CHECKPOINTS_DIR = BASE_DIR / "checkpoints"
55
-
56
- # Python path extends
57
- for p in (str(THIRD_PARTY_DIR), str(SAM2_DIR)):
58
- if p not in sys.path:
59
- sys.path.insert(0, p)
60
-
61
- logger.info(f"Base directory: {BASE_DIR}")
62
- logger.info(f"Python path[0:5]: {sys.path[:5]}")
63
-
64
- # ---------------------------------------------------------------------
65
- # GPU / Torch diagnostics (non-blocking)
66
- # ---------------------------------------------------------------------
67
- try:
68
- import torch
69
- except Exception as e:
70
- logger.warning("Torch import failed at startup: %s", e)
71
- torch = None
72
-
73
- DEVICE = "cuda" if (torch and torch.cuda.is_available()) else "cpu"
74
- if DEVICE == "cuda":
75
- os.environ["SAM2_DEVICE"] = "cuda"
76
- os.environ["MATANY_DEVICE"] = "cuda"
77
- os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")
78
- try:
79
- logger.info(f"CUDA available: {torch.cuda.get_device_name(0)}")
80
- except Exception:
81
- logger.info("CUDA device name not available at startup.")
82
- else:
83
- os.environ["SAM2_DEVICE"] = "cpu"
84
- os.environ["MATANY_DEVICE"] = "cpu"
85
- logger.warning("CUDA not available, falling back to CPU")
86
-
87
- def verify_models():
88
- """Verify critical model files exist and are loadable (cheap checks)."""
89
- results = {"status": "success", "details": {}}
90
- try:
91
- sam2_model_path = os.getenv("SAM2_MODEL_PATH", str(CHECKPOINTS_DIR / "sam2_hiera_large.pt"))
92
- if not os.path.exists(sam2_model_path):
93
- raise FileNotFoundError(f"SAM2 model not found at {sam2_model_path}")
94
- # Cheap load test (map to CPU to avoid VRAM use during boot)
95
- if torch:
96
- sd = torch.load(sam2_model_path, map_location="cpu")
97
- if not isinstance(sd, dict):
98
- raise ValueError("Invalid SAM2 checkpoint format")
99
- results["details"]["sam2"] = {
100
- "status": "success",
101
- "path": sam2_model_path,
102
- "size_mb": round(os.path.getsize(sam2_model_path) / (1024 * 1024), 2),
103
- }
104
- except Exception as e:
105
- results["status"] = "error"
106
- results["details"]["sam2"] = {
107
- "status": "error",
108
- "error": str(e),
109
- "traceback": traceback.format_exc(),
110
- }
111
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
- def run_startup_diagnostics():
114
- diag = {
115
- "system": {
116
- "python": sys.version,
117
- "pytorch": getattr(torch, "__version__", None) if torch else None,
118
- "cuda_available": bool(torch and torch.cuda.is_available()),
119
- "device_count": (torch.cuda.device_count() if torch and torch.cuda.is_available() else 0),
120
- "cuda_version": getattr(getattr(torch, "version", None), "cuda", None) if torch else None,
121
- },
122
- "paths": {
123
- "base_dir": str(BASE_DIR),
124
- "checkpoints_dir": str(CHECKPOINTS_DIR),
125
- "sam2_dir": str(SAM2_DIR),
126
- },
127
- "env_subset": {k: v for k, v in os.environ.items() if k in ("HOST", "PORT", "SPACE_ID", "SPACE_AUTHOR_NAME")},
128
  }
129
- diag["model_verification"] = verify_models()
130
- return diag
131
-
132
- startup_diag = run_startup_diagnostics()
133
- logger.info("Startup diagnostics completed")
134
-
135
- # Noisy heartbeat so logs show life during import time
136
- def _heartbeat():
137
- i = 0
138
- while True:
139
- i += 1
140
- print(f"[startup-heartbeat] {i*5}s…", flush=True)
141
- time.sleep(5)
142
-
143
- threading.Thread(target=_heartbeat, daemon=True).start()
144
-
145
- # Optional perf tuning import (non-fatal)
146
- try:
147
- import perf_tuning # noqa: F401
148
- logger.info("perf_tuning imported successfully.")
149
- except Exception as e:
150
- logger.info("perf_tuning not available: %s", e)
151
-
152
- # MatAnyone non-instantiating probe
153
- try:
154
- import inspect
155
- from matanyone.inference import inference_core as ic # type: ignore
156
- sigs = {}
157
- for name in ("InferenceCore",):
158
- obj = getattr(ic, name, None)
159
- if obj:
160
- sigs[name] = "callable" if callable(obj) else "present"
161
- logger.info(f"[MATANY] probe (non-instantiating): {sigs}")
162
- except Exception as e:
163
- logger.info(f"[MATANY] probe skipped: {e}")
164
-
165
- # ---------------------------------------------------------------------
166
- # Gradio import and web-stack probes
167
- # ---------------------------------------------------------------------
168
- import gradio as gr
169
-
170
- # Standard logger for some libs that use stdlib logging
171
- py_logger = logging.getLogger("backgroundfx_pro")
172
- if not py_logger.handlers:
173
- h = logging.StreamHandler()
174
- h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
175
- py_logger.addHandler(h)
176
- py_logger.setLevel(logging.INFO)
177
-
178
- def _log_web_stack_versions_and_paths():
179
- import inspect
180
- try:
181
- import fastapi, starlette, pydantic, httpx, anyio
182
- try:
183
- import pydantic_core
184
- pc_ver = pydantic_core.__version__
185
- except Exception:
186
- pc_ver = "unknown"
187
- logger.info(
188
- "[WEB-STACK] fastapi=%s | starlette=%s | pydantic=%s | pydantic-core=%s | httpx=%s | anyio=%s",
189
- getattr(fastapi, "__version__", "?"),
190
- getattr(starlette, "__version__", "?"),
191
- getattr(pydantic, "__version__", "?"),
192
- pc_ver,
193
- getattr(httpx, "__version__", "?"),
194
- getattr(anyio, "__version__", "?"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  )
196
- except Exception as e:
197
- logger.warning("[WEB-STACK] version probe failed: %s", e)
198
-
199
- try:
200
- import gradio
201
- import gradio.routes as gr_routes
202
- import gradio.queueing as gr_queueing
203
- logger.info("[PATH] gradio.__file__ = %s", getattr(gradio, "__file__", "?"))
204
- logger.info("[PATH] gradio.routes = %s", inspect.getfile(gr_routes))
205
- logger.info("[PATH] gradio.queueing = %s", inspect.getfile(gr_queueing))
206
- import starlette.exceptions as st_exc
207
- logger.info("[PATH] starlette.exceptions= %s", inspect.getfile(st_exc))
208
- except Exception as e:
209
- logger.warning("[PATH] probe failed: %s", e)
210
-
211
- def _post_launch_diag():
212
- try:
213
- if not torch:
214
- return
215
- avail = torch.cuda.is_available()
216
- logger.info("CUDA available (post-launch): %s", avail)
217
- if avail:
218
- idx = torch.cuda.current_device()
219
- name = torch.cuda.get_device_name(idx)
220
- cap = torch.cuda.get_device_capability(idx)
221
- logger.info("CUDA device %d: %s (cc %d.%d)", idx, name, cap[0], cap[1])
222
- except Exception as e:
223
- logger.warning("Post-launch CUDA diag failed: %s", e)
224
-
225
- # ---------------------------------------------------------------------
226
- # UI factory (uses your existing builder)
227
- # ---------------------------------------------------------------------
228
- def build_ui() -> gr.Blocks:
229
- # FIX: import from ui_core_interface (not from ui)
230
- from ui_core_interface import create_interface
231
- return create_interface()
232
-
233
- # ---------------------------------------------------------------------
234
- # Optional: custom FastAPI mount mode
235
- # ---------------------------------------------------------------------
236
- def build_fastapi_with_gradio(demo: gr.Blocks):
237
  """
238
- Returns a FastAPI app with Gradio mounted at root.
239
- Also exposes JSON health and a config shim using demo.get_config_file().
240
  """
241
- from fastapi import FastAPI
242
- from fastapi.responses import JSONResponse
243
-
244
- app = FastAPI(title="VideoBackgroundReplacer2")
245
-
246
- @app.get("/healthz")
247
- def _healthz():
248
- return {"ok": True, "ts": time.time()}
249
-
250
- @app.get("/config")
251
- def _config():
252
- try:
253
- cfg = demo.get_config_file()
254
- return JSONResponse(content=cfg)
255
- except Exception as e:
256
- return JSONResponse(
257
- status_code=500,
258
- content={"error": "config_generation_failed", "detail": str(e)},
259
- )
260
-
261
- # Mount Gradio UI at root; our /config route remains at parent level
262
- app = gr.mount_gradio_app(app, demo, path="/")
263
- return app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
 
265
- # ---------------------------------------------------------------------
266
- # Entrypoint
267
- # ---------------------------------------------------------------------
268
  if __name__ == "__main__":
269
- host = os.environ.get("HOST", "0.0.0.0")
270
- port = int(os.environ.get("PORT", "7860"))
271
- mount_mode = os.environ.get("GRADIO_MOUNT_MODE", "0") == "1"
272
-
273
- logger.info("Launching on %s:%s (mount_mode=%s)…", host, port, mount_mode)
274
- _log_web_stack_versions_and_paths()
275
-
276
- demo = build_ui()
277
- demo.queue(max_size=16, api_open=False)
278
-
279
- threading.Thread(target=_post_launch_diag, daemon=True).start()
280
-
281
- if mount_mode:
282
- try:
283
- from uvicorn import run as uvicorn_run
284
- except Exception:
285
- logger.error("uvicorn is not installed; mount mode cannot start.")
286
- raise
287
-
288
- app = build_fastapi_with_gradio(demo)
289
- uvicorn_run(app=app, host=host, port=port, log_level="info")
290
- else:
291
- demo.launch(
292
- server_name=host,
293
- server_port=port,
294
- share=False,
295
- show_api=False,
296
- show_error=True,
297
- quiet=False,
298
- debug=True,
299
- max_threads=1,
300
- )
 
1
# Streamlit front-end bootstrap: imports, project path, logging, page config.
import streamlit as st
import os
import sys
import tempfile
import time
from pathlib import Path
import cv2
import numpy as np
from PIL import Image
import logging
import base64
from io import BytesIO

# Add project root to path
sys.path.append(str(Path(__file__).parent.absolute()))

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set page config
# NOTE(review): st.set_page_config must be the first Streamlit call in the
# script — keep it ahead of any other st.* usage.
st.set_page_config(
    page_title="MyAvatar - Video Background Replacer",
    page_icon="🎥",
    layout="wide",
    initial_sidebar_state="expanded"
)
28
 
29
+ # Custom CSS for better UI with logo
30
def add_logo():
    """Inject the app's custom CSS and render the MyAvatar logo top-right.

    The logo asset is optional: if ``myavatar_logo.png`` is absent the app
    keeps running without it (previously the bare ``open(...)`` raised
    FileNotFoundError at startup and also leaked the file handle).
    """
    st.markdown(
        """
        <style>
        .main .block-container {
            padding-top: 2rem;
            padding-bottom: 2rem;
        }
        .stButton>button {
            width: 100%;
            background-color: #4CAF50;
            color: white;
            font-weight: bold;
            transition: all 0.3s;
        }
        .stButton>button:hover {
            background-color: #45a049;
        }
        .stProgress > div > div > div > div {
            background-color: #4CAF50;
        }
        .stAlert {
            border-radius: 10px;
        }
        .stTabs [data-baseweb="tab-list"] {
            gap: 10px;
        }
        .stTabs [data-baseweb="tab"] {
            height: 50px;
            white-space: pre;
            background-color: #f0f2f6;
            border-radius: 4px 4px 0 0;
            padding: 10px 20px;
            margin-right: 5px;
        }
        .stTabs [aria-selected="true"] {
            background-color: #4CAF50;
            color: white;
        }
        .video-container {
            border: 2px dashed #4CAF50;
            border-radius: 10px;
            padding: 10px;
            margin-bottom: 20px;
        }
        .logo-container {
            display: flex;
            justify-content: flex-end;
            padding: 10px 20px 0 0;
        }
        .logo {
            height: 50px;
            width: auto;
            margin-bottom: -20px;
        }
        .title-container {
            text-align: center;
            margin-bottom: 30px;
        }
        .color-swatch {
            display: inline-block;
            width: 30px;
            height: 30px;
            margin: 5px;
            border: 2px solid #ddd;
            border-radius: 4px;
            cursor: pointer;
            transition: transform 0.2s;
        }
        .color-swatch:hover {
            transform: scale(1.1);
            border-color: #4CAF50;
        }
        .background-option {
            padding: 10px;
            margin: 5px 0;
            border-radius: 5px;
            background-color: #f8f9fa;
            border-left: 4px solid #4CAF50;
        }
        </style>
        """,
        unsafe_allow_html=True
    )

    # Add logo to the top right; skip silently (with a log line) if the
    # asset was not shipped with the Space instead of crashing the UI.
    logo_path = "myavatar_logo.png"
    if os.path.exists(logo_path):
        with open(logo_path, "rb") as f:
            encoded_logo = base64.b64encode(f.read()).decode()
        st.markdown(
            """
            <div class="logo-container">
                <img src="data:image/png;base64,{}" class="logo">
            </div>
            """.format(encoded_logo),
            unsafe_allow_html=True
        )
    else:
        logger.warning("Logo file %s not found; skipping logo render.", logo_path)
124
+
125
+ def initialize_session_state():
126
+ """Initialize all session state variables"""
127
+ if 'uploaded_video' not in st.session_state:
128
+ st.session_state.uploaded_video = None
129
+ if 'bg_image' not in st.session_state:
130
+ st.session_state.bg_image = None
131
+ if 'bg_color' not in st.session_state:
132
+ st.session_state.bg_color = "#00FF00"
133
+ if 'bg_type' not in st.session_state:
134
+ st.session_state.bg_type = "image"
135
+ if 'processed_video_path' not in st.session_state:
136
+ st.session_state.processed_video_path = None
137
+ if 'processing' not in st.session_state:
138
+ st.session_state.processing = False
139
+ if 'progress' not in st.session_state:
140
+ st.session_state.progress = 0
141
+ if 'progress_text' not in st.session_state:
142
+ st.session_state.progress_text = "Ready"
143
+
144
+ def handle_video_upload():
145
+ """Handle video file upload"""
146
+ uploaded = st.file_uploader(
147
+ "πŸ“Ή Upload Video",
148
+ type=["mp4", "mov", "avi"],
149
+ key="video_uploader"
150
+ )
151
+ if uploaded is not None:
152
+ st.session_state.uploaded_video = uploaded
153
+
154
+ def show_video_preview():
155
+ """Show video preview in the UI"""
156
+ st.markdown("### Video Preview")
157
+ if st.session_state.uploaded_video is not None:
158
+ video_bytes = st.session_state.uploaded_video.getvalue()
159
+ st.video(video_bytes)
160
+ st.session_state.uploaded_video.seek(0)
161
+
162
+ def handle_background_selection():
163
+ """Handle background selection UI with all options"""
164
+ st.markdown("### Background Options")
165
+
166
+ # Background type selection
167
+ bg_type = st.radio(
168
+ "Select Background Type:",
169
+ ["Image", "Color", "Blur", "Professional Backgrounds", "AI Generated"],
170
+ horizontal=True,
171
+ key="bg_type_radio"
172
+ )
173
+
174
+ st.session_state.bg_type = bg_type.lower()
175
+
176
+ # Show appropriate controls based on selection
177
+ if bg_type == "Image":
178
+ handle_image_background()
179
+ elif bg_type == "Color":
180
+ handle_color_background()
181
+ elif bg_type == "Blur":
182
+ handle_blur_background()
183
+ elif bg_type == "Professional Backgrounds":
184
+ handle_professional_backgrounds()
185
+ elif bg_type == "AI Generated":
186
+ handle_ai_generated_background()
187
+
188
+ def handle_image_background():
189
+ """Handle image background selection"""
190
+ bg_image = st.file_uploader(
191
+ "πŸ–ΌοΈ Upload Background Image",
192
+ type=["jpg", "png", "jpeg"],
193
+ key="bg_image_uploader"
194
+ )
195
+ if bg_image is not None:
196
+ st.session_state.bg_image = Image.open(bg_image)
197
+ st.image(
198
+ st.session_state.bg_image,
199
+ caption="Selected Background",
200
+ use_container_width=True
201
+ )
202
 
203
+ def handle_color_background():
204
+ """Handle color background selection with presets"""
205
+ st.markdown("#### Select a Color")
206
+
207
+ # Color presets
208
+ color_presets = {
209
+ "Pure White": "#FFFFFF",
210
+ "Pure Black": "#000000",
211
+ "Light Gray": "#F5F5F5",
212
+ "Dark Gray": "#333333",
213
+ "Professional Blue": "#0078D4",
214
+ "Corporate Green": "#107C10",
215
+ "Warm Beige": "#F5F5DC",
216
+ "Custom": st.session_state.get('bg_color', "#00FF00")
 
217
  }
218
+
219
+ # Create color swatches
220
+ cols = st.columns(4)
221
+ selected_color = None
222
+
223
+ for i, (name, color) in enumerate(color_presets.items()):
224
+ with cols[i % 4]:
225
+ if name == "Custom":
226
+ # Show color picker for custom color
227
+ st.session_state.bg_color = st.color_picker(
228
+ "Custom Color",
229
+ st.session_state.get('bg_color', "#00FF00"),
230
+ key="custom_color_picker"
231
+ )
232
+ else:
233
+ # Create a color swatch
234
+ if st.button(
235
+ "",
236
+ key=f"color_{name}",
237
+ help=name,
238
+ type="secondary",
239
+ use_container_width=True
240
+ ):
241
+ st.session_state.bg_color = color
242
+
243
+ # Show the color preview
244
+ st.markdown(
245
+ f'<div style="background-color:{color}; height:30px; border-radius:4px; margin-top:-10px;"></div>',
246
+ unsafe_allow_html=True
247
+ )
248
+ st.caption(name)
249
+
250
+ def handle_blur_background():
251
+ """Handle blur background selection"""
252
+ blur_strength = st.select_slider(
253
+ "Blur Strength:",
254
+ options=["Subtle", "Medium", "Strong", "Maximum"],
255
+ value="Medium",
256
+ key="blur_strength"
257
+ )
258
+
259
+ # Show preview of blur effect
260
+ st.markdown("**Preview**")
261
+ preview_img = np.zeros((100, 200, 3), dtype=np.uint8)
262
+ cv2.putText(
263
+ preview_img,
264
+ "Blur Effect",
265
+ (20, 50),
266
+ cv2.FONT_HERSHEY_SIMPLEX,
267
+ 0.8,
268
+ (255, 255, 255),
269
+ 2
270
+ )
271
+
272
+ # Apply blur based on selection
273
+ if blur_strength == "Subtle":
274
+ preview_img = cv2.GaussianBlur(preview_img, (15, 15), 5)
275
+ elif blur_strength == "Medium":
276
+ preview_img = cv2.GaussianBlur(preview_img, (25, 25), 10)
277
+ elif blur_strength == "Strong":
278
+ preview_img = cv2.GaussianBlur(preview_img, (35, 35), 15)
279
+ else: # Maximum
280
+ preview_img = cv2.GaussianBlur(preview_img, (51, 51), 20)
281
+
282
+ st.image(preview_img, use_container_width=True)
283
+ st.caption(f"Selected: {blur_strength} blur")
284
+
285
+ def handle_professional_backgrounds():
286
+ """Handle professional background selection"""
287
+ categories = {
288
+ "Office Settings": ["Modern Office", "Executive Office", "Home Office", "Conference Room"],
289
+ "Virtual Backgrounds": ["Professional", "Minimalist", "Creative", "Branded"],
290
+ "Nature Scenes": ["Forest", "Beach", "Mountain", "City Skyline"],
291
+ "Abstract Designs": ["Gradient", "Geometric", "Particles", "Bokeh"]
292
+ }
293
+
294
+ # Category selection
295
+ selected_category = st.selectbox(
296
+ "Select Category:",
297
+ list(categories.keys()),
298
+ key="bg_category"
299
+ )
300
+
301
+ # Show thumbnails for selected category
302
+ st.markdown("#### Available Backgrounds")
303
+ cols = st.columns(2)
304
+
305
+ for i, bg in enumerate(categories[selected_category]):
306
+ with cols[i % 2]:
307
+ # Create a placeholder image (replace with actual thumbnails)
308
+ img = np.zeros((120, 200, 3), dtype=np.uint8)
309
+ cv2.putText(
310
+ img,
311
+ bg,
312
+ (20, 60),
313
+ cv2.FONT_HERSHEY_SIMPLEX,
314
+ 0.7,
315
+ (255, 255, 255),
316
+ 2
317
+ )
318
+
319
+ if st.button(
320
+ f"Use {bg}",
321
+ key=f"prof_bg_{bg}",
322
+ use_container_width=True
323
+ ):
324
+ st.session_state.selected_bg = bg
325
+ st.success(f"Selected: {bg}")
326
+
327
+ st.image(img, use_container_width=True)
328
+
329
+ def handle_ai_generated_background():
330
+ """Handle AI generated background selection"""
331
+ ai_prompts = [
332
+ "Professional office with bookshelf",
333
+ "Modern co-working space",
334
+ "Neutral abstract background",
335
+ "City skyline at night",
336
+ "Minimalist home office setup",
337
+ "Corporate meeting room",
338
+ "Creative studio background",
339
+ "Custom prompt..."
340
+ ]
341
+
342
+ # Prompt selection
343
+ selected_prompt = st.selectbox(
344
+ "Select a prompt or create your own:",
345
+ ai_prompts,
346
+ key="ai_prompt_select"
347
+ )
348
+
349
+ if selected_prompt == "Custom prompt...":
350
+ custom_prompt = st.text_input(
351
+ "Enter your custom prompt:",
352
+ key="custom_ai_prompt"
353
  )
354
+ if custom_prompt:
355
+ selected_prompt = custom_prompt
356
+
357
+ # Generate button
358
+ if st.button(
359
+ "🎨 Generate Background",
360
+ key="generate_ai_bg",
361
+ use_container_width=True
362
+ ):
363
+ with st.spinner(f"Generating '{selected_prompt}'..."):
364
+ # Simulate generation
365
+ time.sleep(2)
366
+
367
+ # Create a placeholder for the generated image
368
+ img = np.zeros((300, 500, 3), dtype=np.uint8)
369
+ cv2.putText(
370
+ img,
371
+ f"AI Generated:\n{selected_prompt}",
372
+ (50, 150),
373
+ cv2.FONT_HERSHEY_SIMPLEX,
374
+ 0.8,
375
+ (255, 255, 255),
376
+ 2,
377
+ cv2.LINE_AA
378
+ )
379
+
380
+ # Show the "generated" image
381
+ st.image(img, use_container_width=True)
382
+
383
+ # Add use button
384
+ if st.button(
385
+ "Use This Background",
386
+ key="use_ai_bg",
387
+ use_container_width=True
388
+ ):
389
+ st.session_state.bg_image = Image.fromarray(img)
390
+ st.success("Background selected!")
391
+
392
def process_video(input_file, background, bg_type="image"):
    """
    Mock video processing that works without SAM2/MatAnyone.

    Applies a simple colored-border effect to every frame to simulate
    background replacement.

    Args:
        input_file: Uploaded file object (must support ``.getvalue()``).
        background: Selected background (unused by this mock effect).
        bg_type: ``"color"`` uses ``st.session_state.bg_color`` for the
            border; any other value uses a green border.

    Returns:
        Path to the processed video file, or ``None`` on failure.

    Bug fixes vs. previous version:
        * ``cv2.VideoWriter_fourcentCC`` was a typo (AttributeError);
          the OpenCV function is ``cv2.VideoWriter_fourcc``.
        * The output file used to live inside the ``TemporaryDirectory``
          and was deleted the moment this function returned, so the
          caller's ``os.path.exists(output_path)`` always failed. The
          output now goes to a temp file that outlives this call.
    """
    try:
        # Persistent output location -- must survive after this function
        # returns so the caller can display/download it.
        out_fd, output_path = tempfile.mkstemp(suffix=".mp4")
        os.close(out_fd)

        # Create a temporary directory for intermediate files only.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_dir = Path(temp_dir)

            # Save the uploaded video to a temporary file
            input_path = str(temp_dir / "input.mp4")
            with open(input_path, "wb") as f:
                f.write(input_file.getvalue())

            # Set up progress bar
            progress_bar = st.progress(0)
            status_text = st.empty()

            def update_progress(progress, message):
                progress = max(0, min(1, progress))
                progress_bar.progress(progress)
                status_text.text(f"Status: {message}")

            # Simulate processing steps
            update_progress(0.1, "Loading video...")
            time.sleep(1)

            update_progress(0.3, "Processing frames...")
            time.sleep(2)

            # Create a simple output video that just adds a colored border
            cap = cv2.VideoCapture(input_path)
            fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30  # guard: some files report 0 fps
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # Border colour is loop-invariant: compute it once.
            if bg_type == "color":
                color_hex = st.session_state.bg_color.lstrip('#')
                # RRGGBB hex -> BGR tuple for OpenCV
                color_bgr = tuple(int(color_hex[i:i + 2], 16) for i in (4, 2, 0))
            else:
                color_bgr = (0, 255, 0)  # Default green border

            border_size = 20
            for i in range(frame_count):
                ret, frame = cap.read()
                if not ret:
                    break

                # Simple effect: add a colored border to simulate processing
                frame = cv2.copyMakeBorder(
                    frame,
                    border_size, border_size, border_size, border_size,
                    cv2.BORDER_CONSTANT,
                    value=color_bgr
                )

                # Resize back to original dimensions
                frame = cv2.resize(frame, (width, height))

                out.write(frame)

                # Update progress
                if i % 10 == 0:
                    update_progress(0.3 + 0.7 * (i / frame_count), f"Processing frame {i}/{frame_count}")

            cap.release()
            out.release()

            update_progress(1.0, "Processing complete!")
            time.sleep(0.5)

        return output_path

    except Exception as e:
        logger.error(f"Error in mock video processing: {str(e)}", exc_info=True)
        st.error(f"An error occurred during processing: {str(e)}")
        return None
476
+
477
+ def main():
478
+ # Add custom CSS and logo
479
+ add_logo()
480
+
481
+ # Title
482
+ st.markdown(
483
+ """
484
+ <div class="title-container">
485
+ <h1>πŸŽ₯ Video Background Replacer</h1>
486
+ </div>
487
+ """,
488
+ unsafe_allow_html=True
489
+ )
490
+
491
+ st.markdown("---")
492
+
493
+ # Initialize session state
494
+ initialize_session_state()
495
+
496
+ # Main layout
497
+ col1, col2 = st.columns([1, 1], gap="large")
498
+
499
+ with col1:
500
+ st.header("1. Upload Video")
501
+ handle_video_upload()
502
+ show_video_preview()
503
+
504
+ with col2:
505
+ st.header("2. Background Settings")
506
+ handle_background_selection()
507
+
508
+ st.header("3. Process & Download")
509
+ if st.button(
510
+ "πŸš€ Process Video",
511
+ type="primary",
512
+ disabled=not st.session_state.uploaded_video or st.session_state.processing,
513
+ use_container_width=True
514
+ ):
515
+ with st.spinner("Processing video (this may take a few minutes)..."):
516
+ st.session_state.processing = True
517
+
518
+ try:
519
+ # Prepare background based on type
520
+ background = None
521
+ if st.session_state.bg_type == "image" and 'bg_image' in st.session_state and st.session_state.bg_image is not None:
522
+ background = st.session_state.bg_image
523
+ elif st.session_state.bg_type == "color" and 'bg_color' in st.session_state:
524
+ background = st.session_state.bg_color
525
+
526
+ # Process the video
527
+ output_path = process_video(
528
+ st.session_state.uploaded_video,
529
+ background,
530
+ bg_type=st.session_state.bg_type
531
+ )
532
+
533
+ if output_path and os.path.exists(output_path):
534
+ # Store the path to the processed video
535
+ st.session_state.processed_video_path = output_path
536
+ st.success("βœ… Video processing complete!")
537
+ else:
538
+ st.error("❌ Failed to process video. Please check the logs for details.")
539
+
540
+ except Exception as e:
541
+ st.error(f"❌ An error occurred: {str(e)}")
542
+ logger.exception("Video processing failed")
543
+
544
+ finally:
545
+ st.session_state.processing = False
546
+
547
+ # Show processed video if available
548
+ if 'processed_video_path' in st.session_state and st.session_state.processed_video_path:
549
+ st.markdown("### Processed Video")
550
+
551
+ try:
552
+ # Display the video directly from the file
553
+ with open(st.session_state.processed_video_path, 'rb') as f:
554
+ video_bytes = f.read()
555
+ st.video(video_bytes)
556
+
557
+ # Download button
558
+ st.download_button(
559
+ label="πŸ’Ύ Download Processed Video",
560
+ data=video_bytes,
561
+ file_name="processed_video.mp4",
562
+ mime="video/mp4",
563
+ use_container_width=True
564
+ )
565
+ except Exception as e:
566
+ st.error(f"Error displaying video: {str(e)}")
567
+ logger.error(f"Error displaying video: {str(e)}", exc_info=True)
568
 
 
 
 
569
  if __name__ == "__main__":
570
+ main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_backup.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ VideoBackgroundReplacer2 - SAM2 + MatAnyone Integration
4
+ =======================================================
5
+ - Sets up Gradio UI and launches pipeline
6
+ - Aligned with torch/cu121 stack; runs on HF Spaces (Docker)
7
+
8
+ Changes (2025-09-18):
9
+ - Added precise web-stack probes (FastAPI/Starlette/Pydantic/etc. versions + file paths)
10
+ - Added toggleable "mount mode": run Gradio inside our own FastAPI app
11
+ and provide a safe /config route shim (uses demo.get_config_file()).
12
+ - Kept your startup diagnostics, GPU logging, and heartbeats
13
+ """
14
+
15
+ print("=== APP STARTUP: Initializing VideoBackgroundReplacer2 ===")
16
+
17
+ # ---------------------------------------------------------------------
18
+ # Imports & basic setup
19
+ # ---------------------------------------------------------------------
20
+ import sys
21
+ import os
22
+ import gc
23
+ import json
24
+ import logging
25
+ import threading
26
+ import time
27
+ import warnings
28
+ import traceback
29
+ import subprocess
30
+ from pathlib import Path
31
+ from loguru import logger
32
+
33
+ # Logging (loguru to stderr)
34
+ logger.remove()
35
+ logger.add(
36
+ sys.stderr,
37
+ format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> "
38
+ "| <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
39
+ )
40
+
41
+ # Warnings
42
+ warnings.filterwarnings("ignore", category=UserWarning)
43
+ warnings.filterwarnings("ignore", category=FutureWarning)
44
+ warnings.filterwarnings("ignore", module="torchvision.io._video_deprecation_warning")
45
+
46
+ # Environment (lightweight & safe in Spaces)
47
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
48
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
49
+
50
+ # Paths
51
+ BASE_DIR = Path(__file__).parent.absolute()
52
+ THIRD_PARTY_DIR = BASE_DIR / "third_party"
53
+ SAM2_DIR = THIRD_PARTY_DIR / "sam2"
54
+ CHECKPOINTS_DIR = BASE_DIR / "checkpoints"
55
+
56
+ # Python path extends
57
+ for p in (str(THIRD_PARTY_DIR), str(SAM2_DIR)):
58
+ if p not in sys.path:
59
+ sys.path.insert(0, p)
60
+
61
+ logger.info(f"Base directory: {BASE_DIR}")
62
+ logger.info(f"Python path[0:5]: {sys.path[:5]}")
63
+
64
+ # ---------------------------------------------------------------------
65
+ # GPU / Torch diagnostics (non-blocking)
66
+ # ---------------------------------------------------------------------
67
+ try:
68
+ import torch
69
+ except Exception as e:
70
+ logger.warning("Torch import failed at startup: %s", e)
71
+ torch = None
72
+
73
+ DEVICE = "cuda" if (torch and torch.cuda.is_available()) else "cpu"
74
+ if DEVICE == "cuda":
75
+ os.environ["SAM2_DEVICE"] = "cuda"
76
+ os.environ["MATANY_DEVICE"] = "cuda"
77
+ os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")
78
+ try:
79
+ logger.info(f"CUDA available: {torch.cuda.get_device_name(0)}")
80
+ except Exception:
81
+ logger.info("CUDA device name not available at startup.")
82
+ else:
83
+ os.environ["SAM2_DEVICE"] = "cpu"
84
+ os.environ["MATANY_DEVICE"] = "cpu"
85
+ logger.warning("CUDA not available, falling back to CPU")
86
+
87
+ def verify_models():
88
+ """Verify critical model files exist and are loadable (cheap checks)."""
89
+ results = {"status": "success", "details": {}}
90
+ try:
91
+ sam2_model_path = os.getenv("SAM2_MODEL_PATH", str(CHECKPOINTS_DIR / "sam2_hiera_large.pt"))
92
+ if not os.path.exists(sam2_model_path):
93
+ raise FileNotFoundError(f"SAM2 model not found at {sam2_model_path}")
94
+ # Cheap load test (map to CPU to avoid VRAM use during boot)
95
+ if torch:
96
+ sd = torch.load(sam2_model_path, map_location="cpu")
97
+ if not isinstance(sd, dict):
98
+ raise ValueError("Invalid SAM2 checkpoint format")
99
+ results["details"]["sam2"] = {
100
+ "status": "success",
101
+ "path": sam2_model_path,
102
+ "size_mb": round(os.path.getsize(sam2_model_path) / (1024 * 1024), 2),
103
+ }
104
+ except Exception as e:
105
+ results["status"] = "error"
106
+ results["details"]["sam2"] = {
107
+ "status": "error",
108
+ "error": str(e),
109
+ "traceback": traceback.format_exc(),
110
+ }
111
+ return results
112
+
113
+ def run_startup_diagnostics():
114
+ diag = {
115
+ "system": {
116
+ "python": sys.version,
117
+ "pytorch": getattr(torch, "__version__", None) if torch else None,
118
+ "cuda_available": bool(torch and torch.cuda.is_available()),
119
+ "device_count": (torch.cuda.device_count() if torch and torch.cuda.is_available() else 0),
120
+ "cuda_version": getattr(getattr(torch, "version", None), "cuda", None) if torch else None,
121
+ },
122
+ "paths": {
123
+ "base_dir": str(BASE_DIR),
124
+ "checkpoints_dir": str(CHECKPOINTS_DIR),
125
+ "sam2_dir": str(SAM2_DIR),
126
+ },
127
+ "env_subset": {k: v for k, v in os.environ.items() if k in ("HOST", "PORT", "SPACE_ID", "SPACE_AUTHOR_NAME")},
128
+ }
129
+ diag["model_verification"] = verify_models()
130
+ return diag
131
+
132
+ startup_diag = run_startup_diagnostics()
133
+ logger.info("Startup diagnostics completed")
134
+
135
+ # Noisy heartbeat so logs show life during import time
136
+ def _heartbeat():
137
+ i = 0
138
+ while True:
139
+ i += 1
140
+ print(f"[startup-heartbeat] {i*5}s…", flush=True)
141
+ time.sleep(5)
142
+
143
+ threading.Thread(target=_heartbeat, daemon=True).start()
144
+
145
+ # Optional perf tuning import (non-fatal)
146
+ try:
147
+ import perf_tuning # noqa: F401
148
+ logger.info("perf_tuning imported successfully.")
149
+ except Exception as e:
150
+ logger.info("perf_tuning not available: %s", e)
151
+
152
+ # MatAnyone non-instantiating probe
153
+ try:
154
+ import inspect
155
+ from matanyone.inference import inference_core as ic # type: ignore
156
+ sigs = {}
157
+ for name in ("InferenceCore",):
158
+ obj = getattr(ic, name, None)
159
+ if obj:
160
+ sigs[name] = "callable" if callable(obj) else "present"
161
+ logger.info(f"[MATANY] probe (non-instantiating): {sigs}")
162
+ except Exception as e:
163
+ logger.info(f"[MATANY] probe skipped: {e}")
164
+
165
+ # ---------------------------------------------------------------------
166
+ # Gradio import and web-stack probes
167
+ # ---------------------------------------------------------------------
168
+ import gradio as gr
169
+
170
+ # Standard logger for some libs that use stdlib logging
171
+ py_logger = logging.getLogger("backgroundfx_pro")
172
+ if not py_logger.handlers:
173
+ h = logging.StreamHandler()
174
+ h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
175
+ py_logger.addHandler(h)
176
+ py_logger.setLevel(logging.INFO)
177
+
178
+ def _log_web_stack_versions_and_paths():
179
+ import inspect
180
+ try:
181
+ import fastapi, starlette, pydantic, httpx, anyio
182
+ try:
183
+ import pydantic_core
184
+ pc_ver = pydantic_core.__version__
185
+ except Exception:
186
+ pc_ver = "unknown"
187
+ logger.info(
188
+ "[WEB-STACK] fastapi=%s | starlette=%s | pydantic=%s | pydantic-core=%s | httpx=%s | anyio=%s",
189
+ getattr(fastapi, "__version__", "?"),
190
+ getattr(starlette, "__version__", "?"),
191
+ getattr(pydantic, "__version__", "?"),
192
+ pc_ver,
193
+ getattr(httpx, "__version__", "?"),
194
+ getattr(anyio, "__version__", "?"),
195
+ )
196
+ except Exception as e:
197
+ logger.warning("[WEB-STACK] version probe failed: %s", e)
198
+
199
+ try:
200
+ import gradio
201
+ import gradio.routes as gr_routes
202
+ import gradio.queueing as gr_queueing
203
+ logger.info("[PATH] gradio.__file__ = %s", getattr(gradio, "__file__", "?"))
204
+ logger.info("[PATH] gradio.routes = %s", inspect.getfile(gr_routes))
205
+ logger.info("[PATH] gradio.queueing = %s", inspect.getfile(gr_queueing))
206
+ import starlette.exceptions as st_exc
207
+ logger.info("[PATH] starlette.exceptions= %s", inspect.getfile(st_exc))
208
+ except Exception as e:
209
+ logger.warning("[PATH] probe failed: %s", e)
210
+
211
+ def _post_launch_diag():
212
+ try:
213
+ if not torch:
214
+ return
215
+ avail = torch.cuda.is_available()
216
+ logger.info("CUDA available (post-launch): %s", avail)
217
+ if avail:
218
+ idx = torch.cuda.current_device()
219
+ name = torch.cuda.get_device_name(idx)
220
+ cap = torch.cuda.get_device_capability(idx)
221
+ logger.info("CUDA device %d: %s (cc %d.%d)", idx, name, cap[0], cap[1])
222
+ except Exception as e:
223
+ logger.warning("Post-launch CUDA diag failed: %s", e)
224
+
225
+ # ---------------------------------------------------------------------
226
+ # UI factory (uses your existing builder)
227
+ # ---------------------------------------------------------------------
228
+ def build_ui() -> gr.Blocks:
229
+ # FIX: import from ui_core_interface (not from ui)
230
+ from ui_core_interface import create_interface
231
+ return create_interface()
232
+
233
+ # ---------------------------------------------------------------------
234
+ # Optional: custom FastAPI mount mode
235
+ # ---------------------------------------------------------------------
236
+ def build_fastapi_with_gradio(demo: gr.Blocks):
237
+ """
238
+ Returns a FastAPI app with Gradio mounted at root.
239
+ Also exposes JSON health and a config shim using demo.get_config_file().
240
+ """
241
+ from fastapi import FastAPI
242
+ from fastapi.responses import JSONResponse
243
+
244
+ app = FastAPI(title="VideoBackgroundReplacer2")
245
+
246
+ @app.get("/healthz")
247
+ def _healthz():
248
+ return {"ok": True, "ts": time.time()}
249
+
250
+ @app.get("/config")
251
+ def _config():
252
+ try:
253
+ cfg = demo.get_config_file()
254
+ return JSONResponse(content=cfg)
255
+ except Exception as e:
256
+ return JSONResponse(
257
+ status_code=500,
258
+ content={"error": "config_generation_failed", "detail": str(e)},
259
+ )
260
+
261
+ # Mount Gradio UI at root; our /config route remains at parent level
262
+ app = gr.mount_gradio_app(app, demo, path="/")
263
+ return app
264
+
265
+ # ---------------------------------------------------------------------
266
+ # Entrypoint
267
+ # ---------------------------------------------------------------------
268
+ if __name__ == "__main__":
269
+ host = os.environ.get("HOST", "0.0.0.0")
270
+ port = int(os.environ.get("PORT", "7860"))
271
+ mount_mode = os.environ.get("GRADIO_MOUNT_MODE", "0") == "1"
272
+
273
+ logger.info("Launching on %s:%s (mount_mode=%s)…", host, port, mount_mode)
274
+ _log_web_stack_versions_and_paths()
275
+
276
+ demo = build_ui()
277
+ demo.queue(max_size=16, api_open=False)
278
+
279
+ threading.Thread(target=_post_launch_diag, daemon=True).start()
280
+
281
+ if mount_mode:
282
+ try:
283
+ from uvicorn import run as uvicorn_run
284
+ except Exception:
285
+ logger.error("uvicorn is not installed; mount mode cannot start.")
286
+ raise
287
+
288
+ app = build_fastapi_with_gradio(demo)
289
+ uvicorn_run(app=app, host=host, port=port, log_level="info")
290
+ else:
291
+ demo.launch(
292
+ server_name=host,
293
+ server_port=port,
294
+ share=False,
295
+ show_api=False,
296
+ show_error=True,
297
+ quiet=False,
298
+ debug=True,
299
+ max_threads=1,
300
+ )
pipeline_utils.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import tempfile
5
+ from pathlib import Path
6
+ from typing import Optional, Union, Callable
7
+ import logging
8
+ from PIL import Image
9
+
10
+ # Configure logging
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+ class VideoProcessor:
15
+ def __init__(self, temp_dir: Optional[str] = None):
16
+ """
17
+ Initialize the video processor.
18
+
19
+ Args:
20
+ temp_dir: Directory for temporary files. If None, creates a temp directory.
21
+ """
22
+ self.temp_dir = Path(temp_dir) if temp_dir else Path(tempfile.mkdtemp(prefix="bg_replace_"))
23
+ self.temp_dir.mkdir(parents=True, exist_ok=True)
24
+ self.device = self._get_device()
25
+ logger.info(f"Initialized VideoProcessor with device: {self.device}")
26
+
27
+ def _get_device(self) -> str:
28
+ """Check if CUDA is available."""
29
+ try:
30
+ import torch
31
+ return "cuda" if torch.cuda.is_available() else "cpu"
32
+ except ImportError:
33
+ return "cpu"
34
+
35
+ def _create_static_bg_video(
36
+ self,
37
+ bg_image: np.ndarray,
38
+ reference_video: str,
39
+ output_path: str
40
+ ) -> str:
41
+ """
42
+ Create a static background video matching the input video's duration.
43
+ """
44
+ cap = cv2.VideoCapture(reference_video)
45
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
46
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
47
+ fps = cap.get(cv2.CAP_PROP_FPS)
48
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
49
+ cap.release()
50
+
51
+ # Resize background image
52
+ bg_image = cv2.resize(bg_image, (width, height))
53
+
54
+ # Write video
55
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
56
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
57
+
58
+ for _ in range(total_frames):
59
+ out.write(bg_image)
60
+
61
+ out.release()
62
+ return output_path
63
+
64
+ def _process_with_pipeline(
65
+ self,
66
+ input_video: str,
67
+ background: Optional[Union[str, np.ndarray]] = None,
68
+ bg_type: str = "blur",
69
+ progress_callback: Optional[Callable] = None
70
+ ) -> str:
71
+ """
72
+ Process video using the two-stage pipeline.
73
+ """
74
+ try:
75
+ # Import the pipeline
76
+ from integrated_pipeline import TwoStageProcessor
77
+
78
+ # Update progress
79
+ if progress_callback:
80
+ progress_callback(0.1, "Initializing pipeline...")
81
+
82
+ # Handle background
83
+ bg_video_path = ""
84
+ if bg_type == "image" and background is not None:
85
+ if isinstance(background, str):
86
+ bg_image = cv2.imread(background)
87
+ else:
88
+ bg_image = background
89
+
90
+ bg_video_path = str(self.temp_dir / "background.mp4")
91
+ self._create_static_bg_video(bg_image, input_video, bg_video_path)
92
+
93
+ # Initialize processor
94
+ processor = TwoStageProcessor(temp_dir=str(self.temp_dir))
95
+
96
+ # Process video
97
+ output_path = str(self.temp_dir / "output.mp4")
98
+
99
+ # Mock click points (center of frame)
100
+ click_points = [[0.5, 0.5]]
101
+
102
+ # Process
103
+ success = processor.process_video(
104
+ input_video=input_video,
105
+ background_video=bg_video_path if bg_type == "image" else "",
106
+ click_points=click_points,
107
+ output_path=output_path,
108
+ use_matanyone=True,
109
+ progress_callback=progress_callback
110
+ )
111
+
112
+ if not success:
113
+ raise RuntimeError("Video processing failed")
114
+
115
+ return output_path
116
+
117
+ except Exception as e:
118
+ logger.error(f"Error in pipeline: {str(e)}")
119
+ raise
120
+
121
+ def process_video(
122
+ self,
123
+ input_path: Union[str, bytes],
124
+ background: Optional[Union[str, np.ndarray]] = None,
125
+ bg_type: str = "blur",
126
+ progress_callback: Optional[Callable] = None
127
+ ) -> bytes:
128
+ """
129
+ Process a video with the given background.
130
+
131
+ Args:
132
+ input_path: Path to input video or bytes
133
+ background: Background image path or numpy array
134
+ bg_type: Type of background ("image", "color", or "blur")
135
+ progress_callback: Optional callback for progress updates
136
+
137
+ Returns:
138
+ Processed video as bytes
139
+ """
140
+ try:
141
+ # Save input to temp file if it's bytes
142
+ if isinstance(input_path, bytes):
143
+ input_video = str(self.temp_dir / "input.mp4")
144
+ with open(input_video, "wb") as f:
145
+ f.write(input_path)
146
+ else:
147
+ input_video = input_path
148
+
149
+ # Process the video
150
+ output_path = self._process_with_pipeline(
151
+ input_video,
152
+ background,
153
+ bg_type,
154
+ progress_callback
155
+ )
156
+
157
+ # Read the output file
158
+ with open(output_path, "rb") as f:
159
+ return f.read()
160
+
161
+ except Exception as e:
162
+ logger.error(f"Error processing video: {str(e)}")
163
+ raise
164
+
165
+ # Global instance
166
+ video_processor = VideoProcessor()
167
+
168
+ def process_video_pipeline(
169
+ input_data: Union[str, bytes],
170
+ background: Optional[Union[str, np.ndarray]] = None,
171
+ bg_type: str = "blur",
172
+ progress_callback: Optional[Callable] = None
173
+ ) -> bytes:
174
+ """
175
+ High-level function to process a video.
176
+
177
+ Args:
178
+ input_data: Input video path or bytes
179
+ background: Background image path or numpy array
180
+ bg_type: Type of background ("image", "color", or "blur")
181
+ progress_callback: Optional progress callback
182
+
183
+ Returns:
184
+ Processed video as bytes
185
+ """
186
+ return video_processor.process_video(
187
+ input_data,
188
+ background,
189
+ bg_type,
190
+ progress_callback
191
+ )
requirements.txt CHANGED
@@ -35,22 +35,21 @@ iopath>=0.1.10,<0.2.0
35
  kornia>=0.7.0,<0.8.0
36
  tqdm>=4.60.0,<5.0.0
37
 
38
- # ===== UI and API =====
39
- # Bump to avoid gradio_client 1.3.0 bug ("bool is not iterable")
40
- gradio==4.42.0
41
 
42
- # ===== Web stack pins for Gradio 4.42.0 =====
43
- fastapi==0.109.2
44
- starlette==0.36.3
45
- uvicorn==0.29.0
46
- httpx==0.27.2
47
- anyio==4.4.0
48
  orjson>=3.10.0
49
 
50
  # ===== Pydantic family (avoid breaking core 2.23.x) =====
51
  pydantic==2.8.2
52
  pydantic-core==2.20.1
53
- annotated-types==0.6.0
54
  typing-extensions==4.12.2
55
 
56
  # ===== Helpers and Utilities =====
@@ -69,4 +68,4 @@ nvidia-ml-py3>=7.352.0,<12.0.0
69
  loguru>=0.6.0,<1.0.0
70
 
71
  # File handling
72
- python-multipart>=0.0.5,<1.0.0
 
35
  kornia>=0.7.0,<0.8.0
36
  tqdm>=4.60.0,<5.0.0
37
 
38
+ # ===== Streamlit UI =====
39
+ streamlit>=1.32.0
40
+ streamlit-webrtc>=0.50.0 # For real-time video processing
41
 
42
+ # ===== Web stack =====
43
+ fastapi>=0.104.0
44
+ uvicorn>=0.24.0
45
+ httpx>=0.25.0
46
+ anyio>=4.0.0
 
47
  orjson>=3.10.0
48
 
49
  # ===== Pydantic family (avoid breaking core 2.23.x) =====
50
  pydantic==2.8.2
51
  pydantic-core==2.20.1
52
+ annotated-types==0.6.0
53
  typing-extensions==4.12.2
54
 
55
  # ===== Helpers and Utilities =====
 
68
  loguru>=0.6.0,<1.0.0
69
 
70
  # File handling
71
+ python-multipart>=0.0.5,<1.0.0
streamlit_app.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# streamlit_app.py — Streamlit UI for the video background replacer.
import streamlit as st
import os
import sys
import tempfile
import time
from pathlib import Path
import cv2
import numpy as np
from PIL import Image
import logging
import io

# Add project root to path so sibling modules (e.g. integrated_pipeline)
# resolve regardless of the working directory Streamlit launches from.
sys.path.append(str(Path(__file__).parent.absolute()))

# Configure module-level logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set page config — must be the first Streamlit call in the script.
st.set_page_config(
    page_title="🎬 Advanced Video Background Replacer",
    page_icon="🎥",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS for better UI: green primary buttons/progress bar,
# pill-style tabs, and a dashed border around video previews.
st.markdown("""
<style>
    .main .block-container {
        padding-top: 2rem;
        padding-bottom: 2rem;
    }
    .stButton>button {
        width: 100%;
        background-color: #4CAF50;
        color: white;
        font-weight: bold;
        transition: all 0.3s;
    }
    .stButton>button:hover {
        background-color: #45a049;
    }
    .stProgress > div > div > div > div {
        background-color: #4CAF50;
    }
    .stAlert {
        border-radius: 10px;
    }
    .stTabs [data-baseweb="tab-list"] {
        gap: 10px;
    }
    .stTabs [data-baseweb="tab"] {
        height: 50px;
        white-space: pre;
        background-color: #f0f2f6;
        border-radius: 4px 4px 0 0;
        padding: 10px 20px;
        margin-right: 5px;
    }
    .stTabs [aria-selected="true"] {
        background-color: #4CAF50;
        color: white;
    }
    .video-container {
        border: 2px dashed #4CAF50;
        border-radius: 10px;
        padding: 10px;
        margin-bottom: 20px;
    }
</style>
""", unsafe_allow_html=True)
76
def initialize_session_state():
    """Seed every session-state key the app reads with a safe default.

    Only missing keys are written, so values survive Streamlit reruns.
    """
    defaults = {
        'uploaded_video': None,
        'bg_image': None,
        'bg_color': "#00FF00",
        'processed_video_path': None,
        'processing': False,
        'progress': 0,
        'progress_text': "Ready",
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
92
+
93
def handle_video_upload():
    """Render the video uploader widget and stash the file in session state."""
    uploaded_file = st.file_uploader(
        "📹 Upload Video",
        type=["mp4", "mov", "avi"],
        key="video_uploader"
    )
    if uploaded_file is None:
        return
    st.session_state.uploaded_video = uploaded_file
102
+
103
def show_video_preview():
    """Render a preview of the currently uploaded video, if any."""
    st.markdown("### Video Preview")
    video = st.session_state.uploaded_video
    if video is None:
        return
    st.video(video.getvalue())
    # Rewind so the later processing pass reads from the start again.
    video.seek(0)
110
+
111
def handle_background_selection():
    """Render the background-type controls.

    For "Image" the chosen picture is stored in ``st.session_state.bg_image``;
    for "Color" the hex value is stored in ``st.session_state.bg_color`` and a
    swatch is previewed. "Blur" needs no extra input.

    Returns:
        The selected background type as a lowercase string
        ("image", "color", or "blur").
    """
    st.markdown("### Background Options")
    bg_type = st.radio(
        "Select Background Type:",
        ["Image", "Color", "Blur"],
        horizontal=True,
        index=0
    )

    if bg_type == "Image":
        bg_image = st.file_uploader(
            "🖼️ Upload Background Image",
            type=["jpg", "png", "jpeg"],
            key="bg_image_uploader"
        )
        if bg_image is not None:
            st.session_state.bg_image = Image.open(bg_image)
            st.image(
                st.session_state.bg_image,
                caption="Selected Background",
                use_container_width=True
            )

    elif bg_type == "Color":
        st.session_state.bg_color = st.color_picker(
            "🎨 Choose Background Color",
            st.session_state.bg_color
        )
        color_rgb = tuple(int(st.session_state.bg_color.lstrip('#')[i:i + 2], 16) for i in (0, 2, 4))
        # BUGFIX: st.image interprets numpy arrays as RGB by default, so the
        # swatch must be built in RGB order. The original reversed to BGR
        # (an OpenCV-I/O convention), which displayed red/blue swapped.
        color_display = np.full((100, 100, 3), color_rgb, dtype=np.uint8)
        st.image(color_display, caption="Selected Color", width=200)

    return bg_type.lower()
146
+
147
def process_video(input_file, background, bg_type="image"):
    """
    Process video with the selected background using SAM2 and MatAnyone pipeline.

    Args:
        input_file: Uploaded file object (must support ``.getvalue()``).
        background: PIL image when ``bg_type`` is "image"; ignored otherwise.
        bg_type: "image", "color", or "blur".

    Returns:
        Path to the processed video file (a persistent temp file the caller
        may read/delete), or None on failure.
    """
    import shutil  # local import: only needed to copy the final result out

    try:
        # Create a temporary directory for processing. NOTE: everything in it
        # is deleted when the ``with`` block exits, so the finished video is
        # copied to a persistent temp file before returning — the original
        # code returned a path inside this directory, which was already
        # deleted by the time the caller checked os.path.exists() (bug).
        with tempfile.TemporaryDirectory() as work_dir:
            temp_dir = Path(work_dir)

            # Save the uploaded video to a temporary file
            input_path = str(temp_dir / "input.mp4")
            with open(input_path, "wb") as f:
                f.write(input_file.getvalue())

            # Prepare the background image on disk (for "image" and "color")
            bg_path = None
            if bg_type == "image" and background is not None:
                # Convert PIL Image (RGB) to OpenCV's BGR before writing
                bg_cv = cv2.cvtColor(np.array(background), cv2.COLOR_RGB2BGR)
                bg_path = str(temp_dir / "background.jpg")
                cv2.imwrite(bg_path, bg_cv)
            elif bg_type == "color" and hasattr(st.session_state, 'bg_color'):
                # Create a solid-colour image; BGR order for OpenCV, and an
                # explicit uint8 dtype (ones * tuple would promote to int64).
                color_hex = st.session_state.bg_color.lstrip('#')
                color_rgb = tuple(int(color_hex[i:i + 2], 16) for i in (0, 2, 4))
                bg_path = str(temp_dir / "background.jpg")
                cv2.imwrite(bg_path, np.full((100, 100, 3), color_rgb[::-1], dtype=np.uint8))

            # Progress UI plus a callback bridging pipeline progress into it
            progress_bar = st.progress(0)
            status_text = st.empty()

            def progress_callback(progress, message):
                progress = max(0, min(1, float(progress)))
                progress_bar.progress(progress)
                status_text.text(f"Status: {message}")
                st.session_state.progress = int(progress * 100)
                st.session_state.progress_text = message

            output_path = str(temp_dir / "output.mp4")

            # Mock click points (center of the frame) — TODO: let the user pick
            click_points = [[0.5, 0.5]]

            # Imported lazily: pulls in the heavy SAM2/MatAnyone stack
            from integrated_pipeline import TwoStageProcessor

            processor = TwoStageProcessor(temp_dir=str(temp_dir))

            # Pass the prepared background for both "image" and "color"; the
            # original built bg_path for "color" but never handed it to the
            # pipeline, so colour backgrounds were silently ignored.
            success = processor.process_video(
                input_video=input_path,
                background_video=bg_path if bg_path else "",
                click_points=click_points,
                output_path=output_path,
                use_matanyone=True,
                progress_callback=progress_callback
            )

            if not success:
                raise RuntimeError("Video processing failed")

            # Copy the result to a persistent temp file so it survives the
            # TemporaryDirectory cleanup above.
            fd, persistent_path = tempfile.mkstemp(suffix=".mp4")
            os.close(fd)
            shutil.copyfile(output_path, persistent_path)

        return persistent_path

    except Exception as e:
        logger.error(f"Error in video processing: {str(e)}", exc_info=True)
        st.error(f"An error occurred during processing: {str(e)}")
        return None
219
+
220
def main():
    """Top-level Streamlit page: upload → background choice → process → download."""
    st.title("🎬 Advanced Video Background Replacer")
    st.markdown("---")

    # Initialize session state before any widget reads from it.
    initialize_session_state()

    # Main layout: upload/preview on the left, settings/actions on the right.
    col1, col2 = st.columns([1, 1], gap="large")

    with col1:
        st.header("1. Upload Video")
        handle_video_upload()
        show_video_preview()

    with col2:
        st.header("2. Background Settings")
        bg_type = handle_background_selection()

        st.header("3. Process & Download")
        # Button stays disabled until a video is uploaded and no run is active.
        if st.button(
            "🚀 Process Video",
            type="primary",
            disabled=not st.session_state.uploaded_video or st.session_state.processing,
            use_container_width=True
        ):
            with st.spinner("Processing video (this may take a few minutes)..."):
                st.session_state.processing = True

                try:
                    # Prepare background based on type; "blur" needs no value.
                    background = None
                    if bg_type == "image" and 'bg_image' in st.session_state and st.session_state.bg_image is not None:
                        background = st.session_state.bg_image
                    elif bg_type == "color" and 'bg_color' in st.session_state:
                        background = st.session_state.bg_color

                    # Run the heavy pipeline; returns a file path or None.
                    output_path = process_video(
                        st.session_state.uploaded_video,
                        background,
                        bg_type=bg_type
                    )

                    if output_path and os.path.exists(output_path):
                        # Store the path so the result persists across reruns.
                        st.session_state.processed_video_path = output_path
                        st.success("✅ Video processing complete!")
                    else:
                        st.error("❌ Failed to process video. Please check the logs for details.")

                except Exception as e:
                    st.error(f"❌ An error occurred: {str(e)}")
                    logger.exception("Video processing failed")

                finally:
                    # Always clear the flag so the button re-enables on rerun.
                    st.session_state.processing = False

    # Show processed video if available (survives subsequent reruns).
    if 'processed_video_path' in st.session_state and st.session_state.processed_video_path:
        st.markdown("### Processed Video")

        try:
            # Display the video directly from the file on disk.
            with open(st.session_state.processed_video_path, 'rb') as f:
                video_bytes = f.read()
            st.video(video_bytes)

            # Download button re-uses the bytes already read above.
            st.download_button(
                label="💾 Download Processed Video",
                data=video_bytes,
                file_name="processed_video.mp4",
                mime="video/mp4",
                use_container_width=True
            )
        except Exception as e:
            st.error(f"Error displaying video: {str(e)}")
            logger.error(f"Error displaying video: {str(e)}", exc_info=True)

if __name__ == "__main__":
    main()