MogensR committed on
Commit
88eae72
·
1 Parent(s): c29dcc4

Update models/loaders/sam2_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/sam2_loader.py +175 -106
models/loaders/sam2_loader.py CHANGED
@@ -1,22 +1,31 @@
1
- #!/usr/bin/env python3
2
- """
3
- SAM2 Loader + Guarded Predictor Adapter (VRAM-friendly, shape-safe, thread-safe, PyTorch2-ready)
4
- """
5
 
6
- from __future__ import annotations
 
 
7
 
8
- import os
9
- import time
10
- import logging
11
- import traceback
12
- from typing import Optional, Dict, Any, Tuple, List
13
 
14
- import numpy as np
15
- import torch
16
- import cv2
17
- import threading
18
 
 
19
  logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
20
 
21
  def _select_device(pref: str) -> str:
22
  pref = (pref or "").lower()
@@ -26,27 +35,42 @@ def _select_device(pref: str) -> str:
26
  return "cpu"
27
  return "cuda" if torch.cuda.is_available() else "cpu"
28
 
 
29
  def _ensure_rgb_uint8(img: np.ndarray, force_bgr_to_rgb: bool = False) -> np.ndarray:
 
 
 
 
30
  if img is None:
31
  raise ValueError("set_image received None image")
32
  arr = np.asarray(img)
33
  if arr.ndim != 3 or arr.shape[2] < 3:
34
  raise ValueError(f"Expected HxWxC image with C>=3, got shape={arr.shape}")
 
35
  if np.issubdtype(arr.dtype, np.floating):
36
  arr = np.clip(arr, 0.0, 1.0)
37
  arr = (arr * 255.0 + 0.5).astype(np.uint8)
 
 
38
  elif arr.dtype != np.uint8:
39
- if arr.dtype == np.uint16:
40
- arr = (arr / 257).astype(np.uint8)
41
- else:
42
- arr = arr.astype(np.uint8)
43
- if arr.shape[2] == 4:
44
  arr = arr[:, :, :3]
 
45
  if force_bgr_to_rgb:
46
  arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
 
47
  return arr
48
 
 
49
  def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> Tuple[int, int, float]:
 
 
 
 
 
 
50
  if h <= 0 or w <= 0:
51
  return h, w, 1.0
52
  s1 = min(1.0, float(max_edge) / float(max(h, w))) if max_edge > 0 else 1.0
@@ -56,30 +80,34 @@ def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> T
56
  nw = max(1, int(round(w * s)))
57
  return nh, nw, s
58
 
 
59
  def _ladder(nh: int, nw: int) -> List[Tuple[int, int]]:
 
60
  sizes = [(nh, nw)]
61
- sizes.append((max(1, int(nh * 0.85)), max(1, int(nw * 0.85))))
62
- sizes.append((max(1, int(nh * 0.70)), max(1, int(nw * 0.70))))
63
- sizes.append((max(1, int(nh * 0.50)), max(1, int(nw * 0.50))))
64
- sizes.append((max(1, int(nh * 0.35)), max(1, int(nw * 0.35))))
65
- uniq = []
66
- seen = set()
67
  for s in sizes:
68
  if s not in seen:
69
  uniq.append(s); seen.add(s)
70
  return uniq
71
 
 
72
  def _upsample_stack(masks: np.ndarray, out_hw: Tuple[int, int]) -> np.ndarray:
 
 
 
 
 
 
 
 
 
73
  if masks.ndim != 3:
74
- masks = np.asarray(masks)
 
75
  if masks.ndim == 2:
76
  masks = masks[None, ...]
77
- elif masks.ndim == 4 and masks.shape[1] == 1:
78
- masks = masks[:, 0, :, :]
79
- else:
80
- masks = np.squeeze(masks)
81
- if masks.ndim == 2:
82
- masks = masks[None, ...]
83
  n, h, w = masks.shape
84
  H, W = out_hw
85
  if (h, w) == (H, W):
@@ -89,50 +117,49 @@ def _upsample_stack(masks: np.ndarray, out_hw: Tuple[int, int]) -> np.ndarray:
89
  out[i] = cv2.resize(masks[i].astype(np.float32), (W, H), interpolation=cv2.INTER_LINEAR)
90
  return np.clip(out, 0.0, 1.0)
91
 
 
92
  def _normalize_masks_dtype(x: np.ndarray) -> np.ndarray:
93
  x = np.asarray(x)
94
  if x.dtype == np.uint8:
95
  return (x.astype(np.float32) / 255.0)
96
  return x.astype(np.float32, copy=False)
97
 
98
- # -------------------------- adapter --------------------------
99
-
100
  class _SAM2Adapter:
101
  """
102
- Wraps SAM2ImagePredictor to:
103
- - store original H,W
104
- - model-only downscale on set_image
105
- - OOM-aware predict with retry at smaller sizes
106
- - upsample masks back to original size
107
- - now thread-safe
108
  """
109
  def __init__(self, predictor, device: str):
110
  self.pred = predictor
111
  self.device = device
 
 
112
  self.orig_hw: Tuple[int, int] = (0, 0)
 
 
 
 
113
  self.max_edge = int(os.environ.get("SAM2_MAX_EDGE", "1024"))
114
  self.target_pixels = int(os.environ.get("SAM2_TARGET_PIXELS", "900000"))
115
  self.force_bgr_to_rgb = os.environ.get("SAM2_ASSUME_BGR", "0") == "1"
116
- self.use_autocast = (device == "cuda")
117
- self.autocast_dtype = None
118
- if self.use_autocast:
119
- try:
120
- if hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
121
- self.autocast_dtype = torch.bfloat16
122
- else:
123
- cc = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
124
- self.autocast_dtype = torch.float16 if cc[0] >= 7 else None
125
- except Exception:
126
- self.autocast_dtype = None
127
- self._current_rgb: Optional[np.ndarray] = None
128
- self._current_hw: Tuple[int, int] = (0, 0)
129
  self._lock = threading.Lock()
130
 
 
 
131
  def set_image(self, image: np.ndarray):
 
 
 
132
  with self._lock:
133
  rgb = _ensure_rgb_uint8(image, force_bgr_to_rgb=self.force_bgr_to_rgb)
134
  H, W = rgb.shape[:2]
135
  self.orig_hw = (H, W)
 
136
  nh, nw, s = _compute_scaled_size(H, W, self.max_edge, self.target_pixels)
137
  if s < 1.0:
138
  work = cv2.resize(rgb, (nw, nh), interpolation=cv2.INTER_AREA)
@@ -141,68 +168,83 @@ def set_image(self, image: np.ndarray):
141
  else:
142
  self._current_rgb = rgb
143
  self._current_hw = (H, W)
 
144
  self.pred.set_image(self._current_rgb)
145
 
146
  def predict(self, **kwargs) -> Dict[str, Any]:
 
 
 
 
 
147
  with self._lock:
148
  if self._current_rgb is None or self.orig_hw == (0, 0):
149
  raise RuntimeError("SAM2Adapter.predict called before set_image()")
 
150
  H, W = self.orig_hw
151
  nh, nw = self._current_hw
152
  sizes = _ladder(nh, nw)
153
  last_exc: Optional[BaseException] = None
 
154
  for (th, tw) in sizes:
155
  try:
 
156
  if (th, tw) != (nh, nw):
157
  small = cv2.resize(self._current_rgb, (tw, th), interpolation=cv2.INTER_AREA)
158
  self.pred.set_image(small)
 
 
159
  class _NoOp:
160
  def __enter__(self): return None
161
  def __exit__(self, *a): return False
162
- # -------- PyTorch 2.x autocast signature --------
163
- if self.use_autocast and self.autocast_dtype is not None:
164
- amp_ctx = torch.autocast(device_type="cuda", dtype=self.autocast_dtype)
 
 
 
 
165
  else:
166
  amp_ctx = _NoOp()
 
167
  with torch.inference_mode():
168
  with amp_ctx:
169
  out = self.pred.predict(**kwargs)
170
- # normalize outputs to dict
171
- masks = None
172
- scores = None
173
- logits = None
174
  if isinstance(out, dict):
175
- masks = out.get("masks", None)
176
- scores = out.get("scores", None)
177
- logits = out.get("logits", None)
178
  elif isinstance(out, (tuple, list)):
179
  if len(out) >= 1: masks = out[0]
180
  if len(out) >= 2: scores = out[1]
181
  if len(out) >= 3: logits = out[2]
182
  else:
183
  masks = out
 
184
  if masks is None:
185
  raise RuntimeError("SAM2 returned no masks")
186
- masks = np.asarray(masks)
187
- if masks.ndim == 2:
188
- masks = masks[None, ...]
189
- elif masks.ndim == 4 and masks.shape[1] == 1:
190
- masks = masks[:, 0, :, :]
191
  masks = _normalize_masks_dtype(masks)
192
  masks_up = _upsample_stack(masks, (H, W))
 
193
  if scores is None:
194
  scores = np.ones((masks_up.shape[0],), dtype=np.float32) * 0.5
195
  else:
196
  scores = np.asarray(scores).astype(np.float32, copy=False).reshape(-1)
 
197
  out_dict = {"masks": masks_up, "scores": scores}
198
  if logits is not None:
199
  lg = np.asarray(logits)
 
200
  if lg.ndim == 3:
201
  lg = _upsample_stack(lg, (H, W))
202
  elif lg.ndim == 4 and lg.shape[1] == 1:
203
  lg = _upsample_stack(lg[:, 0, :, :], (H, W))
204
  out_dict["logits"] = lg.astype(np.float32, copy=False)
 
205
  return out_dict
 
206
  except torch.cuda.OutOfMemoryError as e:
207
  last_exc = e
208
  if torch.cuda.is_available():
@@ -216,40 +258,52 @@ def __exit__(self, *a): return False
216
  logger.debug(traceback.format_exc())
217
  logger.warning(f"SAM2 predict failed at {th}x{tw}; retrying smaller. {e}")
218
  continue
219
- logger.warning(f"SAM2 calls failed; returning fallback. {last_exc}")
 
220
  return {
221
  "masks": np.ones((1, H, W), dtype=np.float32),
222
  "scores": np.array([0.5], dtype=np.float32),
223
  }
224
- # -------------------------- Loader --------------------------
225
 
226
  class SAM2Loader:
227
- """Dedicated loader for SAM2 models"""
228
 
229
  def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/sam2_cache"):
230
  self.device = _select_device(device)
231
  self.cache_dir = cache_dir
232
  os.makedirs(self.cache_dir, exist_ok=True)
233
 
234
- # HuggingFace Hub for spaces: avoid symlink errors
235
  os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS", "1")
236
  os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0")
237
 
238
- self.model = None # underlying predictor (SAM2ImagePredictor)
239
- self.adapter = None # wrapped predictor exposed to callers
240
  self.model_id = None
241
  self.load_time = 0.0
242
 
243
- def load(self, model_size: str = "auto") -> Optional[Any]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  """
245
- Load SAM2 model with specified size
246
- Args:
247
- model_size: "tiny", "small", "base", "large", or "auto"
248
- Returns:
249
- Wrapped predictor (adapter) or None
250
  """
251
  if model_size == "auto":
252
  model_size = self._determine_optimal_size()
 
253
  model_map = {
254
  "tiny": "facebook/sam2.1-hiera-tiny",
255
  "small": "facebook/sam2.1-hiera-small",
@@ -258,8 +312,8 @@ def load(self, model_size: str = "auto") -> Optional[Any]:
258
  }
259
  self.model_id = model_map.get(model_size, model_map["tiny"])
260
  logger.info(f"Loading SAM2 model: {self.model_id} (device={self.device})")
261
- strategies = [("official", self._load_official), ("fallback", self._load_fallback)]
262
- for name, fn in strategies:
263
  try:
264
  t0 = time.time()
265
  pred = fn()
@@ -273,25 +327,14 @@ def load(self, model_size: str = "auto") -> Optional[Any]:
273
  except Exception as e:
274
  logger.error(f"SAM2 {name} strategy failed: {e}")
275
  logger.debug(traceback.format_exc())
 
276
  logger.error("All SAM2 loading strategies failed")
277
  return None
278
 
279
- def _determine_optimal_size(self) -> str:
280
- """Determine optimal model size based on available memory"""
281
- try:
282
- if torch.cuda.is_available():
283
- props = torch.cuda.get_device_properties(0)
284
- vram_gb = props.total_memory / (1024**3)
285
- if vram_gb < 4: return "tiny"
286
- if vram_gb < 8: return "small"
287
- if vram_gb < 12: return "base"
288
- return "large"
289
- except Exception:
290
- pass
291
- return "tiny"
292
 
293
- def _load_official(self) -> Optional[Any]:
294
- """Load using official SAM2 API"""
295
  from sam2.sam2_image_predictor import SAM2ImagePredictor
296
  predictor = SAM2ImagePredictor.from_pretrained(
297
  self.model_id,
@@ -299,15 +342,14 @@ def _load_official(self) -> Optional[Any]:
299
  local_files_only=False,
300
  trust_remote_code=True,
301
  )
 
302
  if hasattr(predictor, "model"):
303
  predictor.model = predictor.model.to(self.device)
304
  predictor.model.eval()
305
- if hasattr(predictor, "device"):
306
- predictor.device = self.device
307
  return predictor
308
 
309
- def _load_fallback(self) -> Optional[Any]:
310
- """Create a tiny fallback predictor"""
311
  class FallbackSAM2:
312
  def __init__(self, device):
313
  self.device = device
@@ -315,11 +357,7 @@ def __init__(self, device):
315
  def set_image(self, image):
316
  self._img = np.asarray(image)
317
  def predict(self, **kwargs):
318
- if self._img is not None:
319
- h, w = self._img.shape[:2]
320
- else:
321
- h, w = 512, 512
322
- # Return a full-ones mask—**handled downstream!**
323
  return {
324
  "masks": np.ones((1, h, w), dtype=np.float32),
325
  "scores": np.array([0.5], dtype=np.float32),
@@ -327,6 +365,8 @@ def predict(self, **kwargs):
327
  logger.warning("Using fallback SAM2 (no real segmentation)")
328
  return FallbackSAM2(self.device)
329
 
 
 
330
  def cleanup(self):
331
  self.adapter = None
332
  if self.model is not None:
@@ -346,3 +386,32 @@ def get_info(self) -> Dict[str, Any]:
346
  "load_time": self.load_time,
347
  "model_type": type(self.model).__name__ if self.model else None,
348
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# --- module setup: logging + environment hygiene (import-safe) ---
# NOTE(review): this section previously began with demo/usage code that did
# `from sam2_loader import SAM2Loader` *inside* sam2_loader.py (a circular
# self-import) and loaded the model / read "frame0001.jpg" at import time.
# That demo already lives in the `if __name__ == "__main__":` block at the
# bottom of the file, so only the safe setup is kept here.

# Module-level logger; attach a basic handler only when the host application
# has not configured logging itself.
logger = logging.getLogger(__name__)
if not logger.handlers:
    logging.basicConfig(level=logging.INFO)

# Drop malformed OMP_NUM_THREADS values that sometimes leak into Spaces
# containers (a non-numeric value breaks OpenMP-backed libraries at init).
_val = os.environ.get("OMP_NUM_THREADS")
if _val is not None and not str(_val).strip().isdigit():
    try:
        del os.environ["OMP_NUM_THREADS"]
    except Exception:
        pass
29
 
30
  def _select_device(pref: str) -> str:
31
  pref = (pref or "").lower()
 
35
  return "cpu"
36
  return "cuda" if torch.cuda.is_available() else "cpu"
37
 
38
+
39
  def _ensure_rgb_uint8(img: np.ndarray, force_bgr_to_rgb: bool = False) -> np.ndarray:
40
+ """
41
+ Accepts: HxWxC where C>=3; dtype uint8/float/uint16; optional BGRA/RGBA.
42
+ Returns: RGB uint8 HxWx3
43
+ """
44
  if img is None:
45
  raise ValueError("set_image received None image")
46
  arr = np.asarray(img)
47
  if arr.ndim != 3 or arr.shape[2] < 3:
48
  raise ValueError(f"Expected HxWxC image with C>=3, got shape={arr.shape}")
49
+
50
  if np.issubdtype(arr.dtype, np.floating):
51
  arr = np.clip(arr, 0.0, 1.0)
52
  arr = (arr * 255.0 + 0.5).astype(np.uint8)
53
+ elif arr.dtype == np.uint16:
54
+ arr = (arr / 257).astype(np.uint8) # 16→8 bit
55
  elif arr.dtype != np.uint8:
56
+ arr = arr.astype(np.uint8)
57
+
58
+ if arr.shape[2] == 4: # drop alpha
 
 
59
  arr = arr[:, :, :3]
60
+
61
  if force_bgr_to_rgb:
62
  arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
63
+
64
  return arr
65
 
66
+
67
  def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> Tuple[int, int, float]:
68
+ """
69
+ Scale so that:
70
+ - max(h, w) <= max_edge
71
+ - h*w <= target_pixels
72
+ Returns: (nh, nw, scale) with nh,nw >= 1
73
+ """
74
  if h <= 0 or w <= 0:
75
  return h, w, 1.0
76
  s1 = min(1.0, float(max_edge) / float(max(h, w))) if max_edge > 0 else 1.0
 
80
  nw = max(1, int(round(w * s)))
81
  return nh, nw, s
82
 
83
+
84
  def _ladder(nh: int, nw: int) -> List[Tuple[int, int]]:
85
+ """Progressive smaller sizes to retry on OOM or other failures."""
86
  sizes = [(nh, nw)]
87
+ for f in (0.85, 0.70, 0.55, 0.40, 0.30):
88
+ sizes.append((max(64, int(nh * f)), max(64, int(nw * f))))
89
+ uniq, seen = [], set()
 
 
 
90
  for s in sizes:
91
  if s not in seen:
92
  uniq.append(s); seen.add(s)
93
  return uniq
94
 
95
+
96
  def _upsample_stack(masks: np.ndarray, out_hw: Tuple[int, int]) -> np.ndarray:
97
+ """
98
+ Input masks may be (N,H,W) or (N,1,H,W) or (H,W).
99
+ Output is always (N, H_out, W_out) float32 in [0,1].
100
+ """
101
+ masks = np.asarray(masks)
102
+ if masks.ndim == 2:
103
+ masks = masks[None, ...]
104
+ elif masks.ndim == 4 and masks.shape[1] == 1:
105
+ masks = masks[:, 0, :, :]
106
  if masks.ndim != 3:
107
+ # try best-effort squeeze
108
+ masks = np.squeeze(masks)
109
  if masks.ndim == 2:
110
  masks = masks[None, ...]
 
 
 
 
 
 
111
  n, h, w = masks.shape
112
  H, W = out_hw
113
  if (h, w) == (H, W):
 
117
  out[i] = cv2.resize(masks[i].astype(np.float32), (W, H), interpolation=cv2.INTER_LINEAR)
118
  return np.clip(out, 0.0, 1.0)
119
 
120
+
121
  def _normalize_masks_dtype(x: np.ndarray) -> np.ndarray:
122
  x = np.asarray(x)
123
  if x.dtype == np.uint8:
124
  return (x.astype(np.float32) / 255.0)
125
  return x.astype(np.float32, copy=False)
126
 
 
 
127
  class _SAM2Adapter:
128
  """
129
+ Thin guard around SAM2ImagePredictor that:
130
+ - remembers original H,W
131
+ - VRAM-downscales on set_image(); retries smaller on failure
132
+ - upsamples masks to original H,W
133
+ - uses torch.autocast(device_type="cuda", ...) when available
134
+ - is thread-safe (single predictor instance can serve concurrent calls)
135
  """
136
  def __init__(self, predictor, device: str):
137
  self.pred = predictor
138
  self.device = device
139
+
140
+ # Original and working sizes
141
  self.orig_hw: Tuple[int, int] = (0, 0)
142
+ self._current_rgb: Optional[np.ndarray] = None
143
+ self._current_hw: Tuple[int, int] = (0, 0)
144
+
145
+ # Tuning knobs via env
146
  self.max_edge = int(os.environ.get("SAM2_MAX_EDGE", "1024"))
147
  self.target_pixels = int(os.environ.get("SAM2_TARGET_PIXELS", "900000"))
148
  self.force_bgr_to_rgb = os.environ.get("SAM2_ASSUME_BGR", "0") == "1"
149
+
 
 
 
 
 
 
 
 
 
 
 
 
150
  self._lock = threading.Lock()
151
 
152
+ # ------------------ public API ------------------
153
+
154
  def set_image(self, image: np.ndarray):
155
+ """
156
+ image: RGB or BGR; float [0..1] or uint8; HxWx{3,4}
157
+ """
158
  with self._lock:
159
  rgb = _ensure_rgb_uint8(image, force_bgr_to_rgb=self.force_bgr_to_rgb)
160
  H, W = rgb.shape[:2]
161
  self.orig_hw = (H, W)
162
+
163
  nh, nw, s = _compute_scaled_size(H, W, self.max_edge, self.target_pixels)
164
  if s < 1.0:
165
  work = cv2.resize(rgb, (nw, nh), interpolation=cv2.INTER_AREA)
 
168
  else:
169
  self._current_rgb = rgb
170
  self._current_hw = (H, W)
171
+
172
  self.pred.set_image(self._current_rgb)
173
 
174
  def predict(self, **kwargs) -> Dict[str, Any]:
175
+ """
176
+ Calls SAM2 predictor with your prompt args (points/boxes/etc).
177
+ Returns: {"masks": (N,H,W) float32, "scores": (N,) float32, "logits"?: ...}
178
+ On any failure path, returns a full-ones mask as a safe fallback.
179
+ """
180
  with self._lock:
181
  if self._current_rgb is None or self.orig_hw == (0, 0):
182
  raise RuntimeError("SAM2Adapter.predict called before set_image()")
183
+
184
  H, W = self.orig_hw
185
  nh, nw = self._current_hw
186
  sizes = _ladder(nh, nw)
187
  last_exc: Optional[BaseException] = None
188
+
189
  for (th, tw) in sizes:
190
  try:
191
+ # Optionally re-set smaller image
192
  if (th, tw) != (nh, nw):
193
  small = cv2.resize(self._current_rgb, (tw, th), interpolation=cv2.INTER_AREA)
194
  self.pred.set_image(small)
195
+
196
+ # PyTorch 2.x autocast
197
  class _NoOp:
198
  def __enter__(self): return None
199
  def __exit__(self, *a): return False
200
+
201
+ use_amp = (self.device == "cuda")
202
+ if use_amp:
203
+ amp_ctx = torch.autocast(
204
+ device_type="cuda",
205
+ dtype=(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16)
206
+ )
207
  else:
208
  amp_ctx = _NoOp()
209
+
210
  with torch.inference_mode():
211
  with amp_ctx:
212
  out = self.pred.predict(**kwargs)
213
+
214
+ # Normalize outputs
215
+ masks = None; scores = None; logits = None
 
216
  if isinstance(out, dict):
217
+ masks = out.get("masks"); scores = out.get("scores"); logits = out.get("logits")
 
 
218
  elif isinstance(out, (tuple, list)):
219
  if len(out) >= 1: masks = out[0]
220
  if len(out) >= 2: scores = out[1]
221
  if len(out) >= 3: logits = out[2]
222
  else:
223
  masks = out
224
+
225
  if masks is None:
226
  raise RuntimeError("SAM2 returned no masks")
227
+
 
 
 
 
228
  masks = _normalize_masks_dtype(masks)
229
  masks_up = _upsample_stack(masks, (H, W))
230
+
231
  if scores is None:
232
  scores = np.ones((masks_up.shape[0],), dtype=np.float32) * 0.5
233
  else:
234
  scores = np.asarray(scores).astype(np.float32, copy=False).reshape(-1)
235
+
236
  out_dict = {"masks": masks_up, "scores": scores}
237
  if logits is not None:
238
  lg = np.asarray(logits)
239
+ # Best-effort upsample if spatial
240
  if lg.ndim == 3:
241
  lg = _upsample_stack(lg, (H, W))
242
  elif lg.ndim == 4 and lg.shape[1] == 1:
243
  lg = _upsample_stack(lg[:, 0, :, :], (H, W))
244
  out_dict["logits"] = lg.astype(np.float32, copy=False)
245
+
246
  return out_dict
247
+
248
  except torch.cuda.OutOfMemoryError as e:
249
  last_exc = e
250
  if torch.cuda.is_available():
 
258
  logger.debug(traceback.format_exc())
259
  logger.warning(f"SAM2 predict failed at {th}x{tw}; retrying smaller. {e}")
260
  continue
261
+
262
+ logger.warning(f"SAM2 calls failed; returning fallback mask. {last_exc}")
263
  return {
264
  "masks": np.ones((1, H, W), dtype=np.float32),
265
  "scores": np.array([0.5], dtype=np.float32),
266
  }
 
267
 
268
  class SAM2Loader:
269
+ """Dedicated loader for SAM2 models (PyTorch 2.x, Spaces-friendly)."""
270
 
271
def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/sam2_cache"):
    """
    Resolve the compute device, ensure the checkpoint cache directory
    exists, and set Hugging Face Hub env knobs for Spaces (only if unset).
    """
    self.device = _select_device(device)
    self.cache_dir = cache_dir
    os.makedirs(self.cache_dir, exist_ok=True)

    # Hub knobs for Spaces containers; setdefault never overrides the host.
    for key, value in (("HF_HUB_DISABLE_SYMLINKS", "1"),
                       ("HF_HUB_ENABLE_HF_TRANSFER", "0")):
        os.environ.setdefault(key, value)

    self.model = None     # underlying SAM2ImagePredictor
    self.adapter = None   # _SAM2Adapter handed to callers
    self.model_id = None  # resolved HF model id, set by load()
    self.load_time = 0.0  # seconds spent in the successful load strategy
284
 
285
+ def _determine_optimal_size(self) -> str:
286
+ """Choose model size based on VRAM."""
287
+ try:
288
+ if torch.cuda.is_available():
289
+ props = torch.cuda.get_device_properties(0)
290
+ vram_gb = props.total_memory / (1024**3)
291
+ if vram_gb < 4: return "tiny"
292
+ if vram_gb < 8: return "small"
293
+ if vram_gb < 12: return "base"
294
+ return "large"
295
+ except Exception:
296
+ pass
297
+ return "tiny"
298
+
299
+ def load(self, model_size: str = "auto") -> Optional[_SAM2Adapter]:
300
  """
301
+ model_size: "tiny" | "small" | "base" | "large" | "auto"
302
+ Returns: thread-safe adapter or None
 
 
 
303
  """
304
  if model_size == "auto":
305
  model_size = self._determine_optimal_size()
306
+
307
  model_map = {
308
  "tiny": "facebook/sam2.1-hiera-tiny",
309
  "small": "facebook/sam2.1-hiera-small",
 
312
  }
313
  self.model_id = model_map.get(model_size, model_map["tiny"])
314
  logger.info(f"Loading SAM2 model: {self.model_id} (device={self.device})")
315
+
316
+ for name, fn in (("official", self._load_official), ("fallback", self._load_fallback)):
317
  try:
318
  t0 = time.time()
319
  pred = fn()
 
327
  except Exception as e:
328
  logger.error(f"SAM2 {name} strategy failed: {e}")
329
  logger.debug(traceback.format_exc())
330
+
331
  logger.error("All SAM2 loading strategies failed")
332
  return None
333
 
334
+ # -------------- strategies --------------
 
 
 
 
 
 
 
 
 
 
 
 
335
 
336
+ def _load_official(self):
337
+ """Load SAM2ImagePredictor via its official API and move weights to device."""
338
  from sam2.sam2_image_predictor import SAM2ImagePredictor
339
  predictor = SAM2ImagePredictor.from_pretrained(
340
  self.model_id,
 
342
  local_files_only=False,
343
  trust_remote_code=True,
344
  )
345
+ # Move **model** to device; DO NOT set predictor.device (read-only → error)
346
  if hasattr(predictor, "model"):
347
  predictor.model = predictor.model.to(self.device)
348
  predictor.model.eval()
 
 
349
  return predictor
350
 
351
+ def _load_fallback(self):
352
+ """Tiny local fallback that returns a full-ones mask — keeps pipeline alive."""
353
  class FallbackSAM2:
354
  def __init__(self, device):
355
  self.device = device
 
357
  def set_image(self, image):
358
  self._img = np.asarray(image)
359
  def predict(self, **kwargs):
360
+ h, w = (self._img.shape[:2] if self._img is not None else (512, 512))
 
 
 
 
361
  return {
362
  "masks": np.ones((1, h, w), dtype=np.float32),
363
  "scores": np.array([0.5], dtype=np.float32),
 
365
  logger.warning("Using fallback SAM2 (no real segmentation)")
366
  return FallbackSAM2(self.device)
367
 
368
+ # -------------- housekeeping --------------
369
+
370
  def cleanup(self):
371
  self.adapter = None
372
  if self.model is not None:
 
386
  "load_time": self.load_time,
387
  "model_type": type(self.model).__name__ if self.model else None,
388
  }
389
+
390
if __name__ == "__main__":
    # Minimal CLI smoke test: load SAM2, segment one image, dump first mask.
    import sys

    logging.basicConfig(level=logging.INFO)
    dev = "cuda" if torch.cuda.is_available() else "cpu"

    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} image.jpg")
        raise SystemExit(1)

    image_path = sys.argv[1]
    frame = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if frame is None:
        print(f"Could not load image {image_path}")
        raise SystemExit(2)

    adapter = SAM2Loader(device=dev).load("auto")
    if not adapter:
        print("Failed to load SAM2")
        raise SystemExit(3)

    adapter.set_image(frame)
    result = adapter.predict(point_coords=None, point_labels=None)
    mask_stack = result["masks"]
    print("Masks:", mask_stack.shape, mask_stack.dtype, mask_stack.min(), mask_stack.max())
    cv2.imwrite("sam2_mask0.png", (np.clip(mask_stack[0], 0, 1) * 255).astype(np.uint8))
    print("Wrote sam2_mask0.png")