akryldigital committed on
Commit
32d6a0b
·
verified ·
1 Parent(s): 99c582d

add device detection utils

Browse files
Files changed (2) hide show
  1. src/utils/__init__.py +18 -0
  2. src/utils/device.py +109 -0
src/utils/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions for the audit assistant."""
2
+
3
+ from .device import (
4
+ get_device,
5
+ get_device_for_sentence_transformers,
6
+ get_device_for_colpali,
7
+ log_device_info,
8
+ get_cached_device
9
+ )
10
+
11
+ __all__ = [
12
+ "get_device",
13
+ "get_device_for_sentence_transformers",
14
+ "get_device_for_colpali",
15
+ "log_device_info",
16
+ "get_cached_device"
17
+ ]
18
+
src/utils/device.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Device detection utility for automatic hardware acceleration.
3
+
4
+ Automatically detects and returns the best available device:
5
+ - MPS (Apple Silicon)
6
+ - CUDA (NVIDIA GPU)
7
+ - CPU (fallback)
8
+ """
9
+ import torch
10
+ import logging
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
def get_device() -> str:
    """
    Automatically detect the best available device.

    Priority:
        1. MPS (Apple Silicon) - if available
        2. CUDA (NVIDIA GPU) - if available
        3. CPU - fallback

    Returns:
        str: Device name ('mps', 'cuda', or 'cpu')
    """
    # Apple Silicon first: the MPS backend must be both compiled into this
    # torch build and available on the running machine.
    built = torch.backends.mps.is_built()
    available = torch.backends.mps.is_available()

    if available and built:
        try:
            # Smoke-test: allocating a tiny tensor proves MPS actually works
            # (is_available() alone can be true on machines where it fails).
            torch.zeros(1, device='mps')
        except Exception as err:
            logger.warning(f"⚠️ MPS available but not usable: {err}")
        else:
            logger.info("🍎 Using MPS (Apple Silicon) device")
            return 'mps'

    # Next preference: NVIDIA GPU via CUDA.
    if torch.cuda.is_available():
        logger.info(f"🎮 Using CUDA device (GPU: {torch.cuda.get_device_name(0)})")
        return 'cuda'

    # No accelerator found — fall back to plain CPU.
    logger.info(f"💻 Using CPU device (no GPU acceleration) [MPS available={available}, MPS built={built}]")
    return 'cpu'
47
+
48
def get_device_for_sentence_transformers() -> str:
    """
    Get device string for sentence-transformers library.

    Returns:
        str: Device name compatible with sentence-transformers
    """
    # sentence-transformers accepts the same device names torch does,
    # so the detected device passes through unchanged.
    return get_device()
58
+
59
def get_device_for_colpali() -> str:
    """
    Get device string for ColPali models.

    Returns:
        str: Device name compatible with ColPali/transformers
    """
    # ColPali (built on transformers) uses the same device naming as torch,
    # so no translation of the detected device is needed.
    return get_device()
69
+
70
def log_device_info() -> str:
    """
    Log a framed banner describing the selected compute device.

    Runs detection via get_device() (which logs its own selection line),
    then emits a summary with device-specific details: chip family for MPS,
    GPU name / CUDA version / memory for CUDA, or a slowness note for CPU.

    Returns:
        str: The selected device name ('mps', 'cuda', or 'cpu')
    """
    device = get_device()

    logger.info("=" * 70)
    logger.info("🖥️ DEVICE CONFIGURATION")
    logger.info("=" * 70)
    logger.info(f"Selected device: {device.upper()}")

    if device == 'mps':
        logger.info(" Type: Apple Silicon (M1/M2/M3)")
        logger.info(" Acceleration: Metal Performance Shaders")
    elif device == 'cuda':
        # Fixed: this literal previously carried a useless f-prefix (no placeholders).
        logger.info(" Type: NVIDIA GPU")
        logger.info(f" GPU: {torch.cuda.get_device_name(0)}")
        logger.info(f" CUDA Version: {torch.version.cuda}")
        # total_memory is in bytes; convert to GiB for readability.
        logger.info(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    else:
        logger.info(" Type: CPU (no GPU acceleration)")
        logger.info(" Note: Processing will be slower")

    logger.info("=" * 70)

    return device
94
+
95
# Module-level cache so device detection (and its logging) runs only once.
_cached_device = None


def get_cached_device() -> str:
    """
    Get device with caching to avoid repeated checks.

    The first call performs full detection via get_device(); every later
    call returns the memoized result without re-probing hardware.

    Returns:
        str: Device name ('mps', 'cuda', or 'cpu')
    """
    global _cached_device
    if _cached_device is not None:
        return _cached_device
    _cached_device = get_device()
    return _cached_device
109
+