MogensR commited on
Commit
d909e1e
·
verified ·
1 Parent(s): 14dbfef

Update models/sam2_loader.py

Browse files
Files changed (1) hide show
  1. models/sam2_loader.py +97 -75
models/sam2_loader.py CHANGED
@@ -1,7 +1,8 @@
1
  #!/usr/bin/env python3
2
  """
3
- SAM2 Loader with T4-optimized predictor wrapper
4
  Provides SAM2Predictor class with memory management and optimization features
 
5
  """
6
 
7
  import os
@@ -27,109 +28,100 @@ def __init__(self, device: torch.device, model_size: str = "small"):
27
  self._load_predictor()
28
 
29
  def _load_predictor(self):
30
- """Load SAM2 predictor with optimizations"""
31
  try:
32
  from sam2.build_sam import build_sam2_video_predictor
33
 
34
- # Download checkpoint if needed
35
- checkpoint_path = f"./checkpoints/sam2_hiera_{self.model_size}.pt"
36
- if not self._ensure_checkpoint(checkpoint_path):
37
- raise RuntimeError(f"Failed to get SAM2 {self.model_size} checkpoint")
 
 
 
38
 
39
  # Build predictor
40
- model_cfg = f"sam2_hiera_{self.model_size[0]}.yaml" # small -> s, base -> b, large -> l
41
  self.predictor = build_sam2_video_predictor(model_cfg, checkpoint_path, device=self.device)
42
 
43
  # Apply T4 optimizations
44
  self._optimize_for_t4()
45
 
46
- logger.info(f"SAM2 {self.model_size} predictor loaded successfully")
47
 
48
  except ImportError as e:
49
  logger.error(f"SAM2 import failed: {e}")
50
- raise RuntimeError("SAM2 not available - check third_party/sam2 installation")
51
  except Exception as e:
52
  logger.error(f"SAM2 loading failed: {e}")
53
  raise
54
 
55
- def _ensure_checkpoint(self, checkpoint_path: str) -> bool:
56
- """Ensure checkpoint exists, download if needed"""
57
- checkpoint_file = Path(checkpoint_path)
58
-
59
- if checkpoint_file.exists():
60
- file_size = checkpoint_file.stat().st_size / (1024**2)
61
- if file_size > 50: # At least 50MB
62
- logger.info(f"SAM2 checkpoint exists: {file_size:.1f}MB")
63
- return True
64
- else:
65
- logger.warning(f"Checkpoint too small ({file_size:.1f}MB), re-downloading")
66
- checkpoint_file.unlink()
67
-
68
- return self._download_checkpoint(checkpoint_path)
69
-
70
- def _download_checkpoint(self, checkpoint_path: str, timeout_seconds: int = 600) -> bool:
71
- """Download SAM2 checkpoint"""
72
  try:
73
- logger.info(f"Downloading SAM2 {self.model_size} checkpoint...")
74
-
75
- checkpoint_file = Path(checkpoint_path)
76
- checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
77
 
78
- import requests
79
-
80
- # Checkpoint URLs
81
- urls = {
82
- "small": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
83
- "base": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
84
- "large": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"
85
  }
86
 
87
- if self.model_size not in urls:
88
- raise ValueError(f"Unknown model size: {self.model_size}")
89
-
90
- checkpoint_url = urls[self.model_size]
91
-
92
- import time
93
- start_time = time.time()
94
- response = requests.get(checkpoint_url, stream=True, timeout=30)
95
- response.raise_for_status()
96
-
97
- total_size = int(response.headers.get('content-length', 0))
98
 
99
- temp_path = checkpoint_file.with_suffix('.download')
100
- downloaded = 0
101
- last_log = start_time
102
 
103
- with open(temp_path, 'wb') as f:
104
- for chunk in response.iter_content(chunk_size=1024*1024):
105
- if chunk:
106
- f.write(chunk)
107
- downloaded += len(chunk)
108
-
109
- current_time = time.time()
110
- if current_time - start_time > timeout_seconds:
111
- raise TimeoutError(f"Download timeout after {timeout_seconds}s")
112
-
113
- # Progress logging every 15 seconds
114
- if current_time - last_log > 15:
115
- progress = (downloaded / total_size * 100) if total_size > 0 else 0
116
- speed = downloaded / (current_time - start_time) / (1024**2)
117
- logger.info(f"Download: {progress:.1f}% ({speed:.1f}MB/s)")
118
- last_log = current_time
119
 
120
- temp_path.rename(checkpoint_file)
121
 
122
- download_time = time.time() - start_time
123
- speed = downloaded / download_time / (1024**2)
124
- logger.info(f"Download complete: {downloaded/(1024**2):.1f}MB in {download_time:.1f}s ({speed:.1f}MB/s)")
 
 
 
 
 
125
 
126
- return True
 
127
 
128
  except Exception as e:
129
- logger.error(f"Checkpoint download failed: {e}")
 
 
 
 
 
 
 
130
  if Path(checkpoint_path).exists():
131
- Path(checkpoint_path).unlink()
132
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  def _optimize_for_t4(self):
135
  """Apply T4-specific optimizations"""
@@ -175,6 +167,36 @@ def add_new_points(self, inference_state, frame_idx: int, obj_id: int,
175
  logger.error(f"Failed to add new points: {e}")
176
  raise
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  def propagate_in_video(self, inference_state, scale: float = 1.0, **kwargs):
179
  """Propagate through video with optional scaling"""
180
  if self.predictor is None:
 
1
  #!/usr/bin/env python3
2
  """
3
+ SAM2 Loader with Hugging Face Hub integration
4
  Provides SAM2Predictor class with memory management and optimization features
5
+ Updated to use Hugging Face Hub models instead of direct downloads
6
  """
7
 
8
  import os
 
28
  self._load_predictor()
29
 
30
  def _load_predictor(self):
31
+ """Load SAM2 predictor with Hugging Face Hub integration"""
32
  try:
33
  from sam2.build_sam import build_sam2_video_predictor
34
 
35
+ # Get checkpoint from Hugging Face Hub
36
+ checkpoint_path = self._get_hf_checkpoint()
37
+ if not checkpoint_path:
38
+ raise RuntimeError(f"Failed to get SAM2 {self.model_size} checkpoint from HF Hub")
39
+
40
+ # Get model config
41
+ model_cfg = self._get_model_config()
42
 
43
  # Build predictor
 
44
  self.predictor = build_sam2_video_predictor(model_cfg, checkpoint_path, device=self.device)
45
 
46
  # Apply T4 optimizations
47
  self._optimize_for_t4()
48
 
49
+ logger.info(f"SAM2 {self.model_size} predictor loaded successfully from HF Hub")
50
 
51
  except ImportError as e:
52
  logger.error(f"SAM2 import failed: {e}")
53
+ raise RuntimeError("SAM2 not available - check sam2 installation")
54
  except Exception as e:
55
  logger.error(f"SAM2 loading failed: {e}")
56
  raise
57
 
58
+ def _get_hf_checkpoint(self) -> Optional[str]:
59
+ """Download checkpoint from Hugging Face Hub"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  try:
61
+ from huggingface_hub import hf_hub_download
 
 
 
62
 
63
+ # Repository mapping for different model sizes
64
+ repo_mapping = {
65
+ "small": "facebook/sam2-hiera-small",
66
+ "base": "facebook/sam2-hiera-base-plus",
67
+ "large": "facebook/sam2-hiera-large"
 
 
68
  }
69
 
70
+ filename_mapping = {
71
+ "small": "sam2_hiera_small.pt",
72
+ "base": "sam2_hiera_base_plus.pt",
73
+ "large": "sam2_hiera_large.pt"
74
+ }
 
 
 
 
 
 
75
 
76
+ if self.model_size not in repo_mapping:
77
+ logger.error(f"Unknown model size: {self.model_size}")
78
+ return None
79
 
80
+ repo_id = repo_mapping[self.model_size]
81
+ filename = filename_mapping[self.model_size]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
+ logger.info(f"Downloading SAM2 {self.model_size} from HF Hub: {repo_id}")
84
 
85
+ # Download from Hugging Face Hub
86
+ checkpoint_path = hf_hub_download(
87
+ repo_id=repo_id,
88
+ filename=filename,
89
+ cache_dir=None, # Use default cache
90
+ force_download=False, # Use cached version if available
91
+ token=None # No auth token needed for public models
92
+ )
93
 
94
+ logger.info(f"SAM2 checkpoint downloaded to: {checkpoint_path}")
95
+ return checkpoint_path
96
 
97
  except Exception as e:
98
+ logger.error(f"HF Hub download failed: {e}")
99
+ # Fallback to local checkpoint if HF download fails
100
+ return self._fallback_local_checkpoint()
101
+
102
+ def _fallback_local_checkpoint(self) -> Optional[str]:
103
+ """Fallback to local checkpoint files"""
104
+ try:
105
+ checkpoint_path = f"./checkpoints/sam2_hiera_{self.model_size}.pt"
106
  if Path(checkpoint_path).exists():
107
+ logger.info(f"Using local checkpoint: {checkpoint_path}")
108
+ return checkpoint_path
109
+ else:
110
+ logger.error(f"Local checkpoint not found: {checkpoint_path}")
111
+ return None
112
+ except Exception as e:
113
+ logger.error(f"Local checkpoint fallback failed: {e}")
114
+ return None
115
+
116
+ def _get_model_config(self) -> str:
117
+ """Get the appropriate model config file"""
118
+ config_mapping = {
119
+ "small": "sam2_hiera_s.yaml",
120
+ "base": "sam2_hiera_b+.yaml",
121
+ "large": "sam2_hiera_l.yaml"
122
+ }
123
+
124
+ return config_mapping.get(self.model_size, "sam2_hiera_s.yaml")
125
 
126
  def _optimize_for_t4(self):
127
  """Apply T4-specific optimizations"""
 
167
  logger.error(f"Failed to add new points: {e}")
168
  raise
169
 
170
+ def add_new_points_or_box(self, inference_state, frame_idx: int, obj_id: int,
171
+ points: np.ndarray, labels: np.ndarray, clear_old_points: bool = True):
172
+ """Add new points or box for tracking (newer SAM2 API)"""
173
+ if self.predictor is None:
174
+ raise RuntimeError("Predictor not loaded")
175
+
176
+ try:
177
+ # Try the newer API first
178
+ if hasattr(self.predictor, 'add_new_points_or_box'):
179
+ return self.predictor.add_new_points_or_box(
180
+ inference_state=inference_state,
181
+ frame_idx=frame_idx,
182
+ obj_id=obj_id,
183
+ points=points,
184
+ labels=labels,
185
+ clear_old_points=clear_old_points
186
+ )
187
+ else:
188
+ # Fallback to older API
189
+ return self.predictor.add_new_points(
190
+ inference_state=inference_state,
191
+ frame_idx=frame_idx,
192
+ obj_id=obj_id,
193
+ points=points,
194
+ labels=labels
195
+ )
196
+ except Exception as e:
197
+ logger.error(f"Failed to add new points or box: {e}")
198
+ raise
199
+
200
  def propagate_in_video(self, inference_state, scale: float = 1.0, **kwargs):
201
  """Propagate through video with optional scaling"""
202
  if self.predictor is None: