MogensR committed on
Commit
c30a2cc
·
1 Parent(s): 47f3540

Update models/loaders/sam2_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/sam2_loader.py +96 -184
models/loaders/sam2_loader.py CHANGED
@@ -1,12 +1,6 @@
1
  #!/usr/bin/env python3
2
  """
3
- SAM2 Loader + Guarded Predictor Adapter (VRAM-friendly, shape-safe)
4
-
5
- - Loads a SAM2 image predictor on the desired device.
6
- - set_image(): accepts RGB/BGR, uint8/float; optional model-only downscale to save VRAM.
7
- - predict(): forwards prompts, upsamples masks back to original size, normalizes outputs.
8
- - Uses torch.inference_mode + optional autocast on CUDA.
9
- - Returns shapes compatible with utils.cv_processing.segment_person_hq logic.
10
  """
11
 
12
  from __future__ import annotations
@@ -20,12 +14,10 @@
20
  import numpy as np
21
  import torch
22
  import cv2
 
23
 
24
  logger = logging.getLogger(__name__)
25
 
26
-
27
- # -------------------------- helpers --------------------------
28
-
29
  def _select_device(pref: str) -> str:
30
  pref = (pref or "").lower()
31
  if pref.startswith("cuda"):
@@ -34,21 +26,12 @@ def _select_device(pref: str) -> str:
34
  return "cpu"
35
  return "cuda" if torch.cuda.is_available() else "cpu"
36
 
37
-
38
  def _ensure_rgb_uint8(img: np.ndarray, force_bgr_to_rgb: bool = False) -> np.ndarray:
39
- """
40
- Accept BGR/RGB, 3ch/4ch, uint8/float; return RGB uint8 [H,W,3].
41
- We DO NOT blindly swap channels; cv_processing already feeds RGB.
42
- Set force_bgr_to_rgb=True only if you know inputs are BGR.
43
- """
44
  if img is None:
45
  raise ValueError("set_image received None image")
46
-
47
  arr = np.asarray(img)
48
  if arr.ndim != 3 or arr.shape[2] < 3:
49
  raise ValueError(f"Expected HxWxC image with C>=3, got shape={arr.shape}")
50
-
51
- # If float, clamp + scale to uint8
52
  if np.issubdtype(arr.dtype, np.floating):
53
  arr = np.clip(arr, 0.0, 1.0)
54
  arr = (arr * 255.0 + 0.5).astype(np.uint8)
@@ -57,17 +40,12 @@ def _ensure_rgb_uint8(img: np.ndarray, force_bgr_to_rgb: bool = False) -> np.nda
57
  arr = (arr / 257).astype(np.uint8)
58
  else:
59
  arr = arr.astype(np.uint8)
60
-
61
- # If 4-channel, drop alpha
62
  if arr.shape[2] == 4:
63
  arr = arr[:, :, :3]
64
-
65
  if force_bgr_to_rgb:
66
  arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
67
-
68
  return arr
69
 
70
-
71
  def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> Tuple[int, int, float]:
72
  if h <= 0 or w <= 0:
73
  return h, w, 1.0
@@ -78,17 +56,12 @@ def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> T
78
  nw = max(1, int(round(w * s)))
79
  return nh, nw, s
80
 
81
-
82
  def _ladder(nh: int, nw: int) -> List[Tuple[int, int]]:
83
- """
84
- Progressive smaller sizes for OOM fallback.
85
- """
86
  sizes = [(nh, nw)]
87
  sizes.append((max(1, int(nh * 0.85)), max(1, int(nw * 0.85))))
88
  sizes.append((max(1, int(nh * 0.70)), max(1, int(nw * 0.70))))
89
  sizes.append((max(1, int(nh * 0.50)), max(1, int(nw * 0.50))))
90
  sizes.append((max(1, int(nh * 0.35)), max(1, int(nw * 0.35))))
91
- # de-duplicate and keep order
92
  uniq = []
93
  seen = set()
94
  for s in sizes:
@@ -96,11 +69,7 @@ def _ladder(nh: int, nw: int) -> List[Tuple[int, int]]:
96
  uniq.append(s); seen.add(s)
97
  return uniq
98
 
99
-
100
  def _upsample_stack(masks: np.ndarray, out_hw: Tuple[int, int]) -> np.ndarray:
101
- """
102
- masks: (N,h,w) float → bilinear → (N,H,W) float [0..1]
103
- """
104
  if masks.ndim != 3:
105
  masks = np.asarray(masks)
106
  if masks.ndim == 2:
@@ -108,7 +77,6 @@ def _upsample_stack(masks: np.ndarray, out_hw: Tuple[int, int]) -> np.ndarray:
108
  elif masks.ndim == 4 and masks.shape[1] == 1:
109
  masks = masks[:, 0, :, :]
110
  else:
111
- # try to squeeze to N,H,W
112
  masks = np.squeeze(masks)
113
  if masks.ndim == 2:
114
  masks = masks[None, ...]
@@ -121,14 +89,12 @@ def _upsample_stack(masks: np.ndarray, out_hw: Tuple[int, int]) -> np.ndarray:
121
  out[i] = cv2.resize(masks[i].astype(np.float32), (W, H), interpolation=cv2.INTER_LINEAR)
122
  return np.clip(out, 0.0, 1.0)
123
 
124
-
125
  def _normalize_masks_dtype(x: np.ndarray) -> np.ndarray:
126
  x = np.asarray(x)
127
  if x.dtype == np.uint8:
128
  return (x.astype(np.float32) / 255.0)
129
  return x.astype(np.float32, copy=False)
130
 
131
-
132
  # -------------------------- adapter --------------------------
133
 
134
  class _SAM2Adapter:
@@ -138,22 +104,16 @@ class _SAM2Adapter:
138
  - model-only downscale on set_image
139
  - OOM-aware predict with retry at smaller sizes
140
  - upsample masks back to original size
 
141
  """
142
  def __init__(self, predictor, device: str):
143
  self.pred = predictor
144
  self.device = device
145
-
146
- # original image size (for upsample)
147
  self.orig_hw: Tuple[int, int] = (0, 0)
148
-
149
- # env tunables
150
  self.max_edge = int(os.environ.get("SAM2_MAX_EDGE", "1024"))
151
  self.target_pixels = int(os.environ.get("SAM2_TARGET_PIXELS", "900000"))
152
  self.force_bgr_to_rgb = os.environ.get("SAM2_ASSUME_BGR", "0") == "1"
153
-
154
- # precision
155
  self.use_autocast = (device == "cuda")
156
- # prefer bf16 if available, else fp16; it's only a hint for the internal ops
157
  self.autocast_dtype = None
158
  if self.use_autocast:
159
  try:
@@ -164,138 +124,103 @@ def __init__(self, predictor, device: str):
164
  self.autocast_dtype = torch.float16 if cc[0] >= 7 else None
165
  except Exception:
166
  self.autocast_dtype = None
167
-
168
- # cached current working image (RGB uint8) and its size
169
  self._current_rgb: Optional[np.ndarray] = None
170
  self._current_hw: Tuple[int, int] = (0, 0)
171
-
172
- # --- API mirror ---
173
 
174
  def set_image(self, image: np.ndarray):
175
- """
176
- Accept RGB or BGR, uint8 or float, any resolution.
177
- Model-only downscale; keep orig H,W for upsample later.
178
- """
179
- rgb = _ensure_rgb_uint8(image, force_bgr_to_rgb=self.force_bgr_to_rgb)
180
- H, W = rgb.shape[:2]
181
- self.orig_hw = (H, W)
182
-
183
- nh, nw, s = _compute_scaled_size(H, W, self.max_edge, self.target_pixels)
184
- if s < 1.0:
185
- work = cv2.resize(rgb, (nw, nh), interpolation=cv2.INTER_AREA)
186
- self._current_rgb = work
187
- self._current_hw = (nh, nw)
188
- else:
189
- self._current_rgb = rgb
190
- self._current_hw = (H, W)
191
-
192
- # prime embeddings on predictor
193
- self.pred.set_image(self._current_rgb)
194
 
195
  def predict(self, **kwargs) -> Dict[str, Any]:
196
- """
197
- Forwards prompts to underlying predictor; retries smaller if OOM.
198
- Always returns:
199
- {"masks": (N,H,W) float32 [0..1], "scores": (N,), "logits": optional}
200
- where (H,W) are the ORIGINAL image size provided to set_image().
201
- """
202
- if self._current_rgb is None or self.orig_hw == (0, 0):
203
- raise RuntimeError("SAM2Adapter.predict called before set_image()")
204
-
205
- H, W = self.orig_hw
206
- nh, nw = self._current_hw
207
- sizes = _ladder(nh, nw)
208
-
209
- last_exc: Optional[BaseException] = None
210
-
211
- for (th, tw) in sizes:
212
- try:
213
- # if we need a smaller embedding, rebuild set_image()
214
- if (th, tw) != (nh, nw):
215
- small = cv2.resize(self._current_rgb, (tw, th), interpolation=cv2.INTER_AREA)
216
- self.pred.set_image(small)
217
-
218
- # inference guard
219
- class _NoOp:
220
- def __enter__(self): return None
221
- def __exit__(self, *a): return False
222
-
223
- amp_ctx = _NoOp()
224
- if self.use_autocast and self.autocast_dtype is not None:
225
- amp_ctx = torch.cuda.amp.autocast(dtype=self.autocast_dtype)
226
-
227
- with torch.inference_mode():
228
- with amp_ctx:
229
- out = self.pred.predict(**kwargs)
230
-
231
- # normalize outputs to dict
232
- masks = None
233
- scores = None
234
- logits = None
235
-
236
- if isinstance(out, dict):
237
- masks = out.get("masks", None)
238
- scores = out.get("scores", None)
239
- logits = out.get("logits", None)
240
- elif isinstance(out, (tuple, list)):
241
- if len(out) >= 1: masks = out[0]
242
- if len(out) >= 2: scores = out[1]
243
- if len(out) >= 3: logits = out[2]
244
- else:
245
- masks = out
246
-
247
- if masks is None:
248
- raise RuntimeError("SAM2 returned no masks")
249
-
250
- masks = np.asarray(masks)
251
- # SAM2 variants: (N,H,W) or (N,1,H,W) or (H,W)
252
- if masks.ndim == 2:
253
- masks = masks[None, ...]
254
- elif masks.ndim == 4 and masks.shape[1] == 1:
255
- masks = masks[:, 0, :, :]
256
-
257
- masks = _normalize_masks_dtype(masks)
258
-
259
- # upsample to original resolution
260
- masks_up = _upsample_stack(masks, (H, W))
261
-
262
- # standardize scores
263
- if scores is None:
264
- scores = np.ones((masks_up.shape[0],), dtype=np.float32) * 0.5
265
- else:
266
- scores = np.asarray(scores).astype(np.float32, copy=False).reshape(-1)
267
-
268
- out_dict = {"masks": masks_up, "scores": scores}
269
- if logits is not None:
270
- # best-effort: resize per-channel to (H,W)
271
- lg = np.asarray(logits)
272
- if lg.ndim == 3:
273
- lg = _upsample_stack(lg, (H, W))
274
- elif lg.ndim == 4 and lg.shape[1] == 1:
275
- lg = _upsample_stack(lg[:, 0, :, :], (H, W))
276
- out_dict["logits"] = lg.astype(np.float32, copy=False)
277
- return out_dict
278
-
279
- except torch.cuda.OutOfMemoryError as e:
280
- last_exc = e
281
- logger.warning(f"SAM2 OOM at {th}x{tw}; retrying smaller. {e}")
282
- torch.cuda.empty_cache()
283
- continue
284
- except Exception as e:
285
- last_exc = e
286
- logger.debug(traceback.format_exc())
287
- logger.warning(f"SAM2 predict failed at {th}x{tw}; retrying smaller. {e}")
288
- torch.cuda.empty_cache()
289
- continue
290
-
291
- # All attempts failed → safe fallback (full mask)
292
- logger.warning(f"SAM2 calls failed; returning fallback. {last_exc}")
293
- return {
294
- "masks": np.ones((1, H, W), dtype=np.float32),
295
- "scores": np.array([0.5], dtype=np.float32),
296
- }
297
-
298
-
299
  # -------------------------- Loader --------------------------
300
 
301
  class SAM2Loader:
@@ -306,7 +231,7 @@ def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/sam2_ca
306
  self.cache_dir = cache_dir
307
  os.makedirs(self.cache_dir, exist_ok=True)
308
 
309
- # Configure HF hub for spaces
310
  os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS", "1")
311
  os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0")
312
 
@@ -325,20 +250,15 @@ def load(self, model_size: str = "auto") -> Optional[Any]:
325
  """
326
  if model_size == "auto":
327
  model_size = self._determine_optimal_size()
328
-
329
  model_map = {
330
  "tiny": "facebook/sam2.1-hiera-tiny",
331
  "small": "facebook/sam2.1-hiera-small",
332
  "base": "facebook/sam2.1-hiera-base-plus",
333
  "large": "facebook/sam2.1-hiera-large",
334
  }
335
-
336
  self.model_id = model_map.get(model_size, model_map["tiny"])
337
  logger.info(f"Loading SAM2 model: {self.model_id} (device={self.device})")
338
-
339
- # Try the official loader
340
  strategies = [("official", self._load_official), ("fallback", self._load_fallback)]
341
-
342
  for name, fn in strategies:
343
  try:
344
  t0 = time.time()
@@ -353,7 +273,6 @@ def load(self, model_size: str = "auto") -> Optional[Any]:
353
  except Exception as e:
354
  logger.error(f"SAM2 {name} strategy failed: {e}")
355
  logger.debug(traceback.format_exc())
356
-
357
  logger.error("All SAM2 loading strategies failed")
358
  return None
359
 
@@ -374,26 +293,21 @@ def _determine_optimal_size(self) -> str:
374
  def _load_official(self) -> Optional[Any]:
375
  """Load using official SAM2 API"""
376
  from sam2.sam2_image_predictor import SAM2ImagePredictor
377
-
378
  predictor = SAM2ImagePredictor.from_pretrained(
379
  self.model_id,
380
  cache_dir=self.cache_dir,
381
  local_files_only=False,
382
  trust_remote_code=True,
383
  )
384
-
385
- # Move internal model to device if present
386
  if hasattr(predictor, "model"):
387
  predictor.model = predictor.model.to(self.device)
388
  predictor.model.eval()
389
  if hasattr(predictor, "device"):
390
  predictor.device = self.device
391
-
392
  return predictor
393
 
394
  def _load_fallback(self) -> Optional[Any]:
395
  """Create a tiny fallback predictor"""
396
-
397
  class FallbackSAM2:
398
  def __init__(self, device):
399
  self.device = device
@@ -405,16 +319,15 @@ def predict(self, **kwargs):
405
  h, w = self._img.shape[:2]
406
  else:
407
  h, w = 512, 512
 
408
  return {
409
  "masks": np.ones((1, h, w), dtype=np.float32),
410
  "scores": np.array([0.5], dtype=np.float32),
411
  }
412
-
413
  logger.warning("Using fallback SAM2 (no real segmentation)")
414
  return FallbackSAM2(self.device)
415
 
416
  def cleanup(self):
417
- """Clean up resources"""
418
  self.adapter = None
419
  if self.model is not None:
420
  try:
@@ -426,7 +339,6 @@ def cleanup(self):
426
  torch.cuda.empty_cache()
427
 
428
  def get_info(self) -> Dict[str, Any]:
429
- """Get loader information"""
430
  return {
431
  "loaded": self.adapter is not None,
432
  "model_id": self.model_id,
 
1
  #!/usr/bin/env python3
2
  """
3
+ SAM2 Loader + Guarded Predictor Adapter (VRAM-friendly, shape-safe, thread-safe, PyTorch2-ready)
 
 
 
 
 
 
4
  """
5
 
6
  from __future__ import annotations
 
14
  import numpy as np
15
  import torch
16
  import cv2
17
+ import threading
18
 
19
  logger = logging.getLogger(__name__)
20
 
 
 
 
21
  def _select_device(pref: str) -> str:
22
  pref = (pref or "").lower()
23
  if pref.startswith("cuda"):
 
26
  return "cpu"
27
  return "cuda" if torch.cuda.is_available() else "cpu"
28
 
 
29
  def _ensure_rgb_uint8(img: np.ndarray, force_bgr_to_rgb: bool = False) -> np.ndarray:
 
 
 
 
 
30
  if img is None:
31
  raise ValueError("set_image received None image")
 
32
  arr = np.asarray(img)
33
  if arr.ndim != 3 or arr.shape[2] < 3:
34
  raise ValueError(f"Expected HxWxC image with C>=3, got shape={arr.shape}")
 
 
35
  if np.issubdtype(arr.dtype, np.floating):
36
  arr = np.clip(arr, 0.0, 1.0)
37
  arr = (arr * 255.0 + 0.5).astype(np.uint8)
 
40
  arr = (arr / 257).astype(np.uint8)
41
  else:
42
  arr = arr.astype(np.uint8)
 
 
43
  if arr.shape[2] == 4:
44
  arr = arr[:, :, :3]
 
45
  if force_bgr_to_rgb:
46
  arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
 
47
  return arr
48
 
 
49
  def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> Tuple[int, int, float]:
50
  if h <= 0 or w <= 0:
51
  return h, w, 1.0
 
56
  nw = max(1, int(round(w * s)))
57
  return nh, nw, s
58
 
 
59
  def _ladder(nh: int, nw: int) -> List[Tuple[int, int]]:
 
 
 
60
  sizes = [(nh, nw)]
61
  sizes.append((max(1, int(nh * 0.85)), max(1, int(nw * 0.85))))
62
  sizes.append((max(1, int(nh * 0.70)), max(1, int(nw * 0.70))))
63
  sizes.append((max(1, int(nh * 0.50)), max(1, int(nw * 0.50))))
64
  sizes.append((max(1, int(nh * 0.35)), max(1, int(nw * 0.35))))
 
65
  uniq = []
66
  seen = set()
67
  for s in sizes:
 
69
  uniq.append(s); seen.add(s)
70
  return uniq
71
 
 
72
  def _upsample_stack(masks: np.ndarray, out_hw: Tuple[int, int]) -> np.ndarray:
 
 
 
73
  if masks.ndim != 3:
74
  masks = np.asarray(masks)
75
  if masks.ndim == 2:
 
77
  elif masks.ndim == 4 and masks.shape[1] == 1:
78
  masks = masks[:, 0, :, :]
79
  else:
 
80
  masks = np.squeeze(masks)
81
  if masks.ndim == 2:
82
  masks = masks[None, ...]
 
89
  out[i] = cv2.resize(masks[i].astype(np.float32), (W, H), interpolation=cv2.INTER_LINEAR)
90
  return np.clip(out, 0.0, 1.0)
91
 
 
92
  def _normalize_masks_dtype(x: np.ndarray) -> np.ndarray:
93
  x = np.asarray(x)
94
  if x.dtype == np.uint8:
95
  return (x.astype(np.float32) / 255.0)
96
  return x.astype(np.float32, copy=False)
97
 
 
98
  # -------------------------- adapter --------------------------
99
 
100
  class _SAM2Adapter:
 
104
  - model-only downscale on set_image
105
  - OOM-aware predict with retry at smaller sizes
106
  - upsample masks back to original size
107
+ - now thread-safe
108
  """
109
  def __init__(self, predictor, device: str):
110
  self.pred = predictor
111
  self.device = device
 
 
112
  self.orig_hw: Tuple[int, int] = (0, 0)
 
 
113
  self.max_edge = int(os.environ.get("SAM2_MAX_EDGE", "1024"))
114
  self.target_pixels = int(os.environ.get("SAM2_TARGET_PIXELS", "900000"))
115
  self.force_bgr_to_rgb = os.environ.get("SAM2_ASSUME_BGR", "0") == "1"
 
 
116
  self.use_autocast = (device == "cuda")
 
117
  self.autocast_dtype = None
118
  if self.use_autocast:
119
  try:
 
124
  self.autocast_dtype = torch.float16 if cc[0] >= 7 else None
125
  except Exception:
126
  self.autocast_dtype = None
 
 
127
  self._current_rgb: Optional[np.ndarray] = None
128
  self._current_hw: Tuple[int, int] = (0, 0)
129
+ self._lock = threading.Lock()
 
130
 
131
  def set_image(self, image: np.ndarray):
132
+ with self._lock:
133
+ rgb = _ensure_rgb_uint8(image, force_bgr_to_rgb=self.force_bgr_to_rgb)
134
+ H, W = rgb.shape[:2]
135
+ self.orig_hw = (H, W)
136
+ nh, nw, s = _compute_scaled_size(H, W, self.max_edge, self.target_pixels)
137
+ if s < 1.0:
138
+ work = cv2.resize(rgb, (nw, nh), interpolation=cv2.INTER_AREA)
139
+ self._current_rgb = work
140
+ self._current_hw = (nh, nw)
141
+ else:
142
+ self._current_rgb = rgb
143
+ self._current_hw = (H, W)
144
+ self.pred.set_image(self._current_rgb)
 
 
 
 
 
 
145
 
146
  def predict(self, **kwargs) -> Dict[str, Any]:
147
+ with self._lock:
148
+ if self._current_rgb is None or self.orig_hw == (0, 0):
149
+ raise RuntimeError("SAM2Adapter.predict called before set_image()")
150
+ H, W = self.orig_hw
151
+ nh, nw = self._current_hw
152
+ sizes = _ladder(nh, nw)
153
+ last_exc: Optional[BaseException] = None
154
+ for (th, tw) in sizes:
155
+ try:
156
+ if (th, tw) != (nh, nw):
157
+ small = cv2.resize(self._current_rgb, (tw, th), interpolation=cv2.INTER_AREA)
158
+ self.pred.set_image(small)
159
+ class _NoOp:
160
+ def __enter__(self): return None
161
+ def __exit__(self, *a): return False
162
+ # -------- PyTorch 2.x autocast signature --------
163
+ if self.use_autocast and self.autocast_dtype is not None:
164
+ amp_ctx = torch.autocast(device_type="cuda", dtype=self.autocast_dtype)
165
+ else:
166
+ amp_ctx = _NoOp()
167
+ with torch.inference_mode():
168
+ with amp_ctx:
169
+ out = self.pred.predict(**kwargs)
170
+ # normalize outputs to dict
171
+ masks = None
172
+ scores = None
173
+ logits = None
174
+ if isinstance(out, dict):
175
+ masks = out.get("masks", None)
176
+ scores = out.get("scores", None)
177
+ logits = out.get("logits", None)
178
+ elif isinstance(out, (tuple, list)):
179
+ if len(out) >= 1: masks = out[0]
180
+ if len(out) >= 2: scores = out[1]
181
+ if len(out) >= 3: logits = out[2]
182
+ else:
183
+ masks = out
184
+ if masks is None:
185
+ raise RuntimeError("SAM2 returned no masks")
186
+ masks = np.asarray(masks)
187
+ if masks.ndim == 2:
188
+ masks = masks[None, ...]
189
+ elif masks.ndim == 4 and masks.shape[1] == 1:
190
+ masks = masks[:, 0, :, :]
191
+ masks = _normalize_masks_dtype(masks)
192
+ masks_up = _upsample_stack(masks, (H, W))
193
+ if scores is None:
194
+ scores = np.ones((masks_up.shape[0],), dtype=np.float32) * 0.5
195
+ else:
196
+ scores = np.asarray(scores).astype(np.float32, copy=False).reshape(-1)
197
+ out_dict = {"masks": masks_up, "scores": scores}
198
+ if logits is not None:
199
+ lg = np.asarray(logits)
200
+ if lg.ndim == 3:
201
+ lg = _upsample_stack(lg, (H, W))
202
+ elif lg.ndim == 4 and lg.shape[1] == 1:
203
+ lg = _upsample_stack(lg[:, 0, :, :], (H, W))
204
+ out_dict["logits"] = lg.astype(np.float32, copy=False)
205
+ return out_dict
206
+ except torch.cuda.OutOfMemoryError as e:
207
+ last_exc = e
208
+ if torch.cuda.is_available():
209
+ torch.cuda.empty_cache()
210
+ logger.warning(f"SAM2 OOM at {th}x{tw}; retrying smaller. {e}")
211
+ continue
212
+ except Exception as e:
213
+ last_exc = e
214
+ if torch.cuda.is_available():
215
+ torch.cuda.empty_cache()
216
+ logger.debug(traceback.format_exc())
217
+ logger.warning(f"SAM2 predict failed at {th}x{tw}; retrying smaller. {e}")
218
+ continue
219
+ logger.warning(f"SAM2 calls failed; returning fallback. {last_exc}")
220
+ return {
221
+ "masks": np.ones((1, H, W), dtype=np.float32),
222
+ "scores": np.array([0.5], dtype=np.float32),
223
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  # -------------------------- Loader --------------------------
225
 
226
  class SAM2Loader:
 
231
  self.cache_dir = cache_dir
232
  os.makedirs(self.cache_dir, exist_ok=True)
233
 
234
+ # HuggingFace Hub for spaces: avoid symlink errors
235
  os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS", "1")
236
  os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0")
237
 
 
250
  """
251
  if model_size == "auto":
252
  model_size = self._determine_optimal_size()
 
253
  model_map = {
254
  "tiny": "facebook/sam2.1-hiera-tiny",
255
  "small": "facebook/sam2.1-hiera-small",
256
  "base": "facebook/sam2.1-hiera-base-plus",
257
  "large": "facebook/sam2.1-hiera-large",
258
  }
 
259
  self.model_id = model_map.get(model_size, model_map["tiny"])
260
  logger.info(f"Loading SAM2 model: {self.model_id} (device={self.device})")
 
 
261
  strategies = [("official", self._load_official), ("fallback", self._load_fallback)]
 
262
  for name, fn in strategies:
263
  try:
264
  t0 = time.time()
 
273
  except Exception as e:
274
  logger.error(f"SAM2 {name} strategy failed: {e}")
275
  logger.debug(traceback.format_exc())
 
276
  logger.error("All SAM2 loading strategies failed")
277
  return None
278
 
 
293
  def _load_official(self) -> Optional[Any]:
294
  """Load using official SAM2 API"""
295
  from sam2.sam2_image_predictor import SAM2ImagePredictor
 
296
  predictor = SAM2ImagePredictor.from_pretrained(
297
  self.model_id,
298
  cache_dir=self.cache_dir,
299
  local_files_only=False,
300
  trust_remote_code=True,
301
  )
 
 
302
  if hasattr(predictor, "model"):
303
  predictor.model = predictor.model.to(self.device)
304
  predictor.model.eval()
305
  if hasattr(predictor, "device"):
306
  predictor.device = self.device
 
307
  return predictor
308
 
309
  def _load_fallback(self) -> Optional[Any]:
310
  """Create a tiny fallback predictor"""
 
311
  class FallbackSAM2:
312
  def __init__(self, device):
313
  self.device = device
 
319
  h, w = self._img.shape[:2]
320
  else:
321
  h, w = 512, 512
322
+ # Return a full-ones mask—**handled downstream!**
323
  return {
324
  "masks": np.ones((1, h, w), dtype=np.float32),
325
  "scores": np.array([0.5], dtype=np.float32),
326
  }
 
327
  logger.warning("Using fallback SAM2 (no real segmentation)")
328
  return FallbackSAM2(self.device)
329
 
330
  def cleanup(self):
 
331
  self.adapter = None
332
  if self.model is not None:
333
  try:
 
339
  torch.cuda.empty_cache()
340
 
341
  def get_info(self) -> Dict[str, Any]:
 
342
  return {
343
  "loaded": self.adapter is not None,
344
  "model_id": self.model_id,