MogensR committed
Commit ce6bb5b · 1 Parent(s): 61d0a06

Create device_manager.py

Files changed (1): device_manager.py +415 -0
device_manager.py ADDED
@@ -0,0 +1,415 @@
"""
Device Management Module
Handles hardware detection, optimization, and device switching
"""

import gc
import logging
import os
import platform
import subprocess
from typing import Optional, Dict, Any, List

import torch

from exceptions import DeviceError

logger = logging.getLogger(__name__)
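
# Note: DeviceError comes from this project's local exceptions module, which
# is not part of this commit. From the call sites below it is assumed to take
# (device_type, message); a minimal sketch of that assumption:
#
#     class DeviceError(Exception):
#         """Raised when a requested compute device is unavailable or broken."""
#         def __init__(self, device_type: str, message: str):
#             super().__init__(f"[{device_type}] {message}")
#             self.device_type = device_type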

class DeviceManager:
    """
    Manages device detection, validation, and optimization for video processing
    """

    def __init__(self):
        self._optimal_device = None
        self._device_info = {}
        self._cuda_tested = False
        self._mps_tested = False
        self._initialize_device_info()

    def _initialize_device_info(self):
        """Initialize comprehensive device information"""
        self._device_info = {
            'platform': platform.system(),
            'python_version': platform.python_version(),
            'pytorch_version': torch.__version__,
            'cuda_available': torch.cuda.is_available(),
            'cuda_version': torch.version.cuda if torch.cuda.is_available() else None,
            'mps_available': self._check_mps_availability(),
            'cpu_count': os.cpu_count(),  # logical CPU count, not torch's thread setting
        }

        if self._device_info['cuda_available']:
            self._device_info.update(self._get_cuda_info())

        if self._device_info['mps_available']:
            self._device_info.update(self._get_mps_info())

        logger.debug(f"Device info initialized: {self._device_info}")

    def _check_mps_availability(self) -> bool:
        """Check if Metal Performance Shaders (MPS) is available on macOS"""
        try:
            if platform.system() == 'Darwin':  # macOS
                return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
        except Exception:
            pass
        return False

    def _get_cuda_info(self) -> Dict[str, Any]:
        """Get detailed CUDA information"""
        cuda_info = {}
        try:
            if torch.cuda.is_available():
                cuda_info.update({
                    'cuda_device_count': torch.cuda.device_count(),
                    'cuda_current_device': torch.cuda.current_device(),
                    'cuda_devices': []
                })

                for i in range(torch.cuda.device_count()):
                    device_props = torch.cuda.get_device_properties(i)
                    device_info = {
                        'index': i,
                        'name': device_props.name,
                        'memory_total_gb': device_props.total_memory / (1024**3),
                        'memory_total_mb': device_props.total_memory / (1024**2),
                        'multiprocessor_count': device_props.multiprocessor_count,
                        'compute_capability': f"{device_props.major}.{device_props.minor}"
                    }

                    # Get current memory usage
                    try:
                        memory_allocated = torch.cuda.memory_allocated(i) / (1024**3)
                        memory_reserved = torch.cuda.memory_reserved(i) / (1024**3)
                        device_info.update({
                            'memory_allocated_gb': memory_allocated,
                            'memory_reserved_gb': memory_reserved,
                            'memory_free_gb': device_info['memory_total_gb'] - memory_reserved
                        })
                    except Exception as e:
                        logger.warning(f"Could not get memory info for CUDA device {i}: {e}")

                    cuda_info['cuda_devices'].append(device_info)

        except Exception as e:
            logger.error(f"Error getting CUDA info: {e}")

        return cuda_info

    def _get_mps_info(self) -> Dict[str, Any]:
        """Get Metal Performance Shaders information"""
        mps_info = {}
        try:
            if self._device_info['mps_available']:
                # Get system memory as MPS uses unified memory
                try:
                    result = subprocess.run(['sysctl', 'hw.memsize'],
                                            capture_output=True, text=True, timeout=5)
                    if result.returncode == 0:
                        memory_bytes = int(result.stdout.split(':')[1].strip())
                        mps_info['mps_system_memory_gb'] = memory_bytes / (1024**3)
                except Exception as e:
                    logger.warning(f"Could not get system memory info: {e}")

                mps_info['mps_device'] = 'Apple Silicon GPU'

        except Exception as e:
            logger.error(f"Error getting MPS info: {e}")

        return mps_info

    def get_optimal_device(self) -> torch.device:
        """
        Get the optimal device for video processing with comprehensive testing
        """
        if self._optimal_device is not None:
            return self._optimal_device

        logger.info("Determining optimal device for video processing...")

        # Try CUDA first (most common for AI workloads)
        if self._device_info['cuda_available'] and not self._cuda_tested:
            cuda_device = self._test_cuda_device()
            if cuda_device is not None:
                self._optimal_device = cuda_device
                logger.info(f"Selected CUDA device: {self._get_device_name(cuda_device)}")
                return self._optimal_device

        # Try MPS on Apple Silicon
        if self._device_info['mps_available'] and not self._mps_tested:
            mps_device = self._test_mps_device()
            if mps_device is not None:
                self._optimal_device = mps_device
                logger.info(f"Selected MPS device: {self._get_device_name(mps_device)}")
                return self._optimal_device

        # Fallback to CPU
        self._optimal_device = torch.device("cpu")
        logger.info("Using CPU device (no suitable GPU found or GPU tests failed)")
        return self._optimal_device

    def _test_cuda_device(self) -> Optional[torch.device]:
        """Test CUDA device functionality"""
        self._cuda_tested = True

        try:
            # Find best CUDA device (highest free memory)
            best_device_idx = 0
            best_memory = 0
            for device_info in self._device_info.get('cuda_devices', []):
                # memory_free_gb may be missing if the memory query failed above
                free_gb = device_info.get('memory_free_gb', 0)
                if free_gb > best_memory:
                    best_memory = free_gb
                    best_device_idx = device_info['index']

            device = torch.device(f"cuda:{best_device_idx}")

            # Test basic functionality
            test_tensor = torch.tensor([1.0], device=device)
            result = test_tensor * 2

            # Test memory operations
            large_tensor = torch.randn(1000, 1000, device=device)
            del large_tensor, test_tensor, result
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

            logger.info(f"CUDA device {best_device_idx} passed functionality tests")
            return device

        except Exception as e:
            logger.warning(f"CUDA device test failed: {e}")
            return None

    def _test_mps_device(self) -> Optional[torch.device]:
        """Test MPS device functionality"""
        self._mps_tested = True

        try:
            device = torch.device("mps")

            # Test basic functionality
            test_tensor = torch.tensor([1.0], device=device)
            result = test_tensor * 2

            # Test memory operations
            large_tensor = torch.randn(1000, 1000, device=device)
            del large_tensor, test_tensor, result

            # MPS doesn't have explicit cache clearing like CUDA
            logger.info("MPS device passed functionality tests")
            return device

        except Exception as e:
            logger.warning(f"MPS device test failed: {e}")
            return None

    def _get_device_name(self, device: torch.device) -> str:
        """Get human-readable device name"""
        if device.type == 'cuda':
            if self._device_info.get('cuda_devices'):
                device_idx = device.index or 0
                for cuda_device in self._device_info['cuda_devices']:
                    if cuda_device['index'] == device_idx:
                        return cuda_device['name']
            return f"CUDA Device {device.index or 0}"
        elif device.type == 'mps':
            return "Apple Silicon GPU (MPS)"
        else:
            return "CPU"

    def get_device_capabilities(self, device: Optional[torch.device] = None) -> Dict[str, Any]:
        """Get capabilities of the specified device"""
        if device is None:
            device = self.get_optimal_device()

        capabilities = {
            'device_type': device.type,
            'device_name': self._get_device_name(device),
            'supports_mixed_precision': False,
            'recommended_batch_size': 1,
            'memory_efficiency': 'medium'
        }

        if device.type == 'cuda':
            device_idx = device.index or 0
            for cuda_device in self._device_info.get('cuda_devices', []):
                if cuda_device['index'] == device_idx:
                    # Mixed precision (fp16 tensor cores) requires compute
                    # capability 7.0+ (Volta or newer); compare the major version
                    major_version = int(cuda_device.get('compute_capability', '0.0').split('.')[0])
                    capabilities['supports_mixed_precision'] = major_version >= 7

                    # Estimate batch size based on memory
                    memory_gb = cuda_device.get('memory_free_gb', 0)
                    if memory_gb >= 24:
                        capabilities['recommended_batch_size'] = 4
                        capabilities['memory_efficiency'] = 'high'
                    elif memory_gb >= 12:
                        capabilities['recommended_batch_size'] = 2
                        capabilities['memory_efficiency'] = 'high'
                    elif memory_gb >= 6:
                        capabilities['recommended_batch_size'] = 1
                        capabilities['memory_efficiency'] = 'medium'
                    else:
                        capabilities['memory_efficiency'] = 'low'

                    capabilities['memory_available_gb'] = memory_gb
                    break

        elif device.type == 'mps':
            capabilities['supports_mixed_precision'] = True  # MPS supports fp16
            capabilities['memory_efficiency'] = 'high'  # Unified memory
            system_memory = self._device_info.get('mps_system_memory_gb', 8)
            if system_memory >= 16:
                capabilities['recommended_batch_size'] = 2
            capabilities['memory_available_gb'] = system_memory * 0.7  # Rough estimate

        else:  # CPU
            capabilities['memory_efficiency'] = 'low'
            capabilities['supports_mixed_precision'] = False

        return capabilities

    def switch_device(self, device_type: str) -> torch.device:
        """
        Switch to a specific device type

        Args:
            device_type: 'cuda', 'mps', or 'cpu'
        """
        try:
            if device_type.lower() == 'cuda':
                if not self._device_info['cuda_available']:
                    raise DeviceError('cuda', 'CUDA not available on this system')

                device = self._test_cuda_device()
                if device is None:
                    raise DeviceError('cuda', 'CUDA device failed functionality tests')

            elif device_type.lower() == 'mps':
                if not self._device_info['mps_available']:
                    raise DeviceError('mps', 'MPS not available on this system')

                device = self._test_mps_device()
                if device is None:
                    raise DeviceError('mps', 'MPS device failed functionality tests')

            elif device_type.lower() == 'cpu':
                device = torch.device('cpu')

            else:
                raise DeviceError('unknown', f'Unknown device type: {device_type}')

            self._optimal_device = device
            logger.info(f"Switched to device: {self._get_device_name(device)}")
            return device

        except DeviceError:
            raise
        except Exception as e:
            raise DeviceError(device_type, f"Failed to switch to {device_type}: {str(e)}")

    def get_available_devices(self) -> List[str]:
        """Get list of available device types"""
        devices = ['cpu']  # CPU always available

        if self._device_info['cuda_available']:
            devices.append('cuda')

        if self._device_info['mps_available']:
            devices.append('mps')

        return devices

    def get_device_status(self) -> Dict[str, Any]:
        """Get comprehensive device status"""
        current_device = self.get_optimal_device()

        status = {
            'current_device': str(current_device),
            'current_device_name': self._get_device_name(current_device),
            'available_devices': self.get_available_devices(),
            'device_info': self._device_info.copy(),
            'capabilities': self.get_device_capabilities(current_device)
        }

        # Add current memory usage if on GPU
        if current_device.type == 'cuda':
            try:
                device_idx = current_device.index or 0
                status['current_memory_usage'] = {
                    'allocated_gb': torch.cuda.memory_allocated(device_idx) / (1024**3),
                    'reserved_gb': torch.cuda.memory_reserved(device_idx) / (1024**3),
                    'max_allocated_gb': torch.cuda.max_memory_allocated(device_idx) / (1024**3),
                    'max_reserved_gb': torch.cuda.max_memory_reserved(device_idx) / (1024**3)
                }
            except Exception as e:
                logger.warning(f"Could not get current memory usage: {e}")

        return status

    def optimize_for_processing(self) -> Dict[str, Any]:
        """Optimize device settings for video processing"""
        device = self.get_optimal_device()
        optimizations = {
            'device': str(device),
            'optimizations_applied': []
        }

        try:
            if device.type == 'cuda':
                # Enable cuDNN benchmarking for consistent input sizes
                torch.backends.cudnn.benchmark = True
                optimizations['optimizations_applied'].append('cudnn_benchmark')

                # Enable cuDNN deterministic mode if needed for reproducibility
                # torch.backends.cudnn.deterministic = True

                # The allocator strategy (e.g. PYTORCH_CUDA_ALLOC_CONF=
                # 'max_split_size_mb:512') is left to the environment; only
                # record it as applied when it is actually configured
                if os.environ.get('PYTORCH_CUDA_ALLOC_CONF'):
                    optimizations['optimizations_applied'].append('cuda_memory_strategy')

            elif device.type == 'mps':
                # MPS-specific optimizations would go here
                optimizations['optimizations_applied'].append('mps_optimized')

            else:  # CPU
                # Set optimal number of threads for CPU processing
                torch.set_num_threads(min(torch.get_num_threads(), 8))
                optimizations['optimizations_applied'].append('cpu_thread_optimization')

            logger.info(f"Applied optimizations for {device}: {optimizations['optimizations_applied']}")

        except Exception as e:
            logger.warning(f"Some optimizations failed: {e}")
            optimizations['optimization_errors'] = str(e)

        return optimizations

    def cleanup_device_memory(self):
        """Clean up device memory"""
        device = self.get_optimal_device()

        if device.type == 'cuda':
            try:
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                logger.debug("CUDA memory cache cleared")
            except Exception as e:
                logger.warning(f"CUDA memory cleanup failed: {e}")

        elif device.type == 'mps':
            try:
                # MPS uses unified memory, so there is no explicit cache to
                # clear; garbage collection is the best we can do
                gc.collect()
                logger.debug("MPS memory cleanup completed")
            except Exception as e:
                logger.warning(f"MPS memory cleanup failed: {e}")

        else:  # CPU
            try:
                gc.collect()
                logger.debug("CPU memory cleanup completed")
            except Exception as e:
                logger.warning(f"CPU memory cleanup failed: {e}")
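

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the committed module): a typical
# detect -> inspect -> optimize -> cleanup cycle. Everything referenced here
# is defined above; only the __main__ guard is added for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = DeviceManager()

    # Self-tests pick the best backend in order: CUDA, then MPS, then CPU
    device = manager.get_optimal_device()

    # Capabilities drive workload sizing (batch size, mixed precision)
    caps = manager.get_device_capabilities(device)
    print(f"Device: {caps['device_name']} ({caps['device_type']})")
    print(f"Mixed precision: {caps['supports_mixed_precision']}, "
          f"recommended batch size: {caps['recommended_batch_size']}")

    # Apply backend-specific runtime settings before heavy processing
    result = manager.optimize_for_processing()
    print(f"Optimizations applied: {result['optimizations_applied']}")

    # Release cached device memory once processing is finished
    manager.cleanup_device_memory()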