"""
Device detection utility for automatic hardware acceleration.

Automatically detects and returns the best available PyTorch device, in
priority order:

- MPS (Apple Silicon)
- CUDA (NVIDIA GPU)
- CPU (fallback)
"""

import logging

import torch

logger = logging.getLogger(__name__)


def get_device() -> str:
    """
    Automatically detect the best available device.

    Priority:
        1. MPS (Apple Silicon) - if available, built, and a test allocation works
        2. CUDA (NVIDIA GPU) - if available
        3. CPU - fallback

    Every call re-runs the detection and logs the outcome; use
    ``get_cached_device`` to detect only once per process.

    Returns:
        str: Device name ('mps', 'cuda', or 'cpu')
    """
    # ``torch.backends.mps`` is missing from older PyTorch releases; treat its
    # absence as "not available / not built" instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_available = mps_backend is not None and mps_backend.is_available()
    mps_built = mps_backend is not None and mps_backend.is_built()

    if mps_available and mps_built:
        try:
            # The availability flags alone may not guarantee that allocation
            # succeeds, so probe the backend with a tiny tensor.
            torch.zeros(1, device="mps")
        except Exception as exc:  # best effort: any failure means "not usable"
            logger.warning("⚠️ MPS available but not usable: %s", exc)
        else:
            logger.info("🍎 Using MPS (Apple Silicon) device")
            return "mps"

    if torch.cuda.is_available():
        # With no argument, get_device_name() reports the current CUDA device,
        # which is the device the plain 'cuda' string resolves to.
        logger.info("🎮 Using CUDA device (GPU: %s)", torch.cuda.get_device_name())
        return "cuda"

    logger.info(
        "💻 Using CPU device (no GPU acceleration) "
        "[MPS available=%s, MPS built=%s]",
        mps_available,
        mps_built,
    )
    return "cpu"


def get_device_for_sentence_transformers() -> str:
    """
    Get device string for the sentence-transformers library.

    sentence-transformers accepts the same device names that ``get_device``
    returns, so this simply delegates to it.

    Returns:
        str: Device name compatible with sentence-transformers
    """
    return get_device()


def get_device_for_colpali() -> str:
    """
    Get device string for ColPali models.

    ColPali/transformers accept the same device names that ``get_device``
    returns, so this simply delegates to it.

    Returns:
        str: Device name compatible with ColPali/transformers
    """
    return get_device()


def log_device_info() -> str:
    """
    Detect the device and log a detailed, human-readable summary of it.

    Returns:
        str: The selected device name ('mps', 'cuda', or 'cpu'), as returned
        by ``get_device``.
    """
    device = get_device()
    separator = "=" * 70

    logger.info(separator)
    logger.info("🖥️ DEVICE CONFIGURATION")
    logger.info(separator)
    logger.info("Selected device: %s", device.upper())

    if device == "mps":
        logger.info(" Type: Apple Silicon (M1/M2/M3)")
        logger.info(" Acceleration: Metal Performance Shaders")
    elif device == "cuda":
        # Report the current CUDA device (what 'cuda' resolves to), not a
        # hard-coded index 0, so the info stays correct after set_device().
        index = torch.cuda.current_device()
        props = torch.cuda.get_device_properties(index)
        logger.info(" Type: NVIDIA GPU")
        logger.info(" GPU: %s", torch.cuda.get_device_name(index))
        logger.info(" CUDA Version: %s", torch.version.cuda)
        logger.info(" Memory: %.1f GB", props.total_memory / 1024**3)
    else:
        logger.info(" Type: CPU (no GPU acceleration)")
        logger.info(" Note: Processing will be slower")

    logger.info(separator)
    return device


# Module-level cache for get_cached_device(); None means "not detected yet".
_cached_device = None


def get_cached_device() -> str:
    """
    Get the device, running detection only on the first call.

    The first call runs ``get_device`` (including its logging) and stores the
    result in the module-level ``_cached_device``; later calls return that
    stored value without re-probing the hardware.

    Returns:
        str: Device name ('mps', 'cuda', or 'cpu')
    """
    global _cached_device
    if _cached_device is None:
        _cached_device = get_device()
    return _cached_device