Update to Streamlit UI with new features and logo
Browse files- Dockerfile +23 -27
- VideoBackgroundReplacer2/.dockerignore +78 -0
- VideoBackgroundReplacer2/5.0.0 +60 -0
- VideoBackgroundReplacer2/DEPLOYMENT.md +90 -0
- VideoBackgroundReplacer2/Dockerfile +137 -0
- VideoBackgroundReplacer2/README.md +94 -0
- VideoBackgroundReplacer2/app.py +300 -0
- VideoBackgroundReplacer2/integrated_pipeline.py +421 -0
- VideoBackgroundReplacer2/models/__init__.py +868 -0
- VideoBackgroundReplacer2/models/matanyone_loader.py +290 -0
- VideoBackgroundReplacer2/models/sam2_loader.py +262 -0
- VideoBackgroundReplacer2/pipeline.py +477 -0
- VideoBackgroundReplacer2/requirements.txt +72 -0
- VideoBackgroundReplacer2/two_stage_pipeline.py +388 -0
- VideoBackgroundReplacer2/ui.py +140 -0
- VideoBackgroundReplacer2/ui_core_functionality.py +662 -0
- VideoBackgroundReplacer2/ui_core_interface.py +430 -0
- VideoBackgroundReplacer2/update_pins.py +197 -0
- VideoBackgroundReplacer2/utils/__init__.py +0 -0
- VideoBackgroundReplacer2/utils/paths.py +29 -0
- VideoBackgroundReplacer2/utils/perf_tuning.py +21 -0
- app.py +558 -288
- app_backup.py +300 -0
- pipeline_utils.py +191 -0
- requirements.txt +10 -11
- streamlit_app.py +301 -0
Dockerfile
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
# ===============================
|
| 2 |
# Hugging Face Space β Stable Dockerfile
|
| 3 |
-
# CUDA 12.1.1 + PyTorch 2.5.1 (cu121) +
|
| 4 |
# SAM2 installed from source; MatAnyone via pip (repo)
|
| 5 |
# ===============================
|
| 6 |
|
|
@@ -20,7 +20,9 @@ ENV DEBIAN_FRONTEND=noninteractive \
|
|
| 20 |
NUMEXPR_NUM_THREADS=1 \
|
| 21 |
HF_HOME=/home/user/app/.hf \
|
| 22 |
TORCH_HOME=/home/user/app/.torch \
|
| 23 |
-
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# ---- Non-root user ----
|
| 26 |
RUN useradd -m -u 1000 user
|
|
@@ -34,7 +36,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|
| 34 |
build-essential gcc g++ pkg-config \
|
| 35 |
libffi-dev libssl-dev libc6-dev \
|
| 36 |
libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender1 libgomp1 \
|
| 37 |
-
|
| 38 |
|
| 39 |
# ---- Python bootstrap ----
|
| 40 |
RUN python3 -m pip install --upgrade pip setuptools wheel
|
|
@@ -42,17 +44,11 @@ RUN python3 -m pip install --upgrade pip setuptools wheel
|
|
| 42 |
# ---- Install PyTorch (CUDA 12.1 wheels) ----
|
| 43 |
RUN python3 -m pip install --no-cache-dir --index-url https://download.pytorch.org/whl/cu121 \
|
| 44 |
torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 \
|
| 45 |
-
|
| 46 |
import torch
|
| 47 |
print("PyTorch:", torch.__version__)
|
| 48 |
print("CUDA available:", torch.cuda.is_available())
|
| 49 |
print("torch.version.cuda:", getattr(torch.version, "cuda", None))
|
| 50 |
-
try:
|
| 51 |
-
import torchaudio, torchvision
|
| 52 |
-
print("torchaudio:", torchaudio.__version__)
|
| 53 |
-
import torchvision as tv; print("torchvision:", tv.__version__)
|
| 54 |
-
except Exception as e:
|
| 55 |
-
print("aux libs check:", e)
|
| 56 |
PY
|
| 57 |
|
| 58 |
# ---- Copy deps first (better caching) ----
|
|
@@ -92,19 +88,20 @@ RUN mkdir -p /home/user/app/checkpoints /home/user/app/.hf /home/user/app/.torch
|
|
| 92 |
chmod -R 755 /home/user/app && \
|
| 93 |
find /home/user/app -type d -exec chmod 755 {} \; && \
|
| 94 |
find /home/user/app -type f -exec chmod 644 {} \; && \
|
| 95 |
-
chmod +x /home/user/app/
|
| 96 |
|
| 97 |
-
# ---- Healthcheck
|
| 98 |
HEALTHCHECK --interval=30s --timeout=5s --retries=3 CMD \
|
| 99 |
["python3","-c","import torch; print('torch', torch.__version__, '| cuda', getattr(torch.version,'cuda',None), '| ok=', torch.cuda.is_available())"]
|
| 100 |
|
| 101 |
# ---- Runtime ----
|
| 102 |
USER user
|
| 103 |
-
EXPOSE
|
| 104 |
|
|
|
|
| 105 |
CMD ["sh", "-c", "\
|
| 106 |
echo '===========================================' && \
|
| 107 |
-
echo '===
|
| 108 |
echo '===========================================' && \
|
| 109 |
echo 'Timestamp:' $(date) && \
|
| 110 |
echo 'Current directory:' $(pwd) && \
|
|
@@ -115,23 +112,22 @@ CMD ["sh", "-c", "\
|
|
| 115 |
echo 'Files in app directory:' && \
|
| 116 |
ls -la && \
|
| 117 |
echo '' && \
|
| 118 |
-
echo '===
|
| 119 |
-
if [ -f
|
| 120 |
-
echo 'β
|
| 121 |
-
echo 'File size:' $(wc -c <
|
| 122 |
-
echo 'File permissions:' $(ls -l
|
| 123 |
echo 'Testing Python imports...' && \
|
| 124 |
-
python3 -B -c 'import
|
| 125 |
python3 -B -c 'import torch; print(\"β
Torch:\", torch.__version__)' && \
|
| 126 |
-
echo 'Testing
|
| 127 |
-
python3 -B -c 'import sys; sys.path.insert(0, \".\"); import
|
| 128 |
echo 'β
All checks passed!'; \
|
| 129 |
else \
|
| 130 |
-
echo 'β ERROR:
|
| 131 |
exit 1; \
|
| 132 |
fi && \
|
| 133 |
echo '' && \
|
| 134 |
-
echo '=== STARTING
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
"]
|
|
|
|
| 1 |
# ===============================
|
| 2 |
# Hugging Face Space β Stable Dockerfile
|
| 3 |
+
# CUDA 12.1.1 + PyTorch 2.5.1 (cu121) + Streamlit 1.32.0
|
| 4 |
# SAM2 installed from source; MatAnyone via pip (repo)
|
| 5 |
# ===============================
|
| 6 |
|
|
|
|
| 20 |
NUMEXPR_NUM_THREADS=1 \
|
| 21 |
HF_HOME=/home/user/app/.hf \
|
| 22 |
TORCH_HOME=/home/user/app/.torch \
|
| 23 |
+
STREAMLIT_SERVER_PORT=8501 \
|
| 24 |
+
STREAMLIT_SERVER_HEADLESS=true \
|
| 25 |
+
STREAMLIT_BROWSER_GATHER_USAGE_STATS=false
|
| 26 |
|
| 27 |
# ---- Non-root user ----
|
| 28 |
RUN useradd -m -u 1000 user
|
|
|
|
| 36 |
build-essential gcc g++ pkg-config \
|
| 37 |
libffi-dev libssl-dev libc6-dev \
|
| 38 |
libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender1 libgomp1 \
|
| 39 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 40 |
|
| 41 |
# ---- Python bootstrap ----
|
| 42 |
RUN python3 -m pip install --upgrade pip setuptools wheel
|
|
|
|
| 44 |
# ---- Install PyTorch (CUDA 12.1 wheels) ----
|
| 45 |
RUN python3 -m pip install --no-cache-dir --index-url https://download.pytorch.org/whl/cu121 \
|
| 46 |
torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 \
|
| 47 |
+
&& python3 - <<'PY'
|
| 48 |
import torch
|
| 49 |
print("PyTorch:", torch.__version__)
|
| 50 |
print("CUDA available:", torch.cuda.is_available())
|
| 51 |
print("torch.version.cuda:", getattr(torch.version, "cuda", None))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
PY
|
| 53 |
|
| 54 |
# ---- Copy deps first (better caching) ----
|
|
|
|
| 88 |
chmod -R 755 /home/user/app && \
|
| 89 |
find /home/user/app -type d -exec chmod 755 {} \; && \
|
| 90 |
find /home/user/app -type f -exec chmod 644 {} \; && \
|
| 91 |
+
chmod +x /home/user/app/app.py || true
|
| 92 |
|
| 93 |
+
# ---- Healthcheck ----
|
| 94 |
HEALTHCHECK --interval=30s --timeout=5s --retries=3 CMD \
|
| 95 |
["python3","-c","import torch; print('torch', torch.__version__, '| cuda', getattr(torch.version,'cuda',None), '| ok=', torch.cuda.is_available())"]
|
| 96 |
|
| 97 |
# ---- Runtime ----
|
| 98 |
USER user
|
| 99 |
+
EXPOSE 8501
|
| 100 |
|
| 101 |
+
# Streamlit server command
|
| 102 |
CMD ["sh", "-c", "\
|
| 103 |
echo '===========================================' && \
|
| 104 |
+
echo '=== MYAVATAR STREAMLIT CONTAINER STARTUP ===' && \
|
| 105 |
echo '===========================================' && \
|
| 106 |
echo 'Timestamp:' $(date) && \
|
| 107 |
echo 'Current directory:' $(pwd) && \
|
|
|
|
| 112 |
echo 'Files in app directory:' && \
|
| 113 |
ls -la && \
|
| 114 |
echo '' && \
|
| 115 |
+
echo '=== APP.PY VERIFICATION ===' && \
|
| 116 |
+
if [ -f app.py ]; then \
|
| 117 |
+
echo 'β
app.py found' && \
|
| 118 |
+
echo 'File size:' $(wc -c < app.py) 'bytes' && \
|
| 119 |
+
echo 'File permissions:' $(ls -l app.py) && \
|
| 120 |
echo 'Testing Python imports...' && \
|
| 121 |
+
python3 -B -c 'import streamlit; print(\"β
Streamlit:\", streamlit.__version__)' && \
|
| 122 |
python3 -B -c 'import torch; print(\"β
Torch:\", torch.__version__)' && \
|
| 123 |
+
echo 'Testing app.py import...' && \
|
| 124 |
+
python3 -B -c 'import sys; sys.path.insert(0, \".\"); import app; print(\"β
app.py imports successfully\")' && \
|
| 125 |
echo 'β
All checks passed!'; \
|
| 126 |
else \
|
| 127 |
+
echo 'β ERROR: app.py not found!' && \
|
| 128 |
exit 1; \
|
| 129 |
fi && \
|
| 130 |
echo '' && \
|
| 131 |
+
echo '=== STARTING STREAMLIT SERVER ===' && \
|
| 132 |
+
streamlit run --server.port=8501 --server.address=0.0.0.0 app.py \
|
| 133 |
+
"]
|
|
|
VideoBackgroundReplacer2/.dockerignore
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ===========================
|
| 2 |
+
# .dockerignore for HF Spaces
|
| 3 |
+
# ===========================
|
| 4 |
+
|
| 5 |
+
# VCS
|
| 6 |
+
.git
|
| 7 |
+
.gitignore
|
| 8 |
+
.gitattributes
|
| 9 |
+
|
| 10 |
+
# Python cache / build
|
| 11 |
+
__pycache__/
|
| 12 |
+
*.py[cod]
|
| 13 |
+
*.pyo
|
| 14 |
+
*.pyd
|
| 15 |
+
*.pdb
|
| 16 |
+
*.egg-info/
|
| 17 |
+
dist/
|
| 18 |
+
build/
|
| 19 |
+
.pytest_cache/
|
| 20 |
+
.python-version
|
| 21 |
+
|
| 22 |
+
# Virtual environments
|
| 23 |
+
.env
|
| 24 |
+
.venv/
|
| 25 |
+
env/
|
| 26 |
+
venv/
|
| 27 |
+
|
| 28 |
+
# External repos (cloned in Docker, not copied from local)
|
| 29 |
+
third_party/
|
| 30 |
+
|
| 31 |
+
# Hugging Face / Torch caches
|
| 32 |
+
.cache/
|
| 33 |
+
huggingface/
|
| 34 |
+
torch/
|
| 35 |
+
data/
|
| 36 |
+
|
| 37 |
+
# HF Space metadata/state
|
| 38 |
+
.hf_space/
|
| 39 |
+
space.log
|
| 40 |
+
gradio_cached_examples/
|
| 41 |
+
gradio_static/
|
| 42 |
+
__outputs__/
|
| 43 |
+
|
| 44 |
+
# Logs & temp files
|
| 45 |
+
*.log
|
| 46 |
+
logs/
|
| 47 |
+
tmp/
|
| 48 |
+
temp/
|
| 49 |
+
*.swp
|
| 50 |
+
.coverage
|
| 51 |
+
coverage.xml
|
| 52 |
+
|
| 53 |
+
# Media test assets
|
| 54 |
+
*.mp4
|
| 55 |
+
*.avi
|
| 56 |
+
*.mov
|
| 57 |
+
*.mkv
|
| 58 |
+
*.png
|
| 59 |
+
*.jpg
|
| 60 |
+
*.jpeg
|
| 61 |
+
*.gif
|
| 62 |
+
|
| 63 |
+
# OS / IDE cruft
|
| 64 |
+
.DS_Store
|
| 65 |
+
Thumbs.db
|
| 66 |
+
.vscode/
|
| 67 |
+
.idea/
|
| 68 |
+
*.sublime-project
|
| 69 |
+
*.sublime-workspace
|
| 70 |
+
|
| 71 |
+
# Node / frontend (if present)
|
| 72 |
+
node_modules/
|
| 73 |
+
npm-debug.log
|
| 74 |
+
yarn-debug.log
|
| 75 |
+
yarn-error.log
|
| 76 |
+
|
| 77 |
+
# ---- Optional: allow specific checkpoints if needed ----
|
| 78 |
+
!checkpoints/
|
VideoBackgroundReplacer2/5.0.0
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Defaulting to user installation because normal site-packages is not writeable
|
| 2 |
+
Requirement already satisfied: gradio in c:\users\mogen\appdata\roaming\python\python313\site-packages (4.44.0)
|
| 3 |
+
Requirement already satisfied: aiofiles<24.0,>=22.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (23.2.1)
|
| 4 |
+
Requirement already satisfied: anyio<5.0,>=3.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (4.9.0)
|
| 5 |
+
Requirement already satisfied: fastapi<1.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.115.12)
|
| 6 |
+
Requirement already satisfied: ffmpy in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.6.1)
|
| 7 |
+
Requirement already satisfied: gradio-client==1.3.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (1.3.0)
|
| 8 |
+
Requirement already satisfied: httpx>=0.24.1 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.27.2)
|
| 9 |
+
Requirement already satisfied: huggingface-hub>=0.19.3 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.34.4)
|
| 10 |
+
Requirement already satisfied: importlib-resources<7.0,>=1.3 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (6.5.2)
|
| 11 |
+
Requirement already satisfied: jinja2<4.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (3.1.6)
|
| 12 |
+
Requirement already satisfied: markupsafe~=2.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (2.1.5)
|
| 13 |
+
Requirement already satisfied: matplotlib~=3.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (3.10.5)
|
| 14 |
+
Requirement already satisfied: numpy<3.0,>=1.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (1.26.4)
|
| 15 |
+
Requirement already satisfied: orjson~=3.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (3.11.2)
|
| 16 |
+
Requirement already satisfied: packaging in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (24.2)
|
| 17 |
+
Requirement already satisfied: pandas<3.0,>=1.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (2.2.3)
|
| 18 |
+
Requirement already satisfied: pillow<11.0,>=8.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (10.4.0)
|
| 19 |
+
Requirement already satisfied: pydantic>=2.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (2.11.5)
|
| 20 |
+
Requirement already satisfied: pydub in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.25.1)
|
| 21 |
+
Requirement already satisfied: python-multipart>=0.0.9 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.0.20)
|
| 22 |
+
Requirement already satisfied: pyyaml<7.0,>=5.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (6.0.2)
|
| 23 |
+
Requirement already satisfied: ruff>=0.2.2 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.12.9)
|
| 24 |
+
Requirement already satisfied: semantic-version~=2.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (2.10.0)
|
| 25 |
+
Requirement already satisfied: tomlkit==0.12.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.12.0)
|
| 26 |
+
Requirement already satisfied: typer<1.0,>=0.12 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.16.0)
|
| 27 |
+
Requirement already satisfied: typing-extensions~=4.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (4.14.1)
|
| 28 |
+
Requirement already satisfied: urllib3~=2.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (2.5.0)
|
| 29 |
+
Requirement already satisfied: uvicorn>=0.14.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio) (0.34.3)
|
| 30 |
+
Requirement already satisfied: fsspec in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio-client==1.3.0->gradio) (2025.5.1)
|
| 31 |
+
Requirement already satisfied: websockets<13.0,>=10.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from gradio-client==1.3.0->gradio) (10.4)
|
| 32 |
+
Requirement already satisfied: idna>=2.8 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from anyio<5.0,>=3.0->gradio) (3.10)
|
| 33 |
+
Requirement already satisfied: sniffio>=1.1 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from anyio<5.0,>=3.0->gradio) (1.3.1)
|
| 34 |
+
Requirement already satisfied: starlette<0.47.0,>=0.40.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from fastapi<1.0->gradio) (0.46.2)
|
| 35 |
+
Requirement already satisfied: certifi in c:\users\mogen\appdata\roaming\python\python313\site-packages (from httpx>=0.24.1->gradio) (2025.7.9)
|
| 36 |
+
Requirement already satisfied: httpcore==1.* in c:\users\mogen\appdata\roaming\python\python313\site-packages (from httpx>=0.24.1->gradio) (1.0.9)
|
| 37 |
+
Requirement already satisfied: h11>=0.16 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from httpcore==1.*->httpx>=0.24.1->gradio) (0.16.0)
|
| 38 |
+
Requirement already satisfied: filelock in c:\users\mogen\appdata\roaming\python\python313\site-packages (from huggingface-hub>=0.19.3->gradio) (3.18.0)
|
| 39 |
+
Requirement already satisfied: requests in c:\users\mogen\appdata\roaming\python\python313\site-packages (from huggingface-hub>=0.19.3->gradio) (2.32.3)
|
| 40 |
+
Requirement already satisfied: tqdm>=4.42.1 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from huggingface-hub>=0.19.3->gradio) (4.67.1)
|
| 41 |
+
Requirement already satisfied: contourpy>=1.0.1 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from matplotlib~=3.0->gradio) (1.3.3)
|
| 42 |
+
Requirement already satisfied: cycler>=0.10 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from matplotlib~=3.0->gradio) (0.12.1)
|
| 43 |
+
Requirement already satisfied: fonttools>=4.22.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from matplotlib~=3.0->gradio) (4.59.1)
|
| 44 |
+
Requirement already satisfied: kiwisolver>=1.3.1 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from matplotlib~=3.0->gradio) (1.4.9)
|
| 45 |
+
Requirement already satisfied: pyparsing>=2.3.1 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from matplotlib~=3.0->gradio) (3.2.3)
|
| 46 |
+
Requirement already satisfied: python-dateutil>=2.7 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from matplotlib~=3.0->gradio) (2.8.2)
|
| 47 |
+
Requirement already satisfied: pytz>=2020.1 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from pandas<3.0,>=1.0->gradio) (2025.2)
|
| 48 |
+
Requirement already satisfied: tzdata>=2022.7 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from pandas<3.0,>=1.0->gradio) (2025.2)
|
| 49 |
+
Requirement already satisfied: annotated-types>=0.6.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from pydantic>=2.0->gradio) (0.7.0)
|
| 50 |
+
Requirement already satisfied: pydantic-core==2.33.2 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from pydantic>=2.0->gradio) (2.33.2)
|
| 51 |
+
Requirement already satisfied: typing-inspection>=0.4.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from pydantic>=2.0->gradio) (0.4.1)
|
| 52 |
+
Requirement already satisfied: click>=8.0.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from typer<1.0,>=0.12->gradio) (8.2.1)
|
| 53 |
+
Requirement already satisfied: shellingham>=1.3.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from typer<1.0,>=0.12->gradio) (1.5.4)
|
| 54 |
+
Requirement already satisfied: rich>=10.11.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from typer<1.0,>=0.12->gradio) (14.0.0)
|
| 55 |
+
Requirement already satisfied: colorama in c:\users\mogen\appdata\roaming\python\python313\site-packages (from click>=8.0.0->typer<1.0,>=0.12->gradio) (0.4.6)
|
| 56 |
+
Requirement already satisfied: six>=1.5 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from python-dateutil>=2.7->matplotlib~=3.0->gradio) (1.17.0)
|
| 57 |
+
Requirement already satisfied: markdown-it-py>=2.2.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (3.0.0)
|
| 58 |
+
Requirement already satisfied: pygments<3.0.0,>=2.13.0 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (2.19.1)
|
| 59 |
+
Requirement already satisfied: charset-normalizer<4,>=2 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from requests->huggingface-hub>=0.19.3->gradio) (3.4.2)
|
| 60 |
+
Requirement already satisfied: mdurl~=0.1 in c:\users\mogen\appdata\roaming\python\python313\site-packages (from markdown-it-py>=2.2.0->rich>=10.11.0->typer<1.0,>=0.12->gradio) (0.1.2)
|
VideoBackgroundReplacer2/DEPLOYMENT.md
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# VideoBackgroundReplacer2 Deployment Guide
|
| 2 |
+
|
| 3 |
+
This guide provides instructions for deploying the VideoBackgroundReplacer2 application to Hugging Face Spaces with GPU acceleration.
|
| 4 |
+
|
| 5 |
+
## Prerequisites
|
| 6 |
+
|
| 7 |
+
- Docker
|
| 8 |
+
- Git
|
| 9 |
+
- Python 3.8+
|
| 10 |
+
- NVIDIA Container Toolkit (for local GPU testing)
|
| 11 |
+
- Hugging Face account with access to GPU Spaces
|
| 12 |
+
|
| 13 |
+
## Local Development
|
| 14 |
+
|
| 15 |
+
### 1. Clone the repository
|
| 16 |
+
```bash
|
| 17 |
+
git clone <repository-url>
|
| 18 |
+
cd VideoBackgroundReplacer2
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
### 2. Build the Docker image
|
| 22 |
+
```bash
|
| 23 |
+
# Make the build script executable
|
| 24 |
+
chmod +x build_and_deploy.sh
|
| 25 |
+
|
| 26 |
+
# Build the image
|
| 27 |
+
./build_and_deploy.sh
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
### 3. Run the container locally
|
| 31 |
+
```bash
|
| 32 |
+
docker run --gpus all -p 7860:7860 -v $(pwd)/checkpoints:/home/user/app/checkpoints videobackgroundreplacer2:latest
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
## Hugging Face Spaces Deployment
|
| 36 |
+
|
| 37 |
+
### 1. Create a new Space
|
| 38 |
+
- Go to [Hugging Face Spaces](https://huggingface.co/spaces)
|
| 39 |
+
- Click "Create new Space"
|
| 40 |
+
- Select "Docker" as the SDK
|
| 41 |
+
- Choose a name and set the space to private if needed
|
| 42 |
+
- Select GPU as the hardware
|
| 43 |
+
|
| 44 |
+
### 2. Configure the Space
|
| 45 |
+
Add the following environment variables to your Space settings:
|
| 46 |
+
- `SAM2_DEVICE`: `cuda`
|
| 47 |
+
- `MATANY_DEVICE`: `cuda`
|
| 48 |
+
- `PYTORCH_CUDA_ALLOC_CONF`: `max_split_size_mb:256,garbage_collection_threshold:0.8`
|
| 49 |
+
- `TORCH_CUDA_ARCH_LIST`: `7.5 8.0 8.6+PTX`
|
| 50 |
+
|
| 51 |
+
### 3. Deploy to Hugging Face
|
| 52 |
+
```bash
|
| 53 |
+
# Set your Hugging Face token
|
| 54 |
+
export HF_TOKEN=your_hf_token
|
| 55 |
+
export HF_USERNAME=your_username
|
| 56 |
+
|
| 57 |
+
# Build and deploy
|
| 58 |
+
./build_and_deploy.sh
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
## Health Check
|
| 62 |
+
|
| 63 |
+
You can verify the installation by running:
|
| 64 |
+
```bash
|
| 65 |
+
docker run --rm videobackgroundreplacer2:latest python3 health_check.py
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
## Troubleshooting
|
| 69 |
+
|
| 70 |
+
### Build Failures
|
| 71 |
+
- Ensure you have enough disk space (at least 10GB free)
|
| 72 |
+
- Check Docker logs for specific error messages
|
| 73 |
+
- Verify your internet connection is stable
|
| 74 |
+
|
| 75 |
+
### Runtime Issues
|
| 76 |
+
- Check container logs: `docker logs <container_id>`
|
| 77 |
+
- Verify GPU is detected: `nvidia-smi` inside the container
|
| 78 |
+
- Check disk space: `df -h`
|
| 79 |
+
|
| 80 |
+
## Performance Optimization
|
| 81 |
+
|
| 82 |
+
- For faster inference, use the `sam2_hiera_tiny` model
|
| 83 |
+
- Adjust batch size based on available GPU memory
|
| 84 |
+
- Enable gradient checkpointing for large models
|
| 85 |
+
|
| 86 |
+
## Monitoring
|
| 87 |
+
|
| 88 |
+
- Use `nvidia-smi` to monitor GPU usage
|
| 89 |
+
- Check container logs for any warnings or errors
|
| 90 |
+
- Monitor memory usage with `htop` or similar tools
|
VideoBackgroundReplacer2/Dockerfile
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ===============================
|
| 2 |
+
# Hugging Face Space β Stable Dockerfile
|
| 3 |
+
# CUDA 12.1.1 + PyTorch 2.5.1 (cu121) + Gradio 4.41.3
|
| 4 |
+
# SAM2 installed from source; MatAnyone via pip (repo)
|
| 5 |
+
# ===============================
|
| 6 |
+
|
| 7 |
+
FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04
|
| 8 |
+
|
| 9 |
+
# ---- Environment (runtime hygiene) ----
|
| 10 |
+
ENV DEBIAN_FRONTEND=noninteractive \
|
| 11 |
+
PYTHONUNBUFFERED=1 \
|
| 12 |
+
PYTHONDONTWRITEBYTECODE=1 \
|
| 13 |
+
PIP_NO_CACHE_DIR=1 \
|
| 14 |
+
PIP_DISABLE_PIP_VERSION_CHECK=1 \
|
| 15 |
+
TORCH_CUDA_ARCH_LIST="7.5 8.0 8.6+PTX" \
|
| 16 |
+
CUDA_VISIBLE_DEVICES="0" \
|
| 17 |
+
OMP_NUM_THREADS=4 \
|
| 18 |
+
OPENBLAS_NUM_THREADS=1 \
|
| 19 |
+
MKL_NUM_THREADS=1 \
|
| 20 |
+
NUMEXPR_NUM_THREADS=1 \
|
| 21 |
+
HF_HOME=/home/user/app/.hf \
|
| 22 |
+
TORCH_HOME=/home/user/app/.torch \
|
| 23 |
+
GRADIO_SERVER_PORT=7860
|
| 24 |
+
|
| 25 |
+
# ---- Non-root user ----
|
| 26 |
+
RUN useradd -m -u 1000 user
|
| 27 |
+
ENV HOME=/home/user
|
| 28 |
+
WORKDIR $HOME/app
|
| 29 |
+
|
| 30 |
+
# ---- System deps ----
|
| 31 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 32 |
+
git ffmpeg wget curl \
|
| 33 |
+
python3 python3-pip python3-venv python3-dev \
|
| 34 |
+
build-essential gcc g++ pkg-config \
|
| 35 |
+
libffi-dev libssl-dev libc6-dev \
|
| 36 |
+
libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender1 libgomp1 \
|
| 37 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 38 |
+
|
| 39 |
+
# ---- Python bootstrap ----
|
| 40 |
+
RUN python3 -m pip install --upgrade pip setuptools wheel
|
| 41 |
+
|
| 42 |
+
# ---- Install PyTorch (CUDA 12.1 wheels) ----
|
| 43 |
+
RUN python3 -m pip install --no-cache-dir --index-url https://download.pytorch.org/whl/cu121 \
|
| 44 |
+
torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 \
|
| 45 |
+
&& python3 - <<'PY'
|
| 46 |
+
import torch
|
| 47 |
+
print("PyTorch:", torch.__version__)
|
| 48 |
+
print("CUDA available:", torch.cuda.is_available())
|
| 49 |
+
print("torch.version.cuda:", getattr(torch.version, "cuda", None))
|
| 50 |
+
try:
|
| 51 |
+
import torchaudio, torchvision
|
| 52 |
+
print("torchaudio:", torchaudio.__version__)
|
| 53 |
+
import torchvision as tv; print("torchvision:", tv.__version__)
|
| 54 |
+
except Exception as e:
|
| 55 |
+
print("aux libs check:", e)
|
| 56 |
+
PY
|
| 57 |
+
|
| 58 |
+
# ---- Copy deps first (better caching) ----
|
| 59 |
+
COPY --chown=user:user requirements.txt ./
|
| 60 |
+
|
| 61 |
+
# ---- Install remaining Python deps ----
|
| 62 |
+
RUN python3 -m pip install --no-cache-dir -r requirements.txt
|
| 63 |
+
|
| 64 |
+
# ---- MatAnyone (pip install from repo with retry) ----
|
| 65 |
+
RUN echo "Installing MatAnyone..." && \
|
| 66 |
+
(python3 -m pip install --no-cache-dir -v git+https://github.com/pq-yang/MatAnyone@main#egg=matanyone || \
|
| 67 |
+
(echo "Retrying MatAnyone..." && \
|
| 68 |
+
python3 -m pip install --no-cache-dir -v git+https://github.com/pq-yang/MatAnyone@main#egg=matanyone)) && \
|
| 69 |
+
python3 -c "import matanyone; print('MatAnyone import OK')"
|
| 70 |
+
|
| 71 |
+
# ---- App code ----
|
| 72 |
+
COPY --chown=user:user . .
|
| 73 |
+
|
| 74 |
+
# ---- SAM2 from source (editable) ----
|
| 75 |
+
RUN echo "Installing SAM2 (editable)..." && \
|
| 76 |
+
git clone --depth=1 https://github.com/facebookresearch/segment-anything-2.git third_party/sam2 && \
|
| 77 |
+
cd third_party/sam2 && python3 -m pip install --no-cache-dir -e .
|
| 78 |
+
|
| 79 |
+
# ---- App env ----
|
| 80 |
+
ENV PYTHONPATH=/home/user/app:/home/user/app/third_party:/home/user/app/third_party/sam2 \
|
| 81 |
+
FFMPEG_BIN=ffmpeg \
|
| 82 |
+
THIRD_PARTY_SAM2_DIR=/home/user/app/third_party/sam2 \
|
| 83 |
+
ENABLE_MATANY=1 \
|
| 84 |
+
SAM2_DEVICE=cuda \
|
| 85 |
+
MATANY_DEVICE=cuda \
|
| 86 |
+
TF_CPP_MIN_LOG_LEVEL=2 \
|
| 87 |
+
SAM2_CHECKPOINT=/home/user/app/checkpoints/sam2_hiera_large.pt
|
| 88 |
+
|
| 89 |
+
# ---- Create writable dirs (caches + checkpoints) ----
|
| 90 |
+
RUN mkdir -p /home/user/app/checkpoints /home/user/app/.hf /home/user/app/.torch && \
|
| 91 |
+
chown -R user:user /home/user/app && \
|
| 92 |
+
chmod -R 755 /home/user/app && \
|
| 93 |
+
find /home/user/app -type d -exec chmod 755 {} \; && \
|
| 94 |
+
find /home/user/app -type f -exec chmod 644 {} \; && \
|
| 95 |
+
chmod +x /home/user/app/ui.py || true
|
| 96 |
+
|
| 97 |
+
# ---- Healthcheck (use exec-form, no heredoc) ----
|
| 98 |
+
HEALTHCHECK --interval=30s --timeout=5s --retries=3 CMD \
|
| 99 |
+
["python3","-c","import torch; print('torch', torch.__version__, '| cuda', getattr(torch.version,'cuda',None), '| ok=', torch.cuda.is_available())"]
|
| 100 |
+
|
| 101 |
+
# ---- Runtime ----
|
| 102 |
+
USER user
|
| 103 |
+
EXPOSE 7860
|
| 104 |
+
|
| 105 |
+
CMD ["sh", "-c", "\
|
| 106 |
+
echo '===========================================' && \
|
| 107 |
+
echo '=== BACKGROUNDFX PRO CONTAINER STARTUP ===' && \
|
| 108 |
+
echo '===========================================' && \
|
| 109 |
+
echo 'Timestamp:' $(date) && \
|
| 110 |
+
echo 'Current directory:' $(pwd) && \
|
| 111 |
+
echo 'Current user:' $(whoami) && \
|
| 112 |
+
echo 'User ID:' $(id) && \
|
| 113 |
+
echo '' && \
|
| 114 |
+
echo '=== FILE SYSTEM CHECK ===' && \
|
| 115 |
+
echo 'Files in app directory:' && \
|
| 116 |
+
ls -la && \
|
| 117 |
+
echo '' && \
|
| 118 |
+
echo '=== UI.PY VERIFICATION ===' && \
|
| 119 |
+
if [ -f ui.py ]; then \
|
| 120 |
+
echo 'β
ui.py found' && \
|
| 121 |
+
echo 'File size:' $(wc -c < ui.py) 'bytes' && \
|
| 122 |
+
echo 'File permissions:' $(ls -l ui.py) && \
|
| 123 |
+
echo 'Testing Python imports...' && \
|
| 124 |
+
python3 -B -c 'import gradio; print(\"β
Gradio:\", gradio.__version__)' && \
|
| 125 |
+
python3 -B -c 'import torch; print(\"β
Torch:\", torch.__version__)' && \
|
| 126 |
+
echo 'Testing ui.py import...' && \
|
| 127 |
+
python3 -B -c 'import sys; sys.path.insert(0, \".\"); import ui; print(\"β
ui.py imports successfully\")' && \
|
| 128 |
+
echo 'β
All checks passed!'; \
|
| 129 |
+
else \
|
| 130 |
+
echo 'β ERROR: ui.py not found!' && \
|
| 131 |
+
exit 1; \
|
| 132 |
+
fi && \
|
| 133 |
+
echo '' && \
|
| 134 |
+
echo '=== STARTING APPLICATION ===' && \
|
| 135 |
+
echo 'Launching ui.py with bytecode disabled...' && \
|
| 136 |
+
python3 -B -u ui.py \
|
| 137 |
+
"]
|
VideoBackgroundReplacer2/README.md
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: π¬ BackgroundFX Pro - SAM2 + MatAnyone
|
| 3 |
+
emoji: π₯
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
license: mit
|
| 9 |
+
tags:
|
| 10 |
+
- video
|
| 11 |
+
- background-removal
|
| 12 |
+
- segmentation
|
| 13 |
+
- matting
|
| 14 |
+
- SAM2
|
| 15 |
+
- MatAnyone
|
| 16 |
+
---
|
| 17 |
+
|
| 18 |
+
# π¬ BackgroundFX Pro β Professional Video Background Replacement
|
| 19 |
+
|
| 20 |
+
BackgroundFX Pro is a GPU-accelerated app for Hugging Face Spaces (Docker) that replaces video backgrounds using:
|
| 21 |
+
- **SAM2** β high-quality object segmentation
|
| 22 |
+
- **MatAnyone** β temporal video matting for stable alpha over time
|
| 23 |
+
|
| 24 |
+
Built on: **CUDA 12.1.1**, **PyTorch 2.5.1 (cu121)**, **torchvision 0.20.1**, **Gradio 4.41.0**.
|
| 25 |
+
|
| 26 |
+
---
|
| 27 |
+
|
| 28 |
+
## β¨ Features
|
| 29 |
+
|
| 30 |
+
- Replace backgrounds with: **solid color**, **AI-generated** image (procedural), **custom uploaded image**, or **Unsplash** search
|
| 31 |
+
- Optimized for **T4 GPUs** on Hugging Face
|
| 32 |
+
- Caching & logs stored in the repo volume:
|
| 33 |
+
- HF cache β `./.hf`
|
| 34 |
+
- Torch cache β `./.torch`
|
| 35 |
+
- App data & logs β `./data` (see `data/run.log`)
|
| 36 |
+
|
| 37 |
+
---
|
| 38 |
+
|
| 39 |
+
## π Try It
|
| 40 |
+
|
| 41 |
+
Open the Space in your browser (GPU required):
|
| 42 |
+
https://huggingface.co/spaces/MogensR/VideoBackgroundReplacer2
|
| 43 |
+
|
| 44 |
+
---
|
| 45 |
+
|
| 46 |
+
## π±οΈ How to Use
|
| 47 |
+
|
| 48 |
+
1. **Upload a video** (`.mp4`, `.avi`, `.mov`, `.mkv`).
|
| 49 |
+
2. Choose a **Background Type**: Upload Image, AI Generate, Gradient, Solid, or Unsplash.
|
| 50 |
+
3. If not uploading, enter a prompt and click **Generate Background**.
|
| 51 |
+
4. Click **Process Video**.
|
| 52 |
+
5. Preview and **Download Result**.
|
| 53 |
+
|
| 54 |
+
> Tip: Start with 720p/1080p on T4; 4K can exceed memory.
|
| 55 |
+
|
| 56 |
+
---
|
| 57 |
+
|
| 58 |
+
## ποΈ Project Structure (key files)
|
| 59 |
+
|
| 60 |
+
- `Dockerfile`
|
| 61 |
+
- `requirements.txt`
|
| 62 |
+
- `ui.py`
|
| 63 |
+
- `ui_core_interface.py`
|
| 64 |
+
- `ui_core_functionality.py`
|
| 65 |
+
- `two_stage_pipeline.py`
|
| 66 |
+
- `models/sam2_loader.py`
|
| 67 |
+
- `models/matanyone_loader.py`
|
| 68 |
+
- `utils/__init__.py`
|
| 69 |
+
- `data/` (created at runtime for logs/outputs)
|
| 70 |
+
- `tmp/` (created at runtime for jobs/temp files)
|
| 71 |
+
|
| 72 |
+
---
|
| 73 |
+
|
| 74 |
+
## βοΈ Runtime Notes
|
| 75 |
+
|
| 76 |
+
- Binds to `PORT` / `GRADIO_SERVER_PORT` (defaults to **7860**).
|
| 77 |
+
- Heartbeat logs every ~2s with memory & disk stats.
|
| 78 |
+
- If there's no final "PROCESS EXITING" line, it was likely an **OOM** or hard kill.
|
| 79 |
+
|
| 80 |
+
---
|
| 81 |
+
|
| 82 |
+
## π§ͺ Local Development (Docker)
|
| 83 |
+
|
| 84 |
+
Requires an NVIDIA GPU with CUDA drivers.
|
| 85 |
+
|
| 86 |
+
```bash
|
| 87 |
+
git clone https://huggingface.co/spaces/MogensR/VideoBackgroundReplacer2
|
| 88 |
+
cd VideoBackgroundReplacer2
|
| 89 |
+
|
| 90 |
+
# Build (Ubuntu 22.04, CUDA 12.1.1; installs Torch 2.5.1+cu121)
|
| 91 |
+
docker build -t backgroundfx-pro .
|
| 92 |
+
|
| 93 |
+
# Run
|
| 94 |
+
docker run --gpus all -p 7860:7860 backgroundfx-pro
|
VideoBackgroundReplacer2/app.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
VideoBackgroundReplacer2 - SAM2 + MatAnyone Integration
|
| 4 |
+
=======================================================
|
| 5 |
+
- Sets up Gradio UI and launches pipeline
|
| 6 |
+
- Aligned with torch/cu121 stack; runs on HF Spaces (Docker)
|
| 7 |
+
|
| 8 |
+
Changes (2025-09-18):
|
| 9 |
+
- Added precise web-stack probes (FastAPI/Starlette/Pydantic/etc. versions + file paths)
|
| 10 |
+
- Added toggleable "mount mode": run Gradio inside our own FastAPI app
|
| 11 |
+
and provide a safe /config route shim (uses demo.get_config_file()).
|
| 12 |
+
- Kept your startup diagnostics, GPU logging, and heartbeats
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
print("=== APP STARTUP: Initializing VideoBackgroundReplacer2 ===")
|
| 16 |
+
|
| 17 |
+
# ---------------------------------------------------------------------
|
| 18 |
+
# Imports & basic setup
|
| 19 |
+
# ---------------------------------------------------------------------
|
| 20 |
+
import sys
|
| 21 |
+
import os
|
| 22 |
+
import gc
|
| 23 |
+
import json
|
| 24 |
+
import logging
|
| 25 |
+
import threading
|
| 26 |
+
import time
|
| 27 |
+
import warnings
|
| 28 |
+
import traceback
|
| 29 |
+
import subprocess
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
from loguru import logger
|
| 32 |
+
|
| 33 |
+
# Logging (loguru to stderr)
|
| 34 |
+
logger.remove()
|
| 35 |
+
logger.add(
|
| 36 |
+
sys.stderr,
|
| 37 |
+
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> "
|
| 38 |
+
"| <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# Warnings
|
| 42 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 43 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 44 |
+
warnings.filterwarnings("ignore", module="torchvision.io._video_deprecation_warning")
|
| 45 |
+
|
| 46 |
+
# Environment (lightweight & safe in Spaces)
|
| 47 |
+
os.environ.setdefault("OMP_NUM_THREADS", "1")
|
| 48 |
+
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
| 49 |
+
|
| 50 |
+
# Paths
|
| 51 |
+
BASE_DIR = Path(__file__).parent.absolute()
|
| 52 |
+
THIRD_PARTY_DIR = BASE_DIR / "third_party"
|
| 53 |
+
SAM2_DIR = THIRD_PARTY_DIR / "sam2"
|
| 54 |
+
CHECKPOINTS_DIR = BASE_DIR / "checkpoints"
|
| 55 |
+
|
| 56 |
+
# Python path extends
|
| 57 |
+
for p in (str(THIRD_PARTY_DIR), str(SAM2_DIR)):
|
| 58 |
+
if p not in sys.path:
|
| 59 |
+
sys.path.insert(0, p)
|
| 60 |
+
|
| 61 |
+
logger.info(f"Base directory: {BASE_DIR}")
|
| 62 |
+
logger.info(f"Python path[0:5]: {sys.path[:5]}")
|
| 63 |
+
|
| 64 |
+
# ---------------------------------------------------------------------
|
| 65 |
+
# GPU / Torch diagnostics (non-blocking)
|
| 66 |
+
# ---------------------------------------------------------------------
|
| 67 |
+
try:
|
| 68 |
+
import torch
|
| 69 |
+
except Exception as e:
|
| 70 |
+
logger.warning("Torch import failed at startup: %s", e)
|
| 71 |
+
torch = None
|
| 72 |
+
|
| 73 |
+
DEVICE = "cuda" if (torch and torch.cuda.is_available()) else "cpu"
|
| 74 |
+
if DEVICE == "cuda":
|
| 75 |
+
os.environ["SAM2_DEVICE"] = "cuda"
|
| 76 |
+
os.environ["MATANY_DEVICE"] = "cuda"
|
| 77 |
+
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")
|
| 78 |
+
try:
|
| 79 |
+
logger.info(f"CUDA available: {torch.cuda.get_device_name(0)}")
|
| 80 |
+
except Exception:
|
| 81 |
+
logger.info("CUDA device name not available at startup.")
|
| 82 |
+
else:
|
| 83 |
+
os.environ["SAM2_DEVICE"] = "cpu"
|
| 84 |
+
os.environ["MATANY_DEVICE"] = "cpu"
|
| 85 |
+
logger.warning("CUDA not available, falling back to CPU")
|
| 86 |
+
|
| 87 |
+
def verify_models():
    """Cheaply verify that critical model checkpoints exist and are loadable.

    Returns:
        dict: ``{"status": "success"|"error", "details": {...}}`` with one
        entry per model describing path/size, or the failure + traceback.
    """
    report = {"status": "success", "details": {}}
    try:
        ckpt = os.getenv("SAM2_MODEL_PATH", str(CHECKPOINTS_DIR / "sam2_hiera_large.pt"))
        if not os.path.exists(ckpt):
            raise FileNotFoundError(f"SAM2 model not found at {ckpt}")
        # Cheap load test on CPU so boot-time verification never touches VRAM.
        if torch:
            state = torch.load(ckpt, map_location="cpu")
            if not isinstance(state, dict):
                raise ValueError("Invalid SAM2 checkpoint format")
        report["details"]["sam2"] = {
            "status": "success",
            "path": ckpt,
            "size_mb": round(os.path.getsize(ckpt) / (1024 * 1024), 2),
        }
    except Exception as e:
        report["status"] = "error"
        report["details"]["sam2"] = {
            "status": "error",
            "error": str(e),
            "traceback": traceback.format_exc(),
        }
    return report
|
| 112 |
+
|
| 113 |
+
def run_startup_diagnostics():
    """Collect a startup snapshot: interpreter/torch info, key paths, a small
    environment subset, and the result of `verify_models()`."""
    cuda_ok = bool(torch and torch.cuda.is_available())
    report = {
        "system": {
            "python": sys.version,
            "pytorch": getattr(torch, "__version__", None) if torch else None,
            "cuda_available": cuda_ok,
            "device_count": torch.cuda.device_count() if cuda_ok else 0,
            "cuda_version": getattr(getattr(torch, "version", None), "cuda", None) if torch else None,
        },
        "paths": {
            "base_dir": str(BASE_DIR),
            "checkpoints_dir": str(CHECKPOINTS_DIR),
            "sam2_dir": str(SAM2_DIR),
        },
        # Only a safe whitelist of env vars — never dump the whole environment.
        "env_subset": {
            key: os.environ[key]
            for key in ("HOST", "PORT", "SPACE_ID", "SPACE_AUTHOR_NAME")
            if key in os.environ
        },
    }
    report["model_verification"] = verify_models()
    return report
|
| 131 |
+
|
| 132 |
+
startup_diag = run_startup_diagnostics()
|
| 133 |
+
logger.info("Startup diagnostics completed")
|
| 134 |
+
|
| 135 |
+
# Noisy heartbeat so logs show life during import time
|
| 136 |
+
def _heartbeat():
    """Print a liveness line every 5 s so Space logs show activity while the
    heavy imports run (never returns; run on a daemon thread)."""
    beats = 0
    while True:
        beats += 1
        print(f"[startup-heartbeat] {beats*5}s…", flush=True)
        time.sleep(5)
|
| 142 |
+
|
| 143 |
+
threading.Thread(target=_heartbeat, daemon=True).start()
|
| 144 |
+
|
| 145 |
+
# Optional perf tuning import (non-fatal)
|
| 146 |
+
try:
|
| 147 |
+
import perf_tuning # noqa: F401
|
| 148 |
+
logger.info("perf_tuning imported successfully.")
|
| 149 |
+
except Exception as e:
|
| 150 |
+
logger.info("perf_tuning not available: %s", e)
|
| 151 |
+
|
| 152 |
+
# MatAnyone non-instantiating probe
|
| 153 |
+
try:
|
| 154 |
+
import inspect
|
| 155 |
+
from matanyone.inference import inference_core as ic # type: ignore
|
| 156 |
+
sigs = {}
|
| 157 |
+
for name in ("InferenceCore",):
|
| 158 |
+
obj = getattr(ic, name, None)
|
| 159 |
+
if obj:
|
| 160 |
+
sigs[name] = "callable" if callable(obj) else "present"
|
| 161 |
+
logger.info(f"[MATANY] probe (non-instantiating): {sigs}")
|
| 162 |
+
except Exception as e:
|
| 163 |
+
logger.info(f"[MATANY] probe skipped: {e}")
|
| 164 |
+
|
| 165 |
+
# ---------------------------------------------------------------------
|
| 166 |
+
# Gradio import and web-stack probes
|
| 167 |
+
# ---------------------------------------------------------------------
|
| 168 |
+
import gradio as gr
|
| 169 |
+
|
| 170 |
+
# Standard logger for some libs that use stdlib logging
|
| 171 |
+
py_logger = logging.getLogger("backgroundfx_pro")
|
| 172 |
+
if not py_logger.handlers:
|
| 173 |
+
h = logging.StreamHandler()
|
| 174 |
+
h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
|
| 175 |
+
py_logger.addHandler(h)
|
| 176 |
+
py_logger.setLevel(logging.INFO)
|
| 177 |
+
|
| 178 |
+
def _log_web_stack_versions_and_paths():
    """Log exact web-stack versions and on-disk module paths for debugging.

    Fix: loguru formats extra positional args with ``str.format()``, so the
    previous printf-style ``%s`` placeholders were logged literally and the
    values were dropped. All placeholders are now ``{}``.
    """
    import inspect
    try:
        import fastapi, starlette, pydantic, httpx, anyio
        try:
            import pydantic_core
            pc_ver = pydantic_core.__version__
        except Exception:
            pc_ver = "unknown"
        logger.info(
            "[WEB-STACK] fastapi={} | starlette={} | pydantic={} | pydantic-core={} | httpx={} | anyio={}",
            getattr(fastapi, "__version__", "?"),
            getattr(starlette, "__version__", "?"),
            getattr(pydantic, "__version__", "?"),
            pc_ver,
            getattr(httpx, "__version__", "?"),
            getattr(anyio, "__version__", "?"),
        )
    except Exception as e:
        logger.warning("[WEB-STACK] version probe failed: {}", e)

    try:
        import gradio
        import gradio.routes as gr_routes
        import gradio.queueing as gr_queueing
        logger.info("[PATH] gradio.__file__ = {}", getattr(gradio, "__file__", "?"))
        logger.info("[PATH] gradio.routes = {}", inspect.getfile(gr_routes))
        logger.info("[PATH] gradio.queueing = {}", inspect.getfile(gr_queueing))
        import starlette.exceptions as st_exc
        logger.info("[PATH] starlette.exceptions= {}", inspect.getfile(st_exc))
    except Exception as e:
        logger.warning("[PATH] probe failed: {}", e)
|
| 210 |
+
|
| 211 |
+
def _post_launch_diag():
    """Log CUDA availability and device details after the UI has launched.

    Fix: loguru does not interpolate printf-style ``%s``/``%d`` placeholders
    (it uses ``str.format()``), so the previous messages logged literal "%s"
    and dropped the values. Switched to ``{}`` placeholders.
    """
    try:
        if not torch:
            return
        avail = torch.cuda.is_available()
        logger.info("CUDA available (post-launch): {}", avail)
        if avail:
            idx = torch.cuda.current_device()
            name = torch.cuda.get_device_name(idx)
            cap = torch.cuda.get_device_capability(idx)
            logger.info("CUDA device {}: {} (cc {}.{})", idx, name, cap[0], cap[1])
    except Exception as e:
        logger.warning("Post-launch CUDA diag failed: {}", e)
|
| 224 |
+
|
| 225 |
+
# ---------------------------------------------------------------------
|
| 226 |
+
# UI factory (uses your existing builder)
|
| 227 |
+
# ---------------------------------------------------------------------
|
| 228 |
+
def build_ui() -> gr.Blocks:
    """Construct and return the Gradio Blocks UI.

    The interface builder lives in ``ui_core_interface`` (not in ``ui``) and
    is imported lazily so the heavy UI module loads only when needed.
    """
    from ui_core_interface import create_interface
    return create_interface()
|
| 232 |
+
|
| 233 |
+
# ---------------------------------------------------------------------
|
| 234 |
+
# Optional: custom FastAPI mount mode
|
| 235 |
+
# ---------------------------------------------------------------------
|
| 236 |
+
def build_fastapi_with_gradio(demo: gr.Blocks):
    """Wrap *demo* in our own FastAPI application.

    Exposes a JSON liveness endpoint and a ``/config`` shim backed by
    ``demo.get_config_file()``, then mounts the Gradio UI at the root path.
    """
    from fastapi import FastAPI
    from fastapi.responses import JSONResponse

    api = FastAPI(title="VideoBackgroundReplacer2")

    @api.get("/healthz")
    def _healthz():
        # Cheap liveness probe — no model or UI access.
        return {"ok": True, "ts": time.time()}

    @api.get("/config")
    def _config():
        # Shim: surface config-generation failures as a JSON 500 instead of
        # an unhandled server error.
        try:
            return JSONResponse(content=demo.get_config_file())
        except Exception as e:
            return JSONResponse(
                status_code=500,
                content={"error": "config_generation_failed", "detail": str(e)},
            )

    # Mount the Gradio UI last; routes registered above stay at parent level.
    return gr.mount_gradio_app(api, demo, path="/")
|
| 264 |
+
|
| 265 |
+
# ---------------------------------------------------------------------
|
| 266 |
+
# Entrypoint
|
| 267 |
+
# ---------------------------------------------------------------------
|
| 268 |
+
if __name__ == "__main__":
    # Bind address/port come from the Space environment, with HF defaults.
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "7860"))
    mount_mode = os.environ.get("GRADIO_MOUNT_MODE", "0") == "1"

    # Fix: loguru interpolates with str.format(), so use {} placeholders —
    # the previous "%s" form logged the literal "%s" and dropped the values.
    logger.info("Launching on {}:{} (mount_mode={})…", host, port, mount_mode)
    _log_web_stack_versions_and_paths()

    demo = build_ui()
    demo.queue(max_size=16, api_open=False)

    # CUDA diagnostics on a daemon thread so they never block startup/shutdown.
    threading.Thread(target=_post_launch_diag, daemon=True).start()

    if mount_mode:
        try:
            from uvicorn import run as uvicorn_run
        except Exception:
            logger.error("uvicorn is not installed; mount mode cannot start.")
            raise

        app = build_fastapi_with_gradio(demo)
        uvicorn_run(app=app, host=host, port=port, log_level="info")
    else:
        demo.launch(
            server_name=host,
            server_port=port,
            share=False,
            show_api=False,
            show_error=True,
            quiet=False,
            debug=True,
            max_threads=1,
        )
|
VideoBackgroundReplacer2/integrated_pipeline.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
integrated_pipeline.py - Two-stage pipeline with fallback compatibility
|
| 4 |
+
- Stage 1: SAM2 -> lossless mask stream + metadata, then unload SAM2
|
| 5 |
+
- Stage 2: Read masks -> MatAnyone -> composite -> final output
|
| 6 |
+
- Maintains compatibility with existing UI calls
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import gc
|
| 12 |
+
import json
|
| 13 |
+
import subprocess
|
| 14 |
+
import tempfile
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Dict, Any, Optional, Tuple
|
| 17 |
+
import numpy as np
|
| 18 |
+
import cv2
|
| 19 |
+
|
| 20 |
+
# Add the parent directory to Python path for imports
|
| 21 |
+
current_dir = Path(__file__).parent
|
| 22 |
+
parent_dir = current_dir.parent
|
| 23 |
+
sys.path.append(str(parent_dir))
|
| 24 |
+
|
| 25 |
+
class TwoStageProcessor:
|
| 26 |
+
def __init__(self, temp_dir: Optional[str] = None):
|
| 27 |
+
self.temp_dir = Path(temp_dir) if temp_dir else Path(tempfile.mkdtemp())
|
| 28 |
+
self.temp_dir.mkdir(exist_ok=True)
|
| 29 |
+
|
| 30 |
+
# Stage outputs
|
| 31 |
+
self.masks_path = self.temp_dir / "masks.mkv"
|
| 32 |
+
self.metadata_path = self.temp_dir / "meta.json"
|
| 33 |
+
|
| 34 |
+
def process_video(self, input_video: str, background_video: str,
|
| 35 |
+
click_points: list, output_path: str,
|
| 36 |
+
use_matanyone: bool = True, progress_callback=None) -> bool:
|
| 37 |
+
"""
|
| 38 |
+
Main entry point - maintains compatibility with existing UI
|
| 39 |
+
"""
|
| 40 |
+
try:
|
| 41 |
+
# Stage 1: Generate masks
|
| 42 |
+
if progress_callback:
|
| 43 |
+
progress_callback("Stage 1: Generating masks with SAM2...")
|
| 44 |
+
|
| 45 |
+
if not self._stage1_generate_masks(input_video, click_points, progress_callback):
|
| 46 |
+
return False
|
| 47 |
+
|
| 48 |
+
# Stage 2: Process and composite
|
| 49 |
+
if progress_callback:
|
| 50 |
+
progress_callback("Stage 2: Processing and compositing...")
|
| 51 |
+
|
| 52 |
+
return self._stage2_composite(input_video, background_video,
|
| 53 |
+
output_path, use_matanyone, progress_callback)
|
| 54 |
+
|
| 55 |
+
except Exception as e:
|
| 56 |
+
print(f"Two-stage processing failed: {e}")
|
| 57 |
+
return False
|
| 58 |
+
|
| 59 |
+
def _stage1_generate_masks(self, input_video: str, click_points: list,
|
| 60 |
+
progress_callback=None) -> bool:
|
| 61 |
+
"""Stage 1: SAM2 mask generation with complete memory cleanup"""
|
| 62 |
+
try:
|
| 63 |
+
# Import SAM2 only when needed
|
| 64 |
+
print("Loading SAM2...")
|
| 65 |
+
import torch
|
| 66 |
+
from sam2.build_sam import build_sam2_video_predictor
|
| 67 |
+
|
| 68 |
+
# Initialize SAM2
|
| 69 |
+
checkpoint = "checkpoints/sam2.1_hiera_large.pt"
|
| 70 |
+
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
| 71 |
+
|
| 72 |
+
if not os.path.exists(checkpoint):
|
| 73 |
+
print(f"SAM2 checkpoint not found: {checkpoint}")
|
| 74 |
+
return False
|
| 75 |
+
|
| 76 |
+
predictor = build_sam2_video_predictor(model_cfg, checkpoint)
|
| 77 |
+
|
| 78 |
+
# Get video info
|
| 79 |
+
cap = cv2.VideoCapture(input_video)
|
| 80 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
| 81 |
+
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 82 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 83 |
+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 84 |
+
cap.release()
|
| 85 |
+
|
| 86 |
+
# Save metadata
|
| 87 |
+
metadata = {
|
| 88 |
+
"fps": fps,
|
| 89 |
+
"frame_count": frame_count,
|
| 90 |
+
"width": width,
|
| 91 |
+
"height": height,
|
| 92 |
+
"click_points": click_points
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
with open(self.metadata_path, 'w') as f:
|
| 96 |
+
json.dump(metadata, f, indent=2)
|
| 97 |
+
|
| 98 |
+
# Initialize inference state
|
| 99 |
+
inference_state = predictor.init_state(video_path=input_video)
|
| 100 |
+
|
| 101 |
+
# Add prompts
|
| 102 |
+
for i, point in enumerate(click_points):
|
| 103 |
+
x, y = point
|
| 104 |
+
predictor.add_new_points_or_box(
|
| 105 |
+
inference_state=inference_state,
|
| 106 |
+
frame_idx=0,
|
| 107 |
+
obj_id=i,
|
| 108 |
+
points=np.array([[x, y]], dtype=np.float32),
|
| 109 |
+
labels=np.array([1], np.int32),
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# Setup FFmpeg for lossless mask encoding
|
| 113 |
+
ffmpeg_cmd = [
|
| 114 |
+
'ffmpeg', '-y', '-f', 'rawvideo',
|
| 115 |
+
'-pix_fmt', 'gray', '-s', f'{width}x{height}',
|
| 116 |
+
'-r', str(fps), '-i', '-',
|
| 117 |
+
'-c:v', 'ffv1', '-level', '3', '-pix_fmt', 'gray',
|
| 118 |
+
str(self.masks_path)
|
| 119 |
+
]
|
| 120 |
+
|
| 121 |
+
ffmpeg_process = subprocess.Popen(
|
| 122 |
+
ffmpeg_cmd, stdin=subprocess.PIPE,
|
| 123 |
+
stderr=subprocess.PIPE, stdout=subprocess.PIPE
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# Generate and stream masks
|
| 127 |
+
print(f"Processing {frame_count} frames...")
|
| 128 |
+
|
| 129 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
|
| 130 |
+
if progress_callback:
|
| 131 |
+
progress = (out_frame_idx + 1) / frame_count * 50 # 50% of total progress for stage 1
|
| 132 |
+
progress_callback(f"Generating masks... Frame {out_frame_idx + 1}/{frame_count}", progress)
|
| 133 |
+
|
| 134 |
+
# Combine masks from all objects
|
| 135 |
+
combined_mask = np.zeros((height, width), dtype=np.uint8)
|
| 136 |
+
for obj_id in out_obj_ids:
|
| 137 |
+
mask = (out_mask_logits[obj_id] > 0.0).squeeze()
|
| 138 |
+
combined_mask = np.logical_or(combined_mask, mask).astype(np.uint8) * 255
|
| 139 |
+
|
| 140 |
+
# Write to FFmpeg
|
| 141 |
+
ffmpeg_process.stdin.write(combined_mask.tobytes())
|
| 142 |
+
|
| 143 |
+
# Finalize FFmpeg
|
| 144 |
+
ffmpeg_process.stdin.close()
|
| 145 |
+
ffmpeg_process.wait()
|
| 146 |
+
|
| 147 |
+
if ffmpeg_process.returncode != 0:
|
| 148 |
+
error = ffmpeg_process.stderr.read().decode()
|
| 149 |
+
print(f"FFmpeg error: {error}")
|
| 150 |
+
return False
|
| 151 |
+
|
| 152 |
+
print("Stage 1 complete: Masks saved")
|
| 153 |
+
|
| 154 |
+
# CRITICAL: Complete memory cleanup
|
| 155 |
+
del predictor
|
| 156 |
+
del inference_state
|
| 157 |
+
if 'torch' in locals():
|
| 158 |
+
if torch.cuda.is_available():
|
| 159 |
+
torch.cuda.empty_cache()
|
| 160 |
+
torch.cuda.synchronize()
|
| 161 |
+
|
| 162 |
+
# Force garbage collection
|
| 163 |
+
gc.collect()
|
| 164 |
+
|
| 165 |
+
# Clear SAM2 from sys.modules to prevent memory leaks
|
| 166 |
+
modules_to_clear = [mod for mod in sys.modules.keys() if 'sam2' in mod.lower()]
|
| 167 |
+
for mod in modules_to_clear:
|
| 168 |
+
del sys.modules[mod]
|
| 169 |
+
|
| 170 |
+
print("SAM2 completely unloaded from memory")
|
| 171 |
+
return True
|
| 172 |
+
|
| 173 |
+
except Exception as e:
|
| 174 |
+
print(f"Stage 1 failed: {e}")
|
| 175 |
+
return False
|
| 176 |
+
|
| 177 |
+
def _stage2_composite(self, input_video: str, background_video: str,
|
| 178 |
+
output_path: str, use_matanyone: bool, progress_callback=None) -> bool:
|
| 179 |
+
"""Stage 2: Read masks, refine with MatAnyone, and composite"""
|
| 180 |
+
try:
|
| 181 |
+
# Load metadata
|
| 182 |
+
with open(self.metadata_path, 'r') as f:
|
| 183 |
+
metadata = json.load(f)
|
| 184 |
+
|
| 185 |
+
frame_count = metadata["frame_count"]
|
| 186 |
+
|
| 187 |
+
# Read masks back from lossless stream
|
| 188 |
+
masks = self._read_mask_stream()
|
| 189 |
+
if masks is None:
|
| 190 |
+
return False
|
| 191 |
+
|
| 192 |
+
# Optional MatAnyone refinement
|
| 193 |
+
if use_matanyone:
|
| 194 |
+
if progress_callback:
|
| 195 |
+
progress_callback("Refining masks with MatAnyone...")
|
| 196 |
+
masks = self._refine_with_matanyone(input_video, masks, progress_callback)
|
| 197 |
+
if masks is None:
|
| 198 |
+
return False
|
| 199 |
+
|
| 200 |
+
# Final composition
|
| 201 |
+
if progress_callback:
|
| 202 |
+
progress_callback("Compositing final video...")
|
| 203 |
+
|
| 204 |
+
return self._composite_final_video(input_video, background_video,
|
| 205 |
+
masks, output_path, metadata, progress_callback)
|
| 206 |
+
|
| 207 |
+
except Exception as e:
|
| 208 |
+
print(f"Stage 2 failed: {e}")
|
| 209 |
+
return False
|
| 210 |
+
|
| 211 |
+
def _read_mask_stream(self) -> Optional[list]:
|
| 212 |
+
"""Read masks from the lossless FFV1 stream"""
|
| 213 |
+
try:
|
| 214 |
+
# Load metadata for dimensions
|
| 215 |
+
with open(self.metadata_path, 'r') as f:
|
| 216 |
+
metadata = json.load(f)
|
| 217 |
+
|
| 218 |
+
width = metadata["width"]
|
| 219 |
+
height = metadata["height"]
|
| 220 |
+
frame_count = metadata["frame_count"]
|
| 221 |
+
|
| 222 |
+
# Use FFmpeg to decode masks
|
| 223 |
+
ffmpeg_cmd = [
|
| 224 |
+
'ffmpeg', '-i', str(self.masks_path),
|
| 225 |
+
'-f', 'rawvideo', '-pix_fmt', 'gray', '-'
|
| 226 |
+
]
|
| 227 |
+
|
| 228 |
+
process = subprocess.Popen(
|
| 229 |
+
ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
masks = []
|
| 233 |
+
frame_size = width * height
|
| 234 |
+
|
| 235 |
+
for frame_idx in range(frame_count):
|
| 236 |
+
frame_data = process.stdout.read(frame_size)
|
| 237 |
+
if len(frame_data) != frame_size:
|
| 238 |
+
print(f"Unexpected frame size at frame {frame_idx}")
|
| 239 |
+
break
|
| 240 |
+
|
| 241 |
+
mask = np.frombuffer(frame_data, dtype=np.uint8).reshape((height, width))
|
| 242 |
+
masks.append(mask)
|
| 243 |
+
|
| 244 |
+
process.stdout.close()
|
| 245 |
+
process.wait()
|
| 246 |
+
|
| 247 |
+
if process.returncode != 0:
|
| 248 |
+
error = process.stderr.read().decode()
|
| 249 |
+
print(f"FFmpeg decode error: {error}")
|
| 250 |
+
return None
|
| 251 |
+
|
| 252 |
+
print(f"Successfully read {len(masks)} masks from stream")
|
| 253 |
+
return masks
|
| 254 |
+
|
| 255 |
+
except Exception as e:
|
| 256 |
+
print(f"Failed to read mask stream: {e}")
|
| 257 |
+
return None
|
| 258 |
+
|
| 259 |
+
def _refine_with_matanyone(self, input_video: str, masks: list, progress_callback=None) -> Optional[list]:
|
| 260 |
+
"""Apply MatAnyone refinement to masks"""
|
| 261 |
+
try:
|
| 262 |
+
# Import MatAnyone only when needed
|
| 263 |
+
from matanyone.mat_anywhere import matting_inference_video
|
| 264 |
+
|
| 265 |
+
# Create temp directory for MatAnyone
|
| 266 |
+
matanyone_temp = self.temp_dir / "matanyone"
|
| 267 |
+
matanyone_temp.mkdir(exist_ok=True)
|
| 268 |
+
|
| 269 |
+
# Save masks as individual frames for MatAnyone
|
| 270 |
+
mask_dir = matanyone_temp / "masks"
|
| 271 |
+
mask_dir.mkdir(exist_ok=True)
|
| 272 |
+
|
| 273 |
+
for i, mask in enumerate(masks):
|
| 274 |
+
cv2.imwrite(str(mask_dir / f"mask_{i:06d}.png"), mask)
|
| 275 |
+
|
| 276 |
+
# Run MatAnyone
|
| 277 |
+
refined_masks_dir = matanyone_temp / "refined"
|
| 278 |
+
refined_masks_dir.mkdir(exist_ok=True)
|
| 279 |
+
|
| 280 |
+
success = matting_inference_video(
|
| 281 |
+
video_path=input_video,
|
| 282 |
+
mask_dir=str(mask_dir),
|
| 283 |
+
output_dir=str(refined_masks_dir),
|
| 284 |
+
progress_callback=progress_callback
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
if not success:
|
| 288 |
+
print("MatAnyone refinement failed, using original masks")
|
| 289 |
+
return masks
|
| 290 |
+
|
| 291 |
+
# Load refined masks
|
| 292 |
+
refined_masks = []
|
| 293 |
+
for i in range(len(masks)):
|
| 294 |
+
refined_path = refined_masks_dir / f"refined_{i:06d}.png"
|
| 295 |
+
if refined_path.exists():
|
| 296 |
+
refined_mask = cv2.imread(str(refined_path), cv2.IMREAD_GRAYSCALE)
|
| 297 |
+
refined_masks.append(refined_mask)
|
| 298 |
+
else:
|
| 299 |
+
refined_masks.append(masks[i]) # Fallback to original
|
| 300 |
+
|
| 301 |
+
return refined_masks
|
| 302 |
+
|
| 303 |
+
except Exception as e:
|
| 304 |
+
print(f"MatAnyone refinement failed: {e}, using original masks")
|
| 305 |
+
return masks
|
| 306 |
+
|
| 307 |
+
def _composite_final_video(self, input_video: str, background_video: str,
|
| 308 |
+
masks: list, output_path: str, metadata: Dict[str, Any],
|
| 309 |
+
progress_callback=None) -> bool:
|
| 310 |
+
"""Create final composite video"""
|
| 311 |
+
try:
|
| 312 |
+
# Setup video capture
|
| 313 |
+
fg_cap = cv2.VideoCapture(input_video)
|
| 314 |
+
bg_cap = cv2.VideoCapture(background_video)
|
| 315 |
+
|
| 316 |
+
fps = metadata["fps"]
|
| 317 |
+
width = metadata["width"]
|
| 318 |
+
height = metadata["height"]
|
| 319 |
+
|
| 320 |
+
# Setup output writer
|
| 321 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 322 |
+
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
| 323 |
+
|
| 324 |
+
frame_idx = 0
|
| 325 |
+
total_frames = len(masks)
|
| 326 |
+
|
| 327 |
+
while frame_idx < total_frames:
|
| 328 |
+
# Read frames
|
| 329 |
+
ret_fg, fg_frame = fg_cap.read()
|
| 330 |
+
ret_bg, bg_frame = bg_cap.read()
|
| 331 |
+
|
| 332 |
+
if not ret_fg:
|
| 333 |
+
break
|
| 334 |
+
|
| 335 |
+
if not ret_bg:
|
| 336 |
+
# Loop background if shorter
|
| 337 |
+
bg_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
| 338 |
+
ret_bg, bg_frame = bg_cap.read()
|
| 339 |
+
|
| 340 |
+
if not ret_bg:
|
| 341 |
+
print("No background frame available")
|
| 342 |
+
break
|
| 343 |
+
|
| 344 |
+
# Resize background to match foreground
|
| 345 |
+
bg_frame = cv2.resize(bg_frame, (width, height))
|
| 346 |
+
|
| 347 |
+
# Get mask
|
| 348 |
+
mask = masks[frame_idx]
|
| 349 |
+
mask_norm = mask.astype(np.float32) / 255.0
|
| 350 |
+
mask_3ch = np.stack([mask_norm, mask_norm, mask_norm], axis=-1)
|
| 351 |
+
|
| 352 |
+
# Composite
|
| 353 |
+
composite = (fg_frame * mask_3ch + bg_frame * (1 - mask_3ch)).astype(np.uint8)
|
| 354 |
+
out.write(composite)
|
| 355 |
+
|
| 356 |
+
frame_idx += 1
|
| 357 |
+
|
| 358 |
+
if progress_callback and frame_idx % 10 == 0:
|
| 359 |
+
progress = 50 + (frame_idx / total_frames) * 50 # 50-100% for stage 2
|
| 360 |
+
progress_callback(f"Compositing... Frame {frame_idx}/{total_frames}", progress)
|
| 361 |
+
|
| 362 |
+
# Cleanup
|
| 363 |
+
fg_cap.release()
|
| 364 |
+
bg_cap.release()
|
| 365 |
+
out.release()
|
| 366 |
+
|
| 367 |
+
print(f"Final video saved to: {output_path}")
|
| 368 |
+
return True
|
| 369 |
+
|
| 370 |
+
except Exception as e:
|
| 371 |
+
print(f"Final composition failed: {e}")
|
| 372 |
+
return False
|
| 373 |
+
|
| 374 |
+
def cleanup(self):
|
| 375 |
+
"""Clean up temporary files"""
|
| 376 |
+
try:
|
| 377 |
+
if self.temp_dir.exists():
|
| 378 |
+
import shutil
|
| 379 |
+
shutil.rmtree(self.temp_dir)
|
| 380 |
+
except Exception as e:
|
| 381 |
+
print(f"Cleanup failed: {e}")
|
| 382 |
+
|
| 383 |
+
# Compatibility wrapper for existing UI
|
| 384 |
+
def process_video_two_stage(input_video: str, background_video: str,
|
| 385 |
+
click_points: list, output_path: str,
|
| 386 |
+
use_matanyone: bool = True, progress_callback=None) -> bool:
|
| 387 |
+
"""
|
| 388 |
+
Drop-in replacement for existing process_video function
|
| 389 |
+
"""
|
| 390 |
+
processor = TwoStageProcessor()
|
| 391 |
+
try:
|
| 392 |
+
result = processor.process_video(
|
| 393 |
+
input_video, background_video, click_points,
|
| 394 |
+
output_path, use_matanyone, progress_callback
|
| 395 |
+
)
|
| 396 |
+
return result
|
| 397 |
+
finally:
|
| 398 |
+
processor.cleanup()
|
| 399 |
+
|
| 400 |
+
if __name__ == "__main__":
|
| 401 |
+
# Test the pipeline
|
| 402 |
+
import argparse
|
| 403 |
+
parser = argparse.ArgumentParser()
|
| 404 |
+
parser.add_argument("--input", required=True)
|
| 405 |
+
parser.add_argument("--background", required=True)
|
| 406 |
+
parser.add_argument("--output", required=True)
|
| 407 |
+
parser.add_argument("--clicks", required=True, help="JSON string of click points")
|
| 408 |
+
parser.add_argument("--no-matanyone", action="store_true")
|
| 409 |
+
|
| 410 |
+
args = parser.parse_args()
|
| 411 |
+
|
| 412 |
+
click_points = json.loads(args.clicks)
|
| 413 |
+
use_matanyone = not args.no_matanyone
|
| 414 |
+
|
| 415 |
+
success = process_video_two_stage(
|
| 416 |
+
args.input, args.background, click_points,
|
| 417 |
+
args.output, use_matanyone,
|
| 418 |
+
lambda msg, prog=None: print(f"Progress: {msg} ({prog}%)" if prog else msg)
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
print("Processing completed!" if success else "Processing failed!")
|
VideoBackgroundReplacer2/models/__init__.py
ADDED
|
@@ -0,0 +1,868 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
BackgroundFX Pro - Model Loading & Utilities (Hardened)
|
| 4 |
+
======================================================
|
| 5 |
+
- Avoids heavy CUDA/Hydra work at import time
|
| 6 |
+
- Adds timeouts to subprocess probes
|
| 7 |
+
- Safer sys.path wiring for third_party repos
|
| 8 |
+
- MatAnyone loader is probe-only here; actual run happens in matanyone_loader.MatAnyoneSession
|
| 9 |
+
|
| 10 |
+
Changes (2025-09-16):
|
| 11 |
+
- Aligned with torch==2.3.1+cu121 and MatAnyone v1.0.0
|
| 12 |
+
- Updated load_matany to apply T=1 squeeze patch before InferenceCore import
|
| 13 |
+
- Added patch status logging and MatAnyone version
|
| 14 |
+
- Added InferenceCore attributes logging for debugging
|
| 15 |
+
- Fixed InferenceCore import path to matanyone.inference.inference_core
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import os
|
| 21 |
+
import sys
|
| 22 |
+
import cv2
|
| 23 |
+
import subprocess
|
| 24 |
+
import inspect
|
| 25 |
+
import logging
|
| 26 |
+
import importlib.metadata
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from typing import Optional, Tuple, Dict, Any, Union, Callable
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
import yaml
|
| 32 |
+
|
| 33 |
+
# Import torch for GPU memory monitoring
|
| 34 |
+
try:
|
| 35 |
+
import torch
|
| 36 |
+
except ImportError:
|
| 37 |
+
torch = None
|
| 38 |
+
|
| 39 |
+
# --------------------------------------------------------------------------------------
|
| 40 |
+
# Logging (ensure a handler exists very early)
|
| 41 |
+
# --------------------------------------------------------------------------------------
|
| 42 |
+
logger = logging.getLogger("backgroundfx_pro")
|
| 43 |
+
if not logger.handlers:
|
| 44 |
+
_h = logging.StreamHandler()
|
| 45 |
+
_h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
|
| 46 |
+
logger.addHandler(_h)
|
| 47 |
+
logger.setLevel(logging.INFO)
|
| 48 |
+
|
| 49 |
+
# Pin OpenCV threads (helps libgomp stability in Spaces)
|
| 50 |
+
try:
|
| 51 |
+
cv_threads = int(os.environ.get("CV_THREADS", "1"))
|
| 52 |
+
if hasattr(cv2, "setNumThreads"):
|
| 53 |
+
cv2.setNumThreads(cv_threads)
|
| 54 |
+
except Exception:
|
| 55 |
+
pass
|
| 56 |
+
|
| 57 |
+
# --------------------------------------------------------------------------------------
|
| 58 |
+
# Optional dependencies
|
| 59 |
+
# --------------------------------------------------------------------------------------
|
| 60 |
+
try:
|
| 61 |
+
import mediapipe as mp # type: ignore
|
| 62 |
+
_HAS_MEDIAPIPE = True
|
| 63 |
+
except Exception:
|
| 64 |
+
_HAS_MEDIAPIPE = False
|
| 65 |
+
|
| 66 |
+
# --------------------------------------------------------------------------------------
|
| 67 |
+
# Path setup for third_party repos
|
| 68 |
+
# --------------------------------------------------------------------------------------
|
| 69 |
+
ROOT = Path(__file__).resolve().parent.parent # project root
|
| 70 |
+
TP_SAM2 = Path(os.environ.get("THIRD_PARTY_SAM2_DIR", ROOT / "third_party" / "sam2")).resolve()
|
| 71 |
+
TP_MATANY = Path(os.environ.get("THIRD_PARTY_MATANY_DIR", ROOT / "third_party" / "matanyone")).resolve()
|
| 72 |
+
|
| 73 |
+
def _add_sys_path(p: Path) -> None:
|
| 74 |
+
if p.exists():
|
| 75 |
+
p_str = str(p)
|
| 76 |
+
if p_str not in sys.path:
|
| 77 |
+
sys.path.insert(0, p_str)
|
| 78 |
+
else:
|
| 79 |
+
logger.warning(f"third_party path not found: {p}")
|
| 80 |
+
|
| 81 |
+
_add_sys_path(TP_SAM2)
|
| 82 |
+
_add_sys_path(TP_MATANY)
|
| 83 |
+
|
| 84 |
+
# --------------------------------------------------------------------------------------
|
| 85 |
+
# Safe Torch accessors (no top-level import)
|
| 86 |
+
# --------------------------------------------------------------------------------------
|
| 87 |
+
def _torch():
|
| 88 |
+
try:
|
| 89 |
+
import torch # local import avoids early CUDA init during module import
|
| 90 |
+
return torch
|
| 91 |
+
except Exception as e:
|
| 92 |
+
logger.warning(f"[models.safe-torch] import failed: {e}")
|
| 93 |
+
return None
|
| 94 |
+
|
| 95 |
+
def _has_cuda() -> bool:
|
| 96 |
+
t = _torch()
|
| 97 |
+
if t is None:
|
| 98 |
+
return False
|
| 99 |
+
try:
|
| 100 |
+
return bool(t.cuda.is_available())
|
| 101 |
+
except Exception as e:
|
| 102 |
+
logger.warning(f"[models.safe-torch] cuda.is_available() failed: {e}")
|
| 103 |
+
return False
|
| 104 |
+
|
| 105 |
+
def _pick_device(env_key: str) -> str:
|
| 106 |
+
requested = os.environ.get(env_key, "").strip().lower()
|
| 107 |
+
has_cuda = _has_cuda()
|
| 108 |
+
|
| 109 |
+
# Log all CUDA-related environment variables
|
| 110 |
+
cuda_env_vars = {
|
| 111 |
+
'FORCE_CUDA_DEVICE': os.environ.get('FORCE_CUDA_DEVICE', ''),
|
| 112 |
+
'CUDA_MEMORY_FRACTION': os.environ.get('CUDA_MEMORY_FRACTION', ''),
|
| 113 |
+
'PYTORCH_CUDA_ALLOC_CONF': os.environ.get('PYTORCH_CUDA_ALLOC_CONF', ''),
|
| 114 |
+
'REQUIRE_CUDA': os.environ.get('REQUIRE_CUDA', ''),
|
| 115 |
+
'SAM2_DEVICE': os.environ.get('SAM2_DEVICE', ''),
|
| 116 |
+
'MATANY_DEVICE': os.environ.get('MATANY_DEVICE', ''),
|
| 117 |
+
}
|
| 118 |
+
logger.info(f"CUDA environment variables: {cuda_env_vars}")
|
| 119 |
+
|
| 120 |
+
logger.info(f"_pick_device({env_key}): requested='{requested}', has_cuda={has_cuda}")
|
| 121 |
+
|
| 122 |
+
# Force CUDA if available (empty string counts as no explicit CPU request)
|
| 123 |
+
if has_cuda and requested not in {"cpu"}:
|
| 124 |
+
logger.info(f"FORCING CUDA device (GPU available, requested='{requested}')")
|
| 125 |
+
return "cuda"
|
| 126 |
+
elif requested in {"cuda", "cpu"}:
|
| 127 |
+
logger.info(f"Using explicitly requested device: {requested}")
|
| 128 |
+
return requested
|
| 129 |
+
|
| 130 |
+
result = "cuda" if has_cuda else "cpu"
|
| 131 |
+
logger.info(f"Auto-selected device: {result}")
|
| 132 |
+
return result
|
| 133 |
+
|
| 134 |
+
# --------------------------------------------------------------------------------------
|
| 135 |
+
# Basic Utilities
|
| 136 |
+
# --------------------------------------------------------------------------------------
|
| 137 |
+
def _ffmpeg_bin() -> str:
|
| 138 |
+
return os.environ.get("FFMPEG_BIN", "ffmpeg")
|
| 139 |
+
|
| 140 |
+
def _probe_ffmpeg(timeout: int = 2) -> bool:
|
| 141 |
+
try:
|
| 142 |
+
subprocess.run([_ffmpeg_bin(), "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True, timeout=timeout)
|
| 143 |
+
return True
|
| 144 |
+
except Exception:
|
| 145 |
+
return False
|
| 146 |
+
|
| 147 |
+
def _ensure_dir(p: Path) -> None:
|
| 148 |
+
p.mkdir(parents=True, exist_ok=True)
|
| 149 |
+
|
| 150 |
+
def _cv_read_first_frame(video_path: Union[str, Path]) -> Tuple[Optional[np.ndarray], int, Tuple[int, int]]:
|
| 151 |
+
cap = cv2.VideoCapture(str(video_path))
|
| 152 |
+
if not cap.isOpened():
|
| 153 |
+
return None, 0, (0, 0)
|
| 154 |
+
fps = int(round(cap.get(cv2.CAP_PROP_FPS) or 25))
|
| 155 |
+
ok, frame = cap.read()
|
| 156 |
+
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
|
| 157 |
+
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
|
| 158 |
+
cap.release()
|
| 159 |
+
if not ok:
|
| 160 |
+
return None, fps, (w, h)
|
| 161 |
+
return frame, fps, (w, h)
|
| 162 |
+
|
| 163 |
+
def _save_mask_png(mask: np.ndarray, path: Union[str, Path]) -> str:
|
| 164 |
+
if mask.dtype == bool:
|
| 165 |
+
mask = (mask.astype(np.uint8) * 255)
|
| 166 |
+
elif mask.dtype != np.uint8:
|
| 167 |
+
mask = np.clip(mask, 0, 255).astype(np.uint8)
|
| 168 |
+
cv2.imwrite(str(path), mask)
|
| 169 |
+
return str(path)
|
| 170 |
+
|
| 171 |
+
def _resize_keep_ar(image: np.ndarray, target_wh: Tuple[int, int]) -> np.ndarray:
|
| 172 |
+
tw, th = target_wh
|
| 173 |
+
h, w = image.shape[:2]
|
| 174 |
+
if h == 0 or w == 0 or tw == 0 or th == 0:
|
| 175 |
+
return image
|
| 176 |
+
scale = min(tw / w, th / h)
|
| 177 |
+
nw, nh = max(1, int(round(w * scale))), max(1, int(round(h * scale)))
|
| 178 |
+
resized = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_CUBIC)
|
| 179 |
+
canvas = np.zeros((th, tw, 3), dtype=resized.dtype)
|
| 180 |
+
x0 = (tw - nw) // 2
|
| 181 |
+
y0 = (th - nh) // 2
|
| 182 |
+
canvas[y0:y0+nh, x0:x0+nw] = resized
|
| 183 |
+
return canvas
|
| 184 |
+
|
| 185 |
+
def _video_writer(out_path: Path, fps: int, size: Tuple[int, int]) -> cv2.VideoWriter:
|
| 186 |
+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
| 187 |
+
return cv2.VideoWriter(str(out_path), fourcc, max(1, fps), size)
|
| 188 |
+
|
| 189 |
+
def _mux_audio(src_video: Union[str, Path], silent_video: Union[str, Path], out_path: Union[str, Path]) -> bool:
|
| 190 |
+
"""Copy video from silent_video + audio from src_video into out_path (AAC)."""
|
| 191 |
+
try:
|
| 192 |
+
cmd = [
|
| 193 |
+
_ffmpeg_bin(), "-y",
|
| 194 |
+
"-i", str(silent_video),
|
| 195 |
+
"-i", str(src_video),
|
| 196 |
+
"-map", "0:v:0",
|
| 197 |
+
"-map", "1:a:0?",
|
| 198 |
+
"-c:v", "copy",
|
| 199 |
+
"-c:a", "aac", "-b:a", "192k",
|
| 200 |
+
"-shortest",
|
| 201 |
+
str(out_path)
|
| 202 |
+
]
|
| 203 |
+
subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
| 204 |
+
return True
|
| 205 |
+
except Exception as e:
|
| 206 |
+
logger.warning(f"Audio mux failed; returning silent video. Reason: {e}")
|
| 207 |
+
return False
|
| 208 |
+
|
| 209 |
+
# --------------------------------------------------------------------------------------
|
| 210 |
+
# Compositing & Image Processing
|
| 211 |
+
# --------------------------------------------------------------------------------------
|
| 212 |
+
def _refine_alpha(alpha: np.ndarray, erode_px: int = 1, dilate_px: int = 2, blur_px: float = 1.5) -> np.ndarray:
|
| 213 |
+
if alpha.dtype != np.float32:
|
| 214 |
+
a = alpha.astype(np.float32)
|
| 215 |
+
if a.max() > 1.0:
|
| 216 |
+
a = a / 255.0
|
| 217 |
+
else:
|
| 218 |
+
a = alpha.copy()
|
| 219 |
+
|
| 220 |
+
a_u8 = np.clip(np.round(a * 255.0), 0, 255).astype(np.uint8)
|
| 221 |
+
if erode_px > 0:
|
| 222 |
+
k = max(1, int(erode_px))
|
| 223 |
+
a_u8 = cv2.erode(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1)
|
| 224 |
+
if dilate_px > 0:
|
| 225 |
+
k = max(1, int(dilate_px))
|
| 226 |
+
a_u8 = cv2.dilate(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1)
|
| 227 |
+
a = a_u8.astype(np.float32) / 255.0
|
| 228 |
+
|
| 229 |
+
if blur_px and blur_px > 0:
|
| 230 |
+
rad = max(1, int(round(blur_px)))
|
| 231 |
+
a = cv2.GaussianBlur(a, (rad | 1, rad | 1), 0)
|
| 232 |
+
|
| 233 |
+
return np.clip(a, 0.0, 1.0)
|
| 234 |
+
|
| 235 |
+
def _to_linear(rgb: np.ndarray, gamma: float = 2.2) -> np.ndarray:
|
| 236 |
+
x = np.clip(rgb.astype(np.float32) / 255.0, 0.0, 1.0)
|
| 237 |
+
return np.power(x, gamma)
|
| 238 |
+
|
| 239 |
+
def _to_srgb(lin: np.ndarray, gamma: float = 2.2) -> np.ndarray:
|
| 240 |
+
x = np.clip(lin, 0.0, 1.0)
|
| 241 |
+
return np.clip(np.power(x, 1.0 / gamma) * 255.0, 0, 255).astype(np.uint8)
|
| 242 |
+
|
| 243 |
+
def _light_wrap(bg_rgb: np.ndarray, alpha01: np.ndarray, radius: int = 5, amount: float = 0.18) -> np.ndarray:
|
| 244 |
+
r = max(1, int(radius))
|
| 245 |
+
inv = 1.0 - alpha01
|
| 246 |
+
inv_blur = cv2.GaussianBlur(inv, (r | 1, r | 1), 0)
|
| 247 |
+
lw = (bg_rgb.astype(np.float32) * inv_blur[..., None] * float(amount))
|
| 248 |
+
return lw
|
| 249 |
+
|
| 250 |
+
def _despill_edges(fg_rgb: np.ndarray, alpha01: np.ndarray, amount: float = 0.35) -> np.ndarray:
|
| 251 |
+
w = 1.0 - 2.0 * np.abs(alpha01 - 0.5)
|
| 252 |
+
w = np.clip(w, 0.0, 1.0)
|
| 253 |
+
hsv = cv2.cvtColor(fg_rgb.astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
|
| 254 |
+
H, S, V = cv2.split(hsv)
|
| 255 |
+
S = S * (1.0 - amount * w)
|
| 256 |
+
hsv2 = cv2.merge([H, np.clip(S, 0, 255), V])
|
| 257 |
+
out = cv2.cvtColor(hsv2.astype(np.uint8), cv2.COLOR_HSV2RGB)
|
| 258 |
+
return out
|
| 259 |
+
|
| 260 |
+
def _composite_frame_pro(
|
| 261 |
+
fg_rgb: np.ndarray, alpha: np.ndarray, bg_rgb: np.ndarray,
|
| 262 |
+
erode_px: int = None, dilate_px: int = None, blur_px: float = None,
|
| 263 |
+
lw_radius: int = None, lw_amount: float = None, despill_amount: float = None
|
| 264 |
+
) -> np.ndarray:
|
| 265 |
+
erode_px = erode_px if erode_px is not None else int(os.environ.get("EDGE_ERODE", "1"))
|
| 266 |
+
dilate_px = dilate_px if dilate_px is not None else int(os.environ.get("EDGE_DILATE", "2"))
|
| 267 |
+
blur_px = blur_px if blur_px is not None else float(os.environ.get("EDGE_BLUR", "1.5"))
|
| 268 |
+
lw_radius = lw_radius if lw_radius is not None else int(os.environ.get("LIGHTWRAP_RADIUS", "5"))
|
| 269 |
+
lw_amount = lw_amount if lw_amount is not None else float(os.environ.get("LIGHTWRAP_AMOUNT", "0.18"))
|
| 270 |
+
despill_amount = despill_amount if despill_amount is not None else float(os.environ.get("DESPILL_AMOUNT", "0.35"))
|
| 271 |
+
|
| 272 |
+
a = _refine_alpha(alpha, erode_px=erode_px, dilate_px=dilate_px, blur_px=blur_px)
|
| 273 |
+
fg_rgb = _despill_edges(fg_rgb, a, amount=despill_amount)
|
| 274 |
+
|
| 275 |
+
fg_lin = _to_linear(fg_rgb)
|
| 276 |
+
bg_lin = _to_linear(bg_rgb)
|
| 277 |
+
lw = _light_wrap(bg_rgb, a, radius=lw_radius, amount=lw_amount)
|
| 278 |
+
lw_lin = _to_linear(np.clip(lw, 0, 255).astype(np.uint8))
|
| 279 |
+
|
| 280 |
+
comp_lin = fg_lin * a[..., None] + bg_lin * (1.0 - a[..., None]) + lw_lin
|
| 281 |
+
comp = _to_srgb(comp_lin)
|
| 282 |
+
return comp
|
| 283 |
+
|
| 284 |
+
# --------------------------------------------------------------------------------------
|
| 285 |
+
# SAM2 Integration
|
| 286 |
+
# --------------------------------------------------------------------------------------
|
| 287 |
+
def _resolve_sam2_cfg(cfg_str: str) -> str:
|
| 288 |
+
"""Resolve SAM2 config path - return relative path for Hydra compatibility."""
|
| 289 |
+
logger.info(f"_resolve_sam2_cfg called with cfg_str={cfg_str}")
|
| 290 |
+
|
| 291 |
+
# Get the third-party SAM2 directory
|
| 292 |
+
tp_sam2 = os.environ.get("THIRD_PARTY_SAM2_DIR", "/home/user/app/third_party/sam2")
|
| 293 |
+
logger.info(f"TP_SAM2 = {tp_sam2}")
|
| 294 |
+
|
| 295 |
+
# Check if the full path exists
|
| 296 |
+
candidate = os.path.join(tp_sam2, cfg_str)
|
| 297 |
+
logger.info(f"Candidate path: {candidate}")
|
| 298 |
+
logger.info(f"Candidate exists: {os.path.exists(candidate)}")
|
| 299 |
+
|
| 300 |
+
if os.path.exists(candidate):
|
| 301 |
+
# For Hydra compatibility, return just the relative path within sam2 package
|
| 302 |
+
if cfg_str.startswith("sam2/configs/"):
|
| 303 |
+
relative_path = cfg_str.replace("sam2/configs/", "configs/")
|
| 304 |
+
else:
|
| 305 |
+
relative_path = cfg_str
|
| 306 |
+
logger.info(f"Returning Hydra-compatible relative path: {relative_path}")
|
| 307 |
+
return relative_path
|
| 308 |
+
|
| 309 |
+
# If not found, try some fallback paths
|
| 310 |
+
fallbacks = [
|
| 311 |
+
os.path.join(tp_sam2, "sam2", cfg_str),
|
| 312 |
+
os.path.join(tp_sam2, "configs", cfg_str),
|
| 313 |
+
]
|
| 314 |
+
|
| 315 |
+
for fallback in fallbacks:
|
| 316 |
+
logger.info(f"Trying fallback: {fallback}")
|
| 317 |
+
if os.path.exists(fallback):
|
| 318 |
+
# Extract relative path for Hydra
|
| 319 |
+
if "configs/" in fallback:
|
| 320 |
+
relative_path = "configs/" + fallback.split("configs/")[-1]
|
| 321 |
+
logger.info(f"Returning fallback relative path: {relative_path}")
|
| 322 |
+
return relative_path
|
| 323 |
+
|
| 324 |
+
logger.warning(f"Config not found, returning original: {cfg_str}")
|
| 325 |
+
return cfg_str
|
| 326 |
+
|
| 327 |
+
def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
|
| 328 |
+
"""If config references 'hieradet', try to find a 'hiera' config."""
|
| 329 |
+
try:
|
| 330 |
+
with open(cfg_path, "r") as f:
|
| 331 |
+
data = yaml.safe_load(f)
|
| 332 |
+
model = data.get("model", {}) or {}
|
| 333 |
+
enc = model.get("image_encoder") or {}
|
| 334 |
+
trunk = enc.get("trunk") or {}
|
| 335 |
+
target = trunk.get("_target_") or trunk.get("target")
|
| 336 |
+
if isinstance(target, str) and "hieradet" in target:
|
| 337 |
+
for y in TP_SAM2.rglob("*.yaml"):
|
| 338 |
+
try:
|
| 339 |
+
with open(y, "r") as f2:
|
| 340 |
+
d2 = yaml.safe_load(f2) or {}
|
| 341 |
+
e2 = (d2.get("model", {}) or {}).get("image_encoder") or {}
|
| 342 |
+
t2 = (e2.get("trunk") or {})
|
| 343 |
+
tgt2 = t2.get("_target_") or t2.get("target")
|
| 344 |
+
if isinstance(tgt2, str) and ".hiera." in tgt2:
|
| 345 |
+
logger.info(f"SAM2: switching config from 'hieradet' β 'hiera': {y}")
|
| 346 |
+
return str(y)
|
| 347 |
+
except Exception:
|
| 348 |
+
continue
|
| 349 |
+
except Exception:
|
| 350 |
+
pass
|
| 351 |
+
return None
|
| 352 |
+
|
| 353 |
+
def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
|
| 354 |
+
"""Robust SAM2 loader with config resolution and error handling."""
|
| 355 |
+
meta = {"sam2_import_ok": False, "sam2_init_ok": False}
|
| 356 |
+
try:
|
| 357 |
+
from sam2.build_sam import build_sam2 # type: ignore
|
| 358 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor # type: ignore
|
| 359 |
+
meta["sam2_import_ok"] = True
|
| 360 |
+
except Exception as e:
|
| 361 |
+
logger.warning(f"SAM2 import failed: {e}")
|
| 362 |
+
return None, False, meta
|
| 363 |
+
|
| 364 |
+
# Check GPU memory before loading
|
| 365 |
+
if torch and torch.cuda.is_available():
|
| 366 |
+
mem_before = torch.cuda.memory_allocated() / 1024**3
|
| 367 |
+
logger.info(f"π GPU memory before SAM2 load: {mem_before:.2f}GB")
|
| 368 |
+
|
| 369 |
+
device = _pick_device("SAM2_DEVICE")
|
| 370 |
+
cfg_env = os.environ.get("SAM2_MODEL_CFG", "sam2/configs/sam2/sam2_hiera_l.yaml")
|
| 371 |
+
cfg = _resolve_sam2_cfg(cfg_env)
|
| 372 |
+
ckpt = os.environ.get("SAM2_CHECKPOINT", "")
|
| 373 |
+
|
| 374 |
+
def _try_build(cfg_path: str):
|
| 375 |
+
logger.info(f"_try_build called with cfg_path: {cfg_path}")
|
| 376 |
+
params = set(inspect.signature(build_sam2).parameters.keys())
|
| 377 |
+
logger.info(f"build_sam2 parameters: {list(params)}")
|
| 378 |
+
kwargs = {}
|
| 379 |
+
if "config_file" in params:
|
| 380 |
+
kwargs["config_file"] = cfg_path
|
| 381 |
+
logger.info(f"Using config_file parameter: {cfg_path}")
|
| 382 |
+
elif "model_cfg" in params:
|
| 383 |
+
kwargs["model_cfg"] = cfg_path
|
| 384 |
+
logger.info(f"Using model_cfg parameter: {cfg_path}")
|
| 385 |
+
if ckpt:
|
| 386 |
+
if "checkpoint" in params:
|
| 387 |
+
kwargs["checkpoint"] = ckpt
|
| 388 |
+
elif "ckpt_path" in params:
|
| 389 |
+
kwargs["ckpt_path"] = ckpt
|
| 390 |
+
elif "weights" in params:
|
| 391 |
+
kwargs["weights"] = ckpt
|
| 392 |
+
if "device" in params:
|
| 393 |
+
kwargs["device"] = device
|
| 394 |
+
try:
|
| 395 |
+
logger.info(f"Calling build_sam2 with kwargs: {kwargs}")
|
| 396 |
+
result = build_sam2(**kwargs)
|
| 397 |
+
logger.info(f"build_sam2 succeeded with kwargs")
|
| 398 |
+
# Log actual device of the model
|
| 399 |
+
if hasattr(result, 'device'):
|
| 400 |
+
logger.info(f"SAM2 model device: {result.device}")
|
| 401 |
+
elif hasattr(result, 'image_encoder') and hasattr(result.image_encoder, 'device'):
|
| 402 |
+
logger.info(f"SAM2 model device: {result.image_encoder.device}")
|
| 403 |
+
return result
|
| 404 |
+
except TypeError as e:
|
| 405 |
+
logger.info(f"build_sam2 kwargs failed: {e}, trying positional args")
|
| 406 |
+
pos = [cfg_path]
|
| 407 |
+
if ckpt:
|
| 408 |
+
pos.append(ckpt)
|
| 409 |
+
if "device" not in kwargs:
|
| 410 |
+
pos.append(device)
|
| 411 |
+
logger.info(f"Calling build_sam2 with positional args: {pos}")
|
| 412 |
+
result = build_sam2(*pos)
|
| 413 |
+
logger.info(f"build_sam2 succeeded with positional args")
|
| 414 |
+
return result
|
| 415 |
+
|
| 416 |
+
try:
|
| 417 |
+
try:
|
| 418 |
+
sam = _try_build(cfg)
|
| 419 |
+
except Exception:
|
| 420 |
+
alt_cfg = _find_hiera_config_if_hieradet(cfg)
|
| 421 |
+
if alt_cfg:
|
| 422 |
+
sam = _try_build(alt_cfg)
|
| 423 |
+
else:
|
| 424 |
+
raise
|
| 425 |
+
|
| 426 |
+
if sam is not None:
|
| 427 |
+
predictor = SAM2ImagePredictor(sam)
|
| 428 |
+
meta["sam2_init_ok"] = True
|
| 429 |
+
meta["sam2_device"] = device
|
| 430 |
+
return predictor, True, meta
|
| 431 |
+
else:
|
| 432 |
+
return None, False, meta
|
| 433 |
+
|
| 434 |
+
except Exception as e:
|
| 435 |
+
logger.error(f"SAM2 loading failed: {e}")
|
| 436 |
+
return None, False, meta
|
| 437 |
+
|
| 438 |
+
def run_sam2_mask(predictor: object,
                  first_frame_bgr: np.ndarray,
                  point: Optional[Tuple[int, int]] = None,
                  auto: bool = False) -> Tuple[Optional[np.ndarray], bool]:
    """Return (mask_uint8_0_255, ok)."""
    if predictor is None:
        return None, False
    try:
        rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
        predictor.set_image(rgb)

        if not auto and point is not None:
            # Single positive click prompt at the given pixel.
            px, py = int(point[0]), int(point[1])
            coords = np.array([[px, py]], dtype=np.int32)
            lbls = np.array([1], dtype=np.int32)
            masks, _, _ = predictor.predict(point_coords=coords, point_labels=lbls)
        else:
            # Box prompt covering most of the frame; auto mode uses a
            # slightly wider box (5% margin vs 10%).
            frame_h, frame_w = rgb.shape[:2]
            margin = 0.05 if auto else 0.1
            prompt_box = np.array([int(margin * frame_w), int(margin * frame_h),
                                   int((1.0 - margin) * frame_w),
                                   int((1.0 - margin) * frame_h)])
            masks, _, _ = predictor.predict(box=prompt_box)

        if masks is None or len(masks) == 0:
            return None, False

        # First proposal only, binarized to 0/255.
        return masks[0].astype(np.uint8) * 255, True
    except Exception as e:
        logger.warning(f"SAM2 mask failed: {e}")
        return None, False
|
| 471 |
+
|
| 472 |
+
def _refine_mask_grabcut(image_bgr: np.ndarray,
                         mask_u8: np.ndarray,
                         iters: Optional[int] = None,
                         trimap_erode: Optional[int] = None,
                         trimap_dilate: Optional[int] = None) -> np.ndarray:
    """Use SAM2 seed as initialization for GrabCut refinement.

    The seed mask is eroded to get "sure foreground"; its inverse is eroded
    to get "sure background"; the remainder is marked probable-background
    and resolved by ``cv2.grabCut``.

    Args:
        image_bgr: Source frame (BGR, uint8).
        mask_u8: Seed mask; values > 127 are treated as foreground.
        iters: GrabCut iteration count; defaults to env REFINE_GRABCUT_ITERS (2).
        trimap_erode: Kernel size for the sure-foreground erosion;
            defaults to env REFINE_TRIMAP_ERODE (3).
        trimap_dilate: Kernel size for the sure-background erosion;
            defaults to env REFINE_TRIMAP_DILATE (6).

    Returns:
        uint8 mask with values 0/255. On GrabCut failure the binarized
        input mask is returned unchanged.
    """
    # NOTE: the original annotations were `int = None`, which lied about the
    # accepted type; they are now Optional[int]. Behavior is unchanged.
    iters = int(os.environ.get("REFINE_GRABCUT_ITERS", "2")) if iters is None else int(iters)
    e = int(os.environ.get("REFINE_TRIMAP_ERODE", "3")) if trimap_erode is None else int(trimap_erode)
    d = int(os.environ.get("REFINE_TRIMAP_DILATE", "6")) if trimap_dilate is None else int(trimap_dilate)

    h, w = mask_u8.shape[:2]
    m = (mask_u8 > 127).astype(np.uint8) * 255

    # Shrink mask / inverse mask to confident core regions for the trimap.
    sure_fg = cv2.erode(m, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (max(1, e), max(1, e))), iterations=1)
    sure_bg = cv2.erode(255 - m, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (max(1, d), max(1, d))), iterations=1)

    gc_mask = np.full((h, w), cv2.GC_PR_BGD, dtype=np.uint8)
    gc_mask[sure_bg > 0] = cv2.GC_BGD
    gc_mask[sure_fg > 0] = cv2.GC_FGD

    # grabCut requires these two (1, 65) float64 model buffers.
    bgdModel = np.zeros((1, 65), np.float64)
    fgdModel = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(image_bgr, gc_mask, None, bgdModel, fgdModel, iters, cv2.GC_INIT_WITH_MASK)
        out = np.where((gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
        out = cv2.medianBlur(out, 5)  # knock out small speckles
        return out
    except Exception as exc:
        logger.warning(f"GrabCut refinement failed; using original mask. Reason: {exc}")
        return m
|
| 502 |
+
|
| 503 |
+
# --------------------------------------------------------------------------------------
|
| 504 |
+
# MatAnyone Integration
|
| 505 |
+
# --------------------------------------------------------------------------------------
|
| 506 |
+
def load_matany() -> Tuple[Optional[object], bool, Dict[str, Any]]:
    """
    Probe MatAnyone availability with T=1 squeeze patch for conv2d compatibility.
    Returns (None, available, meta); actual instantiation happens in MatAnyoneSession.
    """
    meta: Dict[str, Any] = {"matany_import_ok": False, "matany_init_ok": False}

    # Honor the kill switch before doing any imports.
    if os.environ.get("ENABLE_MATANY", "1").strip().lower() in {"0", "false", "off", "no"}:
        logger.info("MatAnyone disabled by ENABLE_MATANY=0.")
        meta["disabled"] = True
        return None, False, meta

    # Apply the T=1 squeeze patch before importing InferenceCore.
    try:
        from .matany_compat_patch import apply_matany_t1_squeeze_guard
        patched = bool(apply_matany_t1_squeeze_guard())
        if patched:
            logger.info("[MatAnyCompat] T=1 squeeze guard applied")
        else:
            logger.warning("[MatAnyCompat] T=1 squeeze patch failed; conv2d errors may occur")
        meta["patch_applied"] = patched
    except Exception as e:
        logger.warning(f"[MatAnyCompat] Patch import failed: {e}")
        meta["patch_applied"] = False

    try:
        from matanyone.inference.inference_core import InferenceCore  # type: ignore
        meta["matany_import_ok"] = True
        # Record MatAnyone version and InferenceCore surface for debugging.
        try:
            version = importlib.metadata.version("matanyone")
            logger.info(f"[MATANY] MatAnyone version: {version}")
        except Exception:
            logger.info("[MATANY] MatAnyone version unknown")
        logger.debug(f"[MATANY] InferenceCore attributes: {dir(InferenceCore)}")
        device = _pick_device("MATANY_DEVICE")
        repo_id = os.environ.get("MATANY_REPO_ID", "PeiqingYang/MatAnyone")
        meta["matany_repo_id"] = repo_id
        meta["matany_device"] = device
        return None, True, meta
    except Exception as e:
        logger.warning(f"MatAnyone import failed: {e}")
        return None, False, meta
|
| 549 |
+
|
| 550 |
+
# --------------------------------------------------------------------------------------
|
| 551 |
+
# Fallback Functions
|
| 552 |
+
# --------------------------------------------------------------------------------------
|
| 553 |
+
def fallback_mask(first_frame_bgr: np.ndarray) -> np.ndarray:
    """Prefer MediaPipe; fallback to GrabCut. Returns uint8 mask 0/255."""
    h, w = first_frame_bgr.shape[:2]

    # First choice: MediaPipe selfie segmentation, when the package imported.
    if _HAS_MEDIAPIPE:
        try:
            with mp.solutions.selfie_segmentation.SelfieSegmentation(model_selection=1) as segmenter:
                rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
                result = segmenter.process(rgb)
                binary = (np.clip(result.segmentation_mask, 0, 1) > 0.5).astype(np.uint8) * 255
                return cv2.medianBlur(binary, 5)
        except Exception as e:
            logger.warning(f"MediaPipe fallback failed: {e}")

    # Ultimate fallback: GrabCut seeded with a centered rectangle.
    gc_mask = np.zeros((h, w), np.uint8)
    rect = (int(0.1 * w), int(0.1 * h), int(0.8 * w), int(0.8 * h))
    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(first_frame_bgr, gc_mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)
        return np.where((gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
    except Exception as e:
        logger.warning(f"GrabCut failed: {e}")
        return np.zeros((h, w), dtype=np.uint8)
|
| 580 |
+
|
| 581 |
+
def composite_video(fg_path: Union[str, Path],
                    alpha_path: Union[str, Path],
                    bg_image_path: Union[str, Path],
                    out_path: Union[str, Path],
                    fps: int,
                    size: Tuple[int, int]) -> bool:
    """Blend MatAnyone FG+ALPHA over background using pro compositor."""
    fg_cap = cv2.VideoCapture(str(fg_path))
    al_cap = cv2.VideoCapture(str(alpha_path))
    if not fg_cap.isOpened() or not al_cap.isOpened():
        return False

    w, h = size
    bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR)
    if bg is None:
        # Neutral gray placeholder when the background image cannot be read.
        bg = np.full((h, w, 3), 127, dtype=np.uint8)
    bg_f = _resize_keep_ar(bg, (w, h))

    # With ffmpeg available, write a temp file and re-encode to H.264 after.
    post_h264 = _probe_ffmpeg()
    if post_h264:
        tmp_out = Path(str(out_path) + ".tmp.mp4")
        writer = _video_writer(tmp_out, fps, (w, h))
    else:
        writer = _video_writer(Path(out_path), fps, (w, h))

    ok_any = False
    try:
        while True:
            ok_fg, fg = fg_cap.read()
            ok_al, al = al_cap.read()
            if not (ok_fg and ok_al):
                break
            fg = cv2.resize(fg, (w, h), interpolation=cv2.INTER_CUBIC)
            al_gray = cv2.cvtColor(cv2.resize(al, (w, h)), cv2.COLOR_BGR2GRAY)

            comp = _composite_frame_pro(
                cv2.cvtColor(fg, cv2.COLOR_BGR2RGB),
                al_gray,
                cv2.cvtColor(bg_f, cv2.COLOR_BGR2RGB),
            )
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            ok_any = True
    finally:
        fg_cap.release()
        al_cap.release()
        writer.release()

    if post_h264 and ok_any:
        try:
            finalize_cmd = [
                _ffmpeg_bin(), "-y",
                "-i", str(tmp_out),
                "-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart",
                str(out_path),
            ]
            subprocess.run(finalize_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            tmp_out.unlink(missing_ok=True)
        except Exception as e:
            logger.warning(f"ffmpeg finalize failed: {e}")
            # Ship the raw intermediate file as the output instead.
            Path(out_path).unlink(missing_ok=True)
            tmp_out.replace(out_path)

    return ok_any
|
| 645 |
+
|
| 646 |
+
def fallback_composite(video_path: Union[str, Path],
                       mask_path: Union[str, Path],
                       bg_image_path: Union[str, Path],
                       out_path: Union[str, Path]) -> bool:
    """Static-mask compositing using pro compositor."""
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    cap = cv2.VideoCapture(str(video_path))
    if mask is None or not cap.isOpened():
        return False

    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
    fps = int(round(cap.get(cv2.CAP_PROP_FPS) or 25))

    bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR)
    if bg is None:
        bg = np.full((h, w, 3), 127, dtype=np.uint8)

    # NEAREST keeps the binary mask hard-edged when rescaling.
    mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
    bg_f = _resize_keep_ar(bg, (w, h))

    use_post_ffmpeg = _probe_ffmpeg()
    if use_post_ffmpeg:
        tmp_out = Path(str(out_path) + ".tmp.mp4")
        writer = _video_writer(tmp_out, fps, (w, h))
    else:
        writer = _video_writer(Path(out_path), fps, (w, h))

    ok_any = False
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            comp = _composite_frame_pro(
                cv2.cvtColor(frame, cv2.COLOR_BGR2RGB),
                mask_resized,
                cv2.cvtColor(bg_f, cv2.COLOR_BGR2RGB),
            )
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            ok_any = True
    finally:
        cap.release()
        writer.release()

    if use_post_ffmpeg and ok_any:
        try:
            finalize_cmd = [
                _ffmpeg_bin(), "-y",
                "-i", str(tmp_out),
                "-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart",
                str(out_path),
            ]
            subprocess.run(finalize_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            tmp_out.unlink(missing_ok=True)
        except Exception as e:
            logger.warning(f"ffmpeg H.264 finalize failed: {e}")
            # Ship the raw intermediate file as the output instead.
            Path(out_path).unlink(missing_ok=True)
            tmp_out.replace(out_path)

    return ok_any
|
| 708 |
+
|
| 709 |
+
# --------------------------------------------------------------------------------------
|
| 710 |
+
# Stage-A (Transparent Export) Functions
|
| 711 |
+
# --------------------------------------------------------------------------------------
|
| 712 |
+
def _checkerboard_bg(w: int, h: int, tile: int = 32) -> np.ndarray:
|
| 713 |
+
y, x = np.mgrid[0:h, 0:w]
|
| 714 |
+
c = ((x // tile) + (y // tile)) % 2
|
| 715 |
+
a = np.where(c == 0, 200, 150).astype(np.uint8)
|
| 716 |
+
return np.stack([a, a, a], axis=-1)
|
| 717 |
+
|
| 718 |
+
def _build_stage_a_rgba_vp9_from_fg_alpha(
    fg_path: Union[str, Path],
    alpha_path: Union[str, Path],
    out_webm: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
    src_audio: Optional[Union[str, Path]] = None,
) -> bool:
    """Merge FG + per-frame alpha into a VP9 WebM with a real alpha channel."""
    if not _probe_ffmpeg():
        return False
    w, h = size
    try:
        cmd = [_ffmpeg_bin(), "-y", "-i", str(fg_path), "-i", str(alpha_path)]
        has_audio = bool(src_audio)
        if has_audio:
            cmd += ["-i", str(src_audio)]
        # Scale both streams, then alphamerge the gray alpha onto the FG.
        filter_graph = (
            f"[1:v]format=gray,scale={w}:{h},fps={fps}[al];"
            f"[0:v]scale={w}:{h},fps={fps}[fg];"
            f"[fg][al]alphamerge[outv]"
        )
        cmd += ["-filter_complex", filter_graph, "-map", "[outv]"]
        if has_audio:
            cmd += ["-map", "2:a:0?", "-c:a", "libopus", "-b:a", "128k"]
        cmd += [
            "-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p",
            "-crf", os.environ.get("STAGEA_VP9_CRF", "28"),
            "-b:v", "0", "-row-mt", "1", "-shortest", str(out_webm),
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return True
    except Exception as e:
        logger.warning(f"Stage-A VP9(alpha) build failed: {e}")
        return False
|
| 749 |
+
|
| 750 |
+
def _build_stage_a_rgba_vp9_from_mask(
    video_path: Union[str, Path],
    mask_png: Union[str, Path],
    out_webm: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Alpha-merge a looped static mask onto the video; encode VP9 WebM."""
    if not _probe_ffmpeg():
        return False
    w, h = size
    try:
        filter_graph = (
            f"[1:v]format=gray,scale={w}:{h},fps={fps}[al];"
            f"[0:v]scale={w}:{h},fps={fps}[fg];"
            f"[fg][al]alphamerge[outv]"
        )
        cmd = [
            _ffmpeg_bin(), "-y",
            "-i", str(video_path),
            # Loop the single-frame mask so it covers the whole clip.
            "-loop", "1", "-i", str(mask_png),
            "-filter_complex", filter_graph,
            "-map", "[outv]",
            "-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p",
            "-crf", os.environ.get("STAGEA_VP9_CRF", "28"),
            "-b:v", "0", "-row-mt", "1", "-shortest", str(out_webm),
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return True
    except Exception as e:
        logger.warning(f"Stage-A VP9(alpha) (mask) build failed: {e}")
        return False
|
| 779 |
+
|
| 780 |
+
def _build_stage_a_checkerboard_from_fg_alpha(
    fg_path: Union[str, Path],
    alpha_path: Union[str, Path],
    out_mp4: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Preview composite: FG over a checkerboard, using per-frame alpha."""
    fg_cap = cv2.VideoCapture(str(fg_path))
    al_cap = cv2.VideoCapture(str(alpha_path))
    if not (fg_cap.isOpened() and al_cap.isOpened()):
        return False
    w, h = size
    writer = _video_writer(Path(out_mp4), fps, (w, h))
    board = _checkerboard_bg(w, h)
    wrote_any = False
    try:
        while True:
            got_fg, fg = fg_cap.read()
            got_al, al = al_cap.read()
            if not (got_fg and got_al):
                break
            fg = cv2.resize(fg, (w, h))
            alpha = cv2.cvtColor(cv2.resize(al, (w, h)), cv2.COLOR_BGR2GRAY)
            comp = _composite_frame_pro(cv2.cvtColor(fg, cv2.COLOR_BGR2RGB), alpha, board)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            wrote_any = True
    finally:
        fg_cap.release()
        al_cap.release()
        writer.release()
    return wrote_any
|
| 811 |
+
|
| 812 |
+
def _build_stage_a_checkerboard_from_mask(
    video_path: Union[str, Path],
    mask_png: Union[str, Path],
    out_mp4: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Preview composite: video over a checkerboard, using one static mask."""
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        return False
    w, h = size
    mask = cv2.imread(str(mask_png), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        return False
    # NEAREST keeps the binary mask hard-edged when rescaling.
    mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
    writer = _video_writer(Path(out_mp4), fps, (w, h))
    board = _checkerboard_bg(w, h)
    wrote_any = False
    try:
        while True:
            got, frame = cap.read()
            if not got:
                break
            frame = cv2.resize(frame, (w, h))
            comp = _composite_frame_pro(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), mask, board)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            wrote_any = True
    finally:
        cap.release()
        writer.release()
    return wrote_any
|
| 843 |
+
|
| 844 |
+
# --------------------------------------------------------------------------------------
|
| 845 |
+
# MatAnyone Integration
|
| 846 |
+
# --------------------------------------------------------------------------------------
|
| 847 |
+
def run_matany(
    video_path: Union[str, Path],
    mask_path: Optional[Union[str, Path]],
    out_dir: Union[str, Path],
    device: Optional[str] = None,
    progress_callback: Optional[Callable[[float, str], None]] = None,
) -> Tuple[Path, Path]:
    """
    Run MatAnyone streaming matting via our shape-guarded adapter.
    Returns (alpha_mp4_path, fg_mp4_path).
    Raises MatAnyError on failure.
    """
    from .matanyone_loader import MatAnyoneSession, MatAnyError

    seed = Path(mask_path) if mask_path else None
    session = MatAnyoneSession(device=device, precision="auto")
    return session.process_stream(
        video_path=Path(video_path),
        seed_mask_path=seed,
        out_dir=Path(out_dir),
        progress_cb=progress_callback,
    )
|
VideoBackgroundReplacer2/models/matanyone_loader.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
MatAnyone adapter β Using Official API (File-Based)
|
| 5 |
+
|
| 6 |
+
Fixed to use MatAnyone's official process_video() API instead of
|
| 7 |
+
bypassing it with internal tensor manipulation. This eliminates
|
| 8 |
+
all 5D tensor dimension issues.
|
| 9 |
+
|
| 10 |
+
Changes (2025-09-17):
|
| 11 |
+
- Replaced custom tensor processing with official MatAnyone API
|
| 12 |
+
- Uses file-based input/output as designed by MatAnyone authors
|
| 13 |
+
- Eliminates all tensor dimension compatibility issues
|
| 14 |
+
- Simplified error handling and logging
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
import os
|
| 19 |
+
import time
|
| 20 |
+
import logging
|
| 21 |
+
import tempfile
|
| 22 |
+
import importlib.metadata
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import Optional, Callable, Tuple
|
| 25 |
+
|
| 26 |
+
log = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
# ---------- Progress helper ----------
|
| 29 |
+
def _env_flag(name: str, default: str = "0") -> bool:
|
| 30 |
+
return os.getenv(name, default).strip().lower() in {"1", "true", "yes", "on"}
|
| 31 |
+
|
| 32 |
+
_PROGRESS_CB_ENABLED = _env_flag("MATANY_PROGRESS", "1")
|
| 33 |
+
_PROGRESS_MIN_INTERVAL = float(os.getenv("MATANY_PROGRESS_MIN_SEC", "0.25"))
|
| 34 |
+
_progress_last = 0.0
|
| 35 |
+
_progress_last_msg = None
|
| 36 |
+
_progress_disabled = False
|
| 37 |
+
|
| 38 |
+
def _emit_progress(cb, pct: float, msg: str):
    """Invoke the progress callback, rate-limited and fail-safe.

    Supports both (pct, msg) and legacy (msg,) callback signatures. Any
    exception other than the signature-probe TypeError permanently
    disables further progress reporting.
    """
    global _progress_last, _progress_last_msg, _progress_disabled
    if not cb or not _PROGRESS_CB_ENABLED or _progress_disabled:
        return
    now = time.time()
    # Throttle only repeats of the identical message within the interval.
    repeat = (msg == _progress_last_msg)
    if repeat and (now - _progress_last) < _PROGRESS_MIN_INTERVAL:
        return
    try:
        try:
            cb(pct, msg)  # preferred signature: (pct, msg)
        except TypeError:
            cb(msg)       # legacy signature: (msg,)
        _progress_last = now
        _progress_last_msg = msg
    except Exception as e:
        _progress_disabled = True
        log.warning("[progress-cb] disabled due to exception: %s", e)
|
| 55 |
+
|
| 56 |
+
# ---------- Errors ----------
|
| 57 |
+
class MatAnyError(RuntimeError):
    """Raised when MatAnyone matting fails or cannot be initialized."""
|
| 59 |
+
|
| 60 |
+
# ---------- CUDA helpers ----------
|
| 61 |
+
def _cuda_snapshot(device: Optional[str]) -> str:
|
| 62 |
+
try:
|
| 63 |
+
import torch
|
| 64 |
+
if not torch.cuda.is_available():
|
| 65 |
+
return "CUDA: N/A"
|
| 66 |
+
idx = 0
|
| 67 |
+
if device and device.startswith("cuda:"):
|
| 68 |
+
try:
|
| 69 |
+
idx = int(device.split(":")[1])
|
| 70 |
+
except (ValueError, IndexError):
|
| 71 |
+
idx = 0
|
| 72 |
+
name = torch.cuda.get_device_name(idx)
|
| 73 |
+
alloc = torch.cuda.memory_allocated(idx) / (1024**3)
|
| 74 |
+
resv = torch.cuda.memory_reserved(idx) / (1024**3)
|
| 75 |
+
return f"device={idx}, name={name}, alloc={alloc:.2f}GB, reserved={resv:.2f}GB"
|
| 76 |
+
except Exception as e:
|
| 77 |
+
return f"CUDA snapshot error: {e!r}"
|
| 78 |
+
|
| 79 |
+
def _safe_empty_cache():
|
| 80 |
+
try:
|
| 81 |
+
import torch
|
| 82 |
+
if torch.cuda.is_available():
|
| 83 |
+
log.info(f"[MATANY] CUDA memory before empty_cache: {_cuda_snapshot('cuda:0')}")
|
| 84 |
+
torch.cuda.empty_cache()
|
| 85 |
+
log.info(f"[MATANY] CUDA memory after empty_cache: {_cuda_snapshot('cuda:0')}")
|
| 86 |
+
except Exception:
|
| 87 |
+
pass
|
| 88 |
+
|
| 89 |
+
# ============================================================================
|
| 90 |
+
|
| 91 |
+
class MatAnyoneSession:
    """
    Simple wrapper around MatAnyone's official API.
    Uses file-based input/output as designed by the MatAnyone authors.
    """

    def __init__(self, device: Optional[str] = None, precision: str = "auto"):
        self.device = device or ("cuda" if self._cuda_available() else "cpu")
        self.precision = precision.lower()

        # Record which MatAnyone build we are driving.
        try:
            version = importlib.metadata.version("matanyone")
            log.info(f"[MATANY] MatAnyone version: {version}")
        except Exception:
            log.info("[MATANY] MatAnyone version unknown")

        # Bring up the official API; failure is fatal for this session.
        try:
            from matanyone import InferenceCore
            self.processor = InferenceCore("PeiqingYang/MatAnyone")
            log.info("[MATANY] MatAnyone InferenceCore initialized successfully")
        except Exception as e:
            raise MatAnyError(f"Failed to initialize MatAnyone: {e}")

    @staticmethod
    def _cuda_available() -> bool:
        """Best-effort CUDA check; False when torch is absent or broken."""
        try:
            import torch
            return torch.cuda.is_available()
        except Exception:
            return False

    def process_stream(
        self,
        video_path: Path,
        seed_mask_path: Optional[Path] = None,
        out_dir: Optional[Path] = None,
        progress_cb: Optional[Callable] = None,
    ) -> Tuple[Path, Path]:
        """
        Process video using MatAnyone's official API.

        Args:
            video_path: Path to input video file
            seed_mask_path: Path to first-frame mask PNG (white=foreground, black=background)
            out_dir: Output directory for results
            progress_cb: Progress callback function

        Returns:
            Tuple of (alpha_path, foreground_path)
        """
        video_path = Path(video_path)
        if not video_path.exists():
            raise MatAnyError(f"Video file not found: {video_path}")
        if seed_mask_path and not Path(seed_mask_path).exists():
            raise MatAnyError(f"Seed mask not found: {seed_mask_path}")

        out_dir = Path(out_dir) if out_dir else video_path.parent / "matanyone_output"
        out_dir.mkdir(parents=True, exist_ok=True)

        log.info(f"[MATANY] Processing video: {video_path}")
        log.info(f"[MATANY] Using mask: {seed_mask_path}")
        log.info(f"[MATANY] Output directory: {out_dir}")
        _emit_progress(progress_cb, 0.0, "Initializing MatAnyone processing...")

        try:
            start_time = time.time()
            _emit_progress(progress_cb, 0.1, "Running MatAnyone video matting...")

            # Official file-based entry point; returns output file paths.
            foreground_path, alpha_path = self.processor.process_video(
                input_path=str(video_path),
                mask_path=str(seed_mask_path) if seed_mask_path else None,
                output_path=str(out_dir),
            )

            processing_time = time.time() - start_time
            log.info(f"[MATANY] Processing completed in {processing_time:.1f}s")
            log.info(f"[MATANY] Foreground output: {foreground_path}")
            log.info(f"[MATANY] Alpha output: {alpha_path}")

            fg_path = Path(foreground_path) if foreground_path else None
            al_path = Path(alpha_path) if alpha_path else None
            if fg_path is None or not fg_path.exists():
                raise MatAnyError(f"Foreground output not created: {fg_path}")
            if al_path is None or not al_path.exists():
                raise MatAnyError(f"Alpha output not created: {al_path}")

            _emit_progress(progress_cb, 1.0, "MatAnyone processing complete")
            # Callers expect (alpha, foreground) ordering.
            return al_path, fg_path
        except Exception as e:
            log.error(f"[MATANY] Processing failed: {e}")
            raise MatAnyError(f"MatAnyone processing failed: {e}")
        finally:
            _safe_empty_cache()
|
| 195 |
+
|
| 196 |
+
# ============================================================================
|
| 197 |
+
# MatAnyoneModel Wrapper Class for app_hf.py compatibility
|
| 198 |
+
# ============================================================================
|
| 199 |
+
|
| 200 |
+
class MatAnyoneModel:
    """Wrapper class for MatAnyone to match app_hf.py interface."""

    def __init__(self, device="cuda"):
        # Requested compute device; MatAnyoneSession may still fall back to CPU.
        self.device = device
        self.session = None
        self.loaded = False
        log.info(f"Initializing MatAnyoneModel on device: {device}")

        # Initialize the session eagerly so callers can check `.loaded`.
        self._load_model()

    def _load_model(self):
        """Load the MatAnyone session; on failure leave self.loaded False."""
        try:
            self.session = MatAnyoneSession(device=self.device, precision="auto")
            self.loaded = True
            log.info("MatAnyoneModel loaded successfully")
        except Exception as e:
            log.error(f"Error loading MatAnyoneModel: {e}")
            self.loaded = False

    def replace_background(self, video_path, masks, background_path):
        """Replace background in video using MatAnyone.

        Args:
            video_path: Input video file path.
            masks: Path to the first-frame seed mask (str/Path), or None.
            background_path: Reserved for future compositing; currently unused.

        Returns:
            Path (str) to the matted foreground video. The caller owns the
            containing directory and is responsible for cleaning it up.

        Raises:
            MatAnyError: If the model is not loaded or processing fails.
        """
        if not self.loaded:
            raise MatAnyError("MatAnyoneModel not loaded")

        try:
            video_path = Path(video_path)

            # For now, we expect masks to be a path to the first-frame mask
            mask_path = Path(masks) if isinstance(masks, (str, Path)) else None

            # BUGFIX: the previous implementation wrote outputs into a
            # `tempfile.TemporaryDirectory()` context and returned a path
            # inside it, so the directory (and the returned file) was deleted
            # on context exit. Use a non-self-deleting directory instead.
            output_dir = Path(tempfile.mkdtemp(prefix="matanyone_"))

            alpha_path, fg_path = self.session.process_stream(
                video_path=video_path,
                seed_mask_path=mask_path,
                out_dir=output_dir,
                progress_cb=None,
            )

            # Return the foreground video path
            # In a full implementation, you'd composite with the background_path
            return str(fg_path)

        except Exception as e:
            log.error(f"Error in replace_background: {e}")
            raise MatAnyError(f"Background replacement failed: {e}")
|
| 256 |
+
|
| 257 |
+
# ============================================================================
|
| 258 |
+
# Helper function for pipeline integration
|
| 259 |
+
# ============================================================================
|
| 260 |
+
|
| 261 |
+
def create_matanyone_session(device=None):
    """Build a MatAnyoneSession bound to *device* (auto-detected when None)."""
    return MatAnyoneSession(device=device)
|
| 264 |
+
|
| 265 |
+
def run_matanyone_on_files(video_path, mask_path, output_dir, device="cuda", progress_callback=None):
    """
    Run MatAnyone on video and mask files.

    Args:
        video_path: Path to input video
        mask_path: Path to first-frame mask PNG (falsy -> no seed mask)
        output_dir: Directory for outputs
        device: Device to use (cuda/cpu)
        progress_callback: Progress callback function

    Returns:
        Tuple of (alpha_path, foreground_path) as strings, or (None, None)
        on failure.  Best-effort API: errors are logged, never raised.
    """
    # Import locally: sibling methods in this module import Path inside
    # function bodies, so a module-level import is not guaranteed here.
    from pathlib import Path

    try:
        session = MatAnyoneSession(device=device)
        alpha_path, fg_path = session.process_stream(
            video_path=Path(video_path),
            seed_mask_path=Path(mask_path) if mask_path else None,
            out_dir=Path(output_dir),
            progress_cb=progress_callback
        )
        return str(alpha_path), str(fg_path)
    except Exception as e:
        log.error(f"MatAnyone processing failed: {e}")
        return None, None
|
VideoBackgroundReplacer2/models/sam2_loader.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
SAM2 Loader with T4-optimized predictor wrapper
|
| 4 |
+
Provides SAM2Predictor class with memory management and optimization features
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import gc
|
| 9 |
+
import torch
|
| 10 |
+
import logging
|
| 11 |
+
import numpy as np
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Optional, Any, Dict, List, Tuple
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
class SAM2Predictor:
    """
    T4-optimized SAM2 video predictor wrapper with memory management.

    Downloads the requested checkpoint on demand, builds the SAM2 video
    predictor, and applies fp16 + channels_last optimizations suited to a
    Tesla T4.
    """

    # model_size -> (Hydra config name, checkpoint URL).
    # BUGFIX: "base" uses the *base_plus* checkpoint, whose config file is
    # "sam2_hiera_b+.yaml"; the previous `model_size[0]` derivation produced
    # the non-existent "sam2_hiera_b.yaml" for that size.
    _MODEL_REGISTRY: Dict[str, Tuple[str, str]] = {
        "small": (
            "sam2_hiera_s.yaml",
            "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
        ),
        "base": (
            "sam2_hiera_b+.yaml",
            "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
        ),
        "large": (
            "sam2_hiera_l.yaml",
            "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
        ),
    }

    def __init__(self, device: torch.device, model_size: str = "small"):
        """
        Args:
            device: Torch device to run inference on.
            model_size: One of "small", "base", "large".
        """
        self.device = device
        self.model_size = model_size
        self.predictor = None
        self.model = None
        self._load_predictor()

    def _load_predictor(self):
        """Load the SAM2 video predictor and apply T4 optimizations."""
        try:
            from sam2.build_sam import build_sam2_video_predictor

            if self.model_size not in self._MODEL_REGISTRY:
                raise ValueError(f"Unknown model size: {self.model_size}")
            model_cfg, _ = self._MODEL_REGISTRY[self.model_size]

            # Download checkpoint if needed
            checkpoint_path = f"./checkpoints/sam2_hiera_{self.model_size}.pt"
            if not self._ensure_checkpoint(checkpoint_path):
                raise RuntimeError(f"Failed to get SAM2 {self.model_size} checkpoint")

            self.predictor = build_sam2_video_predictor(model_cfg, checkpoint_path, device=self.device)

            # Apply T4 optimizations
            self._optimize_for_t4()

            logger.info(f"SAM2 {self.model_size} predictor loaded successfully")

        except ImportError as e:
            logger.error(f"SAM2 import failed: {e}")
            raise RuntimeError("SAM2 not available - check third_party/sam2 installation")
        except Exception as e:
            logger.error(f"SAM2 loading failed: {e}")
            raise

    def _ensure_checkpoint(self, checkpoint_path: str) -> bool:
        """Return True once a plausible checkpoint exists (downloading if needed).

        A file under 50MB is assumed to be a truncated download; it is
        deleted and re-fetched.
        """
        checkpoint_file = Path(checkpoint_path)

        if checkpoint_file.exists():
            file_size = checkpoint_file.stat().st_size / (1024**2)
            if file_size > 50:  # At least 50MB
                logger.info(f"SAM2 checkpoint exists: {file_size:.1f}MB")
                return True
            logger.warning(f"Checkpoint too small ({file_size:.1f}MB), re-downloading")
            checkpoint_file.unlink()

        return self._download_checkpoint(checkpoint_path)

    def _download_checkpoint(self, checkpoint_path: str, timeout_seconds: int = 600) -> bool:
        """Stream-download the checkpoint with progress logging.

        Writes to a `.download` temp file and renames on completion so a
        partial download never masquerades as a valid checkpoint.

        Returns:
            True on success, False on any failure (logged, not raised).
        """
        try:
            logger.info(f"Downloading SAM2 {self.model_size} checkpoint...")

            checkpoint_file = Path(checkpoint_path)
            checkpoint_file.parent.mkdir(parents=True, exist_ok=True)

            import requests
            import time

            if self.model_size not in self._MODEL_REGISTRY:
                raise ValueError(f"Unknown model size: {self.model_size}")
            _, checkpoint_url = self._MODEL_REGISTRY[self.model_size]

            start_time = time.time()
            response = requests.get(checkpoint_url, stream=True, timeout=30)
            response.raise_for_status()

            total_size = int(response.headers.get('content-length', 0))

            temp_path = checkpoint_file.with_suffix('.download')
            downloaded = 0
            last_log = start_time

            with open(temp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)

                    current_time = time.time()
                    if current_time - start_time > timeout_seconds:
                        raise TimeoutError(f"Download timeout after {timeout_seconds}s")

                    # Progress logging every 15 seconds
                    if current_time - last_log > 15:
                        progress = (downloaded / total_size * 100) if total_size > 0 else 0
                        speed = downloaded / (current_time - start_time) / (1024**2)
                        logger.info(f"Download: {progress:.1f}% ({speed:.1f}MB/s)")
                        last_log = current_time

            # Only a complete download gets the real name.
            temp_path.rename(checkpoint_file)

            download_time = time.time() - start_time
            speed = downloaded / download_time / (1024**2)
            logger.info(f"Download complete: {downloaded/(1024**2):.1f}MB in {download_time:.1f}s ({speed:.1f}MB/s)")

            return True

        except Exception as e:
            logger.error(f"Checkpoint download failed: {e}")
            if Path(checkpoint_path).exists():
                Path(checkpoint_path).unlink()
            return False

    def _optimize_for_t4(self):
        """Best-effort fp16 + channels_last on the wrapped model."""
        try:
            if hasattr(self.predictor, "model") and self.predictor.model is not None:
                self.model = self.predictor.model
                # fp16 + channels_last for T4 tensor-core efficiency
                self.model = self.model.half().to(self.device)
                self.model = self.model.to(memory_format=torch.channels_last)
                logger.info("SAM2: fp16 + channels_last applied for T4 optimization")
        except Exception as e:
            logger.warning(f"SAM2 T4 optimization warning: {e}")

    def init_state(self, video_path: str):
        """Initialize the per-video inference state."""
        if self.predictor is None:
            raise RuntimeError("Predictor not loaded")
        try:
            return self.predictor.init_state(video_path=video_path)
        except Exception as e:
            logger.error(f"Failed to initialize video state: {e}")
            raise

    def add_new_points(self, inference_state, frame_idx: int, obj_id: int,
                       points: np.ndarray, labels: np.ndarray):
        """Add click prompts (points + pos/neg labels) for one object."""
        if self.predictor is None:
            raise RuntimeError("Predictor not loaded")
        try:
            return self.predictor.add_new_points(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                points=points,
                labels=labels
            )
        except Exception as e:
            logger.error(f"Failed to add new points: {e}")
            raise

    def propagate_in_video(self, inference_state, scale: float = 1.0, **kwargs):
        """Propagate masks through the video.

        NOTE(review): `scale` is accepted for interface compatibility but is
        currently not forwarded to the underlying predictor.
        """
        if self.predictor is None:
            raise RuntimeError("Predictor not loaded")
        try:
            return self.predictor.propagate_in_video(inference_state, **kwargs)
        except Exception as e:
            logger.error(f"Failed to propagate in video: {e}")
            raise

    def prune_state(self, inference_state, keep: int):
        """Best-effort pruning of SAM2 state to the most recent `keep` frames.

        Relies on (undocumented) SAM2 internals -- assumes
        `cached_features` / `point_inputs_per_obj` attributes; silently
        degrades to a no-op when they are absent.
        """
        try:
            if hasattr(inference_state, 'cached_features'):
                # Keep only the most recent 'keep' frames
                cached_keys = list(inference_state.cached_features.keys())
                if len(cached_keys) > keep:
                    keys_to_remove = cached_keys[:-keep]
                    for key in keys_to_remove:
                        if key in inference_state.cached_features:
                            del inference_state.cached_features[key]
                    logger.debug(f"Pruned {len(keys_to_remove)} old cached features")

            if hasattr(inference_state, 'point_inputs_per_obj'):
                # Keep recent point inputs only
                for obj_id in list(inference_state.point_inputs_per_obj.keys()):
                    obj_inputs = inference_state.point_inputs_per_obj[obj_id]
                    if len(obj_inputs) > keep:
                        recent_keys = sorted(obj_inputs.keys())[-keep:]
                        inference_state.point_inputs_per_obj[obj_id] = {
                            k: obj_inputs[k] for k in recent_keys
                        }

            # Idiom fix: was `torch.cuda.empty_cache() if ... else None`,
            # a conditional expression used as a statement.
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()

        except Exception as e:
            logger.debug(f"State pruning warning: {e}")

    def clear_memory(self):
        """Aggressively release GPU memory and run garbage collection."""
        try:
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                torch.cuda.ipc_collect()
            # FIX: collect on every device (was CUDA-only, so CPU-side
            # references were never explicitly collected).
            gc.collect()
        except Exception as e:
            logger.warning(f"Memory clearing warning: {e}")

    def get_memory_usage(self) -> Dict[str, float]:
        """Return allocated/reserved/free (and total) GPU memory in GiB.

        Returns zeros for non-CUDA devices or on query failure.
        """
        if self.device.type != 'cuda':
            return {"allocated_gb": 0.0, "reserved_gb": 0.0, "free_gb": 0.0}

        try:
            allocated = torch.cuda.memory_allocated(self.device) / (1024**3)
            reserved = torch.cuda.memory_reserved(self.device) / (1024**3)
            free, total = torch.cuda.mem_get_info(self.device)
            return {
                "allocated_gb": allocated,
                "reserved_gb": reserved,
                "free_gb": free / (1024**3),
                "total_gb": total / (1024**3),
            }
        except Exception:
            return {"allocated_gb": 0.0, "reserved_gb": 0.0, "free_gb": 0.0}

    def __del__(self):
        """Drop model references and release memory on deletion."""
        try:
            if hasattr(self, 'predictor') and self.predictor is not None:
                del self.predictor
            if hasattr(self, 'model') and self.model is not None:
                del self.model
            self.clear_memory()
        except Exception:
            pass
|
VideoBackgroundReplacer2/pipeline.py
ADDED
|
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
pipeline.py β Production SAM2 + MatAnyone (T4-optimized, single-pass streaming)
|
| 4 |
+
|
| 5 |
+
Key features
|
| 6 |
+
------------
|
| 7 |
+
- One SAM2 inference state for the entire video (no per-chunk reinit).
|
| 8 |
+
- In-stream pipeline: Read β SAM2 β MatAnyone β Compose β Write (no big RAM dicts).
|
| 9 |
+
- Bounded memory everywhere (deque/window); optional CPU spill.
|
| 10 |
+
- fp16 + channels_last on SAM2; mixed precision blocks.
|
| 11 |
+
- VRAM-aware controller adjusts memory window/scale.
|
| 12 |
+
- Heartbeat logger to prevent HF watchdog restarts.
|
| 13 |
+
- Safer FFmpeg audio re-mux.
|
| 14 |
+
|
| 15 |
+
Compatible with Tesla T4 (β15β16 GB) and PyTorch 2.5.x + CUDA 12.4 wheels.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import gc
|
| 20 |
+
import cv2
|
| 21 |
+
import time
|
| 22 |
+
import uuid
|
| 23 |
+
import torch
|
| 24 |
+
import queue
|
| 25 |
+
import shutil
|
| 26 |
+
import logging
|
| 27 |
+
import tempfile
|
| 28 |
+
import subprocess
|
| 29 |
+
import threading
|
| 30 |
+
import numpy as np
|
| 31 |
+
from PIL import Image
|
| 32 |
+
from pathlib import Path
|
| 33 |
+
from typing import Optional, Tuple, Dict, Any, Callable
|
| 34 |
+
from collections import deque
|
| 35 |
+
|
| 36 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 37 |
+
# Logging
|
| 38 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 39 |
+
# Module logger; handler is attached only once even if the module is
# re-imported (e.g. by Streamlit reruns).
logger = logging.getLogger("backgroundfx_pro")
if not logger.handlers:
    _handler = logging.StreamHandler()
    _handler.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s: %(message)s"))
    logger.addHandler(_handler)
    logger.setLevel(logging.INFO)
|
| 45 |
+
|
| 46 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 47 |
+
# Environment & Torch tuning for T4
|
| 48 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 49 |
+
def setup_t4_environment():
    """Configure environment variables and torch backends for T4 inference.

    Idempotent: environment settings use setdefault, so values already
    present in the environment are respected.  Disables autograd globally
    (this pipeline only runs inference).
    """
    env_defaults = {
        "PYTORCH_CUDA_ALLOC_CONF":
            "expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7",
        "OMP_NUM_THREADS": "1",
        "OPENBLAS_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
        "OPENCV_OPENCL_RUNTIME": "disabled",
        "OPENCV_IO_ENABLE_OPENEXR": "0",
    }
    for key, value in env_defaults.items():
        os.environ.setdefault(key, value)

    # Inference only -- no autograd bookkeeping needed.
    torch.set_grad_enabled(False)
    try:
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")
    except Exception:
        # Older torch builds may lack some of these knobs; best-effort.
        pass

    if torch.cuda.is_available():
        try:
            frac = float(os.getenv("CUDA_MEMORY_FRACTION", "0.88"))
            torch.cuda.set_per_process_memory_fraction(frac)
            logger.info(f"CUDA per-process memory fraction = {frac:.2f}")
        except Exception as e:
            logger.warning(f"Could not set CUDA memory fraction: {e}")
|
| 74 |
+
|
| 75 |
+
def vram_gb() -> Tuple[float, float]:
    """Return (free, total) GPU memory in GiB; (0.0, 0.0) without CUDA."""
    if not torch.cuda.is_available():
        return 0.0, 0.0
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    gib = 1024 ** 3
    return free_bytes / gib, total_bytes / gib
|
| 80 |
+
|
| 81 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 82 |
+
# Heartbeat (prevents Spaces watchdog killing the job)
|
| 83 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 84 |
+
def heartbeat_monitor(running_flag: Dict[str, bool], interval: float = 8.0):
    """Print a heartbeat line until running_flag['running'] turns falsy.

    Intended to run in a daemon thread so the HF Spaces watchdog does not
    assume the job has hung during long model work.
    """
    while True:
        if not running_flag.get("running", False):
            break
        print(f"[HB] t={int(time.time())}", flush=True)
        time.sleep(interval)
|
| 88 |
+
|
| 89 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 90 |
+
# Streaming video I/O
|
| 91 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 92 |
+
class StreamingVideoIO:
    """Context-managed frame-by-frame reader/writer pair for one video.

    Opens the input on ``__enter__``, creates an mp4v writer with the same
    frame size, and releases both handles on ``__exit__``.
    """

    def __init__(self, video_path: str, out_path: str, fps: float):
        self.video_path = video_path
        self.out_path = out_path
        self.fps = fps
        self.cap = None
        self.writer = None
        self.size = None  # (width, height); populated on __enter__

    def __enter__(self):
        self.cap = cv2.VideoCapture(self.video_path)
        if not self.cap.isOpened():
            raise RuntimeError(f"Cannot open video: {self.video_path}")
        width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.size = (width, height)
        codec = cv2.VideoWriter_fourcc(*'mp4v')
        self.writer = cv2.VideoWriter(self.out_path, codec, self.fps, (width, height))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.cap:
            self.cap.release()
        if self.writer:
            self.writer.release()

    def read_frame(self):
        """Return (ok, frame_bgr); (False, None) if the capture is not open."""
        if not self.cap:
            return False, None
        return self.cap.read()

    def write_frame(self, frame_bgr: np.ndarray):
        """Append one BGR frame; silently no-op if the writer is not open."""
        if not self.writer:
            return
        self.writer.write(frame_bgr)
|
| 127 |
+
|
| 128 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 129 |
+
# Models: loaders and safe optimizations
|
| 130 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 131 |
+
def load_sam2_predictor(device: torch.device):
    """
    Build a SAM2 predictor via the local wrapper (keeps interfaces stable).

    Best-effort fp16 + channels_last is applied when the wrapper exposes the
    underlying model; optimization failures are logged as warnings, while
    construction/import failures are logged and re-raised.
    """
    try:
        from models.sam2_loader import SAM2Predictor  # your wrapper

        predictor = SAM2Predictor(device=device)
        try:
            model = getattr(predictor, "model", None)
            if model is not None:
                model = model.half().to(device)
                predictor.model = model.to(memory_format=torch.channels_last)
                logger.info("SAM2: fp16 + channels_last applied (wrapper model).")
        except Exception as e:
            logger.warning(f"SAM2 fp16 optimization warning: {e}")
        return predictor
    except Exception as e:
        logger.error(f"Failed to import SAM2Predictor: {e}")
        raise
|
| 150 |
+
|
| 151 |
+
def load_matany_session(device: torch.device):
    """
    Create a MatAnyone session, or return None when unavailable.

    Accepts either MatAnyoneSession or MatAnyoneLoader (both names have
    shipped under models.matanyone_loader at different times).  Tries fp16
    eval on the wrapped model where the layers support it.
    """
    try:
        try:
            from models.matanyone_loader import MatAnyoneSession as _MatAny
        except Exception:
            from models.matanyone_loader import MatAnyoneLoader as _MatAny

        session = _MatAny(device=device)
        model = getattr(session, "model", None)
        if model is not None:
            model.eval()
            try:
                session.model = model.half().to(device)
                logger.info("MatAnyone: fp16 + eval applied.")
            except Exception:
                logger.info("MatAnyone: using fp32 (fp16 not supported for some layers).")
        return session
    except Exception as e:
        logger.warning(f"MatAnyone not available ({e}). Proceeding without refinement.")
        return None
|
| 173 |
+
|
| 174 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 175 |
+
# SAM2 state pruning (adapter): we call predictor.prune_state if present, else best-effort
|
| 176 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 177 |
+
def prune_sam2_state(predictor, state: Any, keep: int):
    """
    Best-effort pruning of SAM2 temporal caches to a fixed window length.

    Prefers ``predictor.prune_state(state, keep=N)``; falls back to a
    state-level ``prune()`` method; otherwise a no-op (relying on model
    internals and garbage collection).  Failures are only logged at debug.
    """
    try:
        if hasattr(predictor, "prune_state"):
            predictor.prune_state(state, keep=keep)
        else:
            state_prune = getattr(state, "prune", None)
            if callable(state_prune):
                state_prune(keep=keep)
    except Exception as e:
        logger.debug(f"SAM2 prune_state warning: {e}")
|
| 192 |
+
|
| 193 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 194 |
+
# VRAM-aware controller
|
| 195 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 196 |
+
class VRAMAdaptiveController:
    """Adapts SAM2 memory window / propagation scale to free VRAM.

    Tightens settings when free VRAM drops under ~1.6 GB, relaxes them
    when more than 3 GB is free, and does nothing without CUDA (vram_gb
    reports zeros in that case).
    """

    def __init__(self):
        # Frames kept in the model's temporal state.
        self.memory_window = int(os.getenv("SAM2_WINDOW", "96"))
        # Downscale factor used during propagation (1.0 = full resolution).
        self.propagation_scale = float(os.getenv("SAM2_PROP_SCALE", "0.90"))
        # Run cleanup every N frames.
        self.cleanup_every = 20

    def adapt(self):
        free, _total = vram_gb()
        if free == 0.0:
            return
        if free < 1.6:
            # Tighten under memory pressure.
            self.memory_window = max(48, self.memory_window - 8)
            self.propagation_scale = max(0.75, self.propagation_scale - 0.03)
            self.cleanup_every = max(12, self.cleanup_every - 2)
            logger.warning(f"Low VRAM ({free:.2f} GB free) β window={self.memory_window}, scale={self.propagation_scale:.2f}")
        elif free > 3.0:
            # Relax when plenty is free.
            self.memory_window = min(128, self.memory_window + 4)
            self.propagation_scale = min(1.0, self.propagation_scale + 0.01)
            self.cleanup_every = min(40, self.cleanup_every + 2)
|
| 217 |
+
|
| 218 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 219 |
+
# Audio mux helper (safer stream mapping)
|
| 220 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 221 |
+
def mux_audio(video_path_no_audio: str, source_with_audio: str, out_path: str) -> bool:
    """Copy the video stream and re-encode audio from the source file.

    Maps stream 0's video and stream 1's audio explicitly, copies the video
    codec, transcodes audio to AAC, and truncates to the shorter stream.

    Returns:
        True on success; False on non-zero ffmpeg exit, timeout, or a
        missing ffmpeg binary (all logged as warnings, never raised).
    """
    command = [
        "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
        "-i", video_path_no_audio,
        "-i", source_with_audio,
        "-map", "0:v:0", "-map", "1:a:0",
        "-c:v", "copy", "-c:a", "aac", "-shortest",
        out_path,
    ]
    try:
        result = subprocess.run(command, capture_output=True, text=True, timeout=180)
    except Exception as e:
        logger.warning(f"FFmpeg mux error: {e}")
        return False
    if result.returncode != 0:
        logger.warning(f"FFmpeg mux failed: {result.stderr.strip()}")
        return False
    return True
|
| 239 |
+
|
| 240 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 241 |
+
# Main processing
|
| 242 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 243 |
+
def process(
    video_path: str,
    background_image: Optional[Image.Image] = None,
    background_type: str = "custom",
    background_prompt: str = "",
    job_directory: Optional[Path] = None,
    progress_callback: Optional[Callable[[str, float], None]] = None
) -> str:
    """
    Production SAM2 + MatAnyone pipeline for T4.
    - Single-pass streaming (no large mask dicts)
    - Bounded memory windows

    Args:
        video_path: Path to the input video file; must exist.
        background_image: Replacement background (required; ValueError if None).
            Resized to the video's resolution with LANCZOS.
        background_type: Currently unused by this function.
        background_prompt: Currently unused by this function.
        job_directory: Directory for output artifacts; a fresh
            tmp/job_<8-hex> directory is created when None.
        progress_callback: Optional (step, fraction) reporter; exceptions it
            raises are swallowed (logged at debug level).

    Returns:
        Path to the final video: the audio-muxed file when ffmpeg succeeds,
        otherwise the audio-less composite.

    Raises:
        FileNotFoundError: If video_path does not exist.
        RuntimeError: If OpenCV cannot open the video.
        ValueError: If background_image is None.
    """
    setup_t4_environment()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Heartbeat: daemon thread that keeps logging liveness every 8s while
    # hb_flag["running"] is True.
    # NOTE(review): if an exception escapes mid-stream, the flag is only reset
    # on the validated early-exit paths below — the daemon thread otherwise
    # runs until process exit; confirm this is acceptable.
    hb_flag = {"running": True}
    hb_thread = threading.Thread(target=heartbeat_monitor, args=(hb_flag, 8.0), daemon=True)
    hb_thread.start()

    def report(step: str, p: Optional[float] = None):
        # Log progress and forward to the caller's callback (best-effort).
        if p is None:
            logger.info(step)
        else:
            logger.info(f"{step} [{p:.1%}]")
        if progress_callback:
            try:
                progress_callback(step, p)
            except Exception as e:
                logger.debug(f"progress_callback error: {e}")

    # Validate I/O
    src = Path(video_path)
    if not src.exists():
        hb_flag["running"] = False
        raise FileNotFoundError(f"Video not found: {video_path}")

    if job_directory is None:
        job_directory = Path.cwd() / "tmp" / f"job_{uuid.uuid4().hex[:8]}"
    job_directory.mkdir(parents=True, exist_ok=True)

    # Probe video geometry/timing with a throwaway capture.
    cap_probe = cv2.VideoCapture(str(src))
    if not cap_probe.isOpened():
        hb_flag["running"] = False
        raise RuntimeError(f"Cannot open video: {video_path}")
    fps = cap_probe.get(cv2.CAP_PROP_FPS) or 25.0  # fall back to 25 fps when the container reports 0
    width = int(cap_probe.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap_probe.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_count = int(cap_probe.get(cv2.CAP_PROP_FRAME_COUNT))
    duration = frame_count / fps if fps > 0 else 0.0
    cap_probe.release()
    logger.info(f"Video: {width}x{height} @ {fps:.2f} fps | {frame_count} frames ({duration:.1f}s)")

    # Prepare background: resize once and keep as float32 RGB for compositing.
    if background_image is None:
        hb_flag["running"] = False
        raise ValueError("background_image is required")
    bg = background_image.resize((width, height), Image.LANCZOS)
    bg_np = np.array(bg).astype(np.float32)

    # Load models
    report("Loading SAM2 + MatAnyone", 0.05)
    predictor = load_sam2_predictor(device)
    matany = load_matany_session(device)

    # Init SAM2 state (single)
    report("Initializing SAM2 video state", 0.08)
    state = predictor.init_state(video_path=str(src))

    # Minimal prompt: single positive point at center (replace with your prompt UI if needed)
    center_pt = np.array([[width // 2, height // 2]], dtype=np.float32)
    labels = np.array([1], dtype=np.int32)
    ann_obj_id = 1
    with torch.inference_mode():
        _ = predictor.add_new_points(
            inference_state=state,
            frame_idx=0,
            obj_id=ann_obj_id,
            points=center_pt,
            labels=labels,
        )

    # Controller: adapts memory window / propagation scale to VRAM pressure.
    ctrl = VRAMAdaptiveController()

    # Output paths
    out_raw = str(job_directory / f"composite_{int(time.time())}.mp4")
    out_final = str(job_directory / f"final_{int(time.time())}.mp4")

    # Windows/buffers (bounded)
    # For completeness we keep a tiny deque for any auxiliary temporal ops (e.g., matting history)
    aux_window = deque(maxlen=max(32, min(96, ctrl.memory_window // 2)))

    # Stream processing
    start = time.time()
    frames_done = 0
    next_cleanup_at = ctrl.cleanup_every

    report("Streaming: SAM2 β MatAnyone β Compose β Write", 0.12)
    with StreamingVideoIO(str(src), out_raw, fps) as vio:
        # iterate SAM2 propagation alongside reading frames
        # NOTE(review): autocast is entered with device_type="cuda" even on a
        # CPU-only host (dtype=None) — confirm this is a no-op on CPU builds.
        with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16 if device.type == "cuda" else None):
            for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(state, scale=ctrl.propagation_scale):
                # Read the matching frame; assumes propagation order matches
                # decode order 1:1 — TODO confirm for seeked/short videos.
                ret, frame_bgr = vio.read_frame()
                if not ret:
                    break

                # Get mask for ann_obj_id; keep on GPU as long as possible
                mask_t = None
                try:
                    if isinstance(out_obj_ids, torch.Tensor):
                        # find index where id == ann_obj_id
                        idxs = (out_obj_ids == ann_obj_id).nonzero(as_tuple=False)
                        if idxs.numel() > 0:
                            i = idxs[0].item()
                            logits = out_mask_logits[i]
                        else:
                            logits = None
                    else:
                        # list/array fallback
                        ids_list = list(out_obj_ids)
                        i = ids_list.index(ann_obj_id) if ann_obj_id in ids_list else -1
                        logits = out_mask_logits[i] if i >= 0 else None

                    if logits is not None:
                        # logits -> binary mask (threshold at 0, i.e. prob 0.5)
                        mask_t = (logits > 0).float()  # HxW float mask
                except Exception as e:
                    logger.debug(f"Mask extraction warning @frame {out_frame_idx}: {e}")
                    mask_t = None

                # Optional: MatAnyone refinement
                if mask_t is not None and matany is not None:
                    try:
                        # MatAnyone APIs vary -> try common forms
                        # Convert RGB because many mattors expect RGB
                        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
                        # Move frame to GPU only if your matting backend supports it
                        refined = None
                        if hasattr(matany, "refine_mask"):
                            refined = matany.refine_mask(frame_rgb, mask_t)  # allow handler to decide device
                        elif hasattr(matany, "process_frame"):
                            refined = matany.process_frame(frame_rgb, mask_t)
                        if refined is not None:
                            # ensure float mask 0..1 on CUDA or CPU
                            if isinstance(refined, torch.Tensor):
                                mask_t = refined.float()
                            else:
                                # numpy -> torch
                                mask_t = torch.from_numpy(refined.astype(np.float32))
                                if device.type == "cuda":
                                    mask_t = mask_t.to(device)
                    except Exception as e:
                        # Best-effort: fall back to the unrefined SAM2 mask.
                        logger.debug(f"MatAnyone refinement failed (frame {out_frame_idx}): {e}")

                # Compose and write (convert once, keep math sane)
                if mask_t is not None:
                    # bring mask to CPU for np composition; keep as float [0,1]
                    mask_np = mask_t.detach().clamp(0, 1).to("cpu", non_blocking=True).float().numpy()
                    m3 = mask_np[..., None]  # HxWx1 so it broadcasts over RGB
                    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
                    comp = frame_rgb * m3 + bg_np * (1.0 - m3)
                    comp_bgr = cv2.cvtColor(comp.astype(np.uint8), cv2.COLOR_RGB2BGR)
                    vio.write_frame(comp_bgr)
                else:
                    # No mask -> write original frame
                    vio.write_frame(frame_bgr)

                # Periodic maintenance: adapt controller, prune SAM2 history,
                # and return cached CUDA blocks to the driver.
                frames_done += 1
                if frames_done >= next_cleanup_at:
                    ctrl.adapt()
                    prune_sam2_state(predictor, state, keep=ctrl.memory_window)
                    # Clear small aux buffers
                    aux_window.clear()
                    if device.type == "cuda":
                        torch.cuda.ipc_collect()
                        torch.cuda.empty_cache()
                    next_cleanup_at = frames_done + ctrl.cleanup_every

                # Progress
                if frames_done % 25 == 0 and frame_count > 0:
                    p = 0.12 + 0.75 * (frames_done / frame_count)
                    report(f"Processing frame {frames_done}/{frame_count} | win={ctrl.memory_window} scale={ctrl.propagation_scale:.2f}", p)

    # Audio mux: re-attach the source audio; keep the raw file on failure.
    report("Restoring audio", 0.93)
    ok = mux_audio(out_raw, str(src), out_final)
    final_path = out_final if ok else out_raw

    # Cleanup models/state promptly
    try:
        del predictor
        del state
        if matany is not None:
            del matany
    except Exception:
        pass

    if device.type == "cuda":
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
    gc.collect()

    hb_flag["running"] = False
    elapsed = time.time() - start
    try:
        peak = torch.cuda.max_memory_allocated() / (1024 ** 3) if device.type == "cuda" else 0.0
        logger.info(f"Peak GPU memory: {peak:.2f} GB")
    except Exception:
        pass
    report(f"Done in {elapsed:.1f}s", 1.0)
    logger.info(f"Output: {final_path}")
    logger.info(f"Artifacts: {job_directory}")
    return final_path
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
# -------------------------------------------------------------------------------------------------
|
| 464 |
+
# CLI entry (optional)
|
| 465 |
+
# -------------------------------------------------------------------------------------------------
|
| 466 |
+
if __name__ == "__main__":
    import argparse

    # Minimal command-line front-end around process().
    arg_parser = argparse.ArgumentParser(description="BackgroundFX Pro pipeline")
    arg_parser.add_argument("--video", required=True, help="Path to input video")
    arg_parser.add_argument("--background", required=True, help="Path to background image")
    arg_parser.add_argument("--outdir", default=None, help="Job directory (optional)")
    cli = arg_parser.parse_args()

    background = Image.open(cli.background).convert("RGB")
    job_dir = Path(cli.outdir) if cli.outdir else None
    result = process(cli.video, background_image=background, job_directory=job_dir)
    print(result)
|
VideoBackgroundReplacer2/requirements.txt
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ===== Core Dependencies =====
|
| 2 |
+
# PyTorch is installed in Dockerfile with CUDA 12.1 β REQUIRED for SAM2
|
| 3 |
+
# torch==2.5.1
|
| 4 |
+
# torchvision==0.20.1
|
| 5 |
+
# torchaudio==2.5.1
|
| 6 |
+
|
| 7 |
+
# ===== Base Dependencies =====
|
| 8 |
+
numpy>=1.24.0,<2.1.0
|
| 9 |
+
Pillow>=10.0.0,<12.0.0
|
| 10 |
+
protobuf>=4.25.0,<6.0.0
|
| 11 |
+
|
| 12 |
+
# ===== Image/Video Processing =====
|
| 13 |
+
opencv-python-headless>=4.8.0,<4.11.0
|
| 14 |
+
imageio>=2.25.0,<3.0.0
|
| 15 |
+
imageio-ffmpeg>=0.4.7,<0.6.0
|
| 16 |
+
moviepy>=1.0.3,<2.0.0
|
| 17 |
+
decord>=0.6.0,<0.7.0
|
| 18 |
+
scikit-image>=0.19.3,<0.22.0
|
| 19 |
+
|
| 20 |
+
# ===== MediaPipe =====
|
| 21 |
+
mediapipe>=0.10.0,<0.11.0
|
| 22 |
+
|
| 23 |
+
# ===== SAM2 Dependencies =====
|
| 24 |
+
# SAM2 is installed via git clone in Dockerfile
|
| 25 |
+
hydra-core>=1.3.2,<2.0.0
|
| 26 |
+
omegaconf>=2.3.0,<3.0.0
|
| 27 |
+
einops>=0.6.0,<0.9.0
|
| 28 |
+
timm>=0.9.0,<1.1.0
|
| 29 |
+
pyyaml>=6.0.0,<7.0.0
|
| 30 |
+
matplotlib>=3.5.0,<4.0.0
|
| 31 |
+
iopath>=0.1.10,<0.2.0
|
| 32 |
+
|
| 33 |
+
# ===== MatAnyone Dependencies =====
|
| 34 |
+
# MatAnyone is installed separately in Dockerfile
|
| 35 |
+
kornia>=0.7.0,<0.8.0
|
| 36 |
+
tqdm>=4.60.0,<5.0.0
|
| 37 |
+
|
| 38 |
+
# ===== UI and API =====
|
| 39 |
+
# Bump to avoid gradio_client 1.3.0 bug ("bool is not iterable")
|
| 40 |
+
gradio==4.42.0
|
| 41 |
+
|
| 42 |
+
# ===== Web stack pins for Gradio 4.42.0 =====
|
| 43 |
+
fastapi==0.109.2
|
| 44 |
+
starlette==0.36.3
|
| 45 |
+
uvicorn==0.29.0
|
| 46 |
+
httpx==0.27.2
|
| 47 |
+
anyio==4.4.0
|
| 48 |
+
orjson>=3.10.0
|
| 49 |
+
|
| 50 |
+
# ===== Pydantic family (pinned together: pydantic 2.8.2 requires pydantic-core 2.20.1) =====
|
| 51 |
+
pydantic==2.8.2
|
| 52 |
+
pydantic-core==2.20.1
|
| 53 |
+
annotated-types==0.6.0
|
| 54 |
+
typing-extensions==4.12.2
|
| 55 |
+
|
| 56 |
+
# ===== Helpers and Utilities =====
|
| 57 |
+
huggingface-hub>=0.20.0,<1.0.0
|
| 58 |
+
ffmpeg-python>=0.2.0,<1.0.0
|
| 59 |
+
psutil>=5.8.0,<7.0.0
|
| 60 |
+
requests>=2.25.0,<3.0.0
|
| 61 |
+
scikit-learn>=1.3.0,<2.0.0
|
| 62 |
+
|
| 63 |
+
# ===== Additional Dependencies =====
|
| 64 |
+
# Performance and monitoring
|
| 65 |
+
gputil>=1.4.0,<2.0.0
|
| 66 |
+
nvidia-ml-py3>=7.352.0,<12.0.0
|
| 67 |
+
|
| 68 |
+
# Error handling and logging
|
| 69 |
+
loguru>=0.6.0,<1.0.0
|
| 70 |
+
|
| 71 |
+
# File handling
|
| 72 |
+
python-multipart>=0.0.5,<1.0.0
|
VideoBackgroundReplacer2/two_stage_pipeline.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
two_stage_pipeline.py β Ephemeral SAM2 stage + MatAnyone stage
|
| 4 |
+
- Stage 1: SAM2 -> lossless mask stream (FFV1 .mkv) + meta.json, then unload SAM2
|
| 5 |
+
- Stage 2: read mask stream -> (optional) MatAnyone refine -> composite -> mux audio
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os, sys, gc, json, cv2, time, uuid, torch, shutil, logging, subprocess, threading
|
| 9 |
+
import numpy as np
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Optional, Callable, Tuple, Dict, Any
|
| 12 |
+
from PIL import Image
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger("backgroundfx_pro.two_stage")
|
| 15 |
+
if not logger.handlers:
|
| 16 |
+
h = logging.StreamHandler()
|
| 17 |
+
h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s: %(message)s"))
|
| 18 |
+
logger.addHandler(h)
|
| 19 |
+
logger.setLevel(logging.INFO)
|
| 20 |
+
|
| 21 |
+
# ---------------------------
|
| 22 |
+
# Env & CUDA helpers
|
| 23 |
+
# ---------------------------
|
| 24 |
+
def setup_env():
    """Configure allocator, BLAS-thread, and CUDA knobs for single-job inference.

    Safe to call repeatedly: env vars are only set when absent, and every
    CUDA/backends tweak is wrapped so older torch builds don't crash.
    """
    env_defaults = {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7",
        "OMP_NUM_THREADS": "1",
        "OPENBLAS_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
    }
    for key, value in env_defaults.items():
        os.environ.setdefault(key, value)

    # Inference only — autograd bookkeeping is pure overhead here.
    torch.set_grad_enabled(False)

    try:
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass

    if not torch.cuda.is_available():
        return
    try:
        torch.cuda.set_per_process_memory_fraction(float(os.getenv("CUDA_MEMORY_FRACTION", "0.88")))
    except Exception:
        pass
|
| 42 |
+
|
| 43 |
+
def free_cuda():
    """Return cached CUDA memory to the driver; a no-op on CPU-only hosts."""
    if not torch.cuda.is_available():
        return
    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()
|
| 47 |
+
|
| 48 |
+
def unload_sam2_modules():
    """Drop every imported sam2.* module from sys.modules to shrink RSS.

    Best-effort: any failure is logged as a warning and swallowed.
    """
    try:
        import importlib
        loaded = [name for name in list(sys.modules) if name.startswith("sam2")]
        for name in loaded:
            sys.modules.pop(name, None)
        importlib.invalidate_caches()
        gc.collect()
        free_cuda()
        logger.info("SAM2 modules unloaded.")
    except Exception as e:
        logger.warning(f"Unloading SAM2 modules: {e}")
|
| 61 |
+
|
| 62 |
+
# ---------------------------
|
| 63 |
+
# Video probing
|
| 64 |
+
# ---------------------------
|
| 65 |
+
def probe_video(path: str) -> Tuple[int, int, float, int]:
    """Return (width, height, fps, frame_count) for *path*.

    Raises RuntimeError when OpenCV cannot open the file. A container that
    reports 0 fps is mapped to 25.0.
    """
    capture = cv2.VideoCapture(path)
    if not capture.isOpened():
        raise RuntimeError(f"Cannot open video: {path}")
    width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_rate = capture.get(cv2.CAP_PROP_FPS) or 25.0
    total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    capture.release()
    return width, height, float(frame_rate), total_frames
|
| 75 |
+
|
| 76 |
+
# ---------------------------
|
| 77 |
+
# FFmpeg mask writers/readers
|
| 78 |
+
# ---------------------------
|
| 79 |
+
class MaskFFV1Writer:
    """Stream uint8 grayscale masks into a lossless FFV1 .mkv via an ffmpeg pipe.

    Use as a context manager so the encoder is flushed and shut down
    deterministically (killed if it refuses to exit).
    """

    def __init__(self, path: str, w: int, h: int, fps: float):
        self.path = path
        self.w = w
        self.h = h
        self.fps = fps
        self.proc = None  # ffmpeg subprocess, created in __enter__

    def __enter__(self):
        # Raw gray frames in on stdin; intra-only FFV1 out (-g 1: every frame a keyframe).
        encode_cmd = ["ffmpeg", "-y", "-hide_banner", "-loglevel", "error"]
        encode_cmd += ["-f", "rawvideo", "-pix_fmt", "gray", "-s", f"{self.w}x{self.h}", "-r", f"{self.fps}"]
        encode_cmd += ["-i", "-"]
        encode_cmd += ["-c:v", "ffv1", "-level", "3", "-g", "1", self.path]
        self.proc = subprocess.Popen(encode_cmd, stdin=subprocess.PIPE)
        return self

    def write(self, mask_u8: np.ndarray):
        """Append one HxW mask frame (expected uint8, typically 0/255)."""
        frame = mask_u8 if mask_u8.dtype == np.uint8 else mask_u8.astype(np.uint8)
        self.proc.stdin.write(frame.tobytes())

    def __exit__(self, exc_type, exc, tb):
        if not self.proc:
            return
        try:
            self.proc.stdin.flush()
            self.proc.stdin.close()
            self.proc.wait(timeout=120)
        except Exception:
            self.proc.kill()
|
| 110 |
+
|
| 111 |
+
class MaskFFV1Reader:
    """Decode uint8 mask frames back out of an FFV1 .mkv via an ffmpeg pipe."""

    def __init__(self, path: str, w: int, h: int):
        self.path = path
        self.w = w
        self.h = h
        self.proc = None  # ffmpeg subprocess, created in __enter__
        self.frame_bytes = w * h  # size of one raw gray frame

    def __enter__(self):
        decode_cmd = ["ffmpeg", "-hide_banner", "-loglevel", "error", "-i", self.path]
        decode_cmd += ["-f", "rawvideo", "-pix_fmt", "gray", "-"]
        self.proc = subprocess.Popen(decode_cmd, stdout=subprocess.PIPE)
        return self

    def read(self) -> Optional[np.ndarray]:
        """Return the next HxW uint8 mask, or None at end of stream."""
        raw = self.proc.stdout.read(self.frame_bytes)
        if not raw or len(raw) < self.frame_bytes:
            return None
        return np.frombuffer(raw, dtype=np.uint8).reshape(self.h, self.w)

    def __exit__(self, exc_type, exc, tb):
        if not self.proc:
            return
        try:
            self.proc.stdout.close()
            self.proc.wait(timeout=30)
        except Exception:
            self.proc.kill()
|
| 140 |
+
|
| 141 |
+
# Fallback: PNG sequence (disk heavy but simple & robust)
|
| 142 |
+
class MaskPNGWriter:
    """Fallback mask sink: one zero-padded-name PNG per frame in a directory."""

    def __init__(self, dirpath: Path):
        self.dir = dirpath
        self.dir.mkdir(parents=True, exist_ok=True)
        self.idx = 0  # next frame index

    def write(self, mask_u8: np.ndarray):
        target = self.dir / f"{self.idx:06d}.png"
        cv2.imwrite(str(target), mask_u8)
        self.idx += 1
|
| 148 |
+
|
| 149 |
+
class MaskPNGReader:
    """Fallback mask source: reads the PNG sequence written by MaskPNGWriter."""

    def __init__(self, dirpath: Path):
        self.dir = dirpath
        self.idx = 0  # next frame index

    def read(self) -> Optional[np.ndarray]:
        """Return the next grayscale mask, or None when the sequence ends."""
        frame_path = self.dir / f"{self.idx:06d}.png"
        if not frame_path.exists():
            return None
        mask = cv2.imread(str(frame_path), cv2.IMREAD_GRAYSCALE)
        self.idx += 1
        return mask
|
| 158 |
+
|
| 159 |
+
# ---------------------------
|
| 160 |
+
# Stage 1 β SAM2 β mask dump
|
| 161 |
+
# ---------------------------
|
| 162 |
+
def stage1_dump_masks(video_path:str, out_dir:Path, obj_point:Tuple[int,int]=None) -> Dict[str,Any]:
    """
    Run only SAM2, save masks as FFV1 (preferred) or PNG sequence + meta.json.
    Returns meta dict.

    Args:
        video_path: Input video; must be openable by OpenCV.
        out_dir: Directory for mask.mkv / masks_png/ and meta.json (created).
        obj_point: Positive click prompt as (x, y); defaults to frame center.

    Returns:
        meta dict with keys: video, width, height, fps, frames,
        storage ("ffv1" or "png") and mask_path.

    Side effects:
        Writes the mask stream and meta.json under out_dir, then deletes the
        predictor/state and evicts all sam2.* modules to reclaim memory.
    """
    setup_env()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    w,h,fps,n = probe_video(video_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    meta = {"video":video_path, "width":w,"height":h,"fps":fps,"frames":n, "storage":None}
    logger.info(f"[Stage1] {w}x{h}@{fps:.2f} | frames={n}")

    # Load SAM2 (your wrapper)
    from models.sam2_loader import SAM2Predictor
    predictor = SAM2Predictor(device=device)
    state = predictor.init_state(video_path=video_path)

    # Prompt: center positive if not provided
    if obj_point is None:
        obj_point = (w//2, h//2)
    pts = np.array([[obj_point[0], obj_point[1]]], dtype=np.float32)
    labels = np.array([1], dtype=np.int32)
    ann_obj_id = 1
    with torch.inference_mode():
        predictor.add_new_points(state, 0, ann_obj_id, pts, labels)

    # Preferred: FFV1 mask stream
    mask_mkv = out_dir / "mask.mkv"
    # NOTE(review): use_png is never read — the PNG fallback is selected by the
    # except branch below; consider removing.
    use_png = False
    try:
        # NOTE(review): autocast("cuda") is entered even on CPU-only hosts
        # (dtype=None) — confirm it is a no-op there.
        with MaskFFV1Writer(str(mask_mkv), w, h, fps) as writer, \
             torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16 if device.type=="cuda" else None):
            for _, out_ids, out_logits in predictor.propagate_in_video(state):
                # pick ann_obj_id
                i = None
                if isinstance(out_ids, torch.Tensor):
                    nz = (out_ids == ann_obj_id).nonzero(as_tuple=False)
                    if nz.numel() > 0: i = nz[0].item()
                else:
                    ids = list(out_ids); i = ids.index(ann_obj_id) if ann_obj_id in ids else None
                if i is None:
                    # write empty mask so the stream stays frame-aligned
                    writer.write(np.zeros((h,w), np.uint8))
                    continue
                # Threshold logits at 0 (prob 0.5) then scale to 0/255 uint8.
                mask = (out_logits[i] > 0).detach()
                mask_u8 = (mask.float().mul_(255).to("cpu", non_blocking=True).numpy()).astype(np.uint8)
                writer.write(mask_u8)
        meta["storage"] = "ffv1"
        meta["mask_path"] = str(mask_mkv)
        logger.info("[Stage1] Masks saved as FFV1 .mkv")
    except Exception as e:
        # Fallback path: same propagation loop, but one PNG per frame.
        # NOTE(review): propagate_in_video is restarted on the same state here;
        # confirm the wrapper supports a second pass after a partial run.
        logger.warning(f"FFV1 writer failed ({e}), falling back to PNG sequence.")
        png_dir = out_dir / "masks_png"
        wr = MaskPNGWriter(png_dir)
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16 if device.type=="cuda" else None):
            for _, out_ids, out_logits in predictor.propagate_in_video(state):
                i = None
                if isinstance(out_ids, torch.Tensor):
                    nz = (out_ids == ann_obj_id).nonzero(as_tuple=False)
                    if nz.numel() > 0: i = nz[0].item()
                else:
                    ids = list(out_ids); i = ids.index(ann_obj_id) if ann_obj_id in ids else None
                if i is None:
                    wr.write(np.zeros((h,w), np.uint8)); continue
                mask = (out_logits[i] > 0).detach()
                wr.write((mask.float().mul_(255).to("cpu").numpy()).astype(np.uint8))
        meta["storage"] = "png"
        meta["mask_path"] = str(png_dir)

    # Persist meta
    with open(out_dir / "meta.json","w") as f:
        json.dump(meta, f)
    # Unload SAM2 completely
    del predictor, state
    free_cuda(); unload_sam2_modules()
    return meta
|
| 238 |
+
|
| 239 |
+
# ---------------------------
|
| 240 |
+
# Stage 2 β refine + compose
|
| 241 |
+
# ---------------------------
|
| 242 |
+
def stage2_refine_and_compose(video_path:str, mask_dir:Path, background_image:Image.Image,
                              out_path:str, use_matany:bool=True) -> str:
    """
    Stream the source video against the Stage-1 mask stream, optionally refine
    each mask with MatAnyone, composite over *background_image*, and re-mux the
    original audio.

    Args:
        video_path: The same source video Stage 1 ran on.
        mask_dir: Directory holding meta.json plus the mask stream it points to.
        background_image: Replacement background; resized to the video size.
        out_path: Desired final output path (.mp4).
        use_matany: When True, attempt to load a MatAnyone session for
            per-frame mask refinement (best-effort; failures fall back to the
            raw Stage-1 mask).

    Returns:
        Path to the final video. When audio muxing fails the audio-less
        composite is moved to out_path instead.
    """
    w,h,fps,n = probe_video(video_path)
    bg = background_image.resize((w,h), Image.LANCZOS)
    bg_np = np.array(bg).astype(np.float32)

    # Read meta written by stage1_dump_masks
    with open(mask_dir / "meta.json","r") as f:
        meta = json.load(f)
    storage = meta["storage"]; mask_path = meta["mask_path"]

    # Optional MatAnyone — loader class name varies between revisions, so try both.
    session = None
    if use_matany:
        try:
            from models.matanyone_loader import MatAnyoneSession as _M
        except Exception:
            try:
                from models.matanyone_loader import MatAnyoneLoader as _M
            except Exception:
                _M = None
        if _M:
            session = _M(device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
            if hasattr(session,"model") and session.model is not None:
                session.model.eval()

    # Open video + writer (audio-less intermediate alongside out_path)
    cap = cv2.VideoCapture(video_path)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    tmp_out = str(Path(out_path).with_suffix(".noaudio.mp4"))
    writer = cv2.VideoWriter(tmp_out, fourcc, fps, (w,h))

    # Open mask reader matching the Stage-1 storage format
    if storage == "ffv1":
        mreader = MaskFFV1Reader(mask_path, w, h)
        mreader.__enter__()
        read_mask = lambda : mreader.read()
    else:
        mreader = MaskPNGReader(Path(mask_path))
        read_mask = lambda : mreader.read()

    i = 0
    try:
        while True:
            ok, frame_bgr = cap.read()
            if not ok: break
            mask_u8 = read_mask()
            if mask_u8 is None:
                # out of masks; write original frame unmodified
                writer.write(frame_bgr); i+=1; continue

            # Optional refine (best-effort; keeps the raw mask on any failure)
            if session is not None:
                try:
                    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
                    # Provide a float mask 0..1 to session; adapt if your API differs
                    mask_f = (mask_u8.astype(np.float32) / 255.0)
                    if hasattr(session,"refine_mask"):
                        mask_refined = session.refine_mask(frame_rgb, mask_f)
                    elif hasattr(session,"process_frame"):
                        mask_refined = session.process_frame(frame_rgb, mask_f)
                    else:
                        mask_refined = mask_f
                    # Normalize whatever the session returned back to uint8 0..255.
                    if isinstance(mask_refined, torch.Tensor):
                        mask_u8 = (mask_refined.detach().clamp(0,1).mul(255).to("cpu").numpy()).astype(np.uint8)
                    elif isinstance(mask_refined, np.ndarray):
                        mask_u8 = (np.clip(mask_refined,0,1)*255).astype(np.uint8)
                except Exception as e:
                    logger.debug(f"MatAnyone refine failed @frame {i}: {e}")

            # Composite: alpha-blend foreground over the prepared background.
            m = (mask_u8.astype(np.float32)/255.0)[...,None]  # HxWx1, broadcasts over RGB
            fr = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
            comp = fr*m + bg_np*(1.0-m)
            comp_bgr = cv2.cvtColor(comp.astype(np.uint8), cv2.COLOR_RGB2BGR)
            writer.write(comp_bgr)

            if i % 50 == 0:
                logger.info(f"[Stage2] frame {i}/{n}")
            i += 1
    finally:
        # Always release I/O handles, and close the ffmpeg reader if used.
        cap.release(); writer.release()
        if isinstance(mreader, MaskFFV1Reader):
            mreader.__exit__(None,None,None)

    # Mux audio from the source; on any failure keep the audio-less composite.
    final_out = str(Path(out_path))
    cmd = [
        "ffmpeg","-y","-hide_banner","-loglevel","error",
        "-i", tmp_out, "-i", video_path,
        "-map","0:v:0","-map","1:a:0","-c:v","copy","-c:a","aac","-shortest", final_out
    ]
    try:
        r = subprocess.run(cmd, capture_output=True, text=True, timeout=180)
        if r.returncode != 0:
            logger.warning(f"Audio mux failed: {r.stderr.strip()}")
            shutil.move(tmp_out, final_out)
        else:
            os.remove(tmp_out)
    except Exception:
        shutil.move(tmp_out, final_out)
    return final_out
|
| 344 |
+
|
| 345 |
+
# ---------------------------
|
| 346 |
+
# Orchestrator
|
| 347 |
+
# ---------------------------
|
| 348 |
+
def process_two_stage(
    video_path:str,
    background_image: Image.Image,
    workdir: Optional[Path]=None,
    progress: Optional[Callable[[str,float],None]] = None,
    use_matany: bool = True,
) -> str:
    """Run the complete two-stage pipeline and return the final video path.

    Stage 1 dumps SAM2 masks to disk and unloads SAM2; Stage 2 optionally
    refines with MatAnyone, composites over *background_image*, and muxes audio.
    """
    setup_env()

    if workdir is None:
        workdir = Path.cwd() / "tmp" / f"job_{uuid.uuid4().hex[:8]}"
    workdir.mkdir(parents=True, exist_ok=True)

    def notify(step, fraction):
        # Forward progress to the caller only when a callback was supplied.
        if progress:
            progress(step, fraction)

    # Stage 1
    notify("Stage 1: SAM2 mask pass", 0.05)
    mask_dir = workdir / "sam2_masks"
    meta = stage1_dump_masks(video_path, mask_dir)
    notify("Stage 1 complete", 0.45)

    # Stage 2
    notify("Stage 2: refine + compose", 0.50)
    out_path = workdir / f"final_{int(time.time())}.mp4"
    final_video = stage2_refine_and_compose(video_path, mask_dir, background_image, str(out_path), use_matany=use_matany)
    notify("Done", 1.0)

    logger.info(f"Output: {final_video}")
    return final_video
|
| 373 |
+
|
| 374 |
+
# ---------------------------
|
| 375 |
+
# CLI
|
| 376 |
+
# ---------------------------
|
| 377 |
+
if __name__ == "__main__":
    import argparse

    # Command-line wrapper around process_two_stage().
    cli_parser = argparse.ArgumentParser(description="Two-stage BackgroundFX Pro")
    cli_parser.add_argument("--video", required=True)
    cli_parser.add_argument("--background", required=True)
    cli_parser.add_argument("--outdir", default=None)
    cli_parser.add_argument("--no-matany", action="store_true")
    opts = cli_parser.parse_args()

    background = Image.open(opts.background).convert("RGB")
    target_dir = Path(opts.outdir) if opts.outdir else None
    result = process_two_stage(opts.video, background, target_dir, use_matany=not opts.no_matany)
    print(result)
|
VideoBackgroundReplacer2/ui.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""
BackgroundFX Pro β Main UI Application (Gradio 4.42.x)
Clean, focused main file that coordinates the application.
"""

# ============================================================
# Mount-mode handoff: delegate to app.py when enabled
# (So we can serve a safe /config JSON via our FastAPI shim)
# ============================================================
import os, runpy
# When GRADIO_MOUNT_MODE=1 this file acts only as a shim: execute app.py as
# if it were the main script, then exit so nothing below runs a second time.
if os.getenv("GRADIO_MOUNT_MODE") == "1":
    runpy.run_module("app", run_name="__main__")
    raise SystemExit
|
| 15 |
+
|
| 16 |
+
# ==== Runtime hygiene & paths (very high in file) ====
|
| 17 |
+
import sys
|
| 18 |
+
import logging
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
# --- Sanitize OMP/BLAS threads early (avoids "libgomp: Invalid value..." issues)
|
| 22 |
+
def _sanitize_omp_env():
|
| 23 |
+
import multiprocessing as _mp
|
| 24 |
+
cpu = max(1, _mp.cpu_count())
|
| 25 |
+
default_omp = max(1, cpu // 2)
|
| 26 |
+
|
| 27 |
+
raw = os.environ.get("OMP_NUM_THREADS", "").strip()
|
| 28 |
+
try:
|
| 29 |
+
n = int(raw)
|
| 30 |
+
if n <= 0 or n > cpu * 2:
|
| 31 |
+
raise ValueError
|
| 32 |
+
omp_val = n
|
| 33 |
+
except Exception:
|
| 34 |
+
omp_val = default_omp
|
| 35 |
+
os.environ["OMP_NUM_THREADS"] = str(omp_val)
|
| 36 |
+
|
| 37 |
+
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
|
| 38 |
+
os.environ.setdefault("MKL_NUM_THREADS", "1")
|
| 39 |
+
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
|
| 40 |
+
|
| 41 |
+
_sanitize_omp_env()
|
| 42 |
+
|
| 43 |
+
# Stable app dirs (avoid /tmp surprises on HF)
APP_ROOT = Path(__file__).resolve().parent  # directory containing this file
DATA_ROOT = APP_ROOT / "data"
TMP_ROOT = APP_ROOT / "tmp"
JOB_ROOT = TMP_ROOT / "backgroundfx_jobs"  # per-job working directories
for p in (DATA_ROOT, TMP_ROOT, JOB_ROOT):
    p.mkdir(parents=True, exist_ok=True)

# Keep model/caches local to repo volume
os.environ.setdefault("HF_HOME", str(APP_ROOT / ".hf"))
os.environ.setdefault("TORCH_HOME", str(APP_ROOT / ".torch"))
Path(os.environ["HF_HOME"]).mkdir(parents=True, exist_ok=True)
Path(os.environ["TORCH_HOME"]).mkdir(parents=True, exist_ok=True)

# Make Gradio a bit quieter / safer in Spaces
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
| 60 |
+
# Standard imports (after env is sane)
|
| 61 |
+
import torch
|
| 62 |
+
import gradio as gr
|
| 63 |
+
|
| 64 |
+
# Import our modules
|
| 65 |
+
from ui_core_functionality import startup_probe, logger
|
| 66 |
+
from ui_core_interface import create_interface
|
| 67 |
+
|
| 68 |
+
# Optional: patch a Gradio client util to tolerate boolean JSON Schemas
def _patch_gradio_client_bool_schema():
    """Wrap gradio_client schema helpers so bare `true`/`false` schemas don't crash."""
    try:
        import gradio_client.utils as _gc_utils  # type: ignore

        original_get_type = _gc_utils.get_type

        def _tolerant_get_type(schema):
            # JSON Schema allows a bare boolean: true matches anything, false nothing.
            if isinstance(schema, bool):
                return "Any" if schema else "None"
            return original_get_type(schema)

        _gc_utils.get_type = _tolerant_get_type  # type: ignore[attr-defined]

        if hasattr(_gc_utils, "_json_schema_to_python_type"):
            original_walk = _gc_utils._json_schema_to_python_type  # type: ignore[attr-defined]

            def _tolerant_walk(schema, defs):
                if isinstance(schema, bool):
                    return "Any" if schema else "None"
                return original_walk(schema, defs)

            _gc_utils._json_schema_to_python_type = _tolerant_walk  # type: ignore[attr-defined]

        logger.info("π©Ή Patched gradio_client.utils to handle boolean JSON Schemas.")
    except Exception as e:
        logger.warning("Could not patch gradio_client boolean schema handling: %s", e)

_patch_gradio_client_bool_schema()
|
| 94 |
+
|
| 95 |
+
# =======================================================================
|
| 96 |
+
# MAIN APPLICATION
|
| 97 |
+
# =======================================================================
|
| 98 |
+
|
| 99 |
+
def main():
    """Main application entry point: probe the environment, then launch Gradio.

    Raises:
        Exception: re-raised after logging if startup or launch fails.
    """
    try:
        startup_probe()

        logger.info("π Launching Gradio interface...")
        logger.info(
            "Gradio=%s | torch=%s | cu=%s | cuda_available=%s",
            getattr(gr, "__version__", "?"),
            torch.__version__,
            getattr(torch.version, "cuda", None),
            torch.cuda.is_available(),
        )

        demo = create_interface()

        # Gradio 4.x: keep queue small to avoid RAM spikes (no concurrency_count here)
        demo.queue(max_size=2)

        # Port from env (HF sets PORT)
        port = int(os.environ.get("PORT", os.environ.get("GRADIO_SERVER_PORT", "7860")))

        # share=True would trigger an frpc tunnel download (and 500s) on Spaces,
        # and a public tunnel is never wanted elsewhere either, so it is
        # unconditionally False (the original `False if in_space else False`
        # was a dead conditional that always evaluated to False anyway).
        demo.launch(
            server_name="0.0.0.0",
            server_port=port,
            share=False,
            show_api=False,   # safer on public Spaces
            show_error=True,
            quiet=True,
            debug=False,
            max_threads=1,    # worker threads; per-listener concurrency set in ui_core_interface.py
        )

    except Exception as e:
        logger.error("β Application startup failed: %s", e)
        raise

if __name__ == "__main__":
    main()
|
VideoBackgroundReplacer2/ui_core_functionality.py
ADDED
|
@@ -0,0 +1,662 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
BackgroundFX Pro β Core Functionality
|
| 4 |
+
All processing logic, utilities, background generators, and handlers
|
| 5 |
+
Enhanced with file safety, robust logging, and runtime diagnostics.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import io
|
| 11 |
+
import gc
|
| 12 |
+
import time
|
| 13 |
+
import json
|
| 14 |
+
import uuid
|
| 15 |
+
import shutil
|
| 16 |
+
import logging
|
| 17 |
+
import tempfile
|
| 18 |
+
import requests
|
| 19 |
+
import threading
|
| 20 |
+
import traceback
|
| 21 |
+
import subprocess
|
| 22 |
+
from datetime import datetime
|
| 23 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 24 |
+
from typing import Optional, Tuple, List, Dict, Any, Union, Callable
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
import numpy as np
|
| 29 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 30 |
+
import cv2
|
| 31 |
+
|
| 32 |
+
# ==============================================================================
|
| 33 |
+
# PATHS & ENV
|
| 34 |
+
# ==============================================================================
|
| 35 |
+
|
| 36 |
+
# Repo root (β¦/app)
APP_ROOT = Path(__file__).resolve().parent
DATA_ROOT = APP_ROOT / "data"               # persistent outputs / logs
TMP_ROOT = APP_ROOT / "tmp"                 # scratch space
JOB_ROOT = TMP_ROOT / "backgroundfx_jobs"   # one sub-directory per processing job

# Create every directory the app touches up front so later code can assume
# they exist (Spaces containers start empty).
for p in (
    DATA_ROOT,
    TMP_ROOT,
    JOB_ROOT,
    APP_ROOT / ".hf",
    APP_ROOT / ".torch",
    APP_ROOT / "checkpoints",
    APP_ROOT / "models",
    APP_ROOT / "utils",
):
    p.mkdir(parents=True, exist_ok=True)

# Cache dirs (stable on Spaces)
os.environ.setdefault("HF_HOME", str(APP_ROOT / ".hf"))
os.environ.setdefault("TORCH_HOME", str(APP_ROOT / ".torch"))

# Quiet BLAS/OpenMP spam (in case ui.py wasn't first).  After this check
# OMP_NUM_THREADS is always a digit string, so the original extra
# setdefault("OMP_NUM_THREADS", "4") was a guaranteed no-op and is removed.
if not os.environ.get("OMP_NUM_THREADS", "").isdigit():
    os.environ["OMP_NUM_THREADS"] = "4"
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
os.environ.setdefault("PYTHONFAULTHANDLER", "1")
|
| 66 |
+
|
| 67 |
+
# ==============================================================================
# LOGGING + DIAGNOSTICS (console + file + heartbeat)
# ==============================================================================

# Line-buffer logs so Space UI shows them promptly
try:
    sys.stdout.reconfigure(line_buffering=True)
    sys.stderr.reconfigure(line_buffering=True)
except Exception:
    # reconfigure() is unavailable on some wrapped streams - best effort only.
    pass

# Log to stdout AND to a file under data/ so output survives a UI restart.
LOG_FILE = DATA_ROOT / "run.log"
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    handlers=[logging.StreamHandler(sys.stdout),
              logging.FileHandler(LOG_FILE, encoding="utf-8")],
    force=True,  # replace any handlers already installed by imported libraries
)
logger = logging.getLogger("bgfx")

# Faulthandler (native crashes -> stacks)
try:
    import faulthandler, signal  # type: ignore
    faulthandler.enable(all_threads=True)
    # SIGUSR1 dumps all thread stacks on demand (signal absent on Windows).
    if hasattr(signal, "SIGUSR1"):
        faulthandler.register(signal.SIGUSR1, file=sys.stderr, all_threads=True)
except Exception as e:
    logger.warning("faulthandler setup skipped: %s", e)
|
| 96 |
+
|
| 97 |
+
def _disk_stats(p: Path) -> str:
|
| 98 |
+
try:
|
| 99 |
+
total, used, free = shutil.disk_usage(str(p))
|
| 100 |
+
mb = lambda x: x // (1024 * 1024)
|
| 101 |
+
return f"disk(total={mb(total)}MB, used={mb(used)}MB, free={mb(free)}MB)"
|
| 102 |
+
except Exception:
|
| 103 |
+
return "disk(n/a)"
|
| 104 |
+
|
| 105 |
+
def _cgroup_limit_bytes():
|
| 106 |
+
for fp in ("/sys/fs/cgroup/memory.max", "/sys/fs/cgroup/memory/memory.limit_in_bytes"):
|
| 107 |
+
try:
|
| 108 |
+
s = Path(fp).read_text().strip()
|
| 109 |
+
if s and s != "max":
|
| 110 |
+
return int(s)
|
| 111 |
+
except Exception:
|
| 112 |
+
pass
|
| 113 |
+
|
| 114 |
+
def _rss_bytes():
|
| 115 |
+
try:
|
| 116 |
+
for line in Path("/proc/self/status").read_text().splitlines():
|
| 117 |
+
if line.startswith("VmRSS:"):
|
| 118 |
+
return int(line.split()[1]) * 1024
|
| 119 |
+
except Exception:
|
| 120 |
+
return None
|
| 121 |
+
|
| 122 |
+
def _heartbeat(interval: float = 2.0):
    """Periodically log RSS / memory-limit / disk stats, forever.

    Intended to run on a daemon thread; the last HEARTBEAT line in the log
    makes hard kills (OOM/SIGKILL) diagnosable after the fact.

    Args:
        interval: Seconds between heartbeat lines.  Defaults to 2.0, matching
            the original hard-coded cadence, but is now tunable by callers.
    """
    lim = _cgroup_limit_bytes()  # the cgroup limit is static for the process
    while True:
        rss = _rss_bytes()
        logger.info(
            "HEARTBEAT | rss=%s MB | limit=%s MB | %s",
            f"{rss//2**20}" if rss else "n/a",
            f"{lim//2**20}" if lim else "n/a",
            _disk_stats(APP_ROOT),
        )
        time.sleep(interval)
|
| 133 |
+
|
| 134 |
+
# Start heartbeat as a daemon thread (only once)
try:
    # daemon=True: the thread must never block interpreter shutdown.
    threading.Thread(target=_heartbeat, name="heartbeat", daemon=True).start()
except Exception as e:
    logger.warning("heartbeat skipped: %s", e)

import atexit
@atexit.register
def _on_exit():
    # If this line never appears in the logs the process died without running
    # atexit handlers, i.e. a hard kill such as OOM/SIGKILL.
    logger.info("PROCESS EXITING (atexit) β if you don't see this, it was a hard kill (OOM/SIGKILL)")
|
| 144 |
+
|
| 145 |
+
# ==============================================================================
|
| 146 |
+
# STARTUP VALIDATION
|
| 147 |
+
# ==============================================================================
|
| 148 |
+
|
| 149 |
+
def startup_probe():
    """Comprehensive startup probe - validates system readiness.

    Checks, in order: filesystem writability under TMP_ROOT, torch/GPU
    availability (informational only), required directories, and job-dir
    isolation.  Raises RuntimeError on any fatal failure so the app refuses
    to start half-broken.
    """
    try:
        logger.info("π BACKGROUNDFX PRO STARTUP PROBE")
        logger.info("π Working directory: %s", os.getcwd())
        logger.info("π Python executable: %s", sys.executable)

        # Write probe (fail fast if not writable)
        probe_file = TMP_ROOT / "startup_probe.txt"
        probe_file.write_text("startup_test_ok", encoding="utf-8")
        assert probe_file.read_text(encoding="utf-8") == "startup_test_ok"
        logger.info("β WRITE PROBE OK: %s | %s", probe_file, _disk_stats(APP_ROOT))
        probe_file.unlink(missing_ok=True)

        # GPU/Torch status (CPU fallback is allowed, so failures only warn)
        try:
            logger.info("π§ Torch=%s | cu=%s | cuda_available=%s",
                        torch.__version__, getattr(torch.version, "cuda", None), torch.cuda.is_available())
            if torch.cuda.is_available():
                gpu_count = torch.cuda.device_count()
                name = torch.cuda.get_device_name(0) if gpu_count else "Unknown"
                vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3) if gpu_count else 0
                logger.info("π₯ GPU Available: %s (%d device(s)) β VRAM %.1f GB", name, gpu_count, vram_gb)
            else:
                logger.warning("β οΈ No GPU available β using CPU")
        except Exception as e:
            logger.warning("β οΈ Torch check failed: %s", e)

        # Directory verification (and creation if missing)
        for d in ("checkpoints", "models", "utils"):
            dp = APP_ROOT / d
            dp.mkdir(parents=True, exist_ok=True)
            logger.info("β Directory %s: %s", d, dp)

        # Job dir isolation test: create, write, read back, remove.
        test_job = JOB_ROOT / "startup_test_job"
        test_job.mkdir(parents=True, exist_ok=True)
        tfile = test_job / "test.tmp"
        tfile.write_text("job_isolation_test")
        assert tfile.read_text() == "job_isolation_test"
        logger.info("β Job isolation directory ready: %s", JOB_ROOT)
        shutil.rmtree(test_job, ignore_errors=True)

        # Env summary
        logger.info("π Env: OMP_NUM_THREADS=%s | HF_HOME=%s | TORCH_HOME=%s",
                    os.environ.get("OMP_NUM_THREADS", "unset"),
                    os.environ.get("HF_HOME", "default"),
                    os.environ.get("TORCH_HOME", "default"))

        logger.info("π― Startup probe completed β system ready!")

    except Exception as e:
        logger.error("β STARTUP PROBE FAILED: %s", e)
        logger.error("π %s", _disk_stats(APP_ROOT))
        raise RuntimeError(f"Startup probe failed β system not ready: {e}") from e
|
| 204 |
+
|
| 205 |
+
# ==============================================================================
|
| 206 |
+
# FILE SAFETY UTILITIES
|
| 207 |
+
# ==============================================================================
|
| 208 |
+
|
| 209 |
+
def new_tmp_path(suffix: str) -> Path:
    """Generate a unique, collision-free path under TMP_ROOT with *suffix*."""
    unique = uuid.uuid4().hex
    return TMP_ROOT / f"{unique}{suffix}"
|
| 212 |
+
|
| 213 |
+
def atomic_write_bytes(dst: Path, data: bytes):
    """Atomic file write to prevent corruption.

    Writes into a fresh ".part" temp file under TMP_ROOT first, then renames
    over *dst* (rename is atomic when both live on the same filesystem).
    """
    scratch = new_tmp_path(dst.suffix + ".part")
    try:
        with open(scratch, "wb") as fh:
            fh.write(data)
        scratch.replace(dst)  # atomic on same FS
        logger.debug("β Atomic write: %s", dst)
    except Exception as e:
        # Leave no orphaned .part file behind.
        if scratch.exists():
            scratch.unlink(missing_ok=True)
        raise e
|
| 225 |
+
|
| 226 |
+
def safe_name(name: str, default="file") -> str:
    """Sanitize filename to prevent traversal/unicode issues.

    Any run of characters outside [A-Za-z0-9._-] collapses to one underscore;
    the result is capped at 120 characters and never empty.
    """
    import re

    candidate = name or default
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", candidate)
    truncated = cleaned[:120]
    return truncated if truncated else default
|
| 231 |
+
|
| 232 |
+
def place_uploaded(in_path: str, sub="uploads") -> Path:
    """Safely handle uploaded files with sanitized names.

    Copies *in_path* into DATA_ROOT/<sub>/ under a sanitized filename and
    returns the destination path.
    """
    dest_dir = DATA_ROOT / sub
    dest_dir.mkdir(exist_ok=True, parents=True)
    dest = dest_dir / safe_name(Path(in_path).name)
    shutil.copy2(in_path, dest)
    logger.info("π Uploaded file placed: %s", dest)
    return dest
|
| 240 |
+
|
| 241 |
+
def tmp_video_path(ext=".mp4") -> Path:
    """Fresh unique temp path for a video file (extension configurable)."""
    return new_tmp_path(ext)

def tmp_image_path(ext=".png") -> Path:
    """Fresh unique temp path for an image file (extension configurable)."""
    return new_tmp_path(ext)
|
| 246 |
+
|
| 247 |
+
def run_safely(fn: Callable, *args, **kwargs):
    """Execute function with comprehensive error logging.

    Returns fn(*args, **kwargs); on any exception, logs the traceback plus
    filesystem/env context and re-raises unchanged.
    """
    def _log_failure_context() -> None:
        # All post-mortem breadcrumbs in one place.
        logger.error("PROCESSING FAILED\n%s", "".join(traceback.format_exc()))
        logger.error("CWD=%s | DATA_ROOT=%s | TMP_ROOT=%s | %s",
                     os.getcwd(), DATA_ROOT, TMP_ROOT, _disk_stats(APP_ROOT))
        try:
            logger.error("Env: OMP_NUM_THREADS=%s | CUDA=%s | torch=%s | cu=%s",
                         os.environ.get("OMP_NUM_THREADS"),
                         os.environ.get("CUDA_VISIBLE_DEVICES", "default"),
                         torch.__version__,
                         getattr(torch.version, "cuda", None))
        except Exception:
            pass  # context logging must never mask the original error

    try:
        return fn(*args, **kwargs)
    except Exception:
        _log_failure_context()
        raise
|
| 264 |
+
|
| 265 |
+
# ==============================================================================
|
| 266 |
+
# SYSTEM UTILITIES
|
| 267 |
+
# ==============================================================================
|
| 268 |
+
|
| 269 |
+
def get_device():
    """Get optimal device for processing: CUDA when present, else CPU."""
    use_cuda = torch.cuda.is_available()
    return torch.device("cuda") if use_cuda else torch.device("cpu")
|
| 272 |
+
|
| 273 |
+
def clear_gpu_memory():
    """Aggressive GPU memory cleanup (silently a no-op without CUDA)."""
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()   # release cached allocator blocks
            torch.cuda.synchronize()   # ensure pending kernels have finished
            gc.collect()               # drop Python refs so memory can free
            logger.info("π§Ή GPU memory cleared")
    except Exception as e:
        # Cleanup must never crash the caller - warn and continue.
        logger.warning("GPU cleanup warning: %s", e)
|
| 283 |
+
|
| 284 |
+
def safe_file_operation(operation: Callable, *args, max_retries: int = 3, **kwargs):
    """Safely execute file operations with retries.

    Retries up to *max_retries* times with a short growing backoff
    (0.1s, 0.2s, ...); the last exception is re-raised when all attempts fail.
    """
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            return operation(*args, **kwargs)
        except Exception as exc:
            last_error = exc
            if attempt < max_retries:
                time.sleep(0.1 * attempt)
                logger.warning("File op retry %d: %s", attempt, exc)
            else:
                logger.error("File op failed after %d attempts: %s", max_retries, exc)
    raise last_error
|
| 298 |
+
|
| 299 |
+
# ==============================================================================
|
| 300 |
+
# BACKGROUND GENERATORS
|
| 301 |
+
# ==============================================================================
|
| 302 |
+
|
| 303 |
+
def generate_ai_background(prompt: str, width: int, height: int) -> Image.Image:
    """Generate AI-like background using prompt cues (procedural).

    Keyword matches in *prompt* select a themed vertical gradient; a small
    uniform noise layer is added on top for texture.  Any failure falls back
    to the "sunset" gradient.
    """
    try:
        logger.info("Generating AI background: '%s' (%dx%d)", prompt, width, height)
        img = np.zeros((height, width, 3), dtype=np.uint8)
        prompt_lower = prompt.lower()

        # Each branch paints a row-wise gradient; i/height runs 0 -> 1 top to bottom.
        if any(w in prompt_lower for w in ('city', 'urban', 'futuristic', 'cyberpunk')):
            for i in range(height):
                r = int(20 + 80 * (i / height))
                g = int(30 + 100 * (i / height))
                b = int(60 + 120 * (i / height))
                img[i, :] = [r, g, b]
        elif any(w in prompt_lower for w in ('beach', 'tropical', 'ocean', 'sea')):
            for i in range(height):
                r = int(135 + 120 * (i / height))
                g = int(206 + 49 * (i / height))
                b = int(235 + 20 * (i / height))
                img[i, :] = [r, g, b]
        elif any(w in prompt_lower for w in ('forest', 'jungle', 'nature', 'green')):
            for i in range(height):
                r = int(34 + 105 * (i / height))
                g = int(139 + 30 * (i / height))
                b = int(34 - 15 * (i / height))  # blue can go negative -> clamped below
                img[i, :] = [max(0, r), max(0, g), max(0, b)]
        elif any(w in prompt_lower for w in ('space', 'galaxy', 'stars', 'cosmic')):
            for i in range(height):
                r = int(10 + 50 * (i / height))
                g = int(0 + 30 * (i / height))
                b = int(30 + 100 * (i / height))
                img[i, :] = [r, g, b]
        elif any(w in prompt_lower for w in ('desert', 'sand', 'canyon')):
            for i in range(height):
                r = int(238 + 17 * (i / height))
                g = int(203 + 52 * (i / height))
                b = int(173 + 82 * (i / height))  # channels can exceed 255 -> clamped below
                img[i, :] = [min(255, r), min(255, g), min(255, b)]
        else:
            # No keyword match: pick a pastel deterministically from prompt length.
            colors = [(255, 182, 193), (255, 218, 185), (176, 224, 230)]
            color = colors[len(prompt) % len(colors)]
            for i in range(height):
                t = 1 - (i / height) * 0.3
                img[i, :] = [int(color[0] * t), int(color[1] * t), int(color[2] * t)]

        # Light random speckle so the gradient doesn't band when encoded.
        noise = np.random.randint(-15, 15, (height, width, 3))
        img = np.clip(img.astype(np.int16) + noise, 0, 255).astype(np.uint8)
        return Image.fromarray(img)

    except Exception as e:
        logger.warning("AI background generation failed: %s β using fallback", e)
        return create_gradient_background("sunset", width, height)
|
| 354 |
+
|
| 355 |
+
def create_gradient_background(gradient_type: str, width: int, height: int) -> Image.Image:
    """Vertical two-color gradient for a few named presets; flat mid-gray otherwise."""
    palettes = {
        "sunset": [(255, 165, 0), (128, 64, 128)],
        "ocean": [(0, 100, 255), (30, 144, 255)],
        "forest": [(34, 139, 34), (139, 69, 19)],
        "sky": [(135, 206, 235), (206, 235, 255)],
    }
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    if gradient_type in palettes:
        top, bottom = palettes[gradient_type]
        for row in range(height):
            frac = row / height  # 0 at the top row, approaching 1 at the bottom
            canvas[row, :] = [
                int(top[c] * (1 - frac) + bottom[c] * frac) for c in range(3)
            ]
    else:
        canvas.fill(128)
    return Image.fromarray(canvas)
|
| 373 |
+
|
| 374 |
+
def create_solid_background(color: str, width: int, height: int) -> Image.Image:
    """Uniform-color RGB image; unknown color names fall back to mid-gray."""
    named = {
        "white": (255, 255, 255), "black": (0, 0, 0), "red": (255, 0, 0),
        "green": (0, 255, 0), "blue": (0, 0, 255), "yellow": (255, 255, 0),
        "purple": (128, 0, 128), "orange": (255, 165, 0), "pink": (255, 192, 203),
        "gray": (128, 128, 128),
    }
    fill = named.get(color.lower(), (128, 128, 128))
    return Image.new("RGB", (width, height), fill)
|
| 383 |
+
|
| 384 |
+
def download_unsplash_image(query: str, width: int, height: int) -> Image.Image:
    """Fetch a random photo matching *query* at the requested size; gray on failure.

    NOTE(review): source.unsplash.com appears to have been deprecated by
    Unsplash - if the endpoint is gone this will always hit the gray
    fallback; confirm and consider migrating to the official API.
    """
    try:
        url = f"https://source.unsplash.com/{width}x{height}/?{query}"
        resp = requests.get(url, timeout=10)
        resp.raise_for_status()
        photo = Image.open(io.BytesIO(resp.content))
        if photo.size != (width, height):
            photo = photo.resize((width, height), Image.Resampling.LANCZOS)
        return photo.convert("RGB")
    except Exception as e:
        logger.warning("Unsplash download failed: %s", e)
        return create_solid_background("gray", width, height)
|
| 396 |
+
|
| 397 |
+
# ==============================================================================
|
| 398 |
+
# VIDEO UTILITIES
|
| 399 |
+
# ==============================================================================
|
| 400 |
+
|
| 401 |
+
def get_video_info(video_path: str) -> Dict[str, Any]:
    """Probe a video's fps / frame count / dimensions via OpenCV.

    Returns:
        Dict with keys fps, frame_count, width, height, duration (seconds).
        On any failure returns safe 1080p/30fps defaults instead of raising.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Cannot open video file")
        fps = cap.get(cv2.CAP_PROP_FPS)
        frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        return {"fps": fps, "frame_count": frames, "width": w, "height": h,
                "duration": (frames / fps if fps > 0 else 0)}
    except Exception as e:
        logger.error("get_video_info failed: %s", e)
        return {"fps": 30.0, "frame_count": 0, "width": 1920, "height": 1080, "duration": 0}
    finally:
        # Always release the capture handle - the original leaked it when an
        # exception fired after a successful open.
        if cap is not None:
            cap.release()
|
| 416 |
+
|
| 417 |
+
def extract_frame(video_path: str, frame_number: int) -> Optional[np.ndarray]:
    """Grab one frame from *video_path* as an RGB array, or None when unreadable."""
    try:
        cap = cv2.VideoCapture(video_path)
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
        ok, bgr = cap.read()
        cap.release()
        # OpenCV yields BGR; callers expect RGB.
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) if ok else None
    except Exception as e:
        logger.error("extract_frame failed: %s", e)
        return None
|
| 429 |
+
|
| 430 |
+
def ffmpeg_safe_call(inp: Path, out: Path, extra=()):
    """Transcode *inp* to *out* via ffmpeg; raises on failure or 5-minute timeout."""
    cmd = ["ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
           "-i", str(inp), *extra, str(out)]
    logger.info("FFMPEG %s", " ".join(cmd))
    # List-form argv (shell=False): no shell injection via file names.
    subprocess.run(cmd, check=True, timeout=300)
|
| 434 |
+
|
| 435 |
+
# ==============================================================================
|
| 436 |
+
# PROGRESS TRACKING
|
| 437 |
+
# ==============================================================================
|
| 438 |
+
|
| 439 |
+
class ProgressTracker:
    """Thread-safe progress tracking for video processing.

    All mutation happens under one lock so UI threads can poll
    get_status() while a worker thread reports progress.
    """

    def __init__(self):
        self.current_step = ""          # human-readable label of the current phase
        self.progress = 0.0             # overall fraction in [0, 1]
        self.total_frames = 0
        self.processed_frames = 0
        self.start_time = time.time()   # wall-clock start, for elapsed/ETA
        self.lock = threading.Lock()

    def update(self, step: str, progress: Optional[float] = None):
        """Set the current step label and optionally clamp-set progress.

        The annotation is now Optional[float] (the original said `float = None`).
        """
        with self.lock:
            self.current_step = step
            if progress is not None:
                self.progress = max(0.0, min(1.0, progress))

    def update_frames(self, processed: int, total: Optional[int] = None):
        """Record frame counts; derives overall progress once a total is known."""
        with self.lock:
            self.processed_frames = processed
            if total is not None:
                self.total_frames = total
            if self.total_frames > 0:
                self.progress = self.processed_frames / self.total_frames

    def get_status(self) -> Dict[str, Any]:
        """Snapshot of step/progress/frame counts/elapsed/eta (seconds)."""
        with self.lock:
            elapsed = time.time() - self.start_time
            eta = 0
            # Avoid wild ETAs (and tiny-denominator blow-ups) right at the start.
            if self.progress > 0.01:
                eta = elapsed * (1.0 - self.progress) / self.progress
            return {
                "step": self.current_step, "progress": self.progress,
                "processed_frames": self.processed_frames, "total_frames": self.total_frames,
                "elapsed": elapsed, "eta": eta
            }

# Global tracker
progress_tracker = ProgressTracker()
|
| 477 |
+
|
| 478 |
+
# ==============================================================================
|
| 479 |
+
# SAFE FILE OPS
|
| 480 |
+
# ==============================================================================
|
| 481 |
+
|
| 482 |
+
def create_job_directory() -> Path:
    """Create a fresh, uniquely named working directory under JOB_ROOT and return it."""
    # Short UUID prefix + epoch seconds: unique across concurrent jobs,
    # yet roughly sortable by creation time.
    token = str(uuid.uuid4())[:8]
    job_dir = JOB_ROOT / f"job_{token}_{int(time.time())}"
    job_dir.mkdir(parents=True, exist_ok=True)
    logger.info("π Created job directory: %s", job_dir)
    return job_dir
|
| 488 |
+
|
| 489 |
+
def atomic_file_write(filepath: Path, content: bytes):
    """Atomically write *content* to *filepath*.

    Writes to a same-directory ".tmp" file first, then moves it over the
    destination, so readers never observe a partially written file.

    Raises:
        Whatever the underlying I/O raises; the temp file is removed on failure.
    """
    # Use with_name to append ".tmp" without breaking pathlib rules; keeping
    # the temp file in the same directory guarantees the final move stays on
    # one filesystem (a cross-device move would not be atomic).
    temp_path = filepath.with_name(f"{filepath.name}.tmp")
    try:
        with open(temp_path, 'wb') as f:
            f.write(content)
        # Path.replace() atomically overwrites an existing destination on
        # POSIX and, unlike rename(), also succeeds on Windows when the
        # target already exists.
        temp_path.replace(filepath)
        logger.debug("β Atomic write: %s", filepath)
    except Exception:
        if temp_path.exists():
            temp_path.unlink(missing_ok=True)
        # Bare re-raise preserves the original traceback.
        raise
|
| 501 |
+
|
| 502 |
+
def safe_download(url: str, filepath: Path, max_size: int = 500 * 1024 * 1024):
    """Download *url* to *filepath* with a hard size cap (default 500 MB).

    Streams into a same-directory ".download" temp file and moves it into
    place only after the transfer completes non-empty, so a partial download
    never appears at the final path.

    Raises:
        ValueError: if the declared or actual size exceeds *max_size*, or
            the download is empty.
        requests.RequestException: on connection/HTTP errors.
    """
    # Use with_name to append ".download" safely (e.g., "video.mp4.download");
    # same directory as the target so the final move is a cheap rename.
    temp_path = filepath.with_name(f"{filepath.name}.download")

    try:
        # Context manager releases the pooled connection even when we bail
        # out early (size-limit ValueError, write error, ...).
        with requests.get(url, stream=True, timeout=30) as r:
            r.raise_for_status()
            cl = r.headers.get('content-length')
            # Fast reject when the server declares an oversized payload.
            if cl and int(cl) > max_size:
                raise ValueError(f"File too large: {cl} bytes")

            downloaded = 0
            with open(temp_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:
                        downloaded += len(chunk)
                        # Enforce the cap even if content-length was absent or lying.
                        if downloaded > max_size:
                            raise ValueError(f"Download exceeded size limit: {downloaded} bytes")
                        f.write(chunk)

        if not temp_path.exists() or temp_path.stat().st_size == 0:
            raise ValueError("Download resulted in empty file")

        # replace() atomically overwrites an existing destination; rename()
        # would raise on Windows when filepath already exists.
        temp_path.replace(filepath)
        logger.info("β Downloaded: %s (%d bytes)", filepath, downloaded)

    except Exception as e:
        if temp_path.exists():
            temp_path.unlink(missing_ok=True)
        logger.error("β Download failed: %s", e)
        raise
|
| 533 |
+
|
| 534 |
+
# ==============================================================================
|
| 535 |
+
# ENHANCED PIPELINE INTEGRATION
|
| 536 |
+
# ==============================================================================
|
| 537 |
+
|
| 538 |
+
def process_video_pipeline(
    video_path: str,
    background_image: Optional[Image.Image],
    background_type: str,
    background_prompt: str,
    job_dir: Path,
    progress_callback: Optional[Callable] = None
) -> str:
    """Process video using the two-stage pipeline with enhanced safety and monitoring.

    Args:
        video_path: Path to the uploaded source video; must exist on disk.
        background_image: Replacement background (required; ValueError if None).
        background_type: Selected background mode (logged only in this function).
        background_prompt: User prompt for generated backgrounds (logged only).
        job_dir: Per-job working directory for intermediate/output files.
        progress_callback: Optional callable receiving human-readable progress strings.

    Returns:
        Path (str) to the composited output video produced by the pipeline.

    Raises:
        FileNotFoundError: if video_path does not exist.
        ValueError: if background_image is None.
        RuntimeError: if the pipeline produces no (or an empty) output file.
    """

    def _inner_process():
        # Banner + extensive DEBUG logging: this function is the main failure
        # point on Spaces, so every input is logged up front.
        logger.info("=" * 60)
        logger.info("=== ENHANCED TWO-STAGE PIPELINE (WITH SAFETY) ===")
        logger.info("=" * 60)

        logger.info("DEBUG video_path=%s exists=%s size=%s bytes",
                    video_path, Path(video_path).exists(),
                    (Path(video_path).stat().st_size if Path(video_path).exists() else "N/A"))
        logger.info("DEBUG job_dir=%s writable=%s", job_dir, os.access(job_dir, os.W_OK))
        logger.info("DEBUG bg_image=%s bg_type=%s | %s",
                    (background_image.size if background_image else None),
                    background_type, _disk_stats(APP_ROOT))

        if not Path(video_path).exists():
            raise FileNotFoundError(f"Video file not found: {video_path}")

        # Copy into controlled area
        # (place_uploaded presumably copies the upload under APP_ROOT/"videos"
        #  — defined elsewhere; TODO confirm exact semantics.)
        safe_video_path = place_uploaded(video_path, "videos")
        logger.info("DEBUG safe_video_path=%s", safe_video_path)

        # Import lazily so app startup does not pay for (or fail on) the
        # heavy pipeline dependencies.
        logger.info("DEBUG importing two-stage pipeline…")
        try:
            from two_stage_pipeline import process_two_stage as pipeline_process
            logger.info("β two-stage pipeline import OK")
        except ImportError as e:
            logger.error("Import two_stage_pipeline failed: %s", e)
            raise

        progress_tracker.update("Initializing enhanced two-stage pipeline…")

        # Mutable dict (not locals) so the nested callback can rewrite it.
        current_stage = {"stage": "init", "start_time": time.time()}

        def safe_progress_callback(step: str, progress: float = None):
            # Wraps pipeline progress events: detects stage transitions by
            # substring matching on the step text, logs per-stage timings,
            # and forwards to both the global tracker and the caller's callback.
            # Never lets a progress-reporting error kill the pipeline.
            try:
                now = time.time()
                elapsed = now - current_stage["start_time"]

                if "Stage 1" in step and current_stage["stage"] != "stage1":
                    current_stage["stage"] = "stage1"
                    current_stage["start_time"] = now
                    logger.info("π Entering Stage 1 (SAM2) | %s", _disk_stats(APP_ROOT))
                elif "Stage 2" in step and current_stage["stage"] != "stage2":
                    d1 = now - current_stage["start_time"]
                    current_stage["stage"] = "stage2"
                    current_stage["start_time"] = now
                    logger.info("π Entering Stage 2 (Composition) β Stage 1 time %.1fs | %s", d1, _disk_stats(APP_ROOT))
                elif "Done" in step and current_stage["stage"] != "complete":
                    d2 = now - current_stage["start_time"]
                    current_stage["stage"] = "complete"
                    logger.info("π Pipeline complete β Stage 2 time %.1fs | %s", d2, _disk_stats(APP_ROOT))

                logger.info("PROGRESS [%s] (%.1fs): %s (%s)",
                            current_stage['stage'].upper(), elapsed, step, progress)
                progress_tracker.update(step, progress)

                if progress_callback:
                    progress_callback(f"Progress: {progress:.1%} - {step}" if progress is not None else step)

                # Heuristic watchdog log: Stage 1 (SAM2) is the memory-heavy part.
                if current_stage["stage"] == "stage1" and elapsed > 15:
                    logger.warning("β οΈ Stage 1 running for %.1fs β monitoring memory", elapsed)

            except Exception as e:
                logger.error("Progress callback error: %s", e)

        if background_image is None:
            raise ValueError("Background image is required")

        logger.info("DEBUG: calling two-stage pipeline…")
        # NOTE(review): background_type/background_prompt are NOT forwarded —
        # the pipeline only receives the already-resolved background image.
        result_path = pipeline_process(
            video_path=str(safe_video_path),
            background_image=background_image,
            workdir=job_dir,
            progress=safe_progress_callback,
            use_matany=True
        )

        logger.info("DEBUG: pipeline returned %s (%s)", result_path, type(result_path))

        # Best-effort sanity checks on the output: non-empty file, and
        # openable by OpenCV (failures here only warn — except size 0).
        if result_path:
            result_file = Path(result_path)
            logger.info("DEBUG: result exists=%s", result_file.exists())
            if result_file.exists():
                size = result_file.stat().st_size
                logger.info("DEBUG: result size=%d bytes", size)
                if size == 0:
                    raise RuntimeError("Pipeline produced empty output file")

                # Quick validity check
                try:
                    cap = cv2.VideoCapture(str(result_file))
                    if cap.isOpened():
                        frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                        logger.info("DEBUG: output frame_count=%d", frames)
                        cap.release()
                    else:
                        logger.warning("β οΈ Output may not be a valid video (cannot open)")
                except Exception as e:
                    logger.warning("β οΈ Could not verify output video: %s", e)

        if not result_path or not Path(result_path).exists():
            raise RuntimeError("Two-stage pipeline failed β no output produced")

        logger.info("=" * 60)
        logger.info("β ENHANCED TWO-STAGE PIPELINE COMPLETED: %s", result_path)
        logger.info("=" * 60)
        return result_path

    # run_safely presumably wraps _inner_process with resource guards
    # (defined elsewhere — TODO confirm). On any failure: free GPU memory,
    # dump the job directory contents for post-mortem, then re-raise.
    try:
        return run_safely(_inner_process)
    except Exception as e:
        logger.error("π§Ή Error cleanup…")
        clear_gpu_memory()
        logger.error("Job dir state: %s",
                     (list(job_dir.iterdir()) if job_dir.exists() else "does not exist"))
        raise
|
VideoBackgroundReplacer2/ui_core_interface.py
ADDED
|
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
BackgroundFX Pro β Gradio Interface & Event Handlers
|
| 4 |
+
UI components, event handlers, and interface creation
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import shutil
|
| 9 |
+
import traceback
|
| 10 |
+
from typing import Optional, Tuple
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import gradio as gr
|
| 15 |
+
from PIL import Image
|
| 16 |
+
|
| 17 |
+
# Import our functionality
|
| 18 |
+
from ui_core_functionality import (
|
| 19 |
+
get_device, clear_gpu_memory, get_video_info, extract_frame,
|
| 20 |
+
create_gradient_background, create_solid_background, download_unsplash_image,
|
| 21 |
+
generate_ai_background, create_job_directory, safe_file_operation, process_video_pipeline,
|
| 22 |
+
progress_tracker, JOB_ROOT, APP_ROOT, logger
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
# ===============================================================================
|
| 26 |
+
# GRADIO HANDLERS
|
| 27 |
+
# ===============================================================================
|
| 28 |
+
|
| 29 |
+
def handle_custom_background_upload(image: Optional[Image.Image]) -> Tuple[Optional[Image.Image], str]:
    """Handle custom background image upload"""
    if image is None:
        return None, "No image uploaded"
    try:
        # Normalise to RGB so downstream compositing never sees alpha/palette modes.
        rgb = image if image.mode == "RGB" else image.convert("RGB")
        w, h = rgb.size
        status = f"β Custom background uploaded: {w}x{h}"
        logger.info(status)
        return rgb, status
    except Exception as e:
        error_msg = f"β Background upload failed: {str(e)}"
        logger.error(error_msg)
        return None, error_msg
|
| 43 |
+
|
| 44 |
+
def handle_background_type_change(bg_type: str):
    """Handle background type selection - show/hide relevant controls"""
    logger.info(f"π¨ Background type changed to: {bg_type}")

    # Upload mode: reveal the image uploader, hide the prompt + generate button.
    if bg_type == "upload":
        return (
            gr.update(visible=True, label="Upload Custom Background Image"),
            gr.update(visible=False),
            gr.update(visible=False),
        )

    # Generative modes: hide the uploader, show a tailored prompt + button.
    placeholders = {
        "ai_generate": "Describe the scene: 'futuristic city', 'tropical beach', 'mystical forest'...",
        "gradient": "Choose style: 'sunset', 'ocean', 'forest', 'sky'",
        "solid": "Choose color: 'red', 'blue', 'green', 'white', 'black'...",
        "unsplash": "Search query: 'mountain landscape', 'city skyline', 'nature'..."
    }
    placeholder = placeholders.get(bg_type, "Enter your prompt...")
    button_label = f"Generate {bg_type.replace('_', ' ').title()} Background"
    return (
        gr.update(visible=False),
        gr.update(visible=True, placeholder=placeholder),
        gr.update(visible=True, value=button_label),
    )
|
| 65 |
+
|
| 66 |
+
def handle_video_upload(video_file) -> Tuple[Optional[str], str]:
    """Handle video file upload"""
    if video_file is None:
        return None, "No video file provided"
    try:
        src = Path(video_file)
        # Keep the upload's original extension; default to .mp4 when absent.
        dest = create_job_directory() / f"input_video{src.suffix or '.mp4'}"
        safe_file_operation(lambda s, d: shutil.copy2(s, d), str(src), str(dest))

        # Probe the copied file so the status line reflects what we will process.
        info = get_video_info(str(dest))
        status = (f"β Video uploaded: {info['width']}x{info['height']}, "
                  f"{info['fps']:.1f}fps, {info['duration']:.1f}s")
        logger.info(status)
        return str(dest), status
    except Exception as e:
        error_msg = f"β Video upload failed: {str(e)}"
        logger.error(error_msg)
        return None, error_msg
|
| 87 |
+
|
| 88 |
+
def handle_background_generation(bg_type: str, bg_prompt: str, video_path: str) -> Tuple[Optional[Image.Image], str]:
    """Handle background generation (for non-upload types)"""
    if not video_path:
        return None, "No video loaded"
    if bg_type == "upload":
        return None, "Use the upload field above for custom backgrounds"

    try:
        # Size the background to match the source video exactly.
        info = get_video_info(video_path)
        width = info['width']
        height = info['height']
        prompt_lc = bg_prompt.lower()

        if bg_type == "ai_generate":
            background = generate_ai_background(bg_prompt, width, height)
            status = f"β Generated AI background: '{bg_prompt}'"
        elif bg_type == "gradient":
            # First known gradient keyword found in the prompt wins; "sunset" is the fallback.
            known = ["sunset", "ocean", "forest", "sky"]
            chosen = next((g for g in known if g in prompt_lc), known[0])
            background = create_gradient_background(chosen, width, height)
            status = f"β Generated {chosen} gradient background"
        elif bg_type == "solid":
            palette = ["white", "black", "red", "green", "blue", "yellow", "purple", "orange", "pink", "gray"]
            chosen = next((c for c in palette if c in prompt_lc), "white")
            background = create_solid_background(chosen, width, height)
            status = f"β Generated {chosen} solid background"
        elif bg_type == "unsplash":
            query = bg_prompt.strip() or "nature"
            background = download_unsplash_image(query, width, height)
            status = f"β Downloaded background from Unsplash: '{query}'"
        else:
            # Unknown type: fall back to a neutral gray canvas.
            background = create_solid_background("gray", width, height)
            status = "β Generated default gray background"

        logger.info(status)
        return background, status

    except Exception as e:
        error_msg = f"β Background generation failed: {str(e)}"
        logger.error(error_msg)
        return None, error_msg
|
| 131 |
+
|
| 132 |
+
def handle_video_processing(
    video_path: str,
    background_image: Optional[Image.Image],
    background_type: str,
    background_prompt: str,
    progress=gr.Progress()  # Gradio injects a fresh tracker per call; the default here is the documented idiom
) -> Tuple[Optional[str], str]:
    """Handle complete video processing.

    Orchestrates one full job: creates a job directory, runs the two-stage
    pipeline with progress forwarded into the Gradio progress bar, and frees
    GPU memory on both success and failure.

    Returns:
        (result_video_path, status_message); result path is None on failure.
    """
    if not video_path:
        return None, "β No video provided"
    if not background_image:
        return None, "β No background provided"

    try:
        progress(0, "Starting video processing...")
        logger.info("π¬ Starting video processing")

        job_dir = create_job_directory()
        progress_tracker.update("Creating job directory...")

        def update_progress(message: str):
            # Bridge: pull the fraction from the shared tracker and push it,
            # with the pipeline's message, into Gradio's progress bar.
            # A UI update must never abort processing, hence the broad catch.
            try:
                status = progress_tracker.get_status()
                progress_val = status['progress']
                progress(progress_val, message)
                logger.info(f"Progress: {progress_val:.1%} - {message}")
            except Exception as e:
                logger.warning(f"Progress update failed: {e}")

        result_path = process_video_pipeline(
            video_path=video_path,
            background_image=background_image,
            background_type=background_type,
            background_prompt=background_prompt,
            job_dir=job_dir,
            progress_callback=update_progress
        )

        progress(1.0, "Processing complete!")
        # Release VRAM promptly so back-to-back jobs don't accumulate usage.
        clear_gpu_memory()

        status = "β Video processing completed successfully!"
        logger.info(status)
        return result_path, status

    except Exception as e:
        error_msg = f"β Processing failed: {str(e)}"
        logger.error(error_msg)
        logger.error("Traceback: %s", traceback.format_exc())
        clear_gpu_memory()
        return None, error_msg
|
| 183 |
+
|
| 184 |
+
def handle_preview_generation(video_path: str, frame_number: int = 0) -> Tuple[Optional[Image.Image], str]:
    """Generate preview frame from video"""
    if not video_path:
        return None, "No video loaded"
    try:
        frame = extract_frame(video_path, frame_number)
        if frame is None:
            return None, "Failed to extract frame"
        # Convert the raw frame array into a PIL image for the Gradio preview.
        return Image.fromarray(frame), f"β Preview generated (frame {frame_number})"
    except Exception as e:
        error_msg = f"β Preview generation failed: {str(e)}"
        logger.error(error_msg)
        return None, error_msg
|
| 198 |
+
|
| 199 |
+
# ===============================================================================
|
| 200 |
+
# GRADIO INTERFACE
|
| 201 |
+
# ===============================================================================
|
| 202 |
+
|
| 203 |
+
def create_interface():
    """Create the main Gradio interface.

    Builds the two-column Blocks layout (inputs + background selection on the
    left, processing + results on the right), then wires event handlers.
    Heavy handlers (upload copy, background generation, the pipeline itself)
    are queued single-flight; lightweight UI toggles bypass the queue.

    Returns:
        gr.Blocks: the assembled (unlaunched) demo.
    """

    custom_css = """
    .container { max-width: 1200px; margin: auto; }
    .header { text-align: center; margin-bottom: 30px; }
    .section { margin: 20px 0; padding: 20px; border-radius: 10px; }
    .status { font-family: monospace; font-size: 12px; }
    .progress-bar { margin: 10px 0; }
    """

    with gr.Blocks(
        title="BackgroundFX Pro",
        css=custom_css,
        theme=gr.themes.Soft(),
        analytics_enabled=False,  # keep things quiet/stable on 4.x
    ) as demo:

        gr.HTML("""
        <div class="header">
            <h1>π¬ BackgroundFX Pro</h1>
            <p>Professional AI-powered video background replacement using SAM2 and MatAnyone</p>
        </div>
        """)

        # Cross-event state: resolved video path and chosen background image.
        video_path_state = gr.State(value=None)
        background_image_state = gr.State(value=None)

        with gr.Row():
            # ---- Left column: video input + background selection ----
            with gr.Column(scale=1):
                with gr.Group():
                    gr.HTML("<h3>πΉ Video Input</h3>")
                    video_upload = gr.File(
                        label="Upload Video",
                        file_types=[".mp4", ".avi", ".mov", ".mkv"],
                        type="filepath"
                    )
                    video_preview = gr.Image(
                        label="Video Preview",
                        interactive=False,
                        height=300
                    )
                    # Fixed preview status box (hidden)
                    preview_status = gr.Textbox(
                        label="Preview Status",
                        interactive=False,
                        visible=False,
                        elem_classes=["status"]
                    )
                    video_status = gr.Textbox(
                        label="Video Status",
                        interactive=False,
                        elem_classes=["status"]
                    )

                with gr.Group():
                    gr.HTML("<h3>π¨ Background Selection</h3>")

                    gr.HTML("""
                    <div style='background: #f0f8ff; padding: 10px; border-radius: 5px; margin-bottom: 15px;'>
                        <b>Choose your background method:</b><br>
                        β’ <b>Upload:</b> Use your own image<br>
                        β’ <b>AI Generate:</b> Create with AI prompt<br>
                        β’ <b>Gradient/Solid/Unsplash:</b> Quick generation
                    </div>
                    """)

                    background_type = gr.Radio(
                        choices=[
                            ("π€ Upload Image", "upload"),
                            ("π€ AI Generate", "ai_generate"),
                            ("π Gradient", "gradient"),
                            ("π― Solid Color", "solid"),
                            ("πΈ Unsplash Photo", "unsplash")
                        ],
                        label="Background Type",
                        value="upload"
                    )

                    custom_bg_upload = gr.Image(
                        label="Upload Custom Background",
                        type="pil",
                        interactive=True,
                        height=250,
                        visible=True
                    )

                    background_prompt = gr.Textbox(
                        label="Background Prompt",
                        placeholder=("AI: 'futuristic city', 'tropical beach' | Gradient: 'sunset', 'ocean' | "
                                     "Solid: 'red', 'blue' | Unsplash: 'mountain landscape'"),
                        value="futuristic city skyline at sunset",
                        visible=False
                    )

                    generate_bg_btn = gr.Button(
                        "Generate Background",
                        variant="secondary",
                    )

                    background_preview = gr.Image(
                        label="Background Preview",
                        interactive=False,
                        height=300
                    )

                    background_status = gr.Textbox(
                        label="Background Status",
                        interactive=False,
                        elem_classes=["status"]
                    )

            # ---- Right column: processing trigger + results ----
            with gr.Column(scale=1):
                with gr.Group():
                    gr.HTML("<h3>β‘ Processing</h3>")

                    process_btn = gr.Button(
                        "π Process Video",
                        variant="primary",
                    )

                    processing_status = gr.Textbox(
                        label="Processing Status",
                        interactive=False,
                        elem_classes=["status"]
                    )

                with gr.Group():
                    gr.HTML("<h3>π½οΈ Results</h3>")

                    result_video = gr.Video(
                        label="Processed Video",
                        height=400
                    )

                    # Real downloadable output
                    download_btn = gr.DownloadButton(
                        "π₯ Download Result",
                        visible=False
                    )

        with gr.Accordion("π§ System Information", open=False):
            # Rendered once at interface build time, not per request.
            system_info = gr.HTML(f"""
            <div class="system-info">
                <p><strong>Device:</strong> {get_device()}</p>
                <p><strong>Torch Version:</strong> {torch.__version__}</p>
                <p><strong>CUDA Available:</strong> {torch.cuda.is_available()}</p>
                <p><strong>Job Directory:</strong> {JOB_ROOT}</p>
                <p><strong>App Root:</strong> {APP_ROOT}</p>
            </div>
            """)

        # =========================
        # Event Handlers (4.42.x)
        # =========================

        # Lightweight; no queue needed
        background_type.change(
            fn=handle_background_type_change,
            inputs=[background_type],
            outputs=[custom_bg_upload, background_prompt, generate_bg_btn],
            queue=False,
            concurrency_limit=4,
        )

        # Small, immediate state update; no queue
        custom_bg_upload.change(
            fn=handle_custom_background_upload,
            inputs=[custom_bg_upload],
            outputs=[background_image_state, background_status],
            queue=False,
            concurrency_limit=2,
        ).then(
            fn=lambda img: img,
            inputs=[background_image_state],
            outputs=[background_preview],
            queue=False,
        )

        # Copy to job dir + probe video info; keep queued but single flight
        video_upload.change(
            fn=handle_video_upload,
            inputs=[video_upload],
            outputs=[video_path_state, video_status],
            queue=True,
            concurrency_limit=1,
        ).then(
            fn=handle_preview_generation,
            inputs=[video_path_state],
            outputs=[video_preview, preview_status],
            queue=False,
        )

        # Background generation can be heavier; single-flight
        generate_bg_btn.click(
            fn=handle_background_generation,
            inputs=[background_type, background_prompt, video_path_state],
            outputs=[background_image_state, background_status],
            queue=True,
            concurrency_limit=1,
        ).then(
            fn=lambda img: img,
            inputs=[background_image_state],
            outputs=[background_preview],
            queue=False,
        )

        # The heavy pipeline β single-flight
        process_btn.click(
            fn=handle_video_processing,
            inputs=[
                video_path_state,
                background_image_state,
                background_type,
                background_prompt
            ],
            outputs=[result_video, processing_status],
            queue=True,
            concurrency_limit=1,
        ).then(
            # Wire the download button (set value=path and visibility)
            fn=lambda path: gr.update(value=path, visible=bool(path)),
            inputs=[result_video],
            outputs=[download_btn],
            queue=False,
        )

    return demo
|
VideoBackgroundReplacer2/update_pins.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
update_pins.py
|
| 4 |
+
- Fetch newest SHAs (release tag or default branch) for SAM2 + MatAnyone
|
| 5 |
+
- Update ARG lines in Dockerfile: SAM2_SHA / MATANYONE_SHA
|
| 6 |
+
- Supports dry-run and manual pins
|
| 7 |
+
- Uses GitHub API; set GITHUB_TOKEN to avoid rate limits (optional)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import re
|
| 12 |
+
import sys
|
| 13 |
+
import json
|
| 14 |
+
import argparse
|
| 15 |
+
from urllib.parse import urlparse
|
| 16 |
+
import requests
|
| 17 |
+
from datetime import datetime, timezone
|
| 18 |
+
from shutil import copyfile
|
| 19 |
+
|
| 20 |
+
DOCKERFILE_PATH = "Dockerfile"
|
| 21 |
+
|
| 22 |
+
# Default repos (must match your Dockerfile ARGs)
|
| 23 |
+
SAM2_REPO_URL = "https://github.com/facebookresearch/segment-anything-2"
|
| 24 |
+
MATANY_REPO_URL = "https://github.com/pq-yang/MatAnyone"
|
| 25 |
+
|
| 26 |
+
SESSION = requests.Session()
|
| 27 |
+
if os.getenv("GITHUB_TOKEN"):
|
| 28 |
+
SESSION.headers.update({"Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}"})
|
| 29 |
+
SESSION.headers.update({
|
| 30 |
+
"Accept": "application/vnd.github+json",
|
| 31 |
+
"User-Agent": "update-pins-script"
|
| 32 |
+
})
|
| 33 |
+
|
| 34 |
+
def gh_owner_repo(repo_url: str):
    """Split a GitHub repository URL into its (owner, repo) path components."""
    path_parts = urlparse(repo_url).path.strip("/").split("/")
    if len(path_parts) < 2:
        raise ValueError(f"Invalid repo URL: {repo_url}")
    owner, repo = path_parts[0], path_parts[1]
    return owner, repo
|
| 40 |
+
|
| 41 |
+
def gh_api(path: str):
    """GET a GitHub REST endpoint (path only, e.g. "/repos/o/r") and return decoded JSON."""
    url = f"https://api.github.com{path}"
    resp = SESSION.get(url, timeout=30)
    # Treat any 4xx/5xx as fatal; the body usually carries the API error message.
    if resp.status_code >= 400:
        raise RuntimeError(f"GitHub API error {resp.status_code}: {resp.text}")
    return resp.json()
|
| 47 |
+
|
| 48 |
+
def get_latest_release_sha(repo_url: str) -> tuple[str, str]:
    """Return (ref_desc, commit_sha) using latest release tag."""
    owner, repo = gh_owner_repo(repo_url)
    try:
        tag = gh_api(f"/repos/{owner}/{repo}/releases/latest")["tag_name"]
        # Resolve tag to commit: a ref may point at an annotated tag object,
        # which needs one more dereference to reach the commit it tags.
        obj = gh_api(f"/repos/{owner}/{repo}/git/ref/tags/{tag}")["object"]
        if obj["type"] == "tag":
            sha = gh_api(f"/repos/{owner}/{repo}/git/tags/{obj['sha']}")["object"]["sha"]
        else:
            sha = obj["sha"]
        return (f"release:{tag}", sha)
    except Exception as e:
        raise RuntimeError(f"Could not get latest release for {repo}: {e}")
|
| 65 |
+
|
| 66 |
+
def get_latest_default_branch_sha(repo_url: str) -> tuple[str, str]:
    """Return (ref_desc, commit_sha) using the default branch head."""
    owner, repo = gh_owner_repo(repo_url)
    # Two lookups: the repo metadata names the default branch, then the
    # branch endpoint yields its current head commit.
    default_branch = gh_api(f"/repos/{owner}/{repo}")["default_branch"]
    head_sha = gh_api(f"/repos/{owner}/{repo}/branches/{default_branch}")["commit"]["sha"]
    return (f"branch:{default_branch}", head_sha)
|
| 74 |
+
|
| 75 |
+
def get_sha_for_ref(repo_url: str, ref: str) -> tuple[str, str]:
    """
    Resolve any Git ref (branch name, tag name, or commit SHA) to a commit SHA.
    """
    owner, repo = gh_owner_repo(repo_url)
    # A full 40-char hex string is already a commit SHA -- no API round trip.
    if re.fullmatch(r"[0-9a-f]{40}", ref):
        return (f"commit:{ref[:7]}", ref)
    # Probe branches, then tags, then commits until one resolves.
    candidates = (
        ("branch", f"/repos/{owner}/{repo}/branches/{ref}"),
        ("tag", f"/repos/{owner}/{repo}/git/ref/tags/{ref}"),
        ("commit", f"/repos/{owner}/{repo}/commits/{ref}"),
    )
    for kind, api_path in candidates:
        try:
            payload = gh_api(api_path)
            if kind == "branch":
                return (f"branch:{ref}", payload["commit"]["sha"])
            if kind == "tag":
                target = payload["object"]
                if target["type"] == "tag":
                    # Annotated tag: dereference to the underlying commit.
                    deref = gh_api(f"/repos/{owner}/{repo}/git/tags/{target['sha']}")
                    return (f"tag:{ref}", deref["object"]["sha"])
                return (f"tag:{ref}", target["sha"])
            if kind == "commit":
                return (f"commit:{ref[:7]}", payload["sha"])
        except Exception:
            # This kind didn't match (404 or unexpected shape); try the next.
            continue
    raise RuntimeError(f"Could not resolve ref '{ref}' for {repo}")
def update_dockerfile_arg(dockerfile_text: str, arg_name: str, new_value: str) -> str:
    """
    Replace a line like:
        ARG SAM2_SHA=...
    with:
        ARG SAM2_SHA=<new_value>
    """
    line_re = re.compile(rf"^(ARG\s+{re.escape(arg_name)}=).*$", re.MULTILINE)
    # Callable replacement sidesteps backreference ambiguity (e.g. \12).
    updated, hits = line_re.subn(lambda m: m.group(1) + new_value, dockerfile_text)
    if hits == 0:
        raise RuntimeError(f"ARG {arg_name}=β¦ line not found in Dockerfile.")
    return updated
def main():
    """CLI entry point: resolve fresh commit SHAs and rewrite Dockerfile pins.

    Pins come from the latest GitHub release (default), the default branch
    head (--mode default-branch), or explicit refs (--sam2-ref/--matany-ref).
    """
    ap = argparse.ArgumentParser(description="Update pinned SHAs in Dockerfile.")
    ap.add_argument("--mode", choices=["release", "default-branch"], default="release",
                    help="Where to pull pins from (latest GitHub release tag or default branch head).")
    ap.add_argument("--sam2-ref", help="Explicit ref for SAM2 (tag/branch/sha). Overrides --mode.")
    ap.add_argument("--matany-ref", help="Explicit ref for MatAnyone (tag/branch/sha). Overrides --mode.")
    ap.add_argument("--dockerfile", default=DOCKERFILE_PATH, help="Path to Dockerfile.")
    ap.add_argument("--dry-run", action="store_true", help="Show changes but do not write file.")
    ap.add_argument("--json", action="store_true", help="Print resulting pins as JSON.")
    ap.add_argument("--no-backup", action="store_true", help="Do not create a Dockerfile.bak backup.")
    args = ap.parse_args()

    # Resolve SHAs: an explicit --*-ref always wins over --mode.
    if args.sam2_ref:
        sam2_refdesc, sam2_sha = get_sha_for_ref(SAM2_REPO_URL, args.sam2_ref)
    else:
        sam2_refdesc, sam2_sha = (
            get_latest_release_sha(SAM2_REPO_URL) if args.mode == "release"
            else get_latest_default_branch_sha(SAM2_REPO_URL)
        )

    if args.matany_ref:
        mat_refdesc, mat_sha = get_sha_for_ref(MATANY_REPO_URL, args.matany_ref)
    else:
        mat_refdesc, mat_sha = (
            get_latest_release_sha(MATANY_REPO_URL) if args.mode == "release"
            else get_latest_default_branch_sha(MATANY_REPO_URL)
        )

    # Summary record; also serialized verbatim for --json output.
    result = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "mode": args.mode,
        "SAM2": {"repo": SAM2_REPO_URL, "ref": sam2_refdesc, "sha": sam2_sha},
        "MatAnyone": {"repo": MATANY_REPO_URL, "ref": mat_refdesc, "sha": mat_sha},
    }

    # Show pins
    if args.json:
        print(json.dumps(result, indent=2))
    else:
        print(f"[Pins] SAM2 -> {sam2_refdesc} -> {sam2_sha}")
        print(f"[Pins] MatAnyone -> {mat_refdesc} -> {mat_sha}")

    # Read Dockerfile
    if not os.path.isfile(args.dockerfile):
        raise FileNotFoundError(f"Dockerfile not found at: {args.dockerfile}")
    with open(args.dockerfile, "r", encoding="utf-8") as f:
        text = f.read()

    # Update lines (raises if either ARG line is missing).
    text = update_dockerfile_arg(text, "SAM2_SHA", sam2_sha)
    text = update_dockerfile_arg(text, "MATANYONE_SHA", mat_sha)

    if args.dry_run:
        print("\n--- Dockerfile (preview) ---\n")
        print(text)
        return

    # Backup
    if not args.no_backup:
        copyfile(args.dockerfile, args.dockerfile + ".bak")

    # Write
    with open(args.dockerfile, "w", encoding="utf-8") as f:
        f.write(text)

    # NOTE(review): leading glyph reconstructed from a garbled dump -- confirm
    # against the original file's success-message emoji.
    print(f"\n✅ Updated {args.dockerfile} with new pins.")
if __name__ == "__main__":
    # Exit non-zero with a readable message on any failure.
    try:
        main()
    except Exception as e:
        # NOTE(review): leading glyph reconstructed from a garbled dump.
        print(f"\n❌ Error: {e}", file=sys.stderr)
        sys.exit(1)
|
VideoBackgroundReplacer2/utils/__init__.py
ADDED
|
File without changes
|
VideoBackgroundReplacer2/utils/paths.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# utils/paths.py
from pathlib import Path
import os, re, uuid, shutil

# Repo root is one level above this utils/ package.
APP_ROOT = Path(__file__).resolve().parents[1]
DATA_ROOT = APP_ROOT / "data"
TMP_ROOT = APP_ROOT / "tmp"
# Ensure writable working dirs (and HF/torch cache dirs) exist at import time.
for p in (DATA_ROOT, TMP_ROOT, APP_ROOT / ".hf", APP_ROOT / ".torch"):
    p.mkdir(parents=True, exist_ok=True)

# Point model caches inside the app tree unless the env already overrides them.
os.environ.setdefault("HF_HOME", str(APP_ROOT / ".hf"))
os.environ.setdefault("TORCH_HOME", str(APP_ROOT / ".torch"))
def safe_name(name: str, default="file"):
    """Sanitize *name* into a filesystem-safe identifier, capped at 120 chars."""
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", (name or default))
    return (cleaned or default)[:120]
def job_dir(prefix="job"):
    """Create and return a fresh, uniquely-named job directory under DATA_ROOT."""
    target = DATA_ROOT / f"{prefix}-{uuid.uuid4().hex[:8]}"
    target.mkdir(parents=True, exist_ok=True)
    return target
def disk_stats(p: Path = APP_ROOT) -> str:
    """Human-readable disk usage summary for the filesystem containing *p*.

    Returns "disk(n/a)" instead of raising if usage cannot be queried.
    """
    try:
        total, used, free = shutil.disk_usage(str(p))
        to_mb = lambda n: n // (1024 * 1024)
        return f"disk(total={to_mb(total)}MB, used={to_mb(used)}MB, free={to_mb(free)}MB)"
    except Exception:
        return "disk(n/a)"
|
VideoBackgroundReplacer2/utils/perf_tuning.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# utils/perf_tuning.py
import os, logging
# cv2 is optional; fall back to None so apply() can skip OpenCV tuning.
try:
    import cv2
except Exception:
    cv2 = None
import torch
def apply():
    """Apply cheap, safe performance knobs: CPU thread caps + cuDNN autotune."""
    # Cap OpenMP threads only if the caller hasn't already configured them.
    os.environ.setdefault("OMP_NUM_THREADS", "4")
    if cv2:
        try:
            cv2.setNumThreads(4)
        except Exception as e:
            # Best-effort: some OpenCV builds don't support thread control.
            logging.info("cv2 threads not set: %s", e)
    if torch.cuda.is_available():
        # Benchmark mode autotunes conv kernels; safe for fixed input shapes.
        torch.backends.cudnn.benchmark = True
        try:
            # NOTE(review): separator glyph reconstructed from a garbled dump.
            logging.info("CUDA device %s β cuDNN benchmark ON", torch.cuda.get_device_name(0))
        except Exception:
            logging.info("CUDA available β cuDNN benchmark ON")
|
app.py
CHANGED
|
@@ -1,300 +1,570 @@
|
|
| 1 |
-
|
| 2 |
-
"""
|
| 3 |
-
VideoBackgroundReplacer2 - SAM2 + MatAnyone Integration
|
| 4 |
-
=======================================================
|
| 5 |
-
- Sets up Gradio UI and launches pipeline
|
| 6 |
-
- Aligned with torch/cu121 stack; runs on HF Spaces (Docker)
|
| 7 |
-
|
| 8 |
-
Changes (2025-09-18):
|
| 9 |
-
- Added precise web-stack probes (FastAPI/Starlette/Pydantic/etc. versions + file paths)
|
| 10 |
-
- Added toggleable "mount mode": run Gradio inside our own FastAPI app
|
| 11 |
-
and provide a safe /config route shim (uses demo.get_config_file()).
|
| 12 |
-
- Kept your startup diagnostics, GPU logging, and heartbeats
|
| 13 |
-
"""
|
| 14 |
-
|
| 15 |
-
print("=== APP STARTUP: Initializing VideoBackgroundReplacer2 ===")
|
| 16 |
-
|
| 17 |
-
# ---------------------------------------------------------------------
|
| 18 |
-
# Imports & basic setup
|
| 19 |
-
# ---------------------------------------------------------------------
|
| 20 |
-
import sys
|
| 21 |
import os
|
| 22 |
-
import
|
| 23 |
-
import
|
| 24 |
-
import logging
|
| 25 |
-
import threading
|
| 26 |
import time
|
| 27 |
-
import warnings
|
| 28 |
-
import traceback
|
| 29 |
-
import subprocess
|
| 30 |
from pathlib import Path
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
)
|
| 40 |
|
| 41 |
-
#
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
-
def
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
"
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
"env_subset": {k: v for k, v in os.environ.items() if k in ("HOST", "PORT", "SPACE_ID", "SPACE_AUTHOR_NAME")},
|
| 128 |
}
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
#
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
)
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
#
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
# ---------------------------------------------------------------------
|
| 236 |
-
def build_fastapi_with_gradio(demo: gr.Blocks):
|
| 237 |
"""
|
| 238 |
-
|
| 239 |
-
|
| 240 |
"""
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
-
# ---------------------------------------------------------------------
|
| 266 |
-
# Entrypoint
|
| 267 |
-
# ---------------------------------------------------------------------
|
| 268 |
if __name__ == "__main__":
|
| 269 |
-
|
| 270 |
-
port = int(os.environ.get("PORT", "7860"))
|
| 271 |
-
mount_mode = os.environ.get("GRADIO_MOUNT_MODE", "0") == "1"
|
| 272 |
-
|
| 273 |
-
logger.info("Launching on %s:%s (mount_mode=%s)β¦", host, port, mount_mode)
|
| 274 |
-
_log_web_stack_versions_and_paths()
|
| 275 |
-
|
| 276 |
-
demo = build_ui()
|
| 277 |
-
demo.queue(max_size=16, api_open=False)
|
| 278 |
-
|
| 279 |
-
threading.Thread(target=_post_launch_diag, daemon=True).start()
|
| 280 |
-
|
| 281 |
-
if mount_mode:
|
| 282 |
-
try:
|
| 283 |
-
from uvicorn import run as uvicorn_run
|
| 284 |
-
except Exception:
|
| 285 |
-
logger.error("uvicorn is not installed; mount mode cannot start.")
|
| 286 |
-
raise
|
| 287 |
-
|
| 288 |
-
app = build_fastapi_with_gradio(demo)
|
| 289 |
-
uvicorn_run(app=app, host=host, port=port, log_level="info")
|
| 290 |
-
else:
|
| 291 |
-
demo.launch(
|
| 292 |
-
server_name=host,
|
| 293 |
-
server_port=port,
|
| 294 |
-
share=False,
|
| 295 |
-
show_api=False,
|
| 296 |
-
show_error=True,
|
| 297 |
-
quiet=False,
|
| 298 |
-
debug=True,
|
| 299 |
-
max_threads=1,
|
| 300 |
-
)
|
|
|
|
| 1 |
+
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import os
|
| 3 |
+
import sys
|
| 4 |
+
import tempfile
|
|
|
|
|
|
|
| 5 |
import time
|
|
|
|
|
|
|
|
|
|
| 6 |
from pathlib import Path
|
| 7 |
+
import cv2
|
| 8 |
+
import numpy as np
|
| 9 |
+
from PIL import Image
|
| 10 |
+
import logging
|
| 11 |
+
import base64
|
| 12 |
+
from io import BytesIO
|
| 13 |
+
|
| 14 |
+
# Add project root to path
sys.path.append(str(Path(__file__).parent.absolute()))

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set page config
# NOTE: st.set_page_config must be the first Streamlit call in the script.
# NOTE(review): page_icon glyph reconstructed from a garbled dump (🎥?).
st.set_page_config(
    page_title="MyAvatar - Video Background Replacer",
    page_icon="π₯",
    layout="wide",
    initial_sidebar_state="expanded"
)
| 29 |
+
# Custom CSS for better UI with logo
def add_logo():
    """Inject the app's global CSS: buttons, tabs, progress bar, color
    swatches, and the top-right logo layout. Must run once per page render."""
    st.markdown(
        """
        <style>
        .main .block-container {
            padding-top: 2rem;
            padding-bottom: 2rem;
        }
        .stButton>button {
            width: 100%;
            background-color: #4CAF50;
            color: white;
            font-weight: bold;
            transition: all 0.3s;
        }
        .stButton>button:hover {
            background-color: #45a049;
        }
        .stProgress > div > div > div > div {
            background-color: #4CAF50;
        }
        .stAlert {
            border-radius: 10px;
        }
        .stTabs [data-baseweb="tab-list"] {
            gap: 10px;
        }
        .stTabs [data-baseweb="tab"] {
            height: 50px;
            white-space: pre;
            background-color: #f0f2f6;
            border-radius: 4px 4px 0 0;
            padding: 10px 20px;
            margin-right: 5px;
        }
        .stTabs [aria-selected="true"] {
            background-color: #4CAF50;
            color: white;
        }
        .video-container {
            border: 2px dashed #4CAF50;
            border-radius: 10px;
            padding: 10px;
            margin-bottom: 20px;
        }
        .logo-container {
            display: flex;
            justify-content: flex-end;
            padding: 10px 20px 0 0;
        }
        .logo {
            height: 50px;
            width: auto;
            margin-bottom: -20px;
        }
        .title-container {
            text-align: center;
            margin-bottom: 30px;
        }
        .color-swatch {
            display: inline-block;
            width: 30px;
            height: 30px;
            margin: 5px;
            border: 2px solid #ddd;
            border-radius: 4px;
            cursor: pointer;
            transition: transform 0.2s;
        }
        .color-swatch:hover {
            transform: scale(1.1);
            border-color: #4CAF50;
        }
        .background-option {
            padding: 10px;
            margin: 5px 0;
            border-radius: 5px;
            background-color: #f8f9fa;
            border-left: 4px solid #4CAF50;
        }
        </style>
        """,
        unsafe_allow_html=True
    )
| 114 |
+
|
| 115 |
+
# Add logo to the top right
|
| 116 |
+
st.markdown(
|
| 117 |
+
"""
|
| 118 |
+
<div class="logo-container">
|
| 119 |
+
<img src="data:image/png;base64,{}" class="logo">
|
| 120 |
+
</div>
|
| 121 |
+
""".format(base64.b64encode(open("myavatar_logo.png", "rb").read()).decode()),
|
| 122 |
+
unsafe_allow_html=True
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
def initialize_session_state():
    """Initialize all session state variables"""
    # One table of defaults; each key is seeded only if not already present,
    # so user selections survive Streamlit reruns.
    defaults = {
        'uploaded_video': None,
        'bg_image': None,
        'bg_color': "#00FF00",
        'bg_type': "image",
        'processed_video_path': None,
        'processing': False,
        'progress': 0,
        'progress_text': "Ready",
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
def handle_video_upload():
    """Handle video file upload"""
    video_file = st.file_uploader(
        "πΉ Upload Video",
        type=["mp4", "mov", "avi"],
        key="video_uploader"
    )
    if video_file is None:
        return
    # Persist the upload so other widgets can read it across reruns.
    st.session_state.uploaded_video = video_file
def show_video_preview():
    """Show video preview in the UI"""
    st.markdown("### Video Preview")
    upload = st.session_state.uploaded_video
    if upload is None:
        return
    st.video(upload.getvalue())
    # Rewind so later reads (processing) start from the beginning.
    upload.seek(0)
def handle_background_selection():
    """Handle background selection UI with all options"""
    st.markdown("### Background Options")

    # Background type selection
    bg_type = st.radio(
        "Select Background Type:",
        ["Image", "Color", "Blur", "Professional Backgrounds", "AI Generated"],
        horizontal=True,
        key="bg_type_radio"
    )

    st.session_state.bg_type = bg_type.lower()

    # Route to the matching sub-UI via a dispatch table.
    handlers = {
        "Image": handle_image_background,
        "Color": handle_color_background,
        "Blur": handle_blur_background,
        "Professional Backgrounds": handle_professional_backgrounds,
        "AI Generated": handle_ai_generated_background,
    }
    handler = handlers.get(bg_type)
    if handler is not None:
        handler()
def handle_image_background():
    """Handle image background selection"""
    upload = st.file_uploader(
        "πΌοΈ Upload Background Image",
        type=["jpg", "png", "jpeg"],
        key="bg_image_uploader"
    )
    if upload is None:
        return
    # Keep the decoded image in session state and preview it immediately.
    st.session_state.bg_image = Image.open(upload)
    st.image(
        st.session_state.bg_image,
        caption="Selected Background",
        use_container_width=True
    )
def handle_color_background():
    """Handle color background selection with presets"""
    st.markdown("#### Select a Color")

    # Color presets
    color_presets = {
        "Pure White": "#FFFFFF",
        "Pure Black": "#000000",
        "Light Gray": "#F5F5F5",
        "Dark Gray": "#333333",
        "Professional Blue": "#0078D4",
        "Corporate Green": "#107C10",
        "Warm Beige": "#F5F5DC",
        "Custom": st.session_state.get('bg_color', "#00FF00")
    }

    # Lay the swatches out four to a row.
    cols = st.columns(4)

    for idx, (name, color) in enumerate(color_presets.items()):
        with cols[idx % 4]:
            if name == "Custom":
                # Show color picker for custom color
                st.session_state.bg_color = st.color_picker(
                    "Custom Color",
                    st.session_state.get('bg_color', "#00FF00"),
                    key="custom_color_picker"
                )
                continue
            # Preset swatch: clicking adopts this color.
            if st.button(
                "",
                key=f"color_{name}",
                help=name,
                type="secondary",
                use_container_width=True
            ):
                st.session_state.bg_color = color

            # Show the color preview
            st.markdown(
                f'<div style="background-color:{color}; height:30px; border-radius:4px; margin-top:-10px;"></div>',
                unsafe_allow_html=True
            )
            st.caption(name)
def handle_blur_background():
    """Handle blur background selection.

    Renders a strength slider plus a small preview image demonstrating the
    chosen Gaussian blur level.
    """
    blur_strength = st.select_slider(
        "Blur Strength:",
        options=["Subtle", "Medium", "Strong", "Maximum"],
        value="Medium",
        key="blur_strength"
    )

    # Show preview of blur effect
    st.markdown("**Preview**")
    preview_img = np.zeros((100, 200, 3), dtype=np.uint8)
    cv2.putText(
        preview_img,
        "Blur Effect",
        (20, 50),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8,
        (255, 255, 255),
        2
    )

    # Map the friendly labels onto (kernel, sigma); kernel sizes must be odd.
    blur_params = {
        "Subtle": ((15, 15), 5),
        "Medium": ((25, 25), 10),
        "Strong": ((35, 35), 15),
        "Maximum": ((51, 51), 20),
    }
    kernel, sigma = blur_params[blur_strength]
    preview_img = cv2.GaussianBlur(preview_img, kernel, sigma)

    # CONSISTENCY FIX: use_container_width, matching the rest of this file
    # (use_column_width is deprecated in current Streamlit).
    st.image(preview_img, use_container_width=True)
    st.caption(f"Selected: {blur_strength} blur")
| 285 |
+
def handle_professional_backgrounds():
|
| 286 |
+
"""Handle professional background selection"""
|
| 287 |
+
categories = {
|
| 288 |
+
"Office Settings": ["Modern Office", "Executive Office", "Home Office", "Conference Room"],
|
| 289 |
+
"Virtual Backgrounds": ["Professional", "Minimalist", "Creative", "Branded"],
|
| 290 |
+
"Nature Scenes": ["Forest", "Beach", "Mountain", "City Skyline"],
|
| 291 |
+
"Abstract Designs": ["Gradient", "Geometric", "Particles", "Bokeh"]
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
# Category selection
|
| 295 |
+
selected_category = st.selectbox(
|
| 296 |
+
"Select Category:",
|
| 297 |
+
list(categories.keys()),
|
| 298 |
+
key="bg_category"
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
# Show thumbnails for selected category
|
| 302 |
+
st.markdown("#### Available Backgrounds")
|
| 303 |
+
cols = st.columns(2)
|
| 304 |
+
|
| 305 |
+
for i, bg in enumerate(categories[selected_category]):
|
| 306 |
+
with cols[i % 2]:
|
| 307 |
+
# Create a placeholder image (replace with actual thumbnails)
|
| 308 |
+
img = np.zeros((120, 200, 3), dtype=np.uint8)
|
| 309 |
+
cv2.putText(
|
| 310 |
+
img,
|
| 311 |
+
bg,
|
| 312 |
+
(20, 60),
|
| 313 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
| 314 |
+
0.7,
|
| 315 |
+
(255, 255, 255),
|
| 316 |
+
2
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
if st.button(
|
| 320 |
+
f"Use {bg}",
|
| 321 |
+
key=f"prof_bg_{bg}",
|
| 322 |
+
use_container_width=True
|
| 323 |
+
):
|
| 324 |
+
st.session_state.selected_bg = bg
|
| 325 |
+
st.success(f"Selected: {bg}")
|
| 326 |
+
|
| 327 |
+
st.image(img, use_column_width=True)
|
| 328 |
+
|
| 329 |
+
def handle_ai_generated_background():
|
| 330 |
+
"""Handle AI generated background selection"""
|
| 331 |
+
ai_prompts = [
|
| 332 |
+
"Professional office with bookshelf",
|
| 333 |
+
"Modern co-working space",
|
| 334 |
+
"Neutral abstract background",
|
| 335 |
+
"City skyline at night",
|
| 336 |
+
"Minimalist home office setup",
|
| 337 |
+
"Corporate meeting room",
|
| 338 |
+
"Creative studio background",
|
| 339 |
+
"Custom prompt..."
|
| 340 |
+
]
|
| 341 |
+
|
| 342 |
+
# Prompt selection
|
| 343 |
+
selected_prompt = st.selectbox(
|
| 344 |
+
"Select a prompt or create your own:",
|
| 345 |
+
ai_prompts,
|
| 346 |
+
key="ai_prompt_select"
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
if selected_prompt == "Custom prompt...":
|
| 350 |
+
custom_prompt = st.text_input(
|
| 351 |
+
"Enter your custom prompt:",
|
| 352 |
+
key="custom_ai_prompt"
|
| 353 |
)
|
| 354 |
+
if custom_prompt:
|
| 355 |
+
selected_prompt = custom_prompt
|
| 356 |
+
|
| 357 |
+
# Generate button
|
| 358 |
+
if st.button(
|
| 359 |
+
"οΏ½οΏ½οΏ½ Generate Background",
|
| 360 |
+
key="generate_ai_bg",
|
| 361 |
+
use_container_width=True
|
| 362 |
+
):
|
| 363 |
+
with st.spinner(f"Generating '{selected_prompt}'..."):
|
| 364 |
+
# Simulate generation
|
| 365 |
+
time.sleep(2)
|
| 366 |
+
|
| 367 |
+
# Create a placeholder for the generated image
|
| 368 |
+
img = np.zeros((300, 500, 3), dtype=np.uint8)
|
| 369 |
+
cv2.putText(
|
| 370 |
+
img,
|
| 371 |
+
f"AI Generated:\n{selected_prompt}",
|
| 372 |
+
(50, 150),
|
| 373 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
| 374 |
+
0.8,
|
| 375 |
+
(255, 255, 255),
|
| 376 |
+
2,
|
| 377 |
+
cv2.LINE_AA
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
# Show the "generated" image
|
| 381 |
+
st.image(img, use_column_width=True)
|
| 382 |
+
|
| 383 |
+
# Add use button
|
| 384 |
+
if st.button(
|
| 385 |
+
"Use This Background",
|
| 386 |
+
key="use_ai_bg",
|
| 387 |
+
use_container_width=True
|
| 388 |
+
):
|
| 389 |
+
st.session_state.bg_image = Image.fromarray(img)
|
| 390 |
+
st.success("Background selected!")
|
| 391 |
+
|
| 392 |
+
def process_video(input_file, background, bg_type="image"):
|
|
|
|
|
|
|
| 393 |
"""
|
| 394 |
+
Mock video processing that works without SAM2/MatA
|
| 395 |
+
Just applies a simple effect to simulate background replacement
|
| 396 |
"""
|
| 397 |
+
try:
|
| 398 |
+
# Create a temporary directory for processing
|
| 399 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
| 400 |
+
temp_dir = Path(temp_dir)
|
| 401 |
+
|
| 402 |
+
# Save the uploaded video to a temporary file
|
| 403 |
+
input_path = str(temp_dir / "input.mp4")
|
| 404 |
+
with open(input_path, "wb") as f:
|
| 405 |
+
f.write(input_file.getvalue())
|
| 406 |
+
|
| 407 |
+
# Set up progress bar
|
| 408 |
+
progress_bar = st.progress(0)
|
| 409 |
+
status_text = st.empty()
|
| 410 |
+
|
| 411 |
+
def update_progress(progress, message):
|
| 412 |
+
progress = max(0, min(1, progress))
|
| 413 |
+
progress_bar.progress(progress)
|
| 414 |
+
status_text.text(f"Status: {message}")
|
| 415 |
+
|
| 416 |
+
# Simulate processing steps
|
| 417 |
+
update_progress(0.1, "Loading video...")
|
| 418 |
+
time.sleep(1)
|
| 419 |
+
|
| 420 |
+
update_progress(0.3, "Processing frames...")
|
| 421 |
+
time.sleep(2)
|
| 422 |
+
|
| 423 |
+
# Create a simple output video that just adds a colored border
|
| 424 |
+
cap = cv2.VideoCapture(input_path)
|
| 425 |
+
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
| 426 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 427 |
+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 428 |
+
|
| 429 |
+
output_path = str(temp_dir / "output.mp4")
|
| 430 |
+
fourcc = cv2.VideoWriter_fourcentCC(*'mp4v')
|
| 431 |
+
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
| 432 |
+
|
| 433 |
+
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 434 |
+
|
| 435 |
+
for i in range(frame_count):
|
| 436 |
+
ret, frame = cap.read()
|
| 437 |
+
if not ret:
|
| 438 |
+
break
|
| 439 |
+
|
| 440 |
+
# Simple effect: add a colored border to simulate processing
|
| 441 |
+
border_size = 20
|
| 442 |
+
if bg_type == "color":
|
| 443 |
+
color_hex = st.session_state.bg_color.lstrip('#')
|
| 444 |
+
color_bgr = tuple(int(color_hex[i:i+2], 16) for i in (4, 2, 0)) # RGB to BGR
|
| 445 |
+
else:
|
| 446 |
+
color_bgr = (0, 255, 0) # Default green border
|
| 447 |
+
|
| 448 |
+
frame = cv2.copyMakeBorder(
|
| 449 |
+
frame,
|
| 450 |
+
border_size, border_size, border_size, border_size,
|
| 451 |
+
cv2.BORDER_CONSTANT,
|
| 452 |
+
value=color_bgr
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
# Resize back to original dimensions
|
| 456 |
+
frame = cv2.resize(frame, (width, height))
|
| 457 |
+
|
| 458 |
+
out.write(frame)
|
| 459 |
+
|
| 460 |
+
# Update progress
|
| 461 |
+
if i % 10 == 0:
|
| 462 |
+
update_progress(0.3 + 0.7 * (i/frame_count), f"Processing frame {i}/{frame_count}")
|
| 463 |
+
|
| 464 |
+
cap.release()
|
| 465 |
+
out.release()
|
| 466 |
+
|
| 467 |
+
update_progress(1.0, "Processing complete!")
|
| 468 |
+
time.sleep(0.5)
|
| 469 |
+
|
| 470 |
+
return output_path
|
| 471 |
+
|
| 472 |
+
except Exception as e:
|
| 473 |
+
logger.error(f"Error in mock video processing: {str(e)}", exc_info=True)
|
| 474 |
+
st.error(f"An error occurred during processing: {str(e)}")
|
| 475 |
+
return None
|
| 476 |
+
|
| 477 |
+
def main():
    """Top-level Streamlit page: upload, background settings, process, download.

    Reads/writes st.session_state keys seeded by initialize_session_state().
    """
    # Inject custom CSS and the logo header.
    add_logo()

    # Page title banner.
    st.markdown(
        """
        <div class="title-container">
            <h1>π₯ Video Background Replacer</h1>
        </div>
        """,
        unsafe_allow_html=True
    )

    st.markdown("---")

    # Make sure every session-state key exists before any widget reads it.
    initialize_session_state()

    # Two-column layout: input on the left, settings + output on the right.
    col1, col2 = st.columns([1, 1], gap="large")

    with col1:
        st.header("1. Upload Video")
        handle_video_upload()
        show_video_preview()

    with col2:
        st.header("2. Background Settings")
        handle_background_selection()

        st.header("3. Process & Download")
        # Button is disabled until a video is uploaded and while a run is active.
        if st.button(
            "π Process Video",
            type="primary",
            disabled=not st.session_state.uploaded_video or st.session_state.processing,
            use_container_width=True
        ):
            with st.spinner("Processing video (this may take a few minutes)..."):
                st.session_state.processing = True
                try:
                    # Resolve the background argument from the selected type.
                    background = None
                    if st.session_state.bg_type == "image" and 'bg_image' in st.session_state and st.session_state.bg_image is not None:
                        background = st.session_state.bg_image
                    elif st.session_state.bg_type == "color" and 'bg_color' in st.session_state:
                        background = st.session_state.bg_color

                    # Run the pipeline on the uploaded clip.
                    output_path = process_video(
                        st.session_state.uploaded_video,
                        background,
                        bg_type=st.session_state.bg_type
                    )

                    if output_path and os.path.exists(output_path):
                        # Remember the result so it survives the next rerun.
                        st.session_state.processed_video_path = output_path
                        st.success("β… Video processing complete!")
                    else:
                        st.error("β Failed to process video. Please check the logs for details.")

                except Exception as e:
                    st.error(f"β An error occurred: {str(e)}")
                    logger.exception("Video processing failed")

                finally:
                    # Always re-enable the button, even on failure.
                    st.session_state.processing = False

        # Result section: preview + download, shown once a run has succeeded.
        # NOTE(review): placement inside col2 assumed from the original
        # indentation of the (mangled) source -- confirm against the live app.
        if 'processed_video_path' in st.session_state and st.session_state.processed_video_path:
            st.markdown("### Processed Video")

            try:
                # Read the file once; reuse the bytes for both preview and download.
                with open(st.session_state.processed_video_path, 'rb') as f:
                    video_bytes = f.read()
                st.video(video_bytes)

                st.download_button(
                    label="πΎ Download Processed Video",
                    data=video_bytes,
                    file_name="processed_video.mp4",
                    mime="video/mp4",
                    use_container_width=True
                )
            except Exception as e:
                st.error(f"Error displaying video: {str(e)}")
                logger.error(f"Error displaying video: {str(e)}", exc_info=True)
|
| 568 |
|
|
|
|
|
|
|
|
|
|
| 569 |
if __name__ == "__main__":
|
| 570 |
+
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app_backup.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
VideoBackgroundReplacer2 - SAM2 + MatAnyone Integration
|
| 4 |
+
=======================================================
|
| 5 |
+
- Sets up Gradio UI and launches pipeline
|
| 6 |
+
- Aligned with torch/cu121 stack; runs on HF Spaces (Docker)
|
| 7 |
+
|
| 8 |
+
Changes (2025-09-18):
|
| 9 |
+
- Added precise web-stack probes (FastAPI/Starlette/Pydantic/etc. versions + file paths)
|
| 10 |
+
- Added toggleable "mount mode": run Gradio inside our own FastAPI app
|
| 11 |
+
and provide a safe /config route shim (uses demo.get_config_file()).
|
| 12 |
+
- Kept your startup diagnostics, GPU logging, and heartbeats
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
print("=== APP STARTUP: Initializing VideoBackgroundReplacer2 ===")
|
| 16 |
+
|
| 17 |
+
# ---------------------------------------------------------------------
|
| 18 |
+
# Imports & basic setup
|
| 19 |
+
# ---------------------------------------------------------------------
|
| 20 |
+
import sys
|
| 21 |
+
import os
|
| 22 |
+
import gc
|
| 23 |
+
import json
|
| 24 |
+
import logging
|
| 25 |
+
import threading
|
| 26 |
+
import time
|
| 27 |
+
import warnings
|
| 28 |
+
import traceback
|
| 29 |
+
import subprocess
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
from loguru import logger
|
| 32 |
+
|
| 33 |
+
# Logging (loguru to stderr)
|
| 34 |
+
logger.remove()
|
| 35 |
+
logger.add(
|
| 36 |
+
sys.stderr,
|
| 37 |
+
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> "
|
| 38 |
+
"| <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# Warnings
|
| 42 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 43 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 44 |
+
warnings.filterwarnings("ignore", module="torchvision.io._video_deprecation_warning")
|
| 45 |
+
|
| 46 |
+
# Environment (lightweight & safe in Spaces)
|
| 47 |
+
os.environ.setdefault("OMP_NUM_THREADS", "1")
|
| 48 |
+
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
| 49 |
+
|
| 50 |
+
# Paths
|
| 51 |
+
BASE_DIR = Path(__file__).parent.absolute()
|
| 52 |
+
THIRD_PARTY_DIR = BASE_DIR / "third_party"
|
| 53 |
+
SAM2_DIR = THIRD_PARTY_DIR / "sam2"
|
| 54 |
+
CHECKPOINTS_DIR = BASE_DIR / "checkpoints"
|
| 55 |
+
|
| 56 |
+
# Python path extends
|
| 57 |
+
for p in (str(THIRD_PARTY_DIR), str(SAM2_DIR)):
|
| 58 |
+
if p not in sys.path:
|
| 59 |
+
sys.path.insert(0, p)
|
| 60 |
+
|
| 61 |
+
logger.info(f"Base directory: {BASE_DIR}")
|
| 62 |
+
logger.info(f"Python path[0:5]: {sys.path[:5]}")
|
| 63 |
+
|
| 64 |
+
# ---------------------------------------------------------------------
|
| 65 |
+
# GPU / Torch diagnostics (non-blocking)
|
| 66 |
+
# ---------------------------------------------------------------------
|
| 67 |
+
try:
|
| 68 |
+
import torch
|
| 69 |
+
except Exception as e:
|
| 70 |
+
logger.warning("Torch import failed at startup: %s", e)
|
| 71 |
+
torch = None
|
| 72 |
+
|
| 73 |
+
DEVICE = "cuda" if (torch and torch.cuda.is_available()) else "cpu"
|
| 74 |
+
if DEVICE == "cuda":
|
| 75 |
+
os.environ["SAM2_DEVICE"] = "cuda"
|
| 76 |
+
os.environ["MATANY_DEVICE"] = "cuda"
|
| 77 |
+
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")
|
| 78 |
+
try:
|
| 79 |
+
logger.info(f"CUDA available: {torch.cuda.get_device_name(0)}")
|
| 80 |
+
except Exception:
|
| 81 |
+
logger.info("CUDA device name not available at startup.")
|
| 82 |
+
else:
|
| 83 |
+
os.environ["SAM2_DEVICE"] = "cpu"
|
| 84 |
+
os.environ["MATANY_DEVICE"] = "cpu"
|
| 85 |
+
logger.warning("CUDA not available, falling back to CPU")
|
| 86 |
+
|
| 87 |
+
def verify_models():
    """Verify critical model files exist and are loadable (cheap checks).

    Returns:
        dict: {"status": "success"|"error", "details": {"sam2": {...}}} with
        either the checkpoint path/size or the error message + traceback.
        Never raises; all failures are folded into the returned structure.
    """
    results = {"status": "success", "details": {}}
    try:
        sam2_model_path = os.getenv("SAM2_MODEL_PATH", str(CHECKPOINTS_DIR / "sam2_hiera_large.pt"))
        if not os.path.exists(sam2_model_path):
            raise FileNotFoundError(f"SAM2 model not found at {sam2_model_path}")
        # Cheap load test (map to CPU to avoid VRAM use during boot).
        # FIX: weights_only=True refuses to execute arbitrary pickled code from
        # the checkpoint file -- the torch>=2.4 recommended mode for loading
        # files that may come from outside the build.
        if torch:
            sd = torch.load(sam2_model_path, map_location="cpu", weights_only=True)
            if not isinstance(sd, dict):
                raise ValueError("Invalid SAM2 checkpoint format")
        results["details"]["sam2"] = {
            "status": "success",
            "path": sam2_model_path,
            "size_mb": round(os.path.getsize(sam2_model_path) / (1024 * 1024), 2),
        }
    except Exception as e:
        results["status"] = "error"
        results["details"]["sam2"] = {
            "status": "error",
            "error": str(e),
            "traceback": traceback.format_exc(),
        }
    return results
|
| 112 |
+
|
| 113 |
+
def run_startup_diagnostics():
    """Build a one-shot boot snapshot: interpreter, torch/CUDA, paths, env, models.

    Cheap enough to run unconditionally at import time; the result is logged
    and can be serialized for support/debugging.
    """
    report = {
        "system": {
            "python": sys.version,
            "pytorch": getattr(torch, "__version__", None) if torch else None,
            "cuda_available": bool(torch and torch.cuda.is_available()),
            "device_count": (torch.cuda.device_count() if torch and torch.cuda.is_available() else 0),
            "cuda_version": getattr(getattr(torch, "version", None), "cuda", None) if torch else None,
        },
        "paths": {
            "base_dir": str(BASE_DIR),
            "checkpoints_dir": str(CHECKPOINTS_DIR),
            "sam2_dir": str(SAM2_DIR),
        },
        # Whitelist only non-sensitive platform variables -- never dump the
        # full environment into logs.
        "env_subset": {k: v for k, v in os.environ.items() if k in ("HOST", "PORT", "SPACE_ID", "SPACE_AUTHOR_NAME")},
    }
    # Model checks last: they touch the filesystem and may be the slow part.
    report["model_verification"] = verify_models()
    return report
|
| 131 |
+
|
| 132 |
+
startup_diag = run_startup_diagnostics()
|
| 133 |
+
logger.info("Startup diagnostics completed")
|
| 134 |
+
|
| 135 |
+
# Noisy heartbeat so logs show life during import time
|
| 136 |
+
def _heartbeat():
    """Print a liveness line every 5 s so platform logs show progress while
    slow imports/model loads are still running. Runs forever; intended to be
    started on a daemon thread so it dies with the process.
    """
    tick = 0
    while True:
        tick += 1
        print(f"[startup-heartbeat] {tick*5}sβ¦", flush=True)
        time.sleep(5)
|
| 142 |
+
|
| 143 |
+
threading.Thread(target=_heartbeat, daemon=True).start()
|
| 144 |
+
|
| 145 |
+
# Optional perf tuning import (non-fatal)
|
| 146 |
+
try:
|
| 147 |
+
import perf_tuning # noqa: F401
|
| 148 |
+
logger.info("perf_tuning imported successfully.")
|
| 149 |
+
except Exception as e:
|
| 150 |
+
logger.info("perf_tuning not available: %s", e)
|
| 151 |
+
|
| 152 |
+
# MatAnyone non-instantiating probe
|
| 153 |
+
try:
|
| 154 |
+
import inspect
|
| 155 |
+
from matanyone.inference import inference_core as ic # type: ignore
|
| 156 |
+
sigs = {}
|
| 157 |
+
for name in ("InferenceCore",):
|
| 158 |
+
obj = getattr(ic, name, None)
|
| 159 |
+
if obj:
|
| 160 |
+
sigs[name] = "callable" if callable(obj) else "present"
|
| 161 |
+
logger.info(f"[MATANY] probe (non-instantiating): {sigs}")
|
| 162 |
+
except Exception as e:
|
| 163 |
+
logger.info(f"[MATANY] probe skipped: {e}")
|
| 164 |
+
|
| 165 |
+
# ---------------------------------------------------------------------
|
| 166 |
+
# Gradio import and web-stack probes
|
| 167 |
+
# ---------------------------------------------------------------------
|
| 168 |
+
import gradio as gr
|
| 169 |
+
|
| 170 |
+
# Standard logger for some libs that use stdlib logging
|
| 171 |
+
py_logger = logging.getLogger("backgroundfx_pro")
|
| 172 |
+
if not py_logger.handlers:
|
| 173 |
+
h = logging.StreamHandler()
|
| 174 |
+
h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
|
| 175 |
+
py_logger.addHandler(h)
|
| 176 |
+
py_logger.setLevel(logging.INFO)
|
| 177 |
+
|
| 178 |
+
def _log_web_stack_versions_and_paths():
    """Log versions and on-disk paths of the web stack (FastAPI/Starlette/
    Pydantic/httpx/anyio/Gradio) for debugging dependency mismatches.

    FIX: this module logs through loguru, which formats messages with
    str.format-style ``{}`` placeholders -- the previous printf-style "%s"
    calls emitted the literal placeholders and dropped every argument.
    Rewritten with f-strings. Probes are best-effort and never raise.
    """
    import inspect
    try:
        import fastapi, starlette, pydantic, httpx, anyio
        try:
            import pydantic_core
            pc_ver = pydantic_core.__version__
        except Exception:
            pc_ver = "unknown"
        logger.info(
            "[WEB-STACK] "
            f"fastapi={getattr(fastapi, '__version__', '?')} | "
            f"starlette={getattr(starlette, '__version__', '?')} | "
            f"pydantic={getattr(pydantic, '__version__', '?')} | "
            f"pydantic-core={pc_ver} | "
            f"httpx={getattr(httpx, '__version__', '?')} | "
            f"anyio={getattr(anyio, '__version__', '?')}"
        )
    except Exception as e:
        logger.warning(f"[WEB-STACK] version probe failed: {e}")

    try:
        import gradio
        import gradio.routes as gr_routes
        import gradio.queueing as gr_queueing
        logger.info(f"[PATH] gradio.__file__ = {getattr(gradio, '__file__', '?')}")
        logger.info(f"[PATH] gradio.routes = {inspect.getfile(gr_routes)}")
        logger.info(f"[PATH] gradio.queueing = {inspect.getfile(gr_queueing)}")
        import starlette.exceptions as st_exc
        logger.info(f"[PATH] starlette.exceptions= {inspect.getfile(st_exc)}")
    except Exception as e:
        logger.warning(f"[PATH] probe failed: {e}")
|
| 210 |
+
|
| 211 |
+
def _post_launch_diag():
    """Log CUDA availability and device details after launch.

    Intended to run on a daemon thread so server startup is never blocked.

    FIX: switched printf-style "%s"/"%d" messages to f-strings -- loguru uses
    str.format-style {} placeholders, so the old placeholders were logged
    literally and the values were dropped.
    """
    try:
        if not torch:
            # Torch failed to import at boot; nothing to report.
            return
        avail = torch.cuda.is_available()
        logger.info(f"CUDA available (post-launch): {avail}")
        if avail:
            idx = torch.cuda.current_device()
            name = torch.cuda.get_device_name(idx)
            cap = torch.cuda.get_device_capability(idx)
            logger.info(f"CUDA device {idx}: {name} (cc {cap[0]}.{cap[1]})")
    except Exception as e:
        logger.warning(f"Post-launch CUDA diag failed: {e}")
|
| 224 |
+
|
| 225 |
+
# ---------------------------------------------------------------------
|
| 226 |
+
# UI factory (uses your existing builder)
|
| 227 |
+
# ---------------------------------------------------------------------
|
| 228 |
+
def build_ui() -> gr.Blocks:
    """Construct and return the Gradio Blocks UI.

    The builder is imported lazily, and from ui_core_interface (not ui),
    so the heavy UI module only loads when a server actually starts.
    """
    from ui_core_interface import create_interface
    return create_interface()
|
| 232 |
+
|
| 233 |
+
# ---------------------------------------------------------------------
|
| 234 |
+
# Optional: custom FastAPI mount mode
|
| 235 |
+
# ---------------------------------------------------------------------
|
| 236 |
+
def build_fastapi_with_gradio(demo: gr.Blocks):
    """Wrap *demo* in our own FastAPI app and return it.

    Exposes a JSON liveness probe (/healthz) and a /config shim that serves
    the Gradio config via demo.get_config_file(), then mounts the Gradio UI
    at root. Routes registered before the mount keep precedence.
    """
    from fastapi import FastAPI
    from fastapi.responses import JSONResponse

    app = FastAPI(title="VideoBackgroundReplacer2")

    @app.get("/healthz")
    def _healthz():
        # Liveness probe for the hosting platform.
        return {"ok": True, "ts": time.time()}

    @app.get("/config")
    def _config():
        # Safe shim: regenerate the UI config on demand; surface failures as JSON.
        try:
            cfg = demo.get_config_file()
        except Exception as e:
            return JSONResponse(
                status_code=500,
                content={"error": "config_generation_failed", "detail": str(e)},
            )
        return JSONResponse(content=cfg)

    # Mount Gradio UI at root; our /config route remains at parent level
    app = gr.mount_gradio_app(app, demo, path="/")
    return app
|
| 264 |
+
|
| 265 |
+
# ---------------------------------------------------------------------
|
| 266 |
+
# Entrypoint
|
| 267 |
+
# ---------------------------------------------------------------------
|
| 268 |
+
if __name__ == "__main__":
    # Host/port come from the platform (HF Spaces sets PORT); sane defaults otherwise.
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "7860"))
    # GRADIO_MOUNT_MODE=1 runs Gradio inside our own FastAPI app (see
    # build_fastapi_with_gradio); default is plain demo.launch().
    mount_mode = os.environ.get("GRADIO_MOUNT_MODE", "0") == "1"

    # FIX: loguru formats with {}-style placeholders, not printf %s; the old
    # "%s" message logged the literal placeholders and dropped host/port.
    logger.info(f"Launching on {host}:{port} (mount_mode={mount_mode})β¦")
    _log_web_stack_versions_and_paths()

    demo = build_ui()
    demo.queue(max_size=16, api_open=False)

    # CUDA details are logged from a daemon thread so boot is never blocked.
    threading.Thread(target=_post_launch_diag, daemon=True).start()

    if mount_mode:
        try:
            from uvicorn import run as uvicorn_run
        except Exception:
            logger.error("uvicorn is not installed; mount mode cannot start.")
            raise

        app = build_fastapi_with_gradio(demo)
        uvicorn_run(app=app, host=host, port=port, log_level="info")
    else:
        demo.launch(
            server_name=host,
            server_port=port,
            share=False,
            show_api=False,
            show_error=True,
            quiet=False,
            debug=True,
            max_threads=1,
        )
|
pipeline_utils.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import tempfile
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Optional, Union, Callable
|
| 7 |
+
import logging
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
# Configure logging
|
| 11 |
+
logging.basicConfig(level=logging.INFO)
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
class VideoProcessor:
    """Thin facade over the two-stage background-replacement pipeline.

    Owns a scratch directory for intermediate files and hides the
    bytes-vs-path handling from callers.
    """

    def __init__(self, temp_dir: Optional[str] = None):
        """
        Initialize the video processor.

        Args:
            temp_dir: Directory for temporary files. If None, creates a temp directory.
        """
        self.temp_dir = Path(temp_dir) if temp_dir else Path(tempfile.mkdtemp(prefix="bg_replace_"))
        self.temp_dir.mkdir(parents=True, exist_ok=True)
        self.device = self._get_device()
        logger.info(f"Initialized VideoProcessor with device: {self.device}")

    def _get_device(self) -> str:
        """Return "cuda" when torch reports a GPU, else "cpu" (also when torch is absent)."""
        try:
            import torch
        except ImportError:
            return "cpu"
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _create_static_bg_video(
        self,
        bg_image: np.ndarray,
        reference_video: str,
        output_path: str
    ) -> str:
        """
        Create a static background video matching the input video's duration.
        """
        # Probe the reference clip for geometry, fps and frame count.
        probe = cv2.VideoCapture(reference_video)
        frame_w = int(probe.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_h = int(probe.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = probe.get(cv2.CAP_PROP_FPS)
        n_frames = int(probe.get(cv2.CAP_PROP_FRAME_COUNT))
        probe.release()

        # Fit the still image to the clip's resolution once, then repeat it.
        still = cv2.resize(bg_image, (frame_w, frame_h))

        writer = cv2.VideoWriter(
            output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_w, frame_h)
        )
        for _ in range(n_frames):
            writer.write(still)
        writer.release()
        return output_path

    def _process_with_pipeline(
        self,
        input_video: str,
        background: Optional[Union[str, np.ndarray]] = None,
        bg_type: str = "blur",
        progress_callback: Optional[Callable] = None
    ) -> str:
        """
        Process video using the two-stage pipeline.
        """
        try:
            # Imported lazily so this module loads even when the pipeline
            # dependencies are unavailable.
            from integrated_pipeline import TwoStageProcessor

            if progress_callback:
                progress_callback(0.1, "Initializing pipeline...")

            # When a still image is requested, pre-render it as a background clip.
            bg_video_path = ""
            if bg_type == "image" and background is not None:
                bg_image = cv2.imread(background) if isinstance(background, str) else background
                bg_video_path = str(self.temp_dir / "background.mp4")
                self._create_static_bg_video(bg_image, input_video, bg_video_path)

            processor = TwoStageProcessor(temp_dir=str(self.temp_dir))
            output_path = str(self.temp_dir / "output.mp4")

            # Mock click points (center of frame) to seed segmentation.
            click_points = [[0.5, 0.5]]

            ok = processor.process_video(
                input_video=input_video,
                background_video=bg_video_path if bg_type == "image" else "",
                click_points=click_points,
                output_path=output_path,
                use_matanyone=True,
                progress_callback=progress_callback
            )
            if not ok:
                raise RuntimeError("Video processing failed")
            return output_path

        except Exception as e:
            logger.error(f"Error in pipeline: {str(e)}")
            raise

    def process_video(
        self,
        input_path: Union[str, bytes],
        background: Optional[Union[str, np.ndarray]] = None,
        bg_type: str = "blur",
        progress_callback: Optional[Callable] = None
    ) -> bytes:
        """
        Process a video with the given background.

        Args:
            input_path: Path to input video or bytes
            background: Background image path or numpy array
            bg_type: Type of background ("image", "color", or "blur")
            progress_callback: Optional callback for progress updates

        Returns:
            Processed video as bytes
        """
        try:
            # Accept either raw bytes (spill to a temp file) or a filesystem path.
            if isinstance(input_path, bytes):
                src = str(self.temp_dir / "input.mp4")
                with open(src, "wb") as fh:
                    fh.write(input_path)
            else:
                src = input_path

            result_path = self._process_with_pipeline(
                src,
                background,
                bg_type,
                progress_callback
            )

            with open(result_path, "rb") as fh:
                return fh.read()

        except Exception as e:
            logger.error(f"Error processing video: {str(e)}")
            raise
|
| 164 |
+
|
| 165 |
+
# Global instance
|
| 166 |
+
video_processor = VideoProcessor()
|
| 167 |
+
|
| 168 |
+
def process_video_pipeline(
    input_data: Union[str, bytes],
    background: Optional[Union[str, np.ndarray]] = None,
    bg_type: str = "blur",
    progress_callback: Optional[Callable] = None
) -> bytes:
    """
    High-level function to process a video.

    Delegates to the module-level VideoProcessor singleton.

    Args:
        input_data: Input video path or bytes
        background: Background image path or numpy array
        bg_type: Type of background ("image", "color", or "blur")
        progress_callback: Optional progress callback

    Returns:
        Processed video as bytes
    """
    return video_processor.process_video(input_data, background, bg_type, progress_callback)
|
requirements.txt
CHANGED
|
@@ -35,22 +35,21 @@ iopath>=0.1.10,<0.2.0
|
|
| 35 |
kornia>=0.7.0,<0.8.0
|
| 36 |
tqdm>=4.60.0,<5.0.0
|
| 37 |
|
| 38 |
-
# ===== UI
|
| 39 |
-
|
| 40 |
-
|
| 41 |
|
| 42 |
-
# ===== Web stack
|
| 43 |
-
fastapi
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
anyio==4.4.0
|
| 48 |
orjson>=3.10.0
|
| 49 |
|
| 50 |
# ===== Pydantic family (avoid breaking core 2.23.x) =====
|
| 51 |
pydantic==2.8.2
|
| 52 |
pydantic-core==2.20.1
|
| 53 |
-
annotated-types==0.
|
| 54 |
typing-extensions==4.12.2
|
| 55 |
|
| 56 |
# ===== Helpers and Utilities =====
|
|
@@ -69,4 +68,4 @@ nvidia-ml-py3>=7.352.0,<12.0.0
|
|
| 69 |
loguru>=0.6.0,<1.0.0
|
| 70 |
|
| 71 |
# File handling
|
| 72 |
-
python-multipart>=0.0.5,<1.0.0
|
|
|
|
| 35 |
kornia>=0.7.0,<0.8.0
|
| 36 |
tqdm>=4.60.0,<5.0.0
|
| 37 |
|
| 38 |
+
# ===== Streamlit UI =====
|
| 39 |
+
streamlit>=1.32.0
|
| 40 |
+
streamlit-webrtc>=0.50.0 # For real-time video processing
|
| 41 |
|
| 42 |
+
# ===== Web stack =====
|
| 43 |
+
fastapi>=0.104.0
|
| 44 |
+
uvicorn>=0.24.0
|
| 45 |
+
httpx>=0.25.0
|
| 46 |
+
anyio>=4.0.0
|
|
|
|
| 47 |
orjson>=3.10.0
|
| 48 |
|
| 49 |
# ===== Pydantic family (avoid breaking core 2.23.x) =====
|
| 50 |
pydantic==2.8.2
|
| 51 |
pydantic-core==2.20.1
|
| 52 |
+
annotated-types==0.7.0
|
| 53 |
typing-extensions==4.12.2
|
| 54 |
|
| 55 |
# ===== Helpers and Utilities =====
|
|
|
|
| 68 |
loguru>=0.6.0,<1.0.0
|
| 69 |
|
| 70 |
# File handling
|
| 71 |
+
python-multipart>=0.0.5,<1.0.0
|
streamlit_app.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# streamlit_ui.py
|
| 2 |
+
import streamlit as st
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import tempfile
|
| 6 |
+
import time
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import logging
|
| 12 |
+
import io
|
| 13 |
+
|
| 14 |
+
# Add project root to path
|
| 15 |
+
sys.path.append(str(Path(__file__).parent.absolute()))
|
| 16 |
+
|
| 17 |
+
# Configure logging
|
| 18 |
+
logging.basicConfig(level=logging.INFO)
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
# Set page config
|
| 22 |
+
st.set_page_config(
|
| 23 |
+
page_title="π¬ Advanced Video Background Replacer",
|
| 24 |
+
page_icon="π₯",
|
| 25 |
+
layout="wide",
|
| 26 |
+
initial_sidebar_state="expanded"
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
# Custom CSS for better UI
|
| 30 |
+
st.markdown("""
|
| 31 |
+
<style>
|
| 32 |
+
.main .block-container {
|
| 33 |
+
padding-top: 2rem;
|
| 34 |
+
padding-bottom: 2rem;
|
| 35 |
+
}
|
| 36 |
+
.stButton>button {
|
| 37 |
+
width: 100%;
|
| 38 |
+
background-color: #4CAF50;
|
| 39 |
+
color: white;
|
| 40 |
+
font-weight: bold;
|
| 41 |
+
transition: all 0.3s;
|
| 42 |
+
}
|
| 43 |
+
.stButton>button:hover {
|
| 44 |
+
background-color: #45a049;
|
| 45 |
+
}
|
| 46 |
+
.stProgress > div > div > div > div {
|
| 47 |
+
background-color: #4CAF50;
|
| 48 |
+
}
|
| 49 |
+
.stAlert {
|
| 50 |
+
border-radius: 10px;
|
| 51 |
+
}
|
| 52 |
+
.stTabs [data-baseweb="tab-list"] {
|
| 53 |
+
gap: 10px;
|
| 54 |
+
}
|
| 55 |
+
.stTabs [data-baseweb="tab"] {
|
| 56 |
+
height: 50px;
|
| 57 |
+
white-space: pre;
|
| 58 |
+
background-color: #f0f2f6;
|
| 59 |
+
border-radius: 4px 4px 0 0;
|
| 60 |
+
padding: 10px 20px;
|
| 61 |
+
margin-right: 5px;
|
| 62 |
+
}
|
| 63 |
+
.stTabs [aria-selected="true"] {
|
| 64 |
+
background-color: #4CAF50;
|
| 65 |
+
color: white;
|
| 66 |
+
}
|
| 67 |
+
.video-container {
|
| 68 |
+
border: 2px dashed #4CAF50;
|
| 69 |
+
border-radius: 10px;
|
| 70 |
+
padding: 10px;
|
| 71 |
+
margin-bottom: 20px;
|
| 72 |
+
}
|
| 73 |
+
</style>
|
| 74 |
+
""", unsafe_allow_html=True)
|
| 75 |
+
|
| 76 |
+
def initialize_session_state():
    """Seed st.session_state with every key the app reads, if not yet set.

    Keeps reruns idempotent: existing values are never overwritten.
    """
    defaults = {
        'uploaded_video': None,
        'bg_image': None,
        'bg_color': "#00FF00",
        'processed_video_path': None,
        'processing': False,
        'progress': 0,
        'progress_text': "Ready",
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
|
| 92 |
+
|
| 93 |
+
def handle_video_upload():
    """Render the video uploader and stash the chosen file in session state."""
    video_file = st.file_uploader(
        "πΉ Upload Video",
        type=["mp4", "mov", "avi"],
        key="video_uploader",
    )
    # Only overwrite the stored upload when the widget actually returned one.
    if video_file is not None:
        st.session_state.uploaded_video = video_file
|
| 102 |
+
|
| 103 |
+
def show_video_preview():
    """Render an inline preview of the currently uploaded video, if any."""
    st.markdown("### Video Preview")
    video = st.session_state.uploaded_video
    if video is not None:
        st.video(video.getvalue())
        # Rewind so later consumers (the processing step) read from the start.
        video.seek(0)
|
| 110 |
+
|
| 111 |
+
def handle_background_selection():
    """Render the background-selection UI.

    Shows a radio selector and the matching controls: an image uploader
    (stored in st.session_state.bg_image) or a color picker (stored in
    st.session_state.bg_color) with a swatch preview. The "Blur" option
    has no extra controls here.

    Returns:
        The selected background type as a lowercase string:
        "image", "color", or "blur".
    """
    st.markdown("### Background Options")
    bg_type = st.radio(
        "Select Background Type:",
        ["Image", "Color", "Blur"],
        horizontal=True,
        index=0
    )

    if bg_type == "Image":
        bg_image = st.file_uploader(
            "πΌοΈ Upload Background Image",
            type=["jpg", "png", "jpeg"],
            key="bg_image_uploader"
        )
        if bg_image is not None:
            st.session_state.bg_image = Image.open(bg_image)
            st.image(
                st.session_state.bg_image,
                caption="Selected Background",
                use_container_width=True
            )

    elif bg_type == "Color":
        st.session_state.bg_color = st.color_picker(
            "π¨ Choose Background Color",
            st.session_state.bg_color
        )
        # Parse "#RRGGBB" into an (R, G, B) tuple of ints.
        color_rgb = tuple(int(st.session_state.bg_color.lstrip('#')[i:i+2], 16) for i in (0, 2, 4))
        # BUG FIX: st.image interprets ndarray input as RGB, so the swatch
        # must stay in RGB order. The previous RGB->BGR swap (an OpenCV
        # habit) made the preview display the channel-reversed color.
        color_display = np.full((100, 100, 3), color_rgb, dtype=np.uint8)
        st.image(color_display, caption="Selected Color", width=200)

    return bg_type.lower()
|
| 146 |
+
|
| 147 |
+
def process_video(input_file, background, bg_type="image"):
    """
    Process video with the selected background using SAM2 and MatAnyone pipeline.

    Args:
        input_file: Uploaded file-like object containing the source video.
        background: PIL.Image when bg_type is "image"; otherwise unused here
            (the color case reads st.session_state.bg_color directly).
        bg_type: One of "image", "color", or "blur".

    Returns:
        Path to the processed video as a string, or None on failure. The
        returned file is a persistent temp file that outlives this call;
        the caller is responsible for eventual cleanup.
    """
    import shutil  # local import: only needed to persist the result

    try:
        # Create a temporary directory for processing
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_dir = Path(temp_dir)

            # Save the uploaded video to a temporary file
            input_path = str(temp_dir / "input.mp4")
            with open(input_path, "wb") as f:
                f.write(input_file.getvalue())

            # Prepare background
            bg_path = None
            if bg_type == "image" and background is not None:
                # Convert PIL Image (RGB) to OpenCV format (BGR)
                bg_cv = cv2.cvtColor(np.array(background), cv2.COLOR_RGB2BGR)
                bg_path = str(temp_dir / "background.jpg")
                cv2.imwrite(bg_path, bg_cv)
            elif bg_type == "color" and hasattr(st.session_state, 'bg_color'):
                # Create a solid color image in BGR order for OpenCV.
                # np.full keeps uint8 dtype (np.ones * tuple promoted to int64,
                # which cv2.imwrite does not reliably accept).
                color_hex = st.session_state.bg_color.lstrip('#')
                color_rgb = tuple(int(color_hex[i:i+2], 16) for i in (0, 2, 4))
                bg_path = str(temp_dir / "background.jpg")
                cv2.imwrite(bg_path, np.full((100, 100, 3), color_rgb[::-1], dtype=np.uint8))
                # NOTE(review): this bg_path is written but never handed to the
                # processor below (background_video is only set for "image");
                # confirm whether the pipeline picks it up another way.

            # Set up progress UI; the callback mirrors progress into
            # session state so other parts of the app can observe it.
            progress_bar = st.progress(0)
            status_text = st.empty()

            def progress_callback(progress, message):
                progress = max(0, min(1, float(progress)))  # clamp to [0, 1]
                progress_bar.progress(progress)
                status_text.text(f"Status: {message}")
                st.session_state.progress = int(progress * 100)
                st.session_state.progress_text = message

            output_path = str(temp_dir / "output.mp4")

            # Mock click points (center of the frame)
            click_points = [[0.5, 0.5]]

            # Import the pipeline processor lazily (heavy model deps)
            from integrated_pipeline import TwoStageProcessor

            processor = TwoStageProcessor(temp_dir=str(temp_dir))

            success = processor.process_video(
                input_video=input_path,
                background_video=bg_path if bg_type == "image" else "",
                click_points=click_points,
                output_path=output_path,
                use_matanyone=True,
                progress_callback=progress_callback
            )

            if not success:
                raise RuntimeError("Video processing failed")

            # BUG FIX: temp_dir (and output.mp4 inside it) is deleted as soon
            # as this 'with' block exits, so the previously returned path
            # pointed at a nonexistent file. Copy the result to a persistent
            # temp file and return that path instead.
            fd, persistent_path = tempfile.mkstemp(suffix=".mp4")
            os.close(fd)
            shutil.copyfile(output_path, persistent_path)
            return persistent_path

    except Exception as e:
        logger.error(f"Error in video processing: {str(e)}", exc_info=True)
        st.error(f"An error occurred during processing: {str(e)}")
        return None
|
| 219 |
+
|
| 220 |
+
def main():
    """Top-level Streamlit page: upload -> background -> process -> download."""
    st.title("π¬ Advanced Video Background Replacer")
    st.markdown("---")

    # Initialize session state
    initialize_session_state()

    # Main layout: upload on the left, settings/processing on the right
    col1, col2 = st.columns([1, 1], gap="large")

    with col1:
        st.header("1. Upload Video")
        handle_video_upload()
        show_video_preview()

    with col2:
        st.header("2. Background Settings")
        bg_type = handle_background_selection()

        st.header("3. Process & Download")
        if st.button(
            "π Process Video",
            type="primary",
            # Disabled until a video is uploaded, and while a run is active.
            disabled=not st.session_state.uploaded_video or st.session_state.processing,
            use_container_width=True
        ):
            with st.spinner("Processing video (this may take a few minutes)..."):
                st.session_state.processing = True

                try:
                    # Prepare background based on type: PIL image for "image",
                    # hex color string for "color", None otherwise.
                    background = None
                    if bg_type == "image" and 'bg_image' in st.session_state and st.session_state.bg_image is not None:
                        background = st.session_state.bg_image
                    elif bg_type == "color" and 'bg_color' in st.session_state:
                        background = st.session_state.bg_color

                    # Process the video
                    output_path = process_video(
                        st.session_state.uploaded_video,
                        background,
                        bg_type=bg_type
                    )

                    if output_path and os.path.exists(output_path):
                        # Store the path to the processed video
                        st.session_state.processed_video_path = output_path
                        st.success("β… Video processing complete!")
                    else:
                        st.error("β Failed to process video. Please check the logs for details.")

                except Exception as e:
                    st.error(f"β An error occurred: {str(e)}")
                    logger.exception("Video processing failed")

                finally:
                    # Always re-enable the button, even on failure.
                    st.session_state.processing = False

        # Show processed video if available (persists across reruns via
        # session state, so the preview survives widget interactions).
        if 'processed_video_path' in st.session_state and st.session_state.processed_video_path:
            st.markdown("### Processed Video")

            try:
                # Display the video directly from the file
                with open(st.session_state.processed_video_path, 'rb') as f:
                    video_bytes = f.read()
                st.video(video_bytes)

                # Download button
                st.download_button(
                    label="πΎ Download Processed Video",
                    data=video_bytes,
                    file_name="processed_video.mp4",
                    mime="video/mp4",
                    use_container_width=True
                )
            except Exception as e:
                st.error(f"Error displaying video: {str(e)}")
                logger.error(f"Error displaying video: {str(e)}", exc_info=True)


if __name__ == "__main__":
    main()
|