MultiModalRag / utils /device.py
irajkoohi's picture
chore: update app [space deploy]
6c21523
Raw
History Blame Contribute Delete
3.48 kB
"""
Device detection and selection utility.
Priority: CUDA → MPS (Apple Silicon) → CPU

Usage:
    from utils.device import get_device, device_info
    device = get_device()   # e.g. "mps", "cuda", "cpu"
    print(device_info())    # dict with "device" and a human-readable "label"
"""
import logging
import os
from functools import lru_cache
logger = logging.getLogger(__name__)


@lru_cache(maxsize=1)
def get_device() -> str:
    """Detect the best available torch device.

    The ``TORCH_DEVICE`` environment variable, when set to ``cuda``, ``mps``
    or ``cpu`` (case-insensitive, surrounding whitespace ignored), forces
    that device without importing torch. Any other non-empty value is
    ignored with a warning and auto-detection runs instead.

    Auto-detection priority is CUDA → MPS (Apple Silicon) → CPU. If torch
    cannot be imported the function falls back to ``"cpu"``.

    The result is cached by ``lru_cache``; call ``get_device.cache_clear()``
    to re-run detection after changing the environment.

    Returns:
        One of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    # Allow explicit override.
    forced = os.environ.get("TORCH_DEVICE", "").strip().lower()
    if forced in ("cuda", "mps", "cpu"):
        logger.info("Device forced via TORCH_DEVICE env var: %s", forced)
        return forced
    if forced:
        # A typo in the env var should be visible rather than silently dropped.
        logger.warning(
            "Ignoring unsupported TORCH_DEVICE value %r; auto-detecting", forced
        )

    # Keep the try body to the one statement that can raise ImportError.
    try:
        import torch
    except ImportError:
        logger.warning("torch not importable — falling back to CPU")
        return "cpu"

    # Not every torch build exposes torch.backends.mps; look it up defensively
    # so a missing attribute means "no MPS" instead of an AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)

    if torch.cuda.is_available():
        device = "cuda"
    elif (
        mps_backend is not None
        and mps_backend.is_available()
        and mps_backend.is_built()
    ):
        device = "mps"
    else:
        device = "cpu"

    logger.info("Auto-detected device: %s", device)
    return device
def device_info() -> dict:
    """Return a dict describing where inference runs, with a display label.

    Resolution order (first match wins):

    1. ``DEVICE_LABEL`` env var: its stripped value is used verbatim as the
       label and ``device`` is ``"override"``.
    2. ``HF_TOKEN`` or ``HUGGING_FACE_HUB_TOKEN`` set: ``device`` is ``"hf"``
       and the label names the last path segment of ``HF_MODEL``.
    3. ``GROQ_API_KEY`` set: ``device`` is ``"groq"`` and the label names
       ``GROQ_MODEL``.
    4. Otherwise the local torch device reported by ``get_device()``.

    Returns:
        A dict with ``"device"`` and ``"label"`` keys. For a local CUDA
        device it also carries ``"gpu_name"`` and ``"vram_gb"`` when the GPU
        could be queried; if the query fails the label stays ``"CUDA"``.
    """
    # Explicit override (e.g. set in HF Space secrets).
    label_override = os.environ.get("DEVICE_LABEL", "").strip()
    if label_override:
        return {"device": "override", "label": label_override}

    # Hugging Face Inference API: the model runs remotely.
    if os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN"):
        hf_model = os.environ.get("HF_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
        label = f"HuggingFace ({hf_model.split('/')[-1]})"
        logger.info("Device info: %s", label)
        return {"device": "hf", "label": label}

    # Groq: the LLM runs on Groq cloud, not local hardware.
    if os.environ.get("GROQ_API_KEY"):
        groq_model = os.environ.get("GROQ_MODEL", "llama-3.3-70b-versatile")
        label = f"Groq LPU ({groq_model})"
        logger.info("Device info: %s", label)
        return {"device": "groq", "label": label}

    # Local torch device detection. The default label ("CUDA"/"MPS"/"CPU")
    # is kept whenever the detail lookup below fails.
    device = get_device()
    info: dict = {"device": device, "label": device.upper()}

    if device == "cuda":
        # Best effort: TORCH_DEVICE can force "cuda" on a machine without a
        # usable GPU, so any failure here only costs the detailed label.
        try:
            import torch

            idx = torch.cuda.current_device()
            info["gpu_name"] = torch.cuda.get_device_name(idx)
            total = torch.cuda.get_device_properties(idx).total_memory
            info["vram_gb"] = round(total / 1024 ** 3, 1)
            info["label"] = f"CUDA — {info['gpu_name']} ({info['vram_gb']} GB)"
        except Exception as e:
            logger.debug("device_info detail error: %s", e)
    elif device == "mps":
        # torch is not needed here; the chip name comes from sysctl.
        import subprocess

        try:
            chip = subprocess.check_output(
                ["sysctl", "-n", "machdep.cpu.brand_string"],
                stderr=subprocess.DEVNULL,
                timeout=2,  # never let a stuck child process block the caller
            ).decode().strip()
        except (OSError, subprocess.SubprocessError, ValueError):
            chip = ""
        # An empty sysctl answer gets the same generic name as a failed call.
        info["label"] = f"MPS — {chip or 'Apple Silicon'}"

    logger.info("Device info: %s", info["label"])
    return info