from __future__ import annotations import json import random from contextlib import nullcontext from pathlib import Path from typing import Any import numpy as np import torch def seed_everything(seed: int) -> None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def get_device(prefer_mps: bool = True, requested_device: str = "auto") -> torch.device: if requested_device != "auto": if requested_device == "cuda": if not torch.cuda.is_available(): raise RuntimeError("CUDA was requested but is not available.") return torch.device("cuda") if requested_device == "mps": if not torch.backends.mps.is_available(): raise RuntimeError("MPS was requested but is not available.") return torch.device("mps") if requested_device == "cpu": return torch.device("cpu") raise ValueError(f"Unsupported device: {requested_device}") if torch.cuda.is_available(): return torch.device("cuda") if prefer_mps and torch.backends.mps.is_available(): return torch.device("mps") return torch.device("cpu") def get_autocast(device: torch.device, enabled: bool): if not enabled: return nullcontext() if device.type == "cuda": return torch.autocast(device_type="cuda", dtype=torch.float16) return nullcontext() def dump_json(path: str | Path, payload: dict[str, Any]) -> None: path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) with path.open("w", encoding="utf-8") as handle: json.dump(payload, handle, indent=2) def load_json(path: str | Path) -> dict[str, Any]: with Path(path).open("r", encoding="utf-8") as handle: return json.load(handle)