MogensR commited on
Commit
f92d322
·
1 Parent(s): d5bc777
Files changed (1) hide show
  1. models/sam2_loader.py +107 -279
models/sam2_loader.py CHANGED
@@ -1,279 +1,107 @@
1
- #!/usr/bin/env python3
2
- """
3
- SAM2 Loader Robust loading and mask generation for SAM2
4
- ========================================================
5
- - Loads SAM2 model with Hydra config resolution
6
- - Generates seed masks for MatAnyone
7
- - Aligned with torch==2.3.1+cu121 and SAM2 commit 3c76f73c1a7e7b4a2e8a0a9a3e5b92f7e6e3f2f5
8
-
9
- Changes (2025-09-16):
10
- - Aligned with torch==2.3.1+cu121 and SAM2 commit
11
- - Added GPU memory logging for Tesla T4
12
- - Added SAM2 version logging via importlib.metadata
13
- - Simplified config resolution to match __init__.py
14
- """
15
-
16
- from __future__ import annotations
17
-
18
- import os
19
- import logging
20
- import importlib.metadata
21
- from pathlib import Path
22
- from typing import Optional, Tuple, Dict, Any
23
-
24
- import numpy as np
25
- import yaml
26
- import torch
27
-
28
- # --------------------------------------------------------------------------------------
29
- # Logging
30
- # --------------------------------------------------------------------------------------
31
- logger = logging.getLogger("backgroundfx_pro")
32
- if not logger.handlers:
33
- _h = logging.StreamHandler()
34
- _h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
35
- logger.addHandler(_h)
36
- logger.setLevel(logging.INFO)
37
-
38
- # --------------------------------------------------------------------------------------
39
- # Path setup for third_party repos
40
- # --------------------------------------------------------------------------------------
41
- ROOT = Path(__file__).resolve().parent.parent # project root
42
- TP_SAM2 = Path(os.environ.get("THIRD_PARTY_SAM2_DIR", ROOT / "third_party" / "sam2")).resolve()
43
-
44
- def _add_sys_path(p: Path) -> None:
45
- if p.exists():
46
- p_str = str(p)
47
- if p_str not in sys.path:
48
- sys.path.insert(0, p_str)
49
- else:
50
- logger.warning(f"third_party path not found: {p}")
51
-
52
- _add_sys_path(TP_SAM2)
53
-
54
- # --------------------------------------------------------------------------------------
55
- # Safe Torch accessors
56
- # --------------------------------------------------------------------------------------
57
- def _torch():
58
- try:
59
- import torch
60
- return torch
61
- except Exception as e:
62
- logger.warning(f"[sam2_loader.safe-torch] import failed: {e}")
63
- return None
64
-
65
- def _has_cuda() -> bool:
66
- t = _torch()
67
- if t is None:
68
- return False
69
- try:
70
- return bool(t.cuda.is_available())
71
- except Exception as e:
72
- logger.warning(f"[sam2_loader.safe-torch] cuda.is_available() failed: {e}")
73
- return False
74
-
75
- def _pick_device(env_key: str) -> str:
76
- requested = os.environ.get(env_key, "").strip().lower()
77
- has_cuda = _has_cuda()
78
-
79
- logger.info(f"CUDA environment variables: {{'SAM2_DEVICE': '{os.environ.get('SAM2_DEVICE', '')}'}}")
80
- logger.info(f"_pick_device({env_key}): requested='{requested}', has_cuda={has_cuda}")
81
-
82
- if has_cuda and requested not in {"cpu"}:
83
- logger.info(f"FORCING CUDA device (GPU available, requested='{requested}')")
84
- return "cuda"
85
- elif requested in {"cuda", "cpu"}:
86
- logger.info(f"Using explicitly requested device: {requested}")
87
- return requested
88
-
89
- result = "cuda" if has_cuda else "cpu"
90
- logger.info(f"Auto-selected device: {result}")
91
- return result
92
-
93
- # --------------------------------------------------------------------------------------
94
- # SAM2 Loading and Mask Generation
95
- # --------------------------------------------------------------------------------------
96
- def _resolve_sam2_cfg(cfg_str: str) -> str:
97
- """Resolve SAM2 config path - return relative path for Hydra compatibility."""
98
- logger.info(f"_resolve_sam2_cfg called with cfg_str={cfg_str}")
99
-
100
- candidate = os.path.join(TP_SAM2, cfg_str)
101
- logger.info(f"Candidate path: {candidate}")
102
- logger.info(f"Candidate exists: {os.path.exists(candidate)}")
103
-
104
- if os.path.exists(candidate):
105
- if cfg_str.startswith("sam2/configs/"):
106
- relative_path = cfg_str.replace("sam2/configs/", "configs/")
107
- else:
108
- relative_path = cfg_str
109
- logger.info(f"Returning Hydra-compatible relative path: {relative_path}")
110
- return relative_path
111
-
112
- fallbacks = [
113
- os.path.join(TP_SAM2, "sam2", cfg_str),
114
- os.path.join(TP_SAM2, "configs", cfg_str),
115
- ]
116
-
117
- for fallback in fallbacks:
118
- logger.info(f"Trying fallback: {fallback}")
119
- if os.path.exists(fallback):
120
- if "configs/" in fallback:
121
- relative_path = "configs/" + fallback.split("configs/")[-1]
122
- logger.info(f"Returning fallback relative path: {relative_path}")
123
- return relative_path
124
-
125
- logger.warning(f"Config not found, returning original: {cfg_str}")
126
- return cfg_str
127
-
128
- def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
129
- """If config references 'hieradet', try to find a 'hiera' config."""
130
- try:
131
- with open(cfg_path, "r") as f:
132
- data = yaml.safe_load(f)
133
- model = data.get("model", {}) or {}
134
- enc = model.get("image_encoder") or {}
135
- trunk = enc.get("trunk") or {}
136
- target = trunk.get("_target_") or trunk.get("target")
137
- if isinstance(target, str) and "hieradet" in target:
138
- for y in TP_SAM2.rglob("*.yaml"):
139
- try:
140
- with open(y, "r") as f2:
141
- d2 = yaml.safe_load(f2) or {}
142
- e2 = (d2.get("model", {}) or {}).get("image_encoder") or {}
143
- t2 = (e2.get("trunk") or {})
144
- tgt2 = t2.get("_target_") or t2.get("target")
145
- if isinstance(tgt2, str) and ".hiera." in tgt2:
146
- logger.info(f"SAM2: switching config from 'hieradet' → 'hiera': {y}")
147
- return str(y)
148
- except Exception:
149
- continue
150
- except Exception:
151
- pass
152
- return None
153
-
154
- def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
155
- """Robust SAM2 loader with config resolution and error handling."""
156
- meta = {"sam2_import_ok": False, "sam2_init_ok": False}
157
- try:
158
- from sam2.build_sam import build_sam2
159
- from sam2.sam2_image_predictor import SAM2ImagePredictor
160
- meta["sam2_import_ok"] = True
161
- except Exception as e:
162
- logger.warning(f"SAM2 import failed: {e}")
163
- return None, False, meta
164
-
165
- # Log SAM2 version
166
- try:
167
- version = importlib.metadata.version("segment-anything-2")
168
- logger.info(f"[SAM2] SAM2 version: {version}")
169
- except Exception:
170
- logger.info("[SAM2] SAM2 version unknown")
171
-
172
- # Check GPU memory before loading
173
- if torch and torch.cuda.is_available():
174
- mem_before = torch.cuda.memory_allocated() / 1024**3
175
- logger.info(f"🔍 GPU memory before SAM2 load: {mem_before:.2f}GB")
176
-
177
- device = _pick_device("SAM2_DEVICE")
178
- cfg_env = os.environ.get("SAM2_MODEL_CFG", "sam2/configs/sam2/sam2_hiera_l.yaml")
179
- cfg = _resolve_sam2_cfg(cfg_env)
180
- ckpt = os.environ.get("SAM2_CHECKPOINT", "")
181
-
182
- def _try_build(cfg_path: str):
183
- logger.info(f"_try_build called with cfg_path: {cfg_path}")
184
- params = set(inspect.signature(build_sam2).parameters.keys())
185
- logger.info(f"build_sam2 parameters: {list(params)}")
186
- kwargs = {}
187
- if "config_file" in params:
188
- kwargs["config_file"] = cfg_path
189
- logger.info(f"Using config_file parameter: {cfg_path}")
190
- elif "model_cfg" in params:
191
- kwargs["model_cfg"] = cfg_path
192
- logger.info(f"Using model_cfg parameter: {cfg_path}")
193
- if ckpt:
194
- if "checkpoint" in params:
195
- kwargs["checkpoint"] = ckpt
196
- elif "ckpt_path" in params:
197
- kwargs["ckpt_path"] = ckpt
198
- elif "weights" in params:
199
- kwargs["weights"] = ckpt
200
- if "device" in params:
201
- kwargs["device"] = device
202
- try:
203
- logger.info(f"Calling build_sam2 with kwargs: {kwargs}")
204
- result = build_sam2(**kwargs)
205
- logger.info(f"build_sam2 succeeded with kwargs")
206
- if hasattr(result, 'device'):
207
- logger.info(f"SAM2 model device: {result.device}")
208
- elif hasattr(result, 'image_encoder') and hasattr(result.image_encoder, 'device'):
209
- logger.info(f"SAM2 model device: {result.image_encoder.device}")
210
- return result
211
- except TypeError as e:
212
- logger.info(f"build_sam2 kwargs failed: {e}, trying positional args")
213
- pos = [cfg_path]
214
- if ckpt:
215
- pos.append(ckpt)
216
- if "device" not in kwargs:
217
- pos.append(device)
218
- logger.info(f"Calling build_sam2 with positional args: {pos}")
219
- result = build_sam2(*pos)
220
- logger.info(f"build_sam2 succeeded with positional args")
221
- return result
222
-
223
- try:
224
- try:
225
- sam = _try_build(cfg)
226
- except Exception:
227
- alt_cfg = _find_hiera_config_if_hieradet(cfg)
228
- if alt_cfg:
229
- sam = _try_build(alt_cfg)
230
- else:
231
- raise
232
-
233
- if sam is not None:
234
- predictor = SAM2ImagePredictor(sam)
235
- meta["sam2_init_ok"] = True
236
- meta["sam2_device"] = device
237
- return predictor, True, meta
238
- else:
239
- return None, False, meta
240
-
241
- except Exception as e:
242
- logger.error(f"SAM2 loading failed: {e}")
243
- return None, False, meta
244
-
245
- def run_sam2_mask(predictor: object,
246
- first_frame_bgr: np.ndarray,
247
- point: Optional[Tuple[int, int]] = None,
248
- auto: bool = False) -> Tuple[Optional[np.ndarray], bool]:
249
- """Generate a seed mask for MatAnyone. Returns (mask_uint8_0_255, ok)."""
250
- if predictor is None:
251
- return None, False
252
- try:
253
- import cv2
254
- rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
255
- predictor.set_image(rgb)
256
-
257
- if auto:
258
- h, w = rgb.shape[:2]
259
- box = np.array([int(0.05*w), int(0.05*h), int(0.95*w), int(0.95*h)])
260
- masks, _, _ = predictor.predict(box=box)
261
- elif point is not None:
262
- x, y = int(point[0]), int(point[1])
263
- pts = np.array([[x, y]], dtype=np.int32)
264
- labels = np.array([1], dtype=np.int32)
265
- masks, _, _ = predictor.predict(point_coords=pts, point_labels=labels)
266
- else:
267
- h, w = rgb.shape[:2]
268
- box = np.array([int(0.1*w), int(0.1*h), int(0.9*w), int(0.9*h)])
269
- masks, _, _ = predictor.predict(box=box)
270
-
271
- if masks is None or len(masks) == 0:
272
- return None, False
273
-
274
- m = masks[0].astype(np.uint8) * 255
275
- logger.info(f"[SAM2] Generated mask: shape={m.shape}, dtype={m.dtype}")
276
- return m, True
277
- except Exception as e:
278
- logger.warning(f"SAM2 mask generation failed: {e}")
279
- return None, False
 
1
+ # ===============================
2
+ # Optimized Dockerfile for Hugging Face Spaces
3
+ # PyTorch 2.3.1 + CUDA 12.1
4
+ # ===============================
5
+
6
+ # Base image with CUDA 12.1.1
7
+ FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04
8
+
9
+ # Environment variables
10
+ ENV DEBIAN_FRONTEND=noninteractive \
11
+ PYTHONUNBUFFERED=1 \
12
+ PYTHONDONTWRITEBYTECODE=1 \
13
+ PIP_NO_CACHE_DIR=1 \
14
+ PIP_DISABLE_PIP_VERSION_CHECK=1 \
15
+ TORCH_CUDA_ARCH_LIST="7.5 8.0 8.6+PTX" \
16
+ FORCE_CUDA="1" \
17
+ CUDA_VISIBLE_DEVICES="0"
18
+
19
+ # Create non-root user
20
+ RUN useradd -m -u 1000 user
21
+ ENV HOME=/home/user
22
+ WORKDIR $HOME/app
23
+
24
+ # Install system dependencies in a single layer
25
+ RUN apt-get update && apt-get install -y --no-install-recommends \
26
+ git \
27
+ ffmpeg \
28
+ wget \
29
+ python3 \
30
+ python3-pip \
31
+ python3-venv \
32
+ python3-dev \
33
+ build-essential \
34
+ gcc \
35
+ g++ \
36
+ pkg-config \
37
+ libffi-dev \
38
+ libssl-dev \
39
+ libc6-dev \
40
+ libgl1-mesa-glx \
41
+ libglib2.0-0 \
42
+ libsm6 \
43
+ libxext6 \
44
+ libxrender1 \
45
+ libgomp1 \
46
+ && rm -rf /var/lib/apt/lists/*
47
+
48
+ # Set up Python environment
49
+ RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel
50
+
51
+ # Install PyTorch with CUDA 12.1 first (base for other dependencies)
52
+ RUN python3 -m pip install --no-cache-dir \
53
+ --extra-index-url https://download.pytorch.org/whl/cu121 \
54
+ torch==2.3.1+cu121 \
55
+ torchvision==0.18.1+cu121 \
56
+ torchaudio==2.3.1+cu121 \
57
+ && python3 -c "import torch; print(f'PyTorch version: {torch.__version__}'); print(f'CUDA available: {torch.cuda.is_available()}'); print(f'CUDA version: {torch.version.cuda if torch.cuda.is_available() else \"N/A\"}'); print(f'cuDNN version: {torch.backends.cudnn.version() if torch.cuda.is_available() else \"N/A\"}')"
58
+
59
+ # Copy requirements files first for better caching
60
+ COPY --chown=user requirements.txt requirements-hf.txt ./
61
+
62
+ # Install Python dependencies
63
+ RUN python3 -m pip install --no-cache-dir -r requirements.txt
64
+
65
+ # Install MatAnyone with retry logic and fallback dependencies
66
+ RUN echo "Installing problematic dependencies first..." && \
67
+ python3 -m pip install --no-cache-dir chardet charset-normalizer && \
68
+ echo "Installing MatAnyone..." && \
69
+ (python3 -m pip install --no-cache-dir -v git+https://github.com/pq-yang/MatAnyone@main#egg=matanyone || \
70
+ (echo "Retrying MatAnyone installation..." && \
71
+ python3 -m pip install --no-cache-dir -v git+https://github.com/pq-yang/MatAnyone@main#egg=matanyone)) && \
72
+ python3 -c "import matanyone; print('MatAnyone import successful')"
73
+
74
+ # Copy application code
75
+ COPY --chown=user . .
76
+
77
+ # Install SAM2
78
+ RUN echo "Installing SAM2..." && \
79
+ git clone --depth=1 https://github.com/facebookresearch/segment-anything-2.git third_party/sam2 && \
80
+ cd third_party/sam2 && \
81
+ python3 -m pip install --no-cache-dir -e .
82
+
83
+ # Set up environment variables
84
+ ENV PYTHONPATH=/home/user/app:/home/user/app/third_party:/home/user/app/third_party/sam2 \
85
+ FFMPEG_BIN=ffmpeg \
86
+ THIRD_PARTY_SAM2_DIR=/home/user/app/third_party/sam2 \
87
+ ENABLE_MATANY=1 \
88
+ SAM2_DEVICE=cuda \
89
+ MATANY_DEVICE=cuda \
90
+ OMP_NUM_THREADS=2 \
91
+ TF_CPP_MIN_LOG_LEVEL=2 \
92
+ SAM2_CHECKPOINT=/home/user/app/checkpoints/sam2_hiera_large.pt
93
+
94
+ # Create checkpoints directory
95
+ RUN mkdir -p /home/user/app/checkpoints
96
+
97
+ # Note: SAM2 model will be downloaded at runtime via lazy loading
98
+
99
+ # Health check
100
+ HEALTHCHECK --interval=30s --timeout=5s --retries=3 CMD python3 -c "import torch; print(f'PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}')"
101
+
102
+ # Run as non-root user
103
+ USER user
104
+ EXPOSE 7860
105
+
106
+ # Start the application
107
+ CMD ["python3", "-u", "app_hf.py"]