Spaces:
Paused
Paused
| """ | |
| Device detection and selection utility. | |
| Priority: CUDA β MPS (Apple Silicon) β CPU | |
| Usage: | |
| from utils.device import get_device, device_info | |
| device = get_device() # e.g. "mps", "cuda", "cpu" | |
| print(device_info()) # human-readable summary | |
| """ | |
| import logging | |
| import os | |
| from functools import lru_cache | |
| logger = logging.getLogger(__name__) | |
| def get_device() -> str: | |
| """ | |
| Detect the best available torch device. | |
| Respects the TORCH_DEVICE env var to force a specific device. | |
| Returns one of: "cuda", "mps", "cpu" | |
| """ | |
| # Allow explicit override | |
| forced = os.environ.get("TORCH_DEVICE", "").strip().lower() | |
| if forced in ("cuda", "mps", "cpu"): | |
| logger.info(f"Device forced via TORCH_DEVICE env var: {forced}") | |
| return forced | |
| try: | |
| import torch | |
| if torch.cuda.is_available(): | |
| device = "cuda" | |
| elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): | |
| device = "mps" | |
| else: | |
| device = "cpu" | |
| logger.info(f"Auto-detected device: {device}") | |
| return device | |
| except ImportError: | |
| logger.warning("torch not importable β falling back to CPU") | |
| return "cpu" | |
| def device_info() -> dict: | |
| """ | |
| Return a dict with device name and label. | |
| Priority: DEVICE_LABEL env var β Groq API key β CUDA β MPS β CPU | |
| No hardcoded strings β everything derived from environment or torch. | |
| """ | |
| # Explicit override (e.g. set in HF Space secrets) | |
| label_override = os.environ.get("DEVICE_LABEL", "").strip() | |
| if label_override: | |
| return {"device": "override", "label": label_override} | |
| # HuggingFace Inference API (free, no daily limit) | |
| if os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN"): | |
| hf_model = os.environ.get("HF_MODEL", "meta-llama/Llama-3.1-8B-Instruct") | |
| label = f"HuggingFace ({hf_model.split('/')[-1]})" | |
| logger.info(f"Device info: {label}") | |
| return {"device": "hf", "label": label} | |
| # Groq: LLM runs on Groq cloud, not local hardware | |
| if os.environ.get("GROQ_API_KEY"): | |
| groq_model = os.environ.get("GROQ_MODEL", "llama-3.3-70b-versatile") | |
| label = f"Groq LPU ({groq_model})" | |
| logger.info(f"Device info: {label}") | |
| return {"device": "groq", "label": label} | |
| # Local torch device detection | |
| device = get_device() | |
| info = {"device": device, "label": device.upper()} | |
| try: | |
| import torch | |
| if device == "cuda": | |
| idx = torch.cuda.current_device() | |
| info["gpu_name"] = torch.cuda.get_device_name(idx) | |
| total = torch.cuda.get_device_properties(idx).total_memory | |
| info["vram_gb"] = round(total / 1024 ** 3, 1) | |
| info["label"] = f"CUDA β {info['gpu_name']} ({info['vram_gb']} GB)" | |
| elif device == "mps": | |
| try: | |
| import subprocess | |
| chip = subprocess.check_output( | |
| ["sysctl", "-n", "machdep.cpu.brand_string"], | |
| stderr=subprocess.DEVNULL | |
| ).decode().strip() | |
| except Exception: | |
| chip = "Apple Silicon" | |
| info["label"] = f"MPS β {chip}" | |
| else: | |
| info["label"] = "CPU" | |
| except Exception as e: | |
| logger.debug(f"device_info detail error: {e}") | |
| logger.info(f"Device info: {info['label']}") | |
| return info |