MogensR committed on
Commit
ce45f53
·
1 Parent(s): cca0593

final fix

Browse files
Files changed (1) hide show
  1. models/sam2_loader.py +92 -21
models/sam2_loader.py CHANGED
@@ -3,16 +3,14 @@
3
  SAM2 Loader — Robust loading and mask generation for SAM2
4
  ========================================================
5
  - Loads SAM2 model with Hydra config resolution
 
6
  - Generates seed masks for MatAnyone
7
  - Aligned with torch==2.3.1+cu121 and SAM2 commit 3c76f73c1a7e7b4a2e8a0a9a3e5b92f7e6e3f2f5
8
 
9
- Changes (2025-09-16):
10
- - Aligned with torch==2.3.1+cu121 and SAM2 commit
11
- - Added GPU memory logging for Tesla T4
12
- - Added SAM2 version logging via importlib.metadata
13
- - Simplified config resolution to match __init__.py
14
- - Fixed missing sys and inspect imports
15
- - Added SAM2Model wrapper class for app_hf.py compatibility
16
  """
17
 
18
  from __future__ import annotations
@@ -22,6 +20,8 @@
22
  import inspect
23
  import logging
24
  import importlib.metadata
 
 
25
  from pathlib import Path
26
  from typing import Optional, Tuple, Dict, Any
27
 
@@ -55,6 +55,60 @@ def _add_sys_path(p: Path) -> None:
55
 
56
  _add_sys_path(TP_SAM2)
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  # --------------------------------------------------------------------------------------
59
  # Safe Torch accessors
60
  # --------------------------------------------------------------------------------------
@@ -80,7 +134,7 @@ def _pick_device(env_key: str) -> str:
80
  requested = os.environ.get(env_key, "").strip().lower()
81
  has_cuda = _has_cuda()
82
 
83
- logger.info(f"CUDA environment variables: {{'SAM2_DEVICE': '{os.environ.get('SAM2_DEVICE', '')}'}}")
84
  logger.info(f"_pick_device({env_key}): requested='{requested}', has_cuda={has_cuda}")
85
 
86
  if has_cuda and requested not in {"cpu"}:
@@ -100,12 +154,13 @@ def _pick_device(env_key: str) -> str:
100
  def _resolve_sam2_cfg(cfg_str: str) -> str:
101
  """Resolve SAM2 config path - return relative path for Hydra compatibility."""
102
  logger.info(f"_resolve_sam2_cfg called with cfg_str={cfg_str}")
 
103
 
104
- candidate = os.path.join(TP_SAM2, cfg_str)
105
  logger.info(f"Candidate path: {candidate}")
106
- logger.info(f"Candidate exists: {os.path.exists(candidate)}")
107
 
108
- if os.path.exists(candidate):
109
  if cfg_str.startswith("sam2/configs/"):
110
  relative_path = cfg_str.replace("sam2/configs/", "configs/")
111
  else:
@@ -114,15 +169,15 @@ def _resolve_sam2_cfg(cfg_str: str) -> str:
114
  return relative_path
115
 
116
  fallbacks = [
117
- os.path.join(TP_SAM2, "sam2", cfg_str),
118
- os.path.join(TP_SAM2, "configs", cfg_str),
119
  ]
120
 
121
  for fallback in fallbacks:
122
  logger.info(f"Trying fallback: {fallback}")
123
- if os.path.exists(fallback):
124
- if "configs/" in fallback:
125
- relative_path = "configs/" + fallback.split("configs/")[-1]
126
  logger.info(f"Returning fallback relative path: {relative_path}")
127
  return relative_path
128
 
@@ -156,7 +211,7 @@ def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
156
  return None
157
 
158
  def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
159
- """Robust SAM2 loader with config resolution and error handling."""
160
  meta = {"sam2_import_ok": False, "sam2_init_ok": False}
161
  try:
162
  from sam2.build_sam import build_sam2
@@ -174,14 +229,26 @@ def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
174
  logger.info("[SAM2] SAM2 version unknown")
175
 
176
  # Check GPU memory before loading
177
- if torch and torch.cuda.is_available():
178
  mem_before = torch.cuda.memory_allocated() / 1024**3
179
- logger.info(f"GPU memory before SAM2 load: {mem_before:.2f}GB")
180
 
181
  device = _pick_device("SAM2_DEVICE")
182
  cfg_env = os.environ.get("SAM2_MODEL_CFG", "sam2/configs/sam2/sam2_hiera_l.yaml")
183
  cfg = _resolve_sam2_cfg(cfg_env)
184
- ckpt = os.environ.get("SAM2_CHECKPOINT", "")
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
  def _try_build(cfg_path: str):
187
  logger.info(f"_try_build called with cfg_path: {cfg_path}")
@@ -238,12 +305,16 @@ def _try_build(cfg_path: str):
238
  predictor = SAM2ImagePredictor(sam)
239
  meta["sam2_init_ok"] = True
240
  meta["sam2_device"] = device
 
241
  return predictor, True, meta
242
  else:
 
243
  return None, False, meta
244
 
245
  except Exception as e:
246
- logger.error(f"SAM2 loading failed: {e}")
 
 
247
  return None, False, meta
248
 
249
  def run_sam2_mask(predictor: object,
 
3
  SAM2 Loader — Robust loading and mask generation for SAM2
4
  ========================================================
5
  - Loads SAM2 model with Hydra config resolution
6
+ - Auto-downloads missing checkpoint files
7
  - Generates seed masks for MatAnyone
8
  - Aligned with torch==2.3.1+cu121 and SAM2 commit 3c76f73c1a7e7b4a2e8a0a9a3e5b92f7e6e3f2f5
9
 
10
+ Changes (2025-09-17):
11
+ - Added automatic checkpoint download functionality
12
+ - Enhanced error handling and logging
13
+ - Fixed missing checkpoint issue that was causing fallback mask generation
 
 
 
14
  """
15
 
16
  from __future__ import annotations
 
20
  import inspect
21
  import logging
22
  import importlib.metadata
23
+ import urllib.request
24
+ import urllib.error
25
  from pathlib import Path
26
  from typing import Optional, Tuple, Dict, Any
27
 
 
55
 
56
  _add_sys_path(TP_SAM2)
57
 
58
+ # --------------------------------------------------------------------------------------
59
+ # Checkpoint Download Functionality
60
+ # --------------------------------------------------------------------------------------
61
# Maps a checkpoint file name to the URL it is downloaded from.
SAM2_CHECKPOINT_URLS = {
    "sam2_hiera_large.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
    "sam2_hiera_base_plus.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
    "sam2_hiera_small.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
    "sam2_hiera_tiny.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
}


def _download_checkpoint(checkpoint_path: str, checkpoint_name: str) -> bool:
    """Download a SAM2 checkpoint to ``checkpoint_path`` if it is not already there.

    The file is first written to ``<checkpoint_path>.part`` and only moved into
    place once the transfer has finished, so an interrupted download can never
    be mistaken for a valid checkpoint by the "already exists" check on a
    later call.

    Args:
        checkpoint_path: Destination file path for the checkpoint.
        checkpoint_name: Key into ``SAM2_CHECKPOINT_URLS`` selecting the URL.

    Returns:
        True if the checkpoint exists at ``checkpoint_path`` when the function
        returns (already present or freshly downloaded); False if the name is
        unknown or the download failed. Errors are logged, never raised.
    """
    if os.path.exists(checkpoint_path):
        logger.info(f"Checkpoint already exists: {checkpoint_path}")
        return True

    if checkpoint_name not in SAM2_CHECKPOINT_URLS:
        logger.error(f"Unknown checkpoint: {checkpoint_name}. Available: {list(SAM2_CHECKPOINT_URLS.keys())}")
        return False

    url = SAM2_CHECKPOINT_URLS[checkpoint_name]
    logger.info(f"Downloading SAM2 checkpoint: {checkpoint_name}")
    logger.info(f"URL: {url}")
    logger.info(f"Destination: {checkpoint_path}")

    partial_path = f"{checkpoint_path}.part"
    last_step = -1  # highest 10% step already reported

    def _progress_hook(block_num, block_size, total_size):
        # Report each 10% step once; the hook fires for every block, so
        # checking ``percent % 10 == 0`` alone would log the same value
        # over and over while the percentage is unchanged.
        nonlocal last_step
        if total_size <= 0:
            return
        percent = min(100, block_num * block_size * 100 // total_size)
        step = percent // 10
        if step > last_step:
            last_step = step
            logger.info(f"Download progress: {step * 10}%")

    try:
        # A bare file name has an empty dirname, which os.makedirs rejects.
        parent_dir = os.path.dirname(checkpoint_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # urlretrieve raises ContentTooShortError (a URLError) on truncation.
        urllib.request.urlretrieve(url, partial_path, reporthook=_progress_hook)

        if not os.path.exists(partial_path) or os.path.getsize(partial_path) == 0:
            logger.error(f"Download failed: {checkpoint_path} does not exist or is empty")
            return False

        # Promote the finished download to its final name in one step.
        os.replace(partial_path, checkpoint_path)
        size_mb = os.path.getsize(checkpoint_path) / (1024 * 1024)
        logger.info(f"Successfully downloaded {checkpoint_name} ({size_mb:.1f} MB)")
        return True

    except urllib.error.URLError as e:
        logger.error(f"URL error downloading checkpoint: {e}")
        return False
    except Exception as e:
        logger.error(f"Error downloading checkpoint: {e}")
        return False
    finally:
        # Best-effort cleanup of a leftover partial file; after a successful
        # os.replace it no longer exists, which is expected and ignored.
        try:
            os.remove(partial_path)
        except OSError:
            pass
111
+
112
  # --------------------------------------------------------------------------------------
113
  # Safe Torch accessors
114
  # --------------------------------------------------------------------------------------
 
134
  requested = os.environ.get(env_key, "").strip().lower()
135
  has_cuda = _has_cuda()
136
 
137
+ logger.info(f"CUDA environment variables: {dict((k, v) for k, v in os.environ.items() if 'CUDA' in k or 'SAM2' in k)}")
138
  logger.info(f"_pick_device({env_key}): requested='{requested}', has_cuda={has_cuda}")
139
 
140
  if has_cuda and requested not in {"cpu"}:
 
154
  def _resolve_sam2_cfg(cfg_str: str) -> str:
155
  """Resolve SAM2 config path - return relative path for Hydra compatibility."""
156
  logger.info(f"_resolve_sam2_cfg called with cfg_str={cfg_str}")
157
+ logger.info(f"TP_SAM2 = {TP_SAM2}")
158
 
159
+ candidate = TP_SAM2 / cfg_str
160
  logger.info(f"Candidate path: {candidate}")
161
+ logger.info(f"Candidate exists: {candidate.exists()}")
162
 
163
+ if candidate.exists():
164
  if cfg_str.startswith("sam2/configs/"):
165
  relative_path = cfg_str.replace("sam2/configs/", "configs/")
166
  else:
 
169
  return relative_path
170
 
171
  fallbacks = [
172
+ TP_SAM2 / "sam2" / cfg_str,
173
+ TP_SAM2 / "configs" / cfg_str,
174
  ]
175
 
176
  for fallback in fallbacks:
177
  logger.info(f"Trying fallback: {fallback}")
178
+ if fallback.exists():
179
+ if "configs" in str(fallback):
180
+ relative_path = "configs/" + str(fallback).split("configs/")[-1]
181
  logger.info(f"Returning fallback relative path: {relative_path}")
182
  return relative_path
183
 
 
211
  return None
212
 
213
  def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
214
+ """Robust SAM2 loader with config resolution and checkpoint auto-download."""
215
  meta = {"sam2_import_ok": False, "sam2_init_ok": False}
216
  try:
217
  from sam2.build_sam import build_sam2
 
229
  logger.info("[SAM2] SAM2 version unknown")
230
 
231
  # Check GPU memory before loading
232
+ if torch.cuda.is_available():
233
  mem_before = torch.cuda.memory_allocated() / 1024**3
234
+ logger.info(f"🔍 GPU memory before SAM2 load: {mem_before:.2f}GB")
235
 
236
  device = _pick_device("SAM2_DEVICE")
237
  cfg_env = os.environ.get("SAM2_MODEL_CFG", "sam2/configs/sam2/sam2_hiera_l.yaml")
238
  cfg = _resolve_sam2_cfg(cfg_env)
239
+
240
+ # Handle checkpoint with auto-download
241
+ ckpt = os.environ.get("SAM2_CHECKPOINT", "/home/user/app/checkpoints/sam2_hiera_large.pt")
242
+ checkpoint_name = os.path.basename(ckpt)
243
+
244
+ # Auto-download checkpoint if missing
245
+ if not os.path.exists(ckpt):
246
+ logger.info(f"SAM2 checkpoint not found: {ckpt}")
247
+ if not _download_checkpoint(ckpt, checkpoint_name):
248
+ logger.error(f"Failed to download SAM2 checkpoint: {checkpoint_name}")
249
+ return None, False, meta
250
+ else:
251
+ logger.info(f"Using existing SAM2 checkpoint: {ckpt}")
252
 
253
  def _try_build(cfg_path: str):
254
  logger.info(f"_try_build called with cfg_path: {cfg_path}")
 
305
  predictor = SAM2ImagePredictor(sam)
306
  meta["sam2_init_ok"] = True
307
  meta["sam2_device"] = device
308
+ logger.info("✅ SAM2 loaded successfully with auto-downloaded checkpoint")
309
  return predictor, True, meta
310
  else:
311
+ logger.error("❌ SAM2 initialization returned None")
312
  return None, False, meta
313
 
314
  except Exception as e:
315
+ logger.error(f"❌ SAM2 loading failed: {e}")
316
+ import traceback
317
+ logger.error(f"SAM2 loading traceback: {traceback.format_exc()}")
318
  return None, False, meta
319
 
320
  def run_sam2_mask(predictor: object,