MogensR committed on
Commit
0fb1268
·
1 Parent(s): 3723e02
Files changed (1) hide show
  1. models/sam2_loader.py +233 -384
models/sam2_loader.py CHANGED
@@ -1,413 +1,262 @@
1
  #!/usr/bin/env python3
2
  """
3
- SAM2 Loader Robust loading and mask generation for SAM2
4
- ========================================================
5
- - Loads SAM2 model with Hydra config resolution
6
- - Auto-downloads missing checkpoint files
7
- - Generates seed masks for MatAnyone
8
- - Aligned with torch==2.3.1+cu121 and SAM2 commit 3c76f73c1a7e7b4a2e8a0a9a3e5b92f7e6e3f2f5
9
-
10
- Changes (2025-09-17):
11
- - Added automatic checkpoint download functionality
12
- - Enhanced error handling and logging
13
- - Fixed missing checkpoint issue that was causing fallback mask generation
14
  """
15
 
16
- from __future__ import annotations
17
-
18
  import os
19
- import sys
20
- import inspect
21
  import logging
22
- import importlib.metadata
23
- import urllib.request
24
- import urllib.error
25
- from pathlib import Path
26
- from typing import Optional, Tuple, Dict, Any
27
-
28
  import numpy as np
29
- import yaml
30
- import torch
31
-
32
- # --------------------------------------------------------------------------------------
33
- # Logging
34
- # --------------------------------------------------------------------------------------
35
- logger = logging.getLogger("backgroundfx_pro")
36
- if not logger.handlers:
37
- _h = logging.StreamHandler()
38
- _h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
39
- logger.addHandler(_h)
40
- logger.setLevel(logging.INFO)
41
-
42
- # --------------------------------------------------------------------------------------
43
- # Path setup for third_party repos
44
- # --------------------------------------------------------------------------------------
45
- ROOT = Path(__file__).resolve().parent.parent # project root
46
- TP_SAM2 = Path(os.environ.get("THIRD_PARTY_SAM2_DIR", ROOT / "third_party" / "sam2")).resolve()
47
-
48
- def _add_sys_path(p: Path) -> None:
49
- if p.exists():
50
- p_str = str(p)
51
- if p_str not in sys.path:
52
- sys.path.insert(0, p_str)
53
- else:
54
- logger.warning(f"third_party path not found: {p}")
55
-
56
- _add_sys_path(TP_SAM2)
57
 
58
- # --------------------------------------------------------------------------------------
59
- # Checkpoint Download Functionality
60
- # --------------------------------------------------------------------------------------
61
- SAM2_CHECKPOINT_URLS = {
62
- "sam2_hiera_large.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
63
- "sam2_hiera_base_plus.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
64
- "sam2_hiera_small.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
65
- "sam2_hiera_tiny.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt"
66
- }
67
 
68
- def _download_checkpoint(checkpoint_path: str, checkpoint_name: str) -> bool:
69
- """Download SAM2 checkpoint if it doesn't exist."""
70
- if os.path.exists(checkpoint_path):
71
- logger.info(f"Checkpoint already exists: {checkpoint_path}")
72
- return True
73
-
74
- if checkpoint_name not in SAM2_CHECKPOINT_URLS:
75
- logger.error(f"Unknown checkpoint: {checkpoint_name}. Available: {list(SAM2_CHECKPOINT_URLS.keys())}")
76
- return False
77
-
78
- url = SAM2_CHECKPOINT_URLS[checkpoint_name]
79
- logger.info(f"Downloading SAM2 checkpoint: {checkpoint_name}")
80
- logger.info(f"URL: {url}")
81
- logger.info(f"Destination: {checkpoint_path}")
82
 
83
- try:
84
- # Create directory if it doesn't exist
85
- os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
 
 
 
86
 
87
- # Download with progress
88
- def _progress_hook(block_num, block_size, total_size):
89
- if total_size > 0:
90
- percent = min(100, block_num * block_size * 100 // total_size)
91
- if percent % 10 == 0: # Log every 10%
92
- logger.info(f"Download progress: {percent}%")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
- urllib.request.urlretrieve(url, checkpoint_path, reporthook=_progress_hook)
 
 
 
 
 
 
 
95
 
96
- # Verify the file was downloaded
97
- if os.path.exists(checkpoint_path) and os.path.getsize(checkpoint_path) > 0:
98
- size_mb = os.path.getsize(checkpoint_path) / (1024 * 1024)
99
- logger.info(f"Successfully downloaded {checkpoint_name} ({size_mb:.1f} MB)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  return True
101
- else:
102
- logger.error(f"Download failed: {checkpoint_path} does not exist or is empty")
103
- return False
104
 
105
- except urllib.error.URLError as e:
106
- logger.error(f"URL error downloading checkpoint: {e}")
107
- return False
108
- except Exception as e:
109
- logger.error(f"Error downloading checkpoint: {e}")
110
- return False
111
-
112
- # --------------------------------------------------------------------------------------
113
- # Safe Torch accessors
114
- # --------------------------------------------------------------------------------------
115
- def _torch():
116
- try:
117
- import torch
118
- return torch
119
- except Exception as e:
120
- logger.warning(f"[sam2_loader.safe-torch] import failed: {e}")
121
- return None
122
-
123
- def _has_cuda() -> bool:
124
- t = _torch()
125
- if t is None:
126
- return False
127
- try:
128
- return bool(t.cuda.is_available())
129
- except Exception as e:
130
- logger.warning(f"[sam2_loader.safe-torch] cuda.is_available() failed: {e}")
131
- return False
132
-
133
- def _pick_device(env_key: str) -> str:
134
- requested = os.environ.get(env_key, "").strip().lower()
135
- has_cuda = _has_cuda()
136
-
137
- logger.info(f"CUDA environment variables: {dict((k, v) for k, v in os.environ.items() if 'CUDA' in k or 'SAM2' in k)}")
138
- logger.info(f"_pick_device({env_key}): requested='{requested}', has_cuda={has_cuda}")
139
-
140
- if has_cuda and requested not in {"cpu"}:
141
- logger.info(f"FORCING CUDA device (GPU available, requested='{requested}')")
142
- return "cuda"
143
- elif requested in {"cuda", "cpu"}:
144
- logger.info(f"Using explicitly requested device: {requested}")
145
- return requested
146
-
147
- result = "cuda" if has_cuda else "cpu"
148
- logger.info(f"Auto-selected device: {result}")
149
- return result
150
-
151
- # --------------------------------------------------------------------------------------
152
- # SAM2 Loading and Mask Generation
153
- # --------------------------------------------------------------------------------------
154
- def _resolve_sam2_cfg(cfg_str: str) -> str:
155
- """Resolve SAM2 config path - return relative path for Hydra compatibility."""
156
- logger.info(f"_resolve_sam2_cfg called with cfg_str={cfg_str}")
157
- logger.info(f"TP_SAM2 = {TP_SAM2}")
158
-
159
- candidate = TP_SAM2 / cfg_str
160
- logger.info(f"Candidate path: {candidate}")
161
- logger.info(f"Candidate exists: {candidate.exists()}")
162
-
163
- if candidate.exists():
164
- if cfg_str.startswith("sam2/configs/"):
165
- relative_path = cfg_str.replace("sam2/configs/", "configs/")
166
- else:
167
- relative_path = cfg_str
168
- logger.info(f"Returning Hydra-compatible relative path: {relative_path}")
169
- return relative_path
170
-
171
- fallbacks = [
172
- TP_SAM2 / "sam2" / cfg_str,
173
- TP_SAM2 / "configs" / cfg_str,
174
- ]
175
-
176
- for fallback in fallbacks:
177
- logger.info(f"Trying fallback: {fallback}")
178
- if fallback.exists():
179
- if "configs" in str(fallback):
180
- relative_path = "configs/" + str(fallback).split("configs/")[-1]
181
- logger.info(f"Returning fallback relative path: {relative_path}")
182
- return relative_path
183
-
184
- logger.warning(f"Config not found, returning original: {cfg_str}")
185
- return cfg_str
186
-
187
- def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
188
- """If config references 'hieradet', try to find a 'hiera' config."""
189
- try:
190
- with open(cfg_path, "r") as f:
191
- data = yaml.safe_load(f)
192
- model = data.get("model", {}) or {}
193
- enc = model.get("image_encoder") or {}
194
- trunk = enc.get("trunk") or {}
195
- target = trunk.get("_target_") or trunk.get("target")
196
- if isinstance(target, str) and "hieradet" in target:
197
- for y in TP_SAM2.rglob("*.yaml"):
198
- try:
199
- with open(y, "r") as f2:
200
- d2 = yaml.safe_load(f2) or {}
201
- e2 = (d2.get("model", {}) or {}).get("image_encoder") or {}
202
- t2 = (e2.get("trunk") or {})
203
- tgt2 = t2.get("_target_") or t2.get("target")
204
- if isinstance(tgt2, str) and ".hiera." in tgt2:
205
- logger.info(f"SAM2: switching config from 'hieradet' → 'hiera': {y}")
206
- return str(y)
207
- except Exception:
208
- continue
209
- except Exception:
210
- pass
211
- return None
212
-
213
- def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
214
- """Robust SAM2 loader with config resolution and checkpoint auto-download."""
215
- meta = {"sam2_import_ok": False, "sam2_init_ok": False}
216
- try:
217
- from sam2.build_sam import build_sam2
218
- from sam2.sam2_image_predictor import SAM2ImagePredictor
219
- meta["sam2_import_ok"] = True
220
- except Exception as e:
221
- logger.warning(f"SAM2 import failed: {e}")
222
- return None, False, meta
223
-
224
- # Log SAM2 version
225
- try:
226
- version = importlib.metadata.version("segment-anything-2")
227
- logger.info(f"[SAM2] SAM2 version: {version}")
228
- except Exception:
229
- logger.info("[SAM2] SAM2 version unknown")
230
-
231
- # Check GPU memory before loading
232
- if torch.cuda.is_available():
233
- mem_before = torch.cuda.memory_allocated() / 1024**3
234
- logger.info(f"🔍 GPU memory before SAM2 load: {mem_before:.2f}GB")
235
-
236
- device = _pick_device("SAM2_DEVICE")
237
- cfg_env = os.environ.get("SAM2_MODEL_CFG", "sam2/configs/sam2/sam2_hiera_l.yaml")
238
- cfg = _resolve_sam2_cfg(cfg_env)
239
-
240
- # Handle checkpoint with auto-download
241
- ckpt = os.environ.get("SAM2_CHECKPOINT", "/home/user/app/checkpoints/sam2_hiera_large.pt")
242
- checkpoint_name = os.path.basename(ckpt)
243
 
244
- # Auto-download checkpoint if missing
245
- if not os.path.exists(ckpt):
246
- logger.info(f"SAM2 checkpoint not found: {ckpt}")
247
- if not _download_checkpoint(ckpt, checkpoint_name):
248
- logger.error(f"Failed to download SAM2 checkpoint: {checkpoint_name}")
249
- return None, False, meta
250
- else:
251
- logger.info(f"Using existing SAM2 checkpoint: {ckpt}")
252
-
253
- def _try_build(cfg_path: str):
254
- logger.info(f"_try_build called with cfg_path: {cfg_path}")
255
- params = set(inspect.signature(build_sam2).parameters.keys())
256
- logger.info(f"build_sam2 parameters: {list(params)}")
257
- kwargs = {}
258
- if "config_file" in params:
259
- kwargs["config_file"] = cfg_path
260
- logger.info(f"Using config_file parameter: {cfg_path}")
261
- elif "model_cfg" in params:
262
- kwargs["model_cfg"] = cfg_path
263
- logger.info(f"Using model_cfg parameter: {cfg_path}")
264
- if ckpt:
265
- if "checkpoint" in params:
266
- kwargs["checkpoint"] = ckpt
267
- elif "ckpt_path" in params:
268
- kwargs["ckpt_path"] = ckpt
269
- elif "weights" in params:
270
- kwargs["weights"] = ckpt
271
- if "device" in params:
272
- kwargs["device"] = device
273
- try:
274
- logger.info(f"Calling build_sam2 with kwargs: {kwargs}")
275
- result = build_sam2(**kwargs)
276
- logger.info(f"build_sam2 succeeded with kwargs")
277
- if hasattr(result, 'device'):
278
- logger.info(f"SAM2 model device: {result.device}")
279
- elif hasattr(result, 'image_encoder') and hasattr(result.image_encoder, 'device'):
280
- logger.info(f"SAM2 model device: {result.image_encoder.device}")
281
- return result
282
- except TypeError as e:
283
- logger.info(f"build_sam2 kwargs failed: {e}, trying positional args")
284
- pos = [cfg_path]
285
- if ckpt:
286
- pos.append(ckpt)
287
- if "device" not in kwargs:
288
- pos.append(device)
289
- logger.info(f"Calling build_sam2 with positional args: {pos}")
290
- result = build_sam2(*pos)
291
- logger.info(f"build_sam2 succeeded with positional args")
292
- return result
293
-
294
- try:
295
  try:
296
- sam = _try_build(cfg)
297
- except Exception:
298
- alt_cfg = _find_hiera_config_if_hieradet(cfg)
299
- if alt_cfg:
300
- sam = _try_build(alt_cfg)
301
- else:
302
- raise
 
 
 
 
 
 
 
 
 
303
 
304
- if sam is not None:
305
- predictor = SAM2ImagePredictor(sam)
306
- meta["sam2_init_ok"] = True
307
- meta["sam2_device"] = device
308
- logger.info("✅ SAM2 loaded successfully with auto-downloaded checkpoint")
309
- return predictor, True, meta
310
- else:
311
- logger.error("❌ SAM2 initialization returned None")
312
- return None, False, meta
313
-
314
- except Exception as e:
315
- logger.error(f"❌ SAM2 loading failed: {e}")
316
- import traceback
317
- logger.error(f"SAM2 loading traceback: {traceback.format_exc()}")
318
- return None, False, meta
319
-
320
- def run_sam2_mask(predictor: object,
321
- first_frame_bgr: np.ndarray,
322
- point: Optional[Tuple[int, int]] = None,
323
- auto: bool = False) -> Tuple[Optional[np.ndarray], bool]:
324
- """Generate a seed mask for MatAnyone. Returns (mask_uint8_0_255, ok)."""
325
- if predictor is None:
326
- return None, False
327
- try:
328
- import cv2
329
- rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
330
- predictor.set_image(rgb)
331
-
332
- if auto:
333
- h, w = rgb.shape[:2]
334
- box = np.array([int(0.05*w), int(0.05*h), int(0.95*w), int(0.95*h)])
335
- masks, _, _ = predictor.predict(box=box)
336
- elif point is not None:
337
- x, y = int(point[0]), int(point[1])
338
- pts = np.array([[x, y]], dtype=np.int32)
339
- labels = np.array([1], dtype=np.int32)
340
- masks, _, _ = predictor.predict(point_coords=pts, point_labels=labels)
341
- else:
342
- h, w = rgb.shape[:2]
343
- box = np.array([int(0.1*w), int(0.1*h), int(0.9*w), int(0.9*h)])
344
- masks, _, _ = predictor.predict(box=box)
345
-
346
- if masks is None or len(masks) == 0:
347
- return None, False
348
-
349
- m = masks[0].astype(np.uint8) * 255
350
- logger.info(f"[SAM2] Generated mask: shape={m.shape}, dtype={m.dtype}")
351
- return m, True
352
- except Exception as e:
353
- logger.warning(f"SAM2 mask generation failed: {e}")
354
- return None, False
355
-
356
- # --------------------------------------------------------------------------------------
357
- # SAM2Model Wrapper Class for app_hf.py compatibility
358
- # --------------------------------------------------------------------------------------
359
- class SAM2Model:
360
- """Wrapper class for SAM2 model to match app_hf.py interface"""
361
 
362
- def __init__(self, device="cuda"):
363
- self.device = device
364
- self.predictor = None
365
- self.loaded = False
366
- logger.info(f"Initializing SAM2Model on device: {device}")
367
 
368
- # Load the model immediately
369
- self._load_model()
 
 
 
 
 
 
 
 
 
370
 
371
- def _load_model(self):
372
- """Load the SAM2 model"""
 
 
 
373
  try:
374
- self.predictor, self.loaded, meta = load_sam2()
375
- if self.loaded:
376
- logger.info("SAM2Model loaded successfully")
377
- else:
378
- logger.error("Failed to load SAM2Model")
379
  except Exception as e:
380
- logger.error(f"Error loading SAM2Model: {e}")
381
- self.loaded = False
382
 
383
- def predict(self, video_path):
384
- """Generate masks for video frames"""
385
- if not self.loaded:
386
- logger.error("SAM2Model not loaded")
387
- return None
388
-
389
  try:
390
- import cv2
391
-
392
- # Read first frame of video to generate initial mask
393
- cap = cv2.VideoCapture(video_path)
394
- ret, frame = cap.read()
395
- cap.release()
 
 
 
 
 
396
 
397
- if not ret:
398
- logger.error(f"Could not read video: {video_path}")
399
- return None
 
 
 
 
 
 
 
400
 
401
- # Generate mask for the frame
402
- mask, success = run_sam2_mask(self.predictor, frame, auto=True)
403
 
404
- if success:
405
- logger.info("Successfully generated mask from video")
406
- return mask
407
- else:
408
- logger.error("Failed to generate mask from video")
409
- return None
410
-
411
  except Exception as e:
412
- logger.error(f"Error predicting masks: {e}")
413
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ SAM2 Loader with T4-optimized predictor wrapper
4
+ Provides SAM2Predictor class with memory management and optimization features
 
 
 
 
 
 
 
 
 
5
  """
6
 
 
 
7
  import os
8
+ import gc
9
+ import torch
10
  import logging
 
 
 
 
 
 
11
  import numpy as np
12
+ from pathlib import Path
13
+ from typing import Optional, Any, Dict, List, Tuple
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
16
 
17
+ class SAM2Predictor:
18
+ """
19
+ T4-optimized SAM2 video predictor wrapper with memory management
20
+ """
 
 
 
 
 
 
 
 
 
 
21
 
22
+ def __init__(self, device: torch.device, model_size: str = "small"):
23
+ self.device = device
24
+ self.model_size = model_size
25
+ self.predictor = None
26
+ self.model = None
27
+ self._load_predictor()
28
 
29
+ def _load_predictor(self):
30
+ """Load SAM2 predictor with optimizations"""
31
+ try:
32
+ from sam2.build_sam import build_sam2_video_predictor
33
+
34
+ # Download checkpoint if needed
35
+ checkpoint_path = f"./checkpoints/sam2_hiera_{self.model_size}.pt"
36
+ if not self._ensure_checkpoint(checkpoint_path):
37
+ raise RuntimeError(f"Failed to get SAM2 {self.model_size} checkpoint")
38
+
39
+ # Build predictor
40
+ model_cfg = f"sam2_hiera_{self.model_size[0]}.yaml" # small -> s, base -> b, large -> l
41
+ self.predictor = build_sam2_video_predictor(model_cfg, checkpoint_path, device=self.device)
42
+
43
+ # Apply T4 optimizations
44
+ self._optimize_for_t4()
45
+
46
+ logger.info(f"SAM2 {self.model_size} predictor loaded successfully")
47
+
48
+ except ImportError as e:
49
+ logger.error(f"SAM2 import failed: {e}")
50
+ raise RuntimeError("SAM2 not available - check third_party/sam2 installation")
51
+ except Exception as e:
52
+ logger.error(f"SAM2 loading failed: {e}")
53
+ raise
54
+
55
+ def _ensure_checkpoint(self, checkpoint_path: str) -> bool:
56
+ """Ensure checkpoint exists, download if needed"""
57
+ checkpoint_file = Path(checkpoint_path)
58
 
59
+ if checkpoint_file.exists():
60
+ file_size = checkpoint_file.stat().st_size / (1024**2)
61
+ if file_size > 50: # At least 50MB
62
+ logger.info(f"SAM2 checkpoint exists: {file_size:.1f}MB")
63
+ return True
64
+ else:
65
+ logger.warning(f"Checkpoint too small ({file_size:.1f}MB), re-downloading")
66
+ checkpoint_file.unlink()
67
 
68
+ return self._download_checkpoint(checkpoint_path)
69
+
70
+ def _download_checkpoint(self, checkpoint_path: str, timeout_seconds: int = 600) -> bool:
71
+ """Download SAM2 checkpoint"""
72
+ try:
73
+ logger.info(f"Downloading SAM2 {self.model_size} checkpoint...")
74
+
75
+ checkpoint_file = Path(checkpoint_path)
76
+ checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
77
+
78
+ import requests
79
+
80
+ # Checkpoint URLs
81
+ urls = {
82
+ "small": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
83
+ "base": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
84
+ "large": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"
85
+ }
86
+
87
+ if self.model_size not in urls:
88
+ raise ValueError(f"Unknown model size: {self.model_size}")
89
+
90
+ checkpoint_url = urls[self.model_size]
91
+
92
+ import time
93
+ start_time = time.time()
94
+ response = requests.get(checkpoint_url, stream=True, timeout=30)
95
+ response.raise_for_status()
96
+
97
+ total_size = int(response.headers.get('content-length', 0))
98
+
99
+ temp_path = checkpoint_file.with_suffix('.download')
100
+ downloaded = 0
101
+ last_log = start_time
102
+
103
+ with open(temp_path, 'wb') as f:
104
+ for chunk in response.iter_content(chunk_size=1024*1024):
105
+ if chunk:
106
+ f.write(chunk)
107
+ downloaded += len(chunk)
108
+
109
+ current_time = time.time()
110
+ if current_time - start_time > timeout_seconds:
111
+ raise TimeoutError(f"Download timeout after {timeout_seconds}s")
112
+
113
+ # Progress logging every 15 seconds
114
+ if current_time - last_log > 15:
115
+ progress = (downloaded / total_size * 100) if total_size > 0 else 0
116
+ speed = downloaded / (current_time - start_time) / (1024**2)
117
+ logger.info(f"Download: {progress:.1f}% ({speed:.1f}MB/s)")
118
+ last_log = current_time
119
+
120
+ temp_path.rename(checkpoint_file)
121
+
122
+ download_time = time.time() - start_time
123
+ speed = downloaded / download_time / (1024**2)
124
+ logger.info(f"Download complete: {downloaded/(1024**2):.1f}MB in {download_time:.1f}s ({speed:.1f}MB/s)")
125
+
126
  return True
 
 
 
127
 
128
+ except Exception as e:
129
+ logger.error(f"Checkpoint download failed: {e}")
130
+ if Path(checkpoint_path).exists():
131
+ Path(checkpoint_path).unlink()
132
+ return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
+ def _optimize_for_t4(self):
135
+ """Apply T4-specific optimizations"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  try:
137
+ if hasattr(self.predictor, "model") and self.predictor.model is not None:
138
+ self.model = self.predictor.model
139
+
140
+ # Apply fp16 and channels_last for T4 efficiency
141
+ self.model = self.model.half().to(self.device)
142
+ self.model = self.model.to(memory_format=torch.channels_last)
143
+
144
+ logger.info("SAM2: fp16 + channels_last applied for T4 optimization")
145
+
146
+ except Exception as e:
147
+ logger.warning(f"SAM2 T4 optimization warning: {e}")
148
+
149
+ def init_state(self, video_path: str):
150
+ """Initialize video processing state"""
151
+ if self.predictor is None:
152
+ raise RuntimeError("Predictor not loaded")
153
 
154
+ try:
155
+ return self.predictor.init_state(video_path=video_path)
156
+ except Exception as e:
157
+ logger.error(f"Failed to initialize video state: {e}")
158
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
+ def add_new_points(self, inference_state, frame_idx: int, obj_id: int,
161
+ points: np.ndarray, labels: np.ndarray):
162
+ """Add new points for tracking"""
163
+ if self.predictor is None:
164
+ raise RuntimeError("Predictor not loaded")
165
 
166
+ try:
167
+ return self.predictor.add_new_points(
168
+ inference_state=inference_state,
169
+ frame_idx=frame_idx,
170
+ obj_id=obj_id,
171
+ points=points,
172
+ labels=labels
173
+ )
174
+ except Exception as e:
175
+ logger.error(f"Failed to add new points: {e}")
176
+ raise
177
 
178
+ def propagate_in_video(self, inference_state, scale: float = 1.0, **kwargs):
179
+ """Propagate through video with optional scaling"""
180
+ if self.predictor is None:
181
+ raise RuntimeError("Predictor not loaded")
182
+
183
  try:
184
+ # Use the predictor's propagate_in_video method
185
+ return self.predictor.propagate_in_video(inference_state, **kwargs)
 
 
 
186
  except Exception as e:
187
+ logger.error(f"Failed to propagate in video: {e}")
188
+ raise
189
 
190
+ def prune_state(self, inference_state, keep: int):
191
+ """Prune SAM2 state to keep only recent frames in memory"""
 
 
 
 
192
  try:
193
+ # Try to access and prune internal caches
194
+ # This is model-specific and may need adjustment based on SAM2 internals
195
+ if hasattr(inference_state, 'cached_features'):
196
+ # Keep only the most recent 'keep' frames
197
+ cached_keys = list(inference_state.cached_features.keys())
198
+ if len(cached_keys) > keep:
199
+ keys_to_remove = cached_keys[:-keep]
200
+ for key in keys_to_remove:
201
+ if key in inference_state.cached_features:
202
+ del inference_state.cached_features[key]
203
+ logger.debug(f"Pruned {len(keys_to_remove)} old cached features")
204
 
205
+ # Clear other potential caches
206
+ if hasattr(inference_state, 'point_inputs_per_obj'):
207
+ # Keep recent point inputs only
208
+ for obj_id in list(inference_state.point_inputs_per_obj.keys()):
209
+ obj_inputs = inference_state.point_inputs_per_obj[obj_id]
210
+ if len(obj_inputs) > keep:
211
+ # Keep only recent entries
212
+ recent_keys = sorted(obj_inputs.keys())[-keep:]
213
+ new_inputs = {k: obj_inputs[k] for k in recent_keys}
214
+ inference_state.point_inputs_per_obj[obj_id] = new_inputs
215
 
216
+ # Force garbage collection
217
+ torch.cuda.empty_cache() if self.device.type == 'cuda' else None
218
 
 
 
 
 
 
 
 
219
  except Exception as e:
220
+ logger.debug(f"State pruning warning: {e}")
221
+
222
+ def clear_memory(self):
223
+ """Clear GPU memory aggressively"""
224
+ try:
225
+ if self.device.type == 'cuda':
226
+ torch.cuda.empty_cache()
227
+ torch.cuda.synchronize()
228
+ torch.cuda.ipc_collect()
229
+ gc.collect()
230
+ except Exception as e:
231
+ logger.warning(f"Memory clearing warning: {e}")
232
+
233
+ def get_memory_usage(self) -> Dict[str, float]:
234
+ """Get current memory usage statistics"""
235
+ if self.device.type != 'cuda':
236
+ return {"allocated_gb": 0.0, "reserved_gb": 0.0, "free_gb": 0.0}
237
+
238
+ try:
239
+ allocated = torch.cuda.memory_allocated(self.device) / (1024**3)
240
+ reserved = torch.cuda.memory_reserved(self.device) / (1024**3)
241
+ free, total = torch.cuda.mem_get_info(self.device)
242
+ free_gb = free / (1024**3)
243
+
244
+ return {
245
+ "allocated_gb": allocated,
246
+ "reserved_gb": reserved,
247
+ "free_gb": free_gb,
248
+ "total_gb": total / (1024**3)
249
+ }
250
+ except Exception:
251
+ return {"allocated_gb": 0.0, "reserved_gb": 0.0, "free_gb": 0.0}
252
+
253
+ def __del__(self):
254
+ """Cleanup on deletion"""
255
+ try:
256
+ if hasattr(self, 'predictor') and self.predictor is not None:
257
+ del self.predictor
258
+ if hasattr(self, 'model') and self.model is not None:
259
+ del self.model
260
+ self.clear_memory()
261
+ except Exception:
262
+ pass