Spaces:
Sleeping
Sleeping
| """ | |
| Device detection utility for automatic hardware acceleration. | |
| Automatically detects and returns the best available device: | |
| - MPS (Apple Silicon) | |
| - CUDA (NVIDIA GPU) | |
| - CPU (fallback) | |
| """ | |
| import torch | |
| import logging | |
| logger = logging.getLogger(__name__) | |
def get_device() -> str:
    """
    Automatically detect the best available device.

    Priority:
        1. MPS (Apple Silicon) - if available
        2. CUDA (NVIDIA GPU) - if available
        3. CPU - fallback

    Returns:
        str: Device name ('mps', 'cuda', or 'cpu')
    """
    # Check for MPS (Apple Silicon). Both checks are needed: a torch build
    # can ship with MPS compiled in (is_built) while the runtime lacks a
    # usable Metal device (is_available), and vice versa.
    mps_available = torch.backends.mps.is_available()
    mps_built = torch.backends.mps.is_built()
    if mps_available and mps_built:
        try:
            # Probe with a real allocation: is_available() can report True
            # while tensor creation on 'mps' still fails at runtime.
            torch.zeros(1, device='mps')
            logger.info("๐ Using MPS (Apple Silicon) device")
            return 'mps'
        except Exception as e:
            # Broad catch is deliberate: any MPS failure means "fall through
            # to the next backend" — device selection must never crash.
            # Lazy %-args so formatting only happens if the record is emitted.
            logger.warning("โ ๏ธ MPS available but not usable: %s", e)

    # Check for CUDA (NVIDIA GPU)
    if torch.cuda.is_available():
        logger.info("๐ฎ Using CUDA device (GPU: %s)", torch.cuda.get_device_name(0))
        return 'cuda'

    # Fallback to CPU
    logger.info(
        "๐ป Using CPU device (no GPU acceleration) [MPS available=%s, MPS built=%s]",
        mps_available,
        mps_built,
    )
    return 'cpu'
def get_device_for_sentence_transformers() -> str:
    """
    Get device string for sentence-transformers library.

    Returns:
        str: Device name compatible with sentence-transformers
    """
    # sentence-transformers accepts the same device names as torch
    # ('mps' / 'cuda' / 'cpu'), so the detected device passes through unchanged.
    return get_device()
def get_device_for_colpali() -> str:
    """
    Get device string for ColPali models.

    Returns:
        str: Device name compatible with ColPali/transformers
    """
    # ColPali/transformers uses the same device naming as torch,
    # so the detected device passes through unchanged.
    return get_device()
def log_device_info():
    """Log detailed device information"""
    device = get_device()
    banner = "=" * 70

    logger.info(banner)
    logger.info("๐ฅ๏ธ DEVICE CONFIGURATION")
    logger.info(banner)
    logger.info(f"Selected device: {device.upper()}")

    if device == 'cuda':
        # Query the device once; name/version/memory come from the same GPU 0.
        props = torch.cuda.get_device_properties(0)
        logger.info(" Type: NVIDIA GPU")
        logger.info(f" GPU: {torch.cuda.get_device_name(0)}")
        logger.info(f" CUDA Version: {torch.version.cuda}")
        logger.info(f" Memory: {props.total_memory / 1024**3:.1f} GB")
    elif device == 'mps':
        logger.info(" Type: Apple Silicon (M1/M2/M3)")
        logger.info(" Acceleration: Metal Performance Shaders")
    else:
        logger.info(" Type: CPU (no GPU acceleration)")
        logger.info(" Note: Processing will be slower")

    logger.info(banner)
    return device
# Convenience function to get device once at module level
_cached_device = None


def get_cached_device() -> str:
    """
    Get device with caching to avoid repeated checks.

    Returns:
        str: Device name ('mps', 'cuda', or 'cpu')
    """
    global _cached_device
    # Fast path: detection already ran once, reuse its result.
    if _cached_device is not None:
        return _cached_device
    # First call: run detection and remember the answer for later calls.
    _cached_device = get_device()
    return _cached_device