MogensR commited on
Commit
84a78ca
·
1 Parent(s): d01fd14

Create memory_manager.py

Browse files
Files changed (1) hide show
  1. memory_manager.py +471 -0
memory_manager.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Memory Management Module
3
+ Handles memory cleanup, monitoring, and GPU resource management
4
+ """
5
+
6
+ import gc
7
+ import os
8
+ import psutil
9
+ import torch
10
+ import time
11
+ import logging
12
+ import threading
13
+ from typing import Dict, Any, Optional, Callable
14
+ from exceptions import MemoryError, ResourceExhaustionError
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
class MemoryManager:
    """
    Comprehensive memory management for video processing applications.

    Responsibilities:
    - Report memory usage for CUDA, MPS (unified memory), and CPU devices.
    - Perform basic (cheap) and aggressive (thorough) cleanup passes.
    - Detect memory pressure and suggest mitigations.
    - Optionally run a daemon monitoring thread that auto-cleans under
      critical pressure.
    """

    def __init__(self, device: torch.device, memory_limit_gb: Optional[float] = None):
        """
        Args:
            device: Target torch device ('cuda', 'mps', or 'cpu').
            memory_limit_gb: Soft memory budget in GiB. When None, a
                device-appropriate default is derived in
                _initialize_memory_limits().
        """
        self.device = device
        self.gpu_available = device.type in ['cuda', 'mps']
        self.memory_limit_gb = memory_limit_gb
        self.cleanup_callbacks = []
        self.monitoring_active = False
        self.monitoring_thread = None
        # Event used to interrupt the monitor thread's wait so that
        # stop_monitoring() returns promptly instead of waiting out the
        # full polling interval (which previously exceeded the join timeout).
        self._stop_event = threading.Event()
        # NOTE(review): 'total_allocated'/'total_freed' are reserved counters;
        # nothing in this class updates them yet.
        self.stats = {
            'cleanup_count': 0,
            'peak_memory_usage': 0.0,
            'total_allocated': 0.0,
            'total_freed': 0.0
        }

        # Initialize memory monitoring
        self._initialize_memory_limits()
        logger.info(f"MemoryManager initialized for device: {device}")

    def _initialize_memory_limits(self):
        """Initialize memory limits based on device and system."""
        if self.device.type == 'cuda':
            try:
                device_idx = self.device.index or 0
                device_props = torch.cuda.get_device_properties(device_idx)
                total_memory_gb = device_props.total_memory / (1024**3)

                # Use 80% of GPU memory as default limit if not specified
                if self.memory_limit_gb is None:
                    self.memory_limit_gb = total_memory_gb * 0.8

                logger.info(f"CUDA memory limit set to {self.memory_limit_gb:.1f}GB "
                            f"(total: {total_memory_gb:.1f}GB)")

            except Exception as e:
                logger.warning(f"Could not get CUDA memory info: {e}")
                self.memory_limit_gb = 4.0  # Conservative fallback

        elif self.device.type == 'mps':
            # MPS uses unified memory, so check system memory
            system_memory_gb = psutil.virtual_memory().total / (1024**3)
            if self.memory_limit_gb is None:
                # Use 50% of system memory for MPS as it shares with system
                self.memory_limit_gb = system_memory_gb * 0.5

            logger.info(f"MPS memory limit set to {self.memory_limit_gb:.1f}GB "
                        f"(system: {system_memory_gb:.1f}GB)")

        else:  # CPU
            system_memory_gb = psutil.virtual_memory().total / (1024**3)
            if self.memory_limit_gb is None:
                # Use 60% of system memory for CPU processing
                self.memory_limit_gb = system_memory_gb * 0.6

            logger.info(f"CPU memory limit set to {self.memory_limit_gb:.1f}GB "
                        f"(system: {system_memory_gb:.1f}GB)")

    def get_memory_usage(self) -> Dict[str, Any]:
        """
        Get comprehensive memory usage statistics.

        Returns:
            Dict with device-specific keys (gpu_* for CUDA), system-wide
            keys (system_*, swap_*), and process keys (process_*). Each
            section is best-effort: failures are logged and skipped.
        """
        usage = {
            'device_type': self.device.type,
            'memory_limit_gb': self.memory_limit_gb,
            'timestamp': time.time()
        }

        try:
            if self.device.type == 'cuda':
                device_idx = self.device.index or 0

                # GPU memory
                allocated = torch.cuda.memory_allocated(device_idx)
                reserved = torch.cuda.memory_reserved(device_idx)
                total = torch.cuda.get_device_properties(device_idx).total_memory

                usage.update({
                    'gpu_allocated_gb': allocated / (1024**3),
                    'gpu_reserved_gb': reserved / (1024**3),
                    'gpu_total_gb': total / (1024**3),
                    'gpu_utilization_percent': (allocated / total) * 100,
                    'gpu_reserved_percent': (reserved / total) * 100,
                    'gpu_free_gb': (total - reserved) / (1024**3)
                })

                # Peak memory tracking
                max_allocated = torch.cuda.max_memory_allocated(device_idx)
                max_reserved = torch.cuda.max_memory_reserved(device_idx)
                usage.update({
                    'gpu_max_allocated_gb': max_allocated / (1024**3),
                    'gpu_max_reserved_gb': max_reserved / (1024**3)
                })

            elif self.device.type == 'mps':
                # MPS doesn't have explicit memory tracking like CUDA.
                # Fall back to system memory monitoring.
                vm = psutil.virtual_memory()
                usage.update({
                    'system_memory_gb': vm.total / (1024**3),
                    'system_available_gb': vm.available / (1024**3),
                    'system_used_gb': vm.used / (1024**3),
                    'system_utilization_percent': vm.percent
                })

        except Exception as e:
            logger.warning(f"Error getting GPU memory usage: {e}")

        # Always include system memory info
        try:
            vm = psutil.virtual_memory()
            swap = psutil.swap_memory()

            usage.update({
                'system_total_gb': vm.total / (1024**3),
                'system_available_gb': vm.available / (1024**3),
                'system_used_gb': vm.used / (1024**3),
                'system_percent': vm.percent,
                'swap_total_gb': swap.total / (1024**3),
                'swap_used_gb': swap.used / (1024**3),
                'swap_percent': swap.percent
            })

        except Exception as e:
            logger.warning(f"Error getting system memory usage: {e}")

        # Process-specific memory
        try:
            process = psutil.Process()
            memory_info = process.memory_info()
            usage.update({
                'process_rss_gb': memory_info.rss / (1024**3),  # Physical memory
                'process_vms_gb': memory_info.vms / (1024**3),  # Virtual memory
            })

        except Exception as e:
            logger.warning(f"Error getting process memory usage: {e}")

        # Update peak tracking (GPU allocation preferred, else system usage)
        current_usage = usage.get('gpu_allocated_gb', usage.get('system_used_gb', 0))
        if current_usage > self.stats['peak_memory_usage']:
            self.stats['peak_memory_usage'] = current_usage

        return usage

    def cleanup_basic(self):
        """Basic memory cleanup - lightweight operation (GC + CUDA cache)."""
        try:
            gc.collect()

            if self.device.type == 'cuda':
                torch.cuda.empty_cache()

            self.stats['cleanup_count'] += 1
            logger.debug("Basic memory cleanup completed")

        except Exception as e:
            logger.warning(f"Basic memory cleanup failed: {e}")

    def cleanup_aggressive(self):
        """
        Aggressive memory cleanup - more thorough but slower.

        Runs all registered cleanup callbacks, multiple GC passes, and
        (on CUDA) empties the cache, synchronizes, and resets peak stats.

        Raises:
            MemoryError: If the cleanup sequence itself fails.
        """
        try:
            start_time = time.time()

            # Run all registered cleanup callbacks first
            for callback in self.cleanup_callbacks:
                try:
                    callback()
                except Exception as e:
                    logger.warning(f"Cleanup callback failed: {e}")

            # Multiple garbage collection passes
            for _ in range(3):
                gc.collect()

            if self.device.type == 'cuda':
                # CUDA-specific aggressive cleanup
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

                # Reset peak memory statistics
                device_idx = self.device.index or 0
                torch.cuda.reset_peak_memory_stats(device_idx)

            elif self.device.type == 'mps':
                # MPS cleanup - mainly garbage collection
                # Could add MPS-specific operations if available
                pass

            cleanup_time = time.time() - start_time
            self.stats['cleanup_count'] += 1

            logger.debug(f"Aggressive memory cleanup completed in {cleanup_time:.2f}s")

        except Exception as e:
            logger.error(f"Aggressive memory cleanup failed: {e}")
            raise MemoryError("aggressive_cleanup", str(e))

    @staticmethod
    def _classify_pressure(pressure_info: Dict[str, Any], usage_percent: float,
                           threshold_percent: float, critical_recs: list,
                           warning_recs: list) -> None:
        """Fill pressure_info in place based on usage vs. thresholds."""
        pressure_info['usage_percent'] = usage_percent

        if usage_percent >= threshold_percent:
            pressure_info['under_pressure'] = True

            if usage_percent >= 95:
                pressure_info['pressure_level'] = 'critical'
                pressure_info['recommendations'].extend(critical_recs)
            else:
                # Already >= threshold_percent here, so this is the
                # 'warning' band between the threshold and 95%.
                pressure_info['pressure_level'] = 'warning'
                pressure_info['recommendations'].extend(warning_recs)

    def check_memory_pressure(self, threshold_percent: float = 85.0) -> Dict[str, Any]:
        """
        Check if system is under memory pressure.

        Args:
            threshold_percent: Usage percentage at which pressure is flagged.

        Returns:
            Dict with 'under_pressure', 'pressure_level'
            (normal/warning/critical), 'recommendations', 'usage_percent'.
        """
        usage = self.get_memory_usage()

        pressure_info = {
            'under_pressure': False,
            'pressure_level': 'normal',  # normal, warning, critical
            'recommendations': [],
            'usage_percent': 0.0
        }

        # Determine usage percentage based on device type
        if self.device.type == 'cuda':
            self._classify_pressure(
                pressure_info,
                usage.get('gpu_utilization_percent', 0),
                threshold_percent,
                critical_recs=[
                    'Reduce batch size immediately',
                    'Enable gradient checkpointing',
                    'Consider switching to CPU processing'
                ],
                warning_recs=[
                    'Run aggressive memory cleanup',
                    'Reduce keyframe interval',
                    'Monitor memory usage closely'
                ],
            )
        else:  # CPU or MPS - use system memory
            self._classify_pressure(
                pressure_info,
                usage.get('system_percent', 0),
                threshold_percent,
                critical_recs=[
                    'Free system memory immediately',
                    'Close unnecessary applications',
                    'Reduce video processing quality'
                ],
                warning_recs=[
                    'Run memory cleanup',
                    'Monitor system memory',
                    'Consider processing in smaller chunks'
                ],
            )

        return pressure_info

    def auto_cleanup_if_needed(self, pressure_threshold: float = 80.0) -> bool:
        """
        Automatically run cleanup if memory pressure is detected.

        Returns:
            True if a cleanup was triggered, False otherwise.
        """
        pressure = self.check_memory_pressure(pressure_threshold)

        if pressure['under_pressure']:
            cleanup_method = (
                self.cleanup_aggressive
                if pressure['pressure_level'] == 'critical'
                else self.cleanup_basic
            )

            logger.info(f"Auto-cleanup triggered due to {pressure['pressure_level']} "
                        f"memory pressure ({pressure['usage_percent']:.1f}%)")

            cleanup_method()
            return True

        return False

    def register_cleanup_callback(self, callback: Callable):
        """Register a callback to run during cleanup operations."""
        self.cleanup_callbacks.append(callback)
        logger.debug("Cleanup callback registered")

    def start_monitoring(self, interval_seconds: float = 30.0,
                         pressure_callback: Optional[Callable] = None):
        """
        Start background memory monitoring in a daemon thread.

        Args:
            interval_seconds: Polling interval between pressure checks.
            pressure_callback: Optional callable invoked with the pressure
                dict whenever pressure is detected.
        """
        if self.monitoring_active:
            logger.warning("Memory monitoring already active")
            return

        self.monitoring_active = True
        self._stop_event.clear()

        def monitor_loop():
            while self.monitoring_active:
                try:
                    pressure = self.check_memory_pressure()

                    if pressure['under_pressure']:
                        logger.warning(f"Memory pressure detected: {pressure['pressure_level']} "
                                       f"({pressure['usage_percent']:.1f}%)")

                        if pressure_callback:
                            try:
                                pressure_callback(pressure)
                            except Exception as e:
                                logger.error(f"Pressure callback failed: {e}")

                        # Auto-cleanup on critical pressure
                        if pressure['pressure_level'] == 'critical':
                            self.cleanup_aggressive()

                except Exception as e:
                    logger.error(f"Memory monitoring error: {e}")

                # Interruptible wait: stop_monitoring() sets the event so
                # shutdown does not block for the full interval.
                self._stop_event.wait(interval_seconds)

        self.monitoring_thread = threading.Thread(target=monitor_loop, daemon=True)
        self.monitoring_thread.start()

        logger.info(f"Memory monitoring started (interval: {interval_seconds}s)")

    def stop_monitoring(self):
        """Stop background memory monitoring and join the thread."""
        if self.monitoring_active:
            self.monitoring_active = False
            # Wake the monitor thread immediately so the join below
            # completes within the timeout instead of racing a long sleep.
            self._stop_event.set()
            if self.monitoring_thread and self.monitoring_thread.is_alive():
                self.monitoring_thread.join(timeout=5.0)
            logger.info("Memory monitoring stopped")

    def estimate_memory_requirement(self, video_width: int, video_height: int,
                                    frames_in_memory: int = 5) -> Dict[str, float]:
        """
        Estimate memory requirements for video processing.

        Args:
            video_width: Frame width in pixels.
            video_height: Frame height in pixels.
            frames_in_memory: Number of frames held simultaneously.

        Returns:
            Dict of GiB estimates: frames_memory_gb, model_memory_gb,
            system_overhead_gb, total_estimated_gb.
        """
        # Base memory per frame (RGB image)
        bytes_per_frame = video_width * video_height * 3

        # Additional overhead for processing
        overhead_multiplier = 3.0  # For masks, intermediate results, etc.

        estimated_memory = {
            'frames_memory_gb': (bytes_per_frame * frames_in_memory * overhead_multiplier) / (1024**3),
            'model_memory_gb': 4.0,  # Rough estimate for SAM2 + MatAnyone
            'system_overhead_gb': 2.0,
            'total_estimated_gb': 0.0
        }

        estimated_memory['total_estimated_gb'] = sum([
            estimated_memory['frames_memory_gb'],
            estimated_memory['model_memory_gb'],
            estimated_memory['system_overhead_gb']
        ])

        return estimated_memory

    def can_process_video(self, video_width: int, video_height: int,
                          frames_in_memory: int = 5) -> Dict[str, Any]:
        """
        Check if video can be processed with current memory.

        Returns:
            Dict with 'can_process', estimated/available GiB figures, the
            margin between them, and recommendations when tight or short.
        """
        estimate = self.estimate_memory_requirement(video_width, video_height, frames_in_memory)
        current_usage = self.get_memory_usage()

        # Available memory calculation
        if self.device.type == 'cuda':
            available_memory = current_usage.get('gpu_free_gb', 0)
        else:
            available_memory = current_usage.get('system_available_gb', 0)

        can_process = estimate['total_estimated_gb'] <= available_memory

        result = {
            'can_process': can_process,
            'estimated_memory_gb': estimate['total_estimated_gb'],
            'available_memory_gb': available_memory,
            'memory_margin_gb': available_memory - estimate['total_estimated_gb'],
            'recommendations': []
        }

        if not can_process:
            deficit = estimate['total_estimated_gb'] - available_memory
            result['recommendations'] = [
                f"Free {deficit:.1f}GB of memory",
                "Reduce video resolution",
                "Process in smaller chunks",
                "Use lower quality settings"
            ]
        elif result['memory_margin_gb'] < 1.0:
            result['recommendations'] = [
                "Memory margin is low",
                "Monitor memory usage during processing",
                "Consider reducing batch size"
            ]

        return result

    def get_optimization_suggestions(self) -> Dict[str, Any]:
        """Get memory optimization suggestions based on current state."""
        usage = self.get_memory_usage()

        suggestions = {
            'current_usage_percent': usage.get('gpu_utilization_percent', usage.get('system_percent', 0)),
            'suggestions': [],
            'priority': 'low'  # low, medium, high
        }

        usage_percent = suggestions['current_usage_percent']

        if usage_percent >= 90:
            suggestions['priority'] = 'high'
            suggestions['suggestions'].extend([
                'Run aggressive memory cleanup immediately',
                'Reduce batch size to 1',
                'Enable gradient checkpointing if available',
                'Consider switching to CPU processing'
            ])
        elif usage_percent >= 75:
            suggestions['priority'] = 'medium'
            suggestions['suggestions'].extend([
                'Run memory cleanup regularly',
                'Monitor memory usage closely',
                'Reduce keyframe interval',
                'Use mixed precision if supported'
            ])
        elif usage_percent >= 50:
            suggestions['priority'] = 'low'
            suggestions['suggestions'].extend([
                'Current usage is acceptable',
                'Regular cleanup should be sufficient',
                'Monitor for memory leaks during long operations'
            ])
        else:
            suggestions['suggestions'] = [
                'Memory usage is optimal',
                'No immediate action required'
            ]

        return suggestions

    def get_stats(self) -> Dict[str, Any]:
        """Get memory management statistics."""
        return {
            'cleanup_count': self.stats['cleanup_count'],
            'peak_memory_usage_gb': self.stats['peak_memory_usage'],
            'monitoring_active': self.monitoring_active,
            'device_type': self.device.type,
            'memory_limit_gb': self.memory_limit_gb,
            'registered_callbacks': len(self.cleanup_callbacks)
        }

    def __del__(self):
        """Cleanup when MemoryManager is destroyed (best-effort)."""
        try:
            self.stop_monitoring()
            self.cleanup_aggressive()
        except Exception:
            pass  # Ignore errors during interpreter-shutdown cleanup