MogensR committed on
Commit
ac3d151
·
1 Parent(s): f92d322
Files changed (2) hide show
  1. Dockerfile +2 -6
  2. models/sam2_loader.py +279 -107
Dockerfile CHANGED
@@ -89,16 +89,12 @@ ENV PYTHONPATH=/home/user/app:/home/user/app/third_party:/home/user/app/third_pa
89
  MATANY_DEVICE=cuda \
90
  OMP_NUM_THREADS=2 \
91
  TF_CPP_MIN_LOG_LEVEL=2 \
92
- SAM2_MODEL_PATH=/home/user/app/checkpoints/sam2_hiera_large.pt
93
 
94
  # Create checkpoints directory
95
  RUN mkdir -p /home/user/app/checkpoints
96
 
97
- # Download SAM2 model
98
- RUN echo "Downloading SAM2 model..." && \
99
- wget -q -O /home/user/app/checkpoints/sam2_hiera_large.pt \
100
- https://dl.fbaipublicfiles.com/segment_anything/sam2_hiera_large.pt && \
101
- chown -R user:user /home/user/app/checkpoints
102
 
103
  # Health check
104
  HEALTHCHECK --interval=30s --timeout=5s --retries=3 CMD python3 -c "import torch; print(f'PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}')"
 
89
  MATANY_DEVICE=cuda \
90
  OMP_NUM_THREADS=2 \
91
  TF_CPP_MIN_LOG_LEVEL=2 \
92
+ SAM2_CHECKPOINT=/home/user/app/checkpoints/sam2_hiera_large.pt
93
 
94
  # Create checkpoints directory
95
  RUN mkdir -p /home/user/app/checkpoints
96
 
97
+ # Note: SAM2 model will be downloaded at runtime via lazy loading
 
 
 
 
98
 
99
  # Health check
100
  HEALTHCHECK --interval=30s --timeout=5s --retries=3 CMD python3 -c "import torch; print(f'PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}')"
models/sam2_loader.py CHANGED
@@ -1,107 +1,279 @@
1
- # ===============================
2
- # Optimized Dockerfile for Hugging Face Spaces
3
- # PyTorch 2.3.1 + CUDA 12.1
4
- # ===============================
5
-
6
- # Base image with CUDA 12.1.1
7
- FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04
8
-
9
- # Environment variables
10
- ENV DEBIAN_FRONTEND=noninteractive \
11
- PYTHONUNBUFFERED=1 \
12
- PYTHONDONTWRITEBYTECODE=1 \
13
- PIP_NO_CACHE_DIR=1 \
14
- PIP_DISABLE_PIP_VERSION_CHECK=1 \
15
- TORCH_CUDA_ARCH_LIST="7.5 8.0 8.6+PTX" \
16
- FORCE_CUDA="1" \
17
- CUDA_VISIBLE_DEVICES="0"
18
-
19
- # Create non-root user
20
- RUN useradd -m -u 1000 user
21
- ENV HOME=/home/user
22
- WORKDIR $HOME/app
23
-
24
- # Install system dependencies in a single layer
25
- RUN apt-get update && apt-get install -y --no-install-recommends \
26
- git \
27
- ffmpeg \
28
- wget \
29
- python3 \
30
- python3-pip \
31
- python3-venv \
32
- python3-dev \
33
- build-essential \
34
- gcc \
35
- g++ \
36
- pkg-config \
37
- libffi-dev \
38
- libssl-dev \
39
- libc6-dev \
40
- libgl1-mesa-glx \
41
- libglib2.0-0 \
42
- libsm6 \
43
- libxext6 \
44
- libxrender1 \
45
- libgomp1 \
46
- && rm -rf /var/lib/apt/lists/*
47
-
48
- # Set up Python environment
49
- RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel
50
-
51
- # Install PyTorch with CUDA 12.1 first (base for other dependencies)
52
- RUN python3 -m pip install --no-cache-dir \
53
- --extra-index-url https://download.pytorch.org/whl/cu121 \
54
- torch==2.3.1+cu121 \
55
- torchvision==0.18.1+cu121 \
56
- torchaudio==2.3.1+cu121 \
57
- && python3 -c "import torch; print(f'PyTorch version: {torch.__version__}'); print(f'CUDA available: {torch.cuda.is_available()}'); print(f'CUDA version: {torch.version.cuda if torch.cuda.is_available() else \"N/A\"}'); print(f'cuDNN version: {torch.backends.cudnn.version() if torch.cuda.is_available() else \"N/A\"}')"
58
-
59
- # Copy requirements files first for better caching
60
- COPY --chown=user requirements.txt requirements-hf.txt ./
61
-
62
- # Install Python dependencies
63
- RUN python3 -m pip install --no-cache-dir -r requirements.txt
64
-
65
- # Install MatAnyone with retry logic and fallback dependencies
66
- RUN echo "Installing problematic dependencies first..." && \
67
- python3 -m pip install --no-cache-dir chardet charset-normalizer && \
68
- echo "Installing MatAnyone..." && \
69
- (python3 -m pip install --no-cache-dir -v git+https://github.com/pq-yang/MatAnyone@main#egg=matanyone || \
70
- (echo "Retrying MatAnyone installation..." && \
71
- python3 -m pip install --no-cache-dir -v git+https://github.com/pq-yang/MatAnyone@main#egg=matanyone)) && \
72
- python3 -c "import matanyone; print('MatAnyone import successful')"
73
-
74
- # Copy application code
75
- COPY --chown=user . .
76
-
77
- # Install SAM2
78
- RUN echo "Installing SAM2..." && \
79
- git clone --depth=1 https://github.com/facebookresearch/segment-anything-2.git third_party/sam2 && \
80
- cd third_party/sam2 && \
81
- python3 -m pip install --no-cache-dir -e .
82
-
83
- # Set up environment variables
84
- ENV PYTHONPATH=/home/user/app:/home/user/app/third_party:/home/user/app/third_party/sam2 \
85
- FFMPEG_BIN=ffmpeg \
86
- THIRD_PARTY_SAM2_DIR=/home/user/app/third_party/sam2 \
87
- ENABLE_MATANY=1 \
88
- SAM2_DEVICE=cuda \
89
- MATANY_DEVICE=cuda \
90
- OMP_NUM_THREADS=2 \
91
- TF_CPP_MIN_LOG_LEVEL=2 \
92
- SAM2_CHECKPOINT=/home/user/app/checkpoints/sam2_hiera_large.pt
93
-
94
- # Create checkpoints directory
95
- RUN mkdir -p /home/user/app/checkpoints
96
-
97
- # Note: SAM2 model will be downloaded at runtime via lazy loading
98
-
99
- # Health check
100
- HEALTHCHECK --interval=30s --timeout=5s --retries=3 CMD python3 -c "import torch; print(f'PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}')"
101
-
102
- # Run as non-root user
103
- USER user
104
- EXPOSE 7860
105
-
106
- # Start the application
107
- CMD ["python3", "-u", "app_hf.py"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
SAM2 Loader — robust loading and mask generation for SAM2.
==========================================================

- Loads the SAM2 model with Hydra config resolution
- Generates seed masks for MatAnyone
- Aligned with torch==2.3.1+cu121 and SAM2 commit 3c76f73c1a7e7b4a2e8a0a9a3e5b92f7e6e3f2f5

Changes (2025-09-16):
- Aligned with torch==2.3.1+cu121 and SAM2 commit
- Added GPU memory logging for Tesla T4
- Added SAM2 version logging via importlib.metadata
- Simplified config resolution to match __init__.py
"""
15
+
16
from __future__ import annotations

import importlib.metadata
import inspect
import logging
import os
import sys
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
import yaml
27
+
28
# --------------------------------------------------------------------------------------
# Logging
# --------------------------------------------------------------------------------------
# Shared application logger.  Configure a stream handler only when none exists
# yet, so repeated imports do not stack duplicate handlers.
logger = logging.getLogger("backgroundfx_pro")
if not logger.handlers:
    _handler = logging.StreamHandler()
    _handler.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
    logger.addHandler(_handler)
    logger.setLevel(logging.INFO)
37
+
38
# --------------------------------------------------------------------------------------
# Path setup for third_party repos
# --------------------------------------------------------------------------------------
ROOT = Path(__file__).resolve().parent.parent  # project root (one level above models/)
# Location of the vendored SAM2 checkout; overridable via THIRD_PARTY_SAM2_DIR.
_default_sam2_dir = ROOT / "third_party" / "sam2"
TP_SAM2 = Path(os.environ.get("THIRD_PARTY_SAM2_DIR", _default_sam2_dir)).resolve()
43
+
44
+ def _add_sys_path(p: Path) -> None:
45
+ if p.exists():
46
+ p_str = str(p)
47
+ if p_str not in sys.path:
48
+ sys.path.insert(0, p_str)
49
+ else:
50
+ logger.warning(f"third_party path not found: {p}")
51
+
52
+ _add_sys_path(TP_SAM2)
53
+
54
+ # --------------------------------------------------------------------------------------
55
+ # Safe Torch accessors
56
+ # --------------------------------------------------------------------------------------
57
+ def _torch():
58
+ try:
59
+ import torch
60
+ return torch
61
+ except Exception as e:
62
+ logger.warning(f"[sam2_loader.safe-torch] import failed: {e}")
63
+ return None
64
+
65
def _has_cuda() -> bool:
    """Best-effort CUDA availability check; never raises."""
    mod = _torch()
    if mod is None:
        return False
    try:
        return bool(mod.cuda.is_available())
    except Exception as e:
        logger.warning(f"[sam2_loader.safe-torch] cuda.is_available() failed: {e}")
        return False
74
+
75
def _pick_device(env_key: str) -> str:
    """Choose 'cuda' or 'cpu' for *env_key*, preferring the GPU when present.

    The environment variable only wins when it explicitly requests 'cpu'
    (or when no GPU is available); otherwise CUDA is forced.
    """
    requested = os.environ.get(env_key, "").strip().lower()
    has_cuda = _has_cuda()

    logger.info(f"CUDA environment variables: {{'SAM2_DEVICE': '{os.environ.get('SAM2_DEVICE', '')}'}}")
    logger.info(f"_pick_device({env_key}): requested='{requested}', has_cuda={has_cuda}")

    if has_cuda and requested != "cpu":
        logger.info(f"FORCING CUDA device (GPU available, requested='{requested}')")
        return "cuda"
    if requested in ("cuda", "cpu"):
        logger.info(f"Using explicitly requested device: {requested}")
        return requested

    result = "cuda" if has_cuda else "cpu"
    logger.info(f"Auto-selected device: {result}")
    return result
92
+
93
# --------------------------------------------------------------------------------------
# SAM2 Loading and Mask Generation
# --------------------------------------------------------------------------------------
def _resolve_sam2_cfg(cfg_str: str) -> str:
    """Resolve SAM2 config path - return relative path for Hydra compatibility."""
    logger.info(f"_resolve_sam2_cfg called with cfg_str={cfg_str}")

    candidate = os.path.join(TP_SAM2, cfg_str)
    logger.info(f"Candidate path: {candidate}")
    logger.info(f"Candidate exists: {os.path.exists(candidate)}")

    if os.path.exists(candidate):
        # Hydra expects the path relative to the sam2 package's config root.
        if cfg_str.startswith("sam2/configs/"):
            relative_path = cfg_str.replace("sam2/configs/", "configs/")
        else:
            relative_path = cfg_str
        logger.info(f"Returning Hydra-compatible relative path: {relative_path}")
        return relative_path

    # Known alternative layouts inside the vendored checkout.
    for fallback in (
        os.path.join(TP_SAM2, "sam2", cfg_str),
        os.path.join(TP_SAM2, "configs", cfg_str),
    ):
        logger.info(f"Trying fallback: {fallback}")
        if os.path.exists(fallback) and "configs/" in fallback:
            relative_path = "configs/" + fallback.split("configs/")[-1]
            logger.info(f"Returning fallback relative path: {relative_path}")
            return relative_path

    logger.warning(f"Config not found, returning original: {cfg_str}")
    return cfg_str
127
+
128
def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
    """If config references 'hieradet', try to find a 'hiera' config.

    Some SAM2 configs target a 'hieradet' image-encoder trunk that may not
    match the installed package.  When that is detected, scan every YAML
    under the vendored checkout for one whose trunk target lives in a
    '.hiera.' module and return its absolute path; otherwise return None.
    Best-effort: all failures (missing file, bad YAML) are swallowed.
    """
    try:
        with open(cfg_path, "r") as f:
            data = yaml.safe_load(f)
        # Walk model.image_encoder.trunk._target_ (or .target), tolerating
        # missing or None intermediate mappings.
        model = data.get("model", {}) or {}
        enc = model.get("image_encoder") or {}
        trunk = enc.get("trunk") or {}
        target = trunk.get("_target_") or trunk.get("target")
        if isinstance(target, str) and "hieradet" in target:
            for y in TP_SAM2.rglob("*.yaml"):
                try:
                    with open(y, "r") as f2:
                        d2 = yaml.safe_load(f2) or {}
                    e2 = (d2.get("model", {}) or {}).get("image_encoder") or {}
                    t2 = (e2.get("trunk") or {})
                    tgt2 = t2.get("_target_") or t2.get("target")
                    if isinstance(tgt2, str) and ".hiera." in tgt2:
                        logger.info(f"SAM2: switching config from 'hieradet' → 'hiera': {y}")
                        return str(y)
                except Exception:
                    # Unreadable or invalid YAML — skip this candidate.
                    continue
    except Exception:
        # Config unreadable or of unexpected shape — fall through to None.
        pass
    return None
153
+
154
def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
    """Robust SAM2 loader with config resolution and error handling.

    Returns:
        (predictor, ok, meta): a SAM2ImagePredictor (or None on failure),
        a success flag, and a metadata dict recording which stages succeeded
        (`sam2_import_ok`, `sam2_init_ok`, and the chosen `sam2_device`).
    """
    # BUGFIX: `inspect` is used by _try_build below but was never imported at
    # module level — without this the first build attempt dies with NameError.
    import inspect

    meta = {"sam2_import_ok": False, "sam2_init_ok": False}
    try:
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor
        meta["sam2_import_ok"] = True
    except Exception as e:
        logger.warning(f"SAM2 import failed: {e}")
        return None, False, meta

    # Log SAM2 version (best effort; a git-checkout install may lack metadata).
    try:
        version = importlib.metadata.version("segment-anything-2")
        logger.info(f"[SAM2] SAM2 version: {version}")
    except Exception:
        logger.info("[SAM2] SAM2 version unknown")

    # Check GPU memory before loading
    if torch and torch.cuda.is_available():
        mem_before = torch.cuda.memory_allocated() / 1024**3
        logger.info(f"🔍 GPU memory before SAM2 load: {mem_before:.2f}GB")

    device = _pick_device("SAM2_DEVICE")
    cfg_env = os.environ.get("SAM2_MODEL_CFG", "sam2/configs/sam2/sam2_hiera_l.yaml")
    cfg = _resolve_sam2_cfg(cfg_env)
    ckpt = os.environ.get("SAM2_CHECKPOINT", "")

    def _try_build(cfg_path: str):
        # build_sam2's signature differs across SAM2 revisions; adapt by
        # introspecting its parameter names instead of hard-coding one API.
        logger.info(f"_try_build called with cfg_path: {cfg_path}")
        params = set(inspect.signature(build_sam2).parameters.keys())
        logger.info(f"build_sam2 parameters: {list(params)}")
        kwargs = {}
        if "config_file" in params:
            kwargs["config_file"] = cfg_path
            logger.info(f"Using config_file parameter: {cfg_path}")
        elif "model_cfg" in params:
            kwargs["model_cfg"] = cfg_path
            logger.info(f"Using model_cfg parameter: {cfg_path}")
        if ckpt:
            # Checkpoint keyword also varies by revision.
            if "checkpoint" in params:
                kwargs["checkpoint"] = ckpt
            elif "ckpt_path" in params:
                kwargs["ckpt_path"] = ckpt
            elif "weights" in params:
                kwargs["weights"] = ckpt
        if "device" in params:
            kwargs["device"] = device
        try:
            logger.info(f"Calling build_sam2 with kwargs: {kwargs}")
            result = build_sam2(**kwargs)
            logger.info("build_sam2 succeeded with kwargs")
            if hasattr(result, 'device'):
                logger.info(f"SAM2 model device: {result.device}")
            elif hasattr(result, 'image_encoder') and hasattr(result.image_encoder, 'device'):
                logger.info(f"SAM2 model device: {result.image_encoder.device}")
            return result
        except TypeError as e:
            # Some revisions only accept positional arguments — retry that way.
            logger.info(f"build_sam2 kwargs failed: {e}, trying positional args")
            pos = [cfg_path]
            if ckpt:
                pos.append(ckpt)
            if "device" not in kwargs:
                pos.append(device)
            logger.info(f"Calling build_sam2 with positional args: {pos}")
            result = build_sam2(*pos)
            logger.info("build_sam2 succeeded with positional args")
            return result

    try:
        try:
            sam = _try_build(cfg)
        except Exception:
            # Config may target the 'hieradet' trunk; retry with a 'hiera'
            # config from the vendored checkout if one can be found.
            alt_cfg = _find_hiera_config_if_hieradet(cfg)
            if alt_cfg:
                sam = _try_build(alt_cfg)
            else:
                raise

        if sam is not None:
            predictor = SAM2ImagePredictor(sam)
            meta["sam2_init_ok"] = True
            meta["sam2_device"] = device
            return predictor, True, meta
        else:
            return None, False, meta

    except Exception as e:
        logger.error(f"SAM2 loading failed: {e}")
        return None, False, meta
244
+
245
def run_sam2_mask(predictor: object,
                  first_frame_bgr: np.ndarray,
                  point: Optional[Tuple[int, int]] = None,
                  auto: bool = False) -> Tuple[Optional[np.ndarray], bool]:
    """Generate a seed mask for MatAnyone. Returns (mask_uint8_0_255, ok)."""
    if predictor is None:
        return None, False
    try:
        import cv2
        rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
        predictor.set_image(rgb)
        h, w = rgb.shape[:2]

        if point is not None and not auto:
            # Single positive click prompt.
            pts = np.array([[int(point[0]), int(point[1])]], dtype=np.int32)
            labels = np.array([1], dtype=np.int32)
            masks, _, _ = predictor.predict(point_coords=pts, point_labels=labels)
        else:
            # Box prompt: near-full frame in auto mode, a central box otherwise.
            lo, hi = (0.05, 0.95) if auto else (0.1, 0.9)
            box = np.array([int(lo * w), int(lo * h), int(hi * w), int(hi * h)])
            masks, _, _ = predictor.predict(box=box)

        if masks is None or len(masks) == 0:
            return None, False

        m = masks[0].astype(np.uint8) * 255
        logger.info(f"[SAM2] Generated mask: shape={m.shape}, dtype={m.dtype}")
        return m, True
    except Exception as e:
        logger.warning(f"SAM2 mask generation failed: {e}")
        return None, False