MogensR committed (verified)
Commit 01d284f · 1 Parent(s): 77b93ef

Update models/sam2_loader.py

Files changed (1):
  1. models/sam2_loader.py +55 -70
models/sam2_loader.py CHANGED
@@ -3,6 +3,7 @@
 SAM2 Loader with Hugging Face Hub integration
 Provides SAM2Predictor class with memory management and optimization features
 Updated to use Hugging Face Hub models instead of direct downloads
+(Enhanced logging and exception safety)
 """
 
 import os
@@ -13,6 +14,7 @@
 from pathlib import Path
 from typing import Optional, Any, Dict, List, Tuple
 
+logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 class SAM2Predictor:
@@ -21,6 +23,7 @@ class SAM2Predictor:
     """
 
     def __init__(self, device: torch.device, model_size: str = "small"):
+        logger.info(f"[SAM2Predictor.__init__] device={device}, model_size={model_size}") # [LOG+SAFETY PATCH]
         self.device = device
         self.model_size = model_size
         self.predictor = None
@@ -30,73 +33,60 @@ def __init__(self, device: torch.device, model_size: str = "small"):
     def _load_predictor(self):
         """Load SAM2 predictor with Hugging Face Hub integration"""
         try:
+            logger.info("[SAM2Predictor._load_predictor] Loading SAM2 predictor...") # [LOG+SAFETY PATCH]
             from sam2.build_sam import build_sam2_video_predictor
 
-            # Get checkpoint from Hugging Face Hub
             checkpoint_path = self._get_hf_checkpoint()
             if not checkpoint_path:
+                logger.error(f"Failed to get SAM2 {self.model_size} checkpoint from HF Hub") # [LOG+SAFETY PATCH]
                 raise RuntimeError(f"Failed to get SAM2 {self.model_size} checkpoint from HF Hub")
 
-            # Get model config
             model_cfg = self._get_model_config()
+            logger.info(f"[SAM2Predictor._load_predictor] Using model_cfg: {model_cfg}") # [LOG+SAFETY PATCH]
 
-            # Build predictor
             self.predictor = build_sam2_video_predictor(model_cfg, checkpoint_path, device=self.device)
-
-            # Apply T4 optimizations
             self._optimize_for_t4()
-
             logger.info(f"SAM2 {self.model_size} predictor loaded successfully from HF Hub")
-
         except ImportError as e:
             logger.error(f"SAM2 import failed: {e}")
             raise RuntimeError("SAM2 not available - check sam2 installation")
         except Exception as e:
-            logger.error(f"SAM2 loading failed: {e}")
+            logger.error(f"SAM2 loading failed: {e}", exc_info=True)
             raise
 
     def _get_hf_checkpoint(self) -> Optional[str]:
         """Download checkpoint from Hugging Face Hub"""
         try:
+            logger.info(f"[SAM2Predictor._get_hf_checkpoint] Downloading checkpoint...") # [LOG+SAFETY PATCH]
             from huggingface_hub import hf_hub_download
 
-            # Repository mapping for different model sizes
             repo_mapping = {
                 "small": "facebook/sam2-hiera-small",
                 "base": "facebook/sam2-hiera-base-plus",
                 "large": "facebook/sam2-hiera-large"
             }
-
             filename_mapping = {
                 "small": "sam2_hiera_small.pt",
                 "base": "sam2_hiera_base_plus.pt",
                 "large": "sam2_hiera_large.pt"
             }
-
             if self.model_size not in repo_mapping:
                 logger.error(f"Unknown model size: {self.model_size}")
                 return None
-
             repo_id = repo_mapping[self.model_size]
             filename = filename_mapping[self.model_size]
-
             logger.info(f"Downloading SAM2 {self.model_size} from HF Hub: {repo_id}")
-
-            # Download from Hugging Face Hub
             checkpoint_path = hf_hub_download(
                 repo_id=repo_id,
                 filename=filename,
-                cache_dir=None,  # Use default cache
-                force_download=False,  # Use cached version if available
-                token=None  # No auth token needed for public models
+                cache_dir=None,
+                force_download=False,
+                token=None
             )
-
             logger.info(f"SAM2 checkpoint downloaded to: {checkpoint_path}")
             return checkpoint_path
-
         except Exception as e:
             logger.error(f"HF Hub download failed: {e}")
-            # Fallback to local checkpoint if HF download fails
             return self._fallback_local_checkpoint()
 
     def _fallback_local_checkpoint(self) -> Optional[str]:
@@ -120,63 +110,64 @@ def _get_model_config(self) -> str:
             "base": "sam2_hiera_b+.yaml",
             "large": "sam2_hiera_l.yaml"
         }
-
-        return config_mapping.get(self.model_size, "sam2_hiera_s.yaml")
+        cfg = config_mapping.get(self.model_size, "sam2_hiera_s.yaml")
+        logger.info(f"[SAM2Predictor._get_model_config] Returning config: {cfg}") # [LOG+SAFETY PATCH]
+        return cfg
 
     def _optimize_for_t4(self):
         """Apply T4-specific optimizations"""
         try:
+            logger.info("[SAM2Predictor._optimize_for_t4] Optimizing for T4...") # [LOG+SAFETY PATCH]
             if hasattr(self.predictor, "model") and self.predictor.model is not None:
                 self.model = self.predictor.model
-
-                # Apply fp16 and channels_last for T4 efficiency
                 self.model = self.model.half().to(self.device)
                 self.model = self.model.to(memory_format=torch.channels_last)
-
                 logger.info("SAM2: fp16 + channels_last applied for T4 optimization")
-
         except Exception as e:
-            logger.warning(f"SAM2 T4 optimization warning: {e}")
+            logger.warning(f"SAM2 T4 optimization warning: {e}", exc_info=True)
 
     def init_state(self, video_path: str):
-        """Initialize video processing state"""
+        logger.info(f"[SAM2Predictor.init_state] Initializing video state for: {video_path}") # [LOG+SAFETY PATCH]
         if self.predictor is None:
+            logger.error("Predictor not loaded in init_state")
             raise RuntimeError("Predictor not loaded")
-
         try:
-            return self.predictor.init_state(video_path=video_path)
+            state = self.predictor.init_state(video_path=video_path)
+            logger.info("[SAM2Predictor.init_state] Video state initialized OK")
+            return state
         except Exception as e:
-            logger.error(f"Failed to initialize video state: {e}")
+            logger.error(f"Failed to initialize video state: {e}", exc_info=True)
             raise
 
     def add_new_points(self, inference_state, frame_idx: int, obj_id: int,
                        points: np.ndarray, labels: np.ndarray):
-        """Add new points for tracking"""
+        logger.info(f"[SAM2Predictor.add_new_points] Adding points for frame {frame_idx}, obj {obj_id}") # [LOG+SAFETY PATCH]
         if self.predictor is None:
+            logger.error("Predictor not loaded in add_new_points")
             raise RuntimeError("Predictor not loaded")
-
         try:
-            return self.predictor.add_new_points(
+            out = self.predictor.add_new_points(
                 inference_state=inference_state,
                 frame_idx=frame_idx,
                 obj_id=obj_id,
                 points=points,
                 labels=labels
             )
+            logger.info(f"[SAM2Predictor.add_new_points] Points added OK")
+            return out
         except Exception as e:
-            logger.error(f"Failed to add new points: {e}")
+            logger.error(f"Failed to add new points: {e}", exc_info=True)
             raise
 
     def add_new_points_or_box(self, inference_state, frame_idx: int, obj_id: int,
                               points: np.ndarray, labels: np.ndarray, clear_old_points: bool = True):
-        """Add new points or box for tracking (newer SAM2 API)"""
+        logger.info(f"[SAM2Predictor.add_new_points_or_box] Adding points/box for frame {frame_idx}, obj {obj_id}") # [LOG+SAFETY PATCH]
        if self.predictor is None:
+            logger.error("Predictor not loaded in add_new_points_or_box")
             raise RuntimeError("Predictor not loaded")
-
         try:
-            # Try the newer API first
             if hasattr(self.predictor, 'add_new_points_or_box'):
-                return self.predictor.add_new_points_or_box(
+                out = self.predictor.add_new_points_or_box(
                     inference_state=inference_state,
                     frame_idx=frame_idx,
                     obj_id=obj_id,
@@ -184,38 +175,39 @@ def add_new_points_or_box(self, inference_state, frame_idx: int, obj_id: int,
                     labels=labels,
                     clear_old_points=clear_old_points
                 )
+                logger.info(f"[SAM2Predictor.add_new_points_or_box] Used new API, points/box added OK")
+                return out
             else:
-                # Fallback to older API
-                return self.predictor.add_new_points(
+                out = self.predictor.add_new_points(
                     inference_state=inference_state,
                     frame_idx=frame_idx,
                     obj_id=obj_id,
                     points=points,
                     labels=labels
                 )
+                logger.info(f"[SAM2Predictor.add_new_points_or_box] Used fallback, points added OK")
+                return out
         except Exception as e:
-            logger.error(f"Failed to add new points or box: {e}")
+            logger.error(f"Failed to add new points or box: {e}", exc_info=True)
             raise
 
     def propagate_in_video(self, inference_state, scale: float = 1.0, **kwargs):
-        """Propagate through video with optional scaling"""
+        logger.info(f"[SAM2Predictor.propagate_in_video] Propagating in video...") # [LOG+SAFETY PATCH]
         if self.predictor is None:
+            logger.error("Predictor not loaded in propagate_in_video")
             raise RuntimeError("Predictor not loaded")
-
         try:
-            # Use the predictor's propagate_in_video method
-            return self.predictor.propagate_in_video(inference_state, **kwargs)
+            out = self.predictor.propagate_in_video(inference_state, **kwargs)
+            logger.info(f"[SAM2Predictor.propagate_in_video] Propagation OK")
+            return out
         except Exception as e:
-            logger.error(f"Failed to propagate in video: {e}")
+            logger.error(f"Failed to propagate in video: {e}", exc_info=True)
             raise
 
     def prune_state(self, inference_state, keep: int):
-        """Prune SAM2 state to keep only recent frames in memory"""
+        logger.info(f"[SAM2Predictor.prune_state] Pruning state to keep {keep} frames...") # [LOG+SAFETY PATCH]
         try:
-            # Try to access and prune internal caches
-            # This is model-specific and may need adjustment based on SAM2 internals
             if hasattr(inference_state, 'cached_features'):
-                # Keep only the most recent 'keep' frames
                 cached_keys = list(inference_state.cached_features.keys())
                 if len(cached_keys) > keep:
                     keys_to_remove = cached_keys[:-keep]
@@ -223,26 +215,20 @@ def prune_state(self, inference_state, keep: int):
                         if key in inference_state.cached_features:
                             del inference_state.cached_features[key]
                     logger.debug(f"Pruned {len(keys_to_remove)} old cached features")
-
-            # Clear other potential caches
             if hasattr(inference_state, 'point_inputs_per_obj'):
-                # Keep recent point inputs only
                 for obj_id in list(inference_state.point_inputs_per_obj.keys()):
                     obj_inputs = inference_state.point_inputs_per_obj[obj_id]
                     if len(obj_inputs) > keep:
-                        # Keep only recent entries
                         recent_keys = sorted(obj_inputs.keys())[-keep:]
                         new_inputs = {k: obj_inputs[k] for k in recent_keys}
                         inference_state.point_inputs_per_obj[obj_id] = new_inputs
-
-            # Force garbage collection
-            torch.cuda.empty_cache() if self.device.type == 'cuda' else None
-
+            if self.device.type == 'cuda':
+                torch.cuda.empty_cache()
         except Exception as e:
-            logger.debug(f"State pruning warning: {e}")
+            logger.debug(f"State pruning warning: {e}", exc_info=True)
 
     def clear_memory(self):
-        """Clear GPU memory aggressively"""
+        logger.info("[SAM2Predictor.clear_memory] Clearing GPU memory") # [LOG+SAFETY PATCH]
         try:
             if self.device.type == 'cuda':
                 torch.cuda.empty_cache()
@@ -250,35 +236,34 @@ def clear_memory(self):
                 torch.cuda.ipc_collect()
             gc.collect()
         except Exception as e:
-            logger.warning(f"Memory clearing warning: {e}")
+            logger.warning(f"Memory clearing warning: {e}", exc_info=True)
 
     def get_memory_usage(self) -> Dict[str, float]:
-        """Get current memory usage statistics"""
+        logger.info("[SAM2Predictor.get_memory_usage] Checking memory usage") # [LOG+SAFETY PATCH]
         if self.device.type != 'cuda':
             return {"allocated_gb": 0.0, "reserved_gb": 0.0, "free_gb": 0.0}
-
         try:
             allocated = torch.cuda.memory_allocated(self.device) / (1024**3)
             reserved = torch.cuda.memory_reserved(self.device) / (1024**3)
             free, total = torch.cuda.mem_get_info(self.device)
             free_gb = free / (1024**3)
-
             return {
                 "allocated_gb": allocated,
                 "reserved_gb": reserved,
                 "free_gb": free_gb,
                 "total_gb": total / (1024**3)
             }
-        except Exception:
+        except Exception as e:
+            logger.warning(f"Error checking memory usage: {e}", exc_info=True)
             return {"allocated_gb": 0.0, "reserved_gb": 0.0, "free_gb": 0.0}
 
     def __del__(self):
-        """Cleanup on deletion"""
+        logger.info("[SAM2Predictor.__del__] Cleaning up...") # [LOG+SAFETY PATCH]
         try:
             if hasattr(self, 'predictor') and self.predictor is not None:
                 del self.predictor
             if hasattr(self, 'model') and self.model is not None:
                 del self.model
             self.clear_memory()
-        except Exception:
-            pass
+        except Exception as e:
+            logger.warning(f"Error in __del__: {e}", exc_info=True)