MogensR committed
Commit c9f07e1 · verified · 1 Parent(s): 1be662e

Update models/model_loaders.py

Files changed (1)
  1. models/model_loaders.py +42 -112
models/model_loaders.py CHANGED
@@ -2,6 +2,7 @@
 """
 Model Loading and Memory Management
 Handles lazy loading of SAM2 and MatAnyone models with caching
+(Enhanced logging, error handling, and memory safety)
 """
 
 import os
@@ -11,172 +12,133 @@
 import torch
 import psutil
 import mediapipe as mp
+from contextlib import contextmanager
 
+logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Context manager for CUDA memory cleanup
-from contextlib import contextmanager
-
 @contextmanager
 def torch_memory_manager():
-    """Context manager for CUDA memory cleanup."""
     try:
+        logger.info("[torch_memory_manager] Enter")  # [LOG+SAFETY PATCH]
         yield
     finally:
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         gc.collect()
+        logger.info("[torch_memory_manager] Exit, cleaned up")  # [LOG+SAFETY PATCH]
 
 def get_memory_usage():
-    """Get current memory usage statistics."""
     memory_info = {}
-
-    # GPU memory if available
     if torch.cuda.is_available():
         memory_info['gpu_allocated'] = torch.cuda.memory_allocated() / 1e9
         memory_info['gpu_reserved'] = torch.cuda.memory_reserved() / 1e9
         memory_info['gpu_free'] = (torch.cuda.get_device_properties(0).total_memory -
                                    torch.cuda.memory_allocated()) / 1e9
-
-    # RAM memory
     memory_info['ram_used'] = psutil.virtual_memory().used / 1e9
     memory_info['ram_available'] = psutil.virtual_memory().available / 1e9
-
+    logger.info(f"[get_memory_usage] {memory_info}")  # [LOG+SAFETY PATCH]
     return memory_info
 
 def clear_model_cache():
-    """Clear all cached models and free memory."""
+    logger.info("[clear_model_cache] Clearing all model caches...")  # [LOG+SAFETY PATCH]
     if hasattr(st, 'cache_resource'):
         st.cache_resource.clear()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
     gc.collect()
-    logger.info("Model cache cleared")
-
-# ============================================================================
-# SAM2 Model Loading
-# ============================================================================
+    logger.info("[clear_model_cache] Model cache cleared")  # [LOG+SAFETY PATCH]
 
 @st.cache_resource(show_spinner=False)
 def load_sam2_predictor():
-    """
-    Lazy load SAM2 image predictor with fallback strategies.
-    Returns (predictor, device) tuple. Returns (None, None) if loading fails.
-    """
     try:
-        print("Loading SAM2 image predictor...", flush=True)
+        logger.info("[load_sam2_predictor] Loading SAM2 image predictor...")  # [LOG+SAFETY PATCH]
        from sam2.build_sam import build_sam2
         from sam2.sam2_image_predictor import SAM2ImagePredictor
-
-        # Determine device
+
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        print(f"Using device for SAM2: {device}", flush=True)
-
-        # Try local checkpoints first
+        logger.info(f"[load_sam2_predictor] Using device: {device}")
+
         checkpoint_path = "/home/user/app/checkpoints/sam2.1_hiera_large.pt"
         model_cfg = "/home/user/app/configs/sam2.1/sam2.1_hiera_l.yaml"
-
+
         if not os.path.exists(checkpoint_path) or not os.path.exists(model_cfg):
-            print("Local checkpoints not found, using Hugging Face...", flush=True)
+            logger.warning("[load_sam2_predictor] Local checkpoints not found, using Hugging Face.")
             predictor = SAM2ImagePredictor.from_pretrained(
                 "facebook/sam2-hiera-large",
                 device=device
             )
         else:
-            # Check available GPU memory
             memory_info = get_memory_usage()
             gpu_free = memory_info.get('gpu_free', 0)
-
             if device == "cuda" and gpu_free < 4.0:
-                print(f"Limited GPU memory ({gpu_free:.1f}GB), using smaller SAM2 model...", flush=True)
+                logger.warning(f"[load_sam2_predictor] Limited GPU memory ({gpu_free:.1f}GB), using smaller SAM2 model.")
                 try:
                     predictor = SAM2ImagePredictor.from_pretrained(
                         "facebook/sam2-hiera-tiny",
                         device=device
                     )
-                except Exception:
+                except Exception as e:
+                    logger.warning(f"[load_sam2_predictor] Tiny model failed, trying small. {e}")
                     predictor = SAM2ImagePredictor.from_pretrained(
                         "facebook/sam2-hiera-small",
                         device=device
                     )
             else:
-                # Use local large model
+                logger.info("[load_sam2_predictor] Using local large model")
                 sam2_model = build_sam2(model_cfg, checkpoint_path, device=device)
                 predictor = SAM2ImagePredictor(sam2_model)
-
-        # CRITICAL: Verify and force model to correct device
+
         if hasattr(predictor, 'model'):
             predictor.model.to(device)
             predictor.model.eval()
-            print(f"SAM2 model moved to {device} and set to eval mode", flush=True)
-
-        print(f"✅ SAM2 loaded successfully on {device}!", flush=True)
+            logger.info(f"[load_sam2_predictor] SAM2 model moved to {device} and set to eval mode")
+
+        logger.info(f"✅ SAM2 loaded successfully on {device}!")
         return predictor, device
-
+
     except Exception as e:
-        print(f"❌ Failed to load SAM2 predictor: {e}", flush=True)
+        logger.error(f"❌ Failed to load SAM2 predictor: {e}", exc_info=True)
         import traceback
         traceback.print_exc()
         return None, None
 
-# Alias for backward compatibility
 def load_sam2():
-    """Alias for load_sam2_predictor() - returns just predictor for compatibility"""
     predictor, device = load_sam2_predictor()
     return predictor
 
-# ============================================================================
-# MatAnyone Model Loading
-# ============================================================================
-
 @st.cache_resource(show_spinner=False)
 def load_matanyone_processor():
-    """
-    Lazy load MatAnyone processor with explicit GPU placement.
-    Returns (processor, device) tuple. Returns (None, None) if loading fails.
-    """
     try:
-        print("Loading MatAnyone processor...", flush=True)
+        logger.info("[load_matanyone_processor] Loading MatAnyone processor...")  # [LOG+SAFETY PATCH]
         from matanyone import InferenceCore
-
-        # Determine device
+
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        print(f"MatAnyone using device: {device}", flush=True)
-
-        # Load processor with explicit device
+        logger.info(f"[load_matanyone_processor] MatAnyone using device: {device}")
+
         processor = InferenceCore("PeiqingYang/MatAnyone", device=device)
-
-        # CRITICAL: Verify the processor's model is actually on GPU
         if hasattr(processor, 'model'):
             processor.model.to(device)
             processor.model.eval()
-            print(f"MatAnyone model explicitly moved to {device}", flush=True)
-
-        # Check if processor has device attribute and set it
+            logger.info(f"[load_matanyone_processor] MatAnyone model explicitly moved to {device}")
+
         if not hasattr(processor, 'device'):
             processor.device = device
-            print(f"Set processor.device to {device}", flush=True)
-
-        print(f"✅ MatAnyone loaded successfully on {device}!", flush=True)
+            logger.info(f"[load_matanyone_processor] Set processor.device to {device}")
+
+        logger.info(f"✅ MatAnyone loaded successfully on {device}!")
         return processor, device
-
+
     except Exception as e:
-        print(f"❌ Failed to load MatAnyone: {e}", flush=True)
+        logger.error(f"❌ Failed to load MatAnyone: {e}", exc_info=True)
         import traceback
         traceback.print_exc()
         return None, None
 
-# Alias for backward compatibility
 def load_matanyone():
-    """Alias for load_matanyone_processor() - returns just processor for compatibility"""
     processor, device = load_matanyone_processor()
     return processor
 
-# ============================================================================
-# MediaPipe Pose
-# ============================================================================
-
-# Initialize MediaPipe Pose as a module-level variable
 mp_pose = mp.solutions.pose
 pose = mp_pose.Pose(
     static_image_mode=False,
@@ -184,23 +146,13 @@ def load_matanyone():
     enable_segmentation=True,
     min_detection_confidence=0.5
 )
-print("✅ MediaPipe Pose initialized", flush=True)
-
-# ============================================================================
-# Model Health Check
-# ============================================================================
+logger.info("✅ MediaPipe Pose initialized")  # [LOG+SAFETY PATCH]
 
 def test_models():
-    """
-    Test if both models can load successfully.
-    Returns dict with test results.
-    """
     results = {
         'sam2': {'loaded': False, 'error': None, 'device': None},
         'matanyone': {'loaded': False, 'error': None, 'device': None}
     }
-
-    # Test SAM2
     try:
         sam2_predictor, sam2_device = load_sam2_predictor()
         if sam2_predictor is not None:
@@ -210,8 +162,7 @@ def test_models():
             results['sam2']['error'] = "Predictor returned None"
     except Exception as e:
         results['sam2']['error'] = str(e)
-
-    # Test MatAnyone
+        logger.error(f"[test_models] SAM2 error: {e}", exc_info=True)
     try:
         matanyone_processor, matanyone_device = load_matanyone_processor()
         if matanyone_processor is not None:
@@ -221,53 +172,35 @@ def test_models():
             results['matanyone']['error'] = "Processor returned None"
     except Exception as e:
         results['matanyone']['error'] = str(e)
-
+        logger.error(f"[test_models] MatAnyone error: {e}", exc_info=True)
+    logger.info(f"[test_models] Results: {results}")  # [LOG+SAFETY PATCH]
     return results
 
-# ============================================================================
-# Memory Monitoring
-# ============================================================================
-
 def log_memory_usage(stage=""):
-    """Log current memory usage with optional stage label."""
     memory_info = get_memory_usage()
-
     log_msg = f"Memory usage"
     if stage:
         log_msg += f" ({stage})"
     log_msg += ":"
-
     if 'gpu_allocated' in memory_info:
         log_msg += f" GPU {memory_info['gpu_allocated']:.1f}GB allocated, {memory_info['gpu_free']:.1f}GB free"
-
     log_msg += f" | RAM {memory_info['ram_used']:.1f}GB used"
-
     print(log_msg, flush=True)
     logger.info(log_msg)
     return memory_info
 
 def check_memory_available(required_gb=2.0):
-    """
-    Check if enough GPU memory is available.
-    Returns (bool, float) - (is_available, free_gb)
-    """
     if not torch.cuda.is_available():
         return False, 0.0
-
     memory_info = get_memory_usage()
     free_gb = memory_info.get('gpu_free', 0)
-
+    logger.info(f"[check_memory_available] free_gb={free_gb}, required={required_gb}")  # [LOG+SAFETY PATCH]
     return free_gb >= required_gb, free_gb
 
 def free_memory_aggressive():
-    """Aggressively free GPU and system memory."""
+    logger.info("[free_memory_aggressive] Performing aggressive memory cleanup...")  # [LOG+SAFETY PATCH]
     print("Performing aggressive memory cleanup...", flush=True)
-    logger.info("Performing aggressive memory cleanup...")
-
-    # Clear model cache
     clear_model_cache()
-
-    # CUDA cleanup
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
@@ -275,10 +208,7 @@ def free_memory_aggressive():
             torch.cuda.ipc_collect()
     except Exception:
         pass
-
-    # System cleanup
     gc.collect()
-
     print("Memory cleanup complete", flush=True)
     logger.info("Memory cleanup complete")
-    log_memory_usage("after cleanup")
+    log_memory_usage("after cleanup")
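The functions touched by this commit keep their signatures, so the calling pattern is unchanged: both loaders return an (object, device) tuple and fall back to (None, None) when loading fails. The sketch below shows how the module might be driven from application code; it is a minimal illustration that assumes the file is importable as models.model_loaders and that the Streamlit and CUDA environment referenced above is present. The import path and the 2.0 GB requirement are illustrative assumptions, not part of the commit.

# Minimal usage sketch (assumptions: the module is importable as
# `models.model_loaders`; the 2.0 GB requirement is only an example value).
from models.model_loaders import (
    load_sam2_predictor,
    load_matanyone_processor,
    check_memory_available,
    free_memory_aggressive,
    log_memory_usage,
)

log_memory_usage("before model load")

# Both loaders are cached via st.cache_resource and return (None, None) on
# failure, so callers must check before use.
predictor, sam2_device = load_sam2_predictor()
processor, mat_device = load_matanyone_processor()
if predictor is None or processor is None:
    raise RuntimeError("Model loading failed; see the log output for details")

# Before a heavy inference pass, confirm GPU headroom and clean up if needed.
enough, free_gb = check_memory_available(required_gb=2.0)
if not enough:
    free_memory_aggressive()

log_memory_usage("after model load")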
 