# Source: akryldigital — "add device detection utils" (commit 32d6a0b, verified; 3.17 kB)
"""
Device detection utility for automatic hardware acceleration.
Automatically detects and returns the best available device:
- MPS (Apple Silicon)
- CUDA (NVIDIA GPU)
- CPU (fallback)
"""
import torch
import logging
# Shared logger used by every device-detection message emitted in this module.
logger: logging.Logger = logging.getLogger(__name__)
def get_device() -> str:
"""
Automatically detect the best available device.
Priority:
1. MPS (Apple Silicon) - if available
2. CUDA (NVIDIA GPU) - if available
3. CPU - fallback
Returns:
str: Device name ('mps', 'cuda', or 'cpu')
"""
# Check for MPS (Apple Silicon)
mps_available = torch.backends.mps.is_available()
mps_built = torch.backends.mps.is_built()
if mps_available and mps_built:
try:
# Test if MPS is actually usable
torch.zeros(1, device='mps')
logger.info("๐ŸŽ Using MPS (Apple Silicon) device")
return 'mps'
except Exception as e:
logger.warning(f"โš ๏ธ MPS available but not usable: {e}")
# Check for CUDA (NVIDIA GPU)
if torch.cuda.is_available():
logger.info(f"๐ŸŽฎ Using CUDA device (GPU: {torch.cuda.get_device_name(0)})")
return 'cuda'
# Fallback to CPU
logger.info(f"๐Ÿ’ป Using CPU device (no GPU acceleration) [MPS available={mps_available}, MPS built={mps_built}]")
return 'cpu'
def get_device_for_sentence_transformers() -> str:
    """
    Return the device string to hand to the sentence-transformers library.

    sentence-transformers uses the same 'mps' / 'cuda' / 'cpu' naming as
    ``get_device()``, so the detected value is passed through unchanged.

    Returns:
        str: Device name compatible with sentence-transformers
    """
    return get_device()
def get_device_for_colpali() -> str:
    """
    Return the device string to hand to ColPali models.

    ColPali (via transformers) uses the same 'mps' / 'cuda' / 'cpu' naming
    as ``get_device()``, so the detected value is passed through unchanged.

    Returns:
        str: Device name compatible with ColPali/transformers
    """
    return get_device()
def log_device_info() -> str:
    """
    Log a banner describing the device chosen by ``get_device()``.

    ``get_device()`` is called afresh here, so it also emits its own
    "Using ... device" line before the banner.

    Returns:
        str: The selected device name ('mps', 'cuda', or 'cpu').
    """
    device = get_device()
    separator = "=" * 70

    logger.info(separator)
    logger.info("๐Ÿ–ฅ๏ธ DEVICE CONFIGURATION")
    logger.info(separator)
    logger.info(f"Selected device: {device.upper()}")

    if device == 'mps':
        logger.info(" Type: Apple Silicon (M1/M2/M3)")
        logger.info(" Acceleration: Metal Performance Shaders")
    elif device == 'cuda':
        # Details are read for CUDA device index 0, the same index
        # get_device() reports in its log line.
        logger.info(" Type: NVIDIA GPU")
        logger.info(f" GPU: {torch.cuda.get_device_name(0)}")
        logger.info(f" CUDA Version: {torch.version.cuda}")
        logger.info(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    else:
        logger.info(" Type: CPU (no GPU acceleration)")
        logger.info(" Note: Processing will be slower")

    logger.info(separator)
    return device
# Lazily-filled result for get_cached_device(); None means "not detected yet".
# Nothing is probed at import time; detection happens on the first call.
_cached_device = None


def get_cached_device() -> str:
    """
    Return the detected device, running detection only on the first call.

    The first call stores the result of ``get_device()`` in the module-level
    ``_cached_device``; later calls return that stored value without
    re-probing the hardware or logging again.

    Returns:
        str: Device name ('mps', 'cuda', or 'cpu')
    """
    global _cached_device
    if _cached_device is not None:
        return _cached_device
    _cached_device = get_device()
    return _cached_device