Update models/sam2_loader.py
Browse files- models/sam2_loader.py +55 -70
models/sam2_loader.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
SAM2 Loader with Hugging Face Hub integration
|
| 4 |
Provides SAM2Predictor class with memory management and optimization features
|
| 5 |
Updated to use Hugging Face Hub models instead of direct downloads
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
import os
|
|
@@ -13,6 +14,7 @@
|
|
| 13 |
from pathlib import Path
|
| 14 |
from typing import Optional, Any, Dict, List, Tuple
|
| 15 |
|
|
|
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
| 18 |
class SAM2Predictor:
|
|
@@ -21,6 +23,7 @@ class SAM2Predictor:
|
|
| 21 |
"""
|
| 22 |
|
| 23 |
def __init__(self, device: torch.device, model_size: str = "small"):
|
|
|
|
| 24 |
self.device = device
|
| 25 |
self.model_size = model_size
|
| 26 |
self.predictor = None
|
|
@@ -30,73 +33,60 @@ def __init__(self, device: torch.device, model_size: str = "small"):
|
|
| 30 |
def _load_predictor(self):
|
| 31 |
"""Load SAM2 predictor with Hugging Face Hub integration"""
|
| 32 |
try:
|
|
|
|
| 33 |
from sam2.build_sam import build_sam2_video_predictor
|
| 34 |
|
| 35 |
-
# Get checkpoint from Hugging Face Hub
|
| 36 |
checkpoint_path = self._get_hf_checkpoint()
|
| 37 |
if not checkpoint_path:
|
|
|
|
| 38 |
raise RuntimeError(f"Failed to get SAM2 {self.model_size} checkpoint from HF Hub")
|
| 39 |
|
| 40 |
-
# Get model config
|
| 41 |
model_cfg = self._get_model_config()
|
|
|
|
| 42 |
|
| 43 |
-
# Build predictor
|
| 44 |
self.predictor = build_sam2_video_predictor(model_cfg, checkpoint_path, device=self.device)
|
| 45 |
-
|
| 46 |
-
# Apply T4 optimizations
|
| 47 |
self._optimize_for_t4()
|
| 48 |
-
|
| 49 |
logger.info(f"SAM2 {self.model_size} predictor loaded successfully from HF Hub")
|
| 50 |
-
|
| 51 |
except ImportError as e:
|
| 52 |
logger.error(f"SAM2 import failed: {e}")
|
| 53 |
raise RuntimeError("SAM2 not available - check sam2 installation")
|
| 54 |
except Exception as e:
|
| 55 |
-
logger.error(f"SAM2 loading failed: {e}")
|
| 56 |
raise
|
| 57 |
|
| 58 |
def _get_hf_checkpoint(self) -> Optional[str]:
|
| 59 |
"""Download checkpoint from Hugging Face Hub"""
|
| 60 |
try:
|
|
|
|
| 61 |
from huggingface_hub import hf_hub_download
|
| 62 |
|
| 63 |
-
# Repository mapping for different model sizes
|
| 64 |
repo_mapping = {
|
| 65 |
"small": "facebook/sam2-hiera-small",
|
| 66 |
"base": "facebook/sam2-hiera-base-plus",
|
| 67 |
"large": "facebook/sam2-hiera-large"
|
| 68 |
}
|
| 69 |
-
|
| 70 |
filename_mapping = {
|
| 71 |
"small": "sam2_hiera_small.pt",
|
| 72 |
"base": "sam2_hiera_base_plus.pt",
|
| 73 |
"large": "sam2_hiera_large.pt"
|
| 74 |
}
|
| 75 |
-
|
| 76 |
if self.model_size not in repo_mapping:
|
| 77 |
logger.error(f"Unknown model size: {self.model_size}")
|
| 78 |
return None
|
| 79 |
-
|
| 80 |
repo_id = repo_mapping[self.model_size]
|
| 81 |
filename = filename_mapping[self.model_size]
|
| 82 |
-
|
| 83 |
logger.info(f"Downloading SAM2 {self.model_size} from HF Hub: {repo_id}")
|
| 84 |
-
|
| 85 |
-
# Download from Hugging Face Hub
|
| 86 |
checkpoint_path = hf_hub_download(
|
| 87 |
repo_id=repo_id,
|
| 88 |
filename=filename,
|
| 89 |
-
cache_dir=None,
|
| 90 |
-
force_download=False,
|
| 91 |
-
token=None
|
| 92 |
)
|
| 93 |
-
|
| 94 |
logger.info(f"SAM2 checkpoint downloaded to: {checkpoint_path}")
|
| 95 |
return checkpoint_path
|
| 96 |
-
|
| 97 |
except Exception as e:
|
| 98 |
logger.error(f"HF Hub download failed: {e}")
|
| 99 |
-
# Fallback to local checkpoint if HF download fails
|
| 100 |
return self._fallback_local_checkpoint()
|
| 101 |
|
| 102 |
def _fallback_local_checkpoint(self) -> Optional[str]:
|
|
@@ -120,63 +110,64 @@ def _get_model_config(self) -> str:
|
|
| 120 |
"base": "sam2_hiera_b+.yaml",
|
| 121 |
"large": "sam2_hiera_l.yaml"
|
| 122 |
}
|
| 123 |
-
|
| 124 |
-
|
|
|
|
| 125 |
|
| 126 |
def _optimize_for_t4(self):
|
| 127 |
"""Apply T4-specific optimizations"""
|
| 128 |
try:
|
|
|
|
| 129 |
if hasattr(self.predictor, "model") and self.predictor.model is not None:
|
| 130 |
self.model = self.predictor.model
|
| 131 |
-
|
| 132 |
-
# Apply fp16 and channels_last for T4 efficiency
|
| 133 |
self.model = self.model.half().to(self.device)
|
| 134 |
self.model = self.model.to(memory_format=torch.channels_last)
|
| 135 |
-
|
| 136 |
logger.info("SAM2: fp16 + channels_last applied for T4 optimization")
|
| 137 |
-
|
| 138 |
except Exception as e:
|
| 139 |
-
logger.warning(f"SAM2 T4 optimization warning: {e}")
|
| 140 |
|
| 141 |
def init_state(self, video_path: str):
|
| 142 |
-
"
|
| 143 |
if self.predictor is None:
|
|
|
|
| 144 |
raise RuntimeError("Predictor not loaded")
|
| 145 |
-
|
| 146 |
try:
|
| 147 |
-
|
|
|
|
|
|
|
| 148 |
except Exception as e:
|
| 149 |
-
logger.error(f"Failed to initialize video state: {e}")
|
| 150 |
raise
|
| 151 |
|
| 152 |
def add_new_points(self, inference_state, frame_idx: int, obj_id: int,
|
| 153 |
points: np.ndarray, labels: np.ndarray):
|
| 154 |
-
"
|
| 155 |
if self.predictor is None:
|
|
|
|
| 156 |
raise RuntimeError("Predictor not loaded")
|
| 157 |
-
|
| 158 |
try:
|
| 159 |
-
|
| 160 |
inference_state=inference_state,
|
| 161 |
frame_idx=frame_idx,
|
| 162 |
obj_id=obj_id,
|
| 163 |
points=points,
|
| 164 |
labels=labels
|
| 165 |
)
|
|
|
|
|
|
|
| 166 |
except Exception as e:
|
| 167 |
-
logger.error(f"Failed to add new points: {e}")
|
| 168 |
raise
|
| 169 |
|
| 170 |
def add_new_points_or_box(self, inference_state, frame_idx: int, obj_id: int,
|
| 171 |
points: np.ndarray, labels: np.ndarray, clear_old_points: bool = True):
|
| 172 |
-
"
|
| 173 |
if self.predictor is None:
|
|
|
|
| 174 |
raise RuntimeError("Predictor not loaded")
|
| 175 |
-
|
| 176 |
try:
|
| 177 |
-
# Try the newer API first
|
| 178 |
if hasattr(self.predictor, 'add_new_points_or_box'):
|
| 179 |
-
|
| 180 |
inference_state=inference_state,
|
| 181 |
frame_idx=frame_idx,
|
| 182 |
obj_id=obj_id,
|
|
@@ -184,38 +175,39 @@ def add_new_points_or_box(self, inference_state, frame_idx: int, obj_id: int,
|
|
| 184 |
labels=labels,
|
| 185 |
clear_old_points=clear_old_points
|
| 186 |
)
|
|
|
|
|
|
|
| 187 |
else:
|
| 188 |
-
|
| 189 |
-
return self.predictor.add_new_points(
|
| 190 |
inference_state=inference_state,
|
| 191 |
frame_idx=frame_idx,
|
| 192 |
obj_id=obj_id,
|
| 193 |
points=points,
|
| 194 |
labels=labels
|
| 195 |
)
|
|
|
|
|
|
|
| 196 |
except Exception as e:
|
| 197 |
-
logger.error(f"Failed to add new points or box: {e}")
|
| 198 |
raise
|
| 199 |
|
| 200 |
def propagate_in_video(self, inference_state, scale: float = 1.0, **kwargs):
|
| 201 |
-
"
|
| 202 |
if self.predictor is None:
|
|
|
|
| 203 |
raise RuntimeError("Predictor not loaded")
|
| 204 |
-
|
| 205 |
try:
|
| 206 |
-
|
| 207 |
-
|
|
|
|
| 208 |
except Exception as e:
|
| 209 |
-
logger.error(f"Failed to propagate in video: {e}")
|
| 210 |
raise
|
| 211 |
|
| 212 |
def prune_state(self, inference_state, keep: int):
|
| 213 |
-
"
|
| 214 |
try:
|
| 215 |
-
# Try to access and prune internal caches
|
| 216 |
-
# This is model-specific and may need adjustment based on SAM2 internals
|
| 217 |
if hasattr(inference_state, 'cached_features'):
|
| 218 |
-
# Keep only the most recent 'keep' frames
|
| 219 |
cached_keys = list(inference_state.cached_features.keys())
|
| 220 |
if len(cached_keys) > keep:
|
| 221 |
keys_to_remove = cached_keys[:-keep]
|
|
@@ -223,26 +215,20 @@ def prune_state(self, inference_state, keep: int):
|
|
| 223 |
if key in inference_state.cached_features:
|
| 224 |
del inference_state.cached_features[key]
|
| 225 |
logger.debug(f"Pruned {len(keys_to_remove)} old cached features")
|
| 226 |
-
|
| 227 |
-
# Clear other potential caches
|
| 228 |
if hasattr(inference_state, 'point_inputs_per_obj'):
|
| 229 |
-
# Keep recent point inputs only
|
| 230 |
for obj_id in list(inference_state.point_inputs_per_obj.keys()):
|
| 231 |
obj_inputs = inference_state.point_inputs_per_obj[obj_id]
|
| 232 |
if len(obj_inputs) > keep:
|
| 233 |
-
# Keep only recent entries
|
| 234 |
recent_keys = sorted(obj_inputs.keys())[-keep:]
|
| 235 |
new_inputs = {k: obj_inputs[k] for k in recent_keys}
|
| 236 |
inference_state.point_inputs_per_obj[obj_id] = new_inputs
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
torch.cuda.empty_cache() if self.device.type == 'cuda' else None
|
| 240 |
-
|
| 241 |
except Exception as e:
|
| 242 |
-
logger.debug(f"State pruning warning: {e}")
|
| 243 |
|
| 244 |
def clear_memory(self):
|
| 245 |
-
"
|
| 246 |
try:
|
| 247 |
if self.device.type == 'cuda':
|
| 248 |
torch.cuda.empty_cache()
|
|
@@ -250,35 +236,34 @@ def clear_memory(self):
|
|
| 250 |
torch.cuda.ipc_collect()
|
| 251 |
gc.collect()
|
| 252 |
except Exception as e:
|
| 253 |
-
logger.warning(f"Memory clearing warning: {e}")
|
| 254 |
|
| 255 |
def get_memory_usage(self) -> Dict[str, float]:
|
| 256 |
-
"
|
| 257 |
if self.device.type != 'cuda':
|
| 258 |
return {"allocated_gb": 0.0, "reserved_gb": 0.0, "free_gb": 0.0}
|
| 259 |
-
|
| 260 |
try:
|
| 261 |
allocated = torch.cuda.memory_allocated(self.device) / (1024**3)
|
| 262 |
reserved = torch.cuda.memory_reserved(self.device) / (1024**3)
|
| 263 |
free, total = torch.cuda.mem_get_info(self.device)
|
| 264 |
free_gb = free / (1024**3)
|
| 265 |
-
|
| 266 |
return {
|
| 267 |
"allocated_gb": allocated,
|
| 268 |
"reserved_gb": reserved,
|
| 269 |
"free_gb": free_gb,
|
| 270 |
"total_gb": total / (1024**3)
|
| 271 |
}
|
| 272 |
-
except Exception:
|
|
|
|
| 273 |
return {"allocated_gb": 0.0, "reserved_gb": 0.0, "free_gb": 0.0}
|
| 274 |
|
| 275 |
def __del__(self):
|
| 276 |
-
"
|
| 277 |
try:
|
| 278 |
if hasattr(self, 'predictor') and self.predictor is not None:
|
| 279 |
del self.predictor
|
| 280 |
if hasattr(self, 'model') and self.model is not None:
|
| 281 |
del self.model
|
| 282 |
self.clear_memory()
|
| 283 |
-
except Exception:
|
| 284 |
-
|
|
|
|
| 3 |
SAM2 Loader with Hugging Face Hub integration
|
| 4 |
Provides SAM2Predictor class with memory management and optimization features
|
| 5 |
Updated to use Hugging Face Hub models instead of direct downloads
|
| 6 |
+
(Enhanced logging and exception safety)
|
| 7 |
"""
|
| 8 |
|
| 9 |
import os
|
|
|
|
| 14 |
from pathlib import Path
|
| 15 |
from typing import Optional, Any, Dict, List, Tuple
|
| 16 |
|
| 17 |
+
logging.basicConfig(level=logging.INFO)
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
|
| 20 |
class SAM2Predictor:
|
|
|
|
| 23 |
"""
|
| 24 |
|
| 25 |
def __init__(self, device: torch.device, model_size: str = "small"):
|
| 26 |
+
logger.info(f"[SAM2Predictor.__init__] device={device}, model_size={model_size}") # [LOG+SAFETY PATCH]
|
| 27 |
self.device = device
|
| 28 |
self.model_size = model_size
|
| 29 |
self.predictor = None
|
|
|
|
| 33 |
def _load_predictor(self):
|
| 34 |
"""Load SAM2 predictor with Hugging Face Hub integration"""
|
| 35 |
try:
|
| 36 |
+
logger.info("[SAM2Predictor._load_predictor] Loading SAM2 predictor...") # [LOG+SAFETY PATCH]
|
| 37 |
from sam2.build_sam import build_sam2_video_predictor
|
| 38 |
|
|
|
|
| 39 |
checkpoint_path = self._get_hf_checkpoint()
|
| 40 |
if not checkpoint_path:
|
| 41 |
+
logger.error(f"Failed to get SAM2 {self.model_size} checkpoint from HF Hub") # [LOG+SAFETY PATCH]
|
| 42 |
raise RuntimeError(f"Failed to get SAM2 {self.model_size} checkpoint from HF Hub")
|
| 43 |
|
|
|
|
| 44 |
model_cfg = self._get_model_config()
|
| 45 |
+
logger.info(f"[SAM2Predictor._load_predictor] Using model_cfg: {model_cfg}") # [LOG+SAFETY PATCH]
|
| 46 |
|
|
|
|
| 47 |
self.predictor = build_sam2_video_predictor(model_cfg, checkpoint_path, device=self.device)
|
|
|
|
|
|
|
| 48 |
self._optimize_for_t4()
|
|
|
|
| 49 |
logger.info(f"SAM2 {self.model_size} predictor loaded successfully from HF Hub")
|
|
|
|
| 50 |
except ImportError as e:
|
| 51 |
logger.error(f"SAM2 import failed: {e}")
|
| 52 |
raise RuntimeError("SAM2 not available - check sam2 installation")
|
| 53 |
except Exception as e:
|
| 54 |
+
logger.error(f"SAM2 loading failed: {e}", exc_info=True)
|
| 55 |
raise
|
| 56 |
|
| 57 |
def _get_hf_checkpoint(self) -> Optional[str]:
    """Fetch the SAM2 checkpoint for ``self.model_size`` from the HF Hub.

    Returns the local path to the downloaded checkpoint file, ``None`` when
    the model size is unknown, or whatever the local-checkpoint fallback
    yields when the Hub download fails.
    """
    try:
        logger.info(f"[SAM2Predictor._get_hf_checkpoint] Downloading checkpoint...")  # [LOG+SAFETY PATCH]
        from huggingface_hub import hf_hub_download

        # One entry per supported size: (Hub repo id, checkpoint filename).
        hub_sources = {
            "small": ("facebook/sam2-hiera-small", "sam2_hiera_small.pt"),
            "base": ("facebook/sam2-hiera-base-plus", "sam2_hiera_base_plus.pt"),
            "large": ("facebook/sam2-hiera-large", "sam2_hiera_large.pt"),
        }

        source = hub_sources.get(self.model_size)
        if source is None:
            logger.error(f"Unknown model size: {self.model_size}")
            return None
        repo_id, filename = source

        logger.info(f"Downloading SAM2 {self.model_size} from HF Hub: {repo_id}")
        checkpoint_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=None,
            force_download=False,
            token=None
        )
        logger.info(f"SAM2 checkpoint downloaded to: {checkpoint_path}")
        return checkpoint_path
    except Exception as e:
        logger.error(f"HF Hub download failed: {e}")
        # Fallback to local checkpoint if HF download fails.
        return self._fallback_local_checkpoint()
|
| 91 |
|
| 92 |
def _fallback_local_checkpoint(self) -> Optional[str]:
|
|
|
|
| 110 |
"base": "sam2_hiera_b+.yaml",
|
| 111 |
"large": "sam2_hiera_l.yaml"
|
| 112 |
}
|
| 113 |
+
cfg = config_mapping.get(self.model_size, "sam2_hiera_s.yaml")
|
| 114 |
+
logger.info(f"[SAM2Predictor._get_model_config] Returning config: {cfg}") # [LOG+SAFETY PATCH]
|
| 115 |
+
return cfg
|
| 116 |
|
| 117 |
def _optimize_for_t4(self):
|
| 118 |
"""Apply T4-specific optimizations"""
|
| 119 |
try:
|
| 120 |
+
logger.info("[SAM2Predictor._optimize_for_t4] Optimizing for T4...") # [LOG+SAFETY PATCH]
|
| 121 |
if hasattr(self.predictor, "model") and self.predictor.model is not None:
|
| 122 |
self.model = self.predictor.model
|
|
|
|
|
|
|
| 123 |
self.model = self.model.half().to(self.device)
|
| 124 |
self.model = self.model.to(memory_format=torch.channels_last)
|
|
|
|
| 125 |
logger.info("SAM2: fp16 + channels_last applied for T4 optimization")
|
|
|
|
| 126 |
except Exception as e:
|
| 127 |
+
logger.warning(f"SAM2 T4 optimization warning: {e}", exc_info=True)
|
| 128 |
|
| 129 |
def init_state(self, video_path: str):
|
| 130 |
+
logger.info(f"[SAM2Predictor.init_state] Initializing video state for: {video_path}") # [LOG+SAFETY PATCH]
|
| 131 |
if self.predictor is None:
|
| 132 |
+
logger.error("Predictor not loaded in init_state")
|
| 133 |
raise RuntimeError("Predictor not loaded")
|
|
|
|
| 134 |
try:
|
| 135 |
+
state = self.predictor.init_state(video_path=video_path)
|
| 136 |
+
logger.info("[SAM2Predictor.init_state] Video state initialized OK")
|
| 137 |
+
return state
|
| 138 |
except Exception as e:
|
| 139 |
+
logger.error(f"Failed to initialize video state: {e}", exc_info=True)
|
| 140 |
raise
|
| 141 |
|
| 142 |
def add_new_points(self, inference_state, frame_idx: int, obj_id: int,
|
| 143 |
points: np.ndarray, labels: np.ndarray):
|
| 144 |
+
logger.info(f"[SAM2Predictor.add_new_points] Adding points for frame {frame_idx}, obj {obj_id}") # [LOG+SAFETY PATCH]
|
| 145 |
if self.predictor is None:
|
| 146 |
+
logger.error("Predictor not loaded in add_new_points")
|
| 147 |
raise RuntimeError("Predictor not loaded")
|
|
|
|
| 148 |
try:
|
| 149 |
+
out = self.predictor.add_new_points(
|
| 150 |
inference_state=inference_state,
|
| 151 |
frame_idx=frame_idx,
|
| 152 |
obj_id=obj_id,
|
| 153 |
points=points,
|
| 154 |
labels=labels
|
| 155 |
)
|
| 156 |
+
logger.info(f"[SAM2Predictor.add_new_points] Points added OK")
|
| 157 |
+
return out
|
| 158 |
except Exception as e:
|
| 159 |
+
logger.error(f"Failed to add new points: {e}", exc_info=True)
|
| 160 |
raise
|
| 161 |
|
| 162 |
def add_new_points_or_box(self, inference_state, frame_idx: int, obj_id: int,
|
| 163 |
points: np.ndarray, labels: np.ndarray, clear_old_points: bool = True):
|
| 164 |
+
logger.info(f"[SAM2Predictor.add_new_points_or_box] Adding points/box for frame {frame_idx}, obj {obj_id}") # [LOG+SAFETY PATCH]
|
| 165 |
if self.predictor is None:
|
| 166 |
+
logger.error("Predictor not loaded in add_new_points_or_box")
|
| 167 |
raise RuntimeError("Predictor not loaded")
|
|
|
|
| 168 |
try:
|
|
|
|
| 169 |
if hasattr(self.predictor, 'add_new_points_or_box'):
|
| 170 |
+
out = self.predictor.add_new_points_or_box(
|
| 171 |
inference_state=inference_state,
|
| 172 |
frame_idx=frame_idx,
|
| 173 |
obj_id=obj_id,
|
|
|
|
| 175 |
labels=labels,
|
| 176 |
clear_old_points=clear_old_points
|
| 177 |
)
|
| 178 |
+
logger.info(f"[SAM2Predictor.add_new_points_or_box] Used new API, points/box added OK")
|
| 179 |
+
return out
|
| 180 |
else:
|
| 181 |
+
out = self.predictor.add_new_points(
|
|
|
|
| 182 |
inference_state=inference_state,
|
| 183 |
frame_idx=frame_idx,
|
| 184 |
obj_id=obj_id,
|
| 185 |
points=points,
|
| 186 |
labels=labels
|
| 187 |
)
|
| 188 |
+
logger.info(f"[SAM2Predictor.add_new_points_or_box] Used fallback, points added OK")
|
| 189 |
+
return out
|
| 190 |
except Exception as e:
|
| 191 |
+
logger.error(f"Failed to add new points or box: {e}", exc_info=True)
|
| 192 |
raise
|
| 193 |
|
| 194 |
def propagate_in_video(self, inference_state, scale: float = 1.0, **kwargs):
|
| 195 |
+
logger.info(f"[SAM2Predictor.propagate_in_video] Propagating in video...") # [LOG+SAFETY PATCH]
|
| 196 |
if self.predictor is None:
|
| 197 |
+
logger.error("Predictor not loaded in propagate_in_video")
|
| 198 |
raise RuntimeError("Predictor not loaded")
|
|
|
|
| 199 |
try:
|
| 200 |
+
out = self.predictor.propagate_in_video(inference_state, **kwargs)
|
| 201 |
+
logger.info(f"[SAM2Predictor.propagate_in_video] Propagation OK")
|
| 202 |
+
return out
|
| 203 |
except Exception as e:
|
| 204 |
+
logger.error(f"Failed to propagate in video: {e}", exc_info=True)
|
| 205 |
raise
|
| 206 |
|
| 207 |
def prune_state(self, inference_state, keep: int):
|
| 208 |
+
logger.info(f"[SAM2Predictor.prune_state] Pruning state to keep {keep} frames...") # [LOG+SAFETY PATCH]
|
| 209 |
try:
|
|
|
|
|
|
|
| 210 |
if hasattr(inference_state, 'cached_features'):
|
|
|
|
| 211 |
cached_keys = list(inference_state.cached_features.keys())
|
| 212 |
if len(cached_keys) > keep:
|
| 213 |
keys_to_remove = cached_keys[:-keep]
|
|
|
|
| 215 |
if key in inference_state.cached_features:
|
| 216 |
del inference_state.cached_features[key]
|
| 217 |
logger.debug(f"Pruned {len(keys_to_remove)} old cached features")
|
|
|
|
|
|
|
| 218 |
if hasattr(inference_state, 'point_inputs_per_obj'):
|
|
|
|
| 219 |
for obj_id in list(inference_state.point_inputs_per_obj.keys()):
|
| 220 |
obj_inputs = inference_state.point_inputs_per_obj[obj_id]
|
| 221 |
if len(obj_inputs) > keep:
|
|
|
|
| 222 |
recent_keys = sorted(obj_inputs.keys())[-keep:]
|
| 223 |
new_inputs = {k: obj_inputs[k] for k in recent_keys}
|
| 224 |
inference_state.point_inputs_per_obj[obj_id] = new_inputs
|
| 225 |
+
if self.device.type == 'cuda':
|
| 226 |
+
torch.cuda.empty_cache()
|
|
|
|
|
|
|
| 227 |
except Exception as e:
|
| 228 |
+
logger.debug(f"State pruning warning: {e}", exc_info=True)
|
| 229 |
|
| 230 |
def clear_memory(self):
|
| 231 |
+
logger.info("[SAM2Predictor.clear_memory] Clearing GPU memory") # [LOG+SAFETY PATCH]
|
| 232 |
try:
|
| 233 |
if self.device.type == 'cuda':
|
| 234 |
torch.cuda.empty_cache()
|
|
|
|
| 236 |
torch.cuda.ipc_collect()
|
| 237 |
gc.collect()
|
| 238 |
except Exception as e:
|
| 239 |
+
logger.warning(f"Memory clearing warning: {e}", exc_info=True)
|
| 240 |
|
| 241 |
def get_memory_usage(self) -> Dict[str, float]:
|
| 242 |
+
logger.info("[SAM2Predictor.get_memory_usage] Checking memory usage") # [LOG+SAFETY PATCH]
|
| 243 |
if self.device.type != 'cuda':
|
| 244 |
return {"allocated_gb": 0.0, "reserved_gb": 0.0, "free_gb": 0.0}
|
|
|
|
| 245 |
try:
|
| 246 |
allocated = torch.cuda.memory_allocated(self.device) / (1024**3)
|
| 247 |
reserved = torch.cuda.memory_reserved(self.device) / (1024**3)
|
| 248 |
free, total = torch.cuda.mem_get_info(self.device)
|
| 249 |
free_gb = free / (1024**3)
|
|
|
|
| 250 |
return {
|
| 251 |
"allocated_gb": allocated,
|
| 252 |
"reserved_gb": reserved,
|
| 253 |
"free_gb": free_gb,
|
| 254 |
"total_gb": total / (1024**3)
|
| 255 |
}
|
| 256 |
+
except Exception as e:
|
| 257 |
+
logger.warning(f"Error checking memory usage: {e}", exc_info=True)
|
| 258 |
return {"allocated_gb": 0.0, "reserved_gb": 0.0, "free_gb": 0.0}
|
| 259 |
|
| 260 |
def __del__(self):
|
| 261 |
+
logger.info("[SAM2Predictor.__del__] Cleaning up...") # [LOG+SAFETY PATCH]
|
| 262 |
try:
|
| 263 |
if hasattr(self, 'predictor') and self.predictor is not None:
|
| 264 |
del self.predictor
|
| 265 |
if hasattr(self, 'model') and self.model is not None:
|
| 266 |
del self.model
|
| 267 |
self.clear_memory()
|
| 268 |
+
except Exception as e:
|
| 269 |
+
logger.warning(f"Error in __del__: {e}", exc_info=True)
|