| """shared utils: seeding, device selection, timing, json i/o.""" |
| from __future__ import annotations |
|
|
| import json |
| import os |
| import random |
| import time |
| from contextlib import contextmanager |
|
|
| import numpy as np |
|
|
|
|
def set_seed(seed: int = 0) -> None:
    """Seed python's ``random``, numpy's global RNG and, if installed, torch.

    When torch is importable this also seeds all CUDA devices, forces cudnn
    into deterministic mode and disables cudnn autotuning (``benchmark``).

    Note: ``PYTHONHASHSEED`` is read at interpreter start-up, so setting it
    here does not change hash randomization of the *current* process; it only
    affects child interpreters started afterwards with this environment.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import torch
    except ImportError:
        # torch is optional; python/numpy seeding above still applies.
        return

    torch.manual_seed(seed)
    # documented by torch as safe (silently ignored) when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
| def pick_device(prefer_index: int | None = None): |
| """return a torch device, picking the gpu with the most free memory. |
| |
| on this shared node several gpus may be partly occupied, so we grab the |
| emptiest one unless ``prefer_index`` is given. |
| """ |
| import torch |
|
|
| if not torch.cuda.is_available(): |
| return torch.device("cpu") |
| if prefer_index is not None: |
| return torch.device(f"cuda:{prefer_index}") |
| best, best_free = 0, -1 |
| for i in range(torch.cuda.device_count()): |
| free, _ = torch.cuda.mem_get_info(i) |
| if free > best_free: |
| best, best_free = i, free |
| return torch.device(f"cuda:{best}") |
|
|
|
|
@contextmanager
def timer(name: str = "block"):
    """Measure the wall-clock duration of a ``with`` block.

    Yields an object whose ``t`` attribute reads 0.0 while the block runs and
    is set to the elapsed seconds (a ``time.perf_counter`` difference) once
    the block exits, including when it exits by raising.

    ``name`` is accepted for call-site labeling but is not used here.
    """

    class _Elapsed:
        t = 0.0

    stopwatch = _Elapsed()
    began = time.perf_counter()
    try:
        yield stopwatch
    finally:
        # runs on normal exit and on exceptions, so .t is always populated.
        stopwatch.t = time.perf_counter() - began
|
|
|
|
def save_json(obj, path: str) -> None:
    """Write ``obj`` to ``path`` as 2-space-indented JSON.

    Missing parent directories are created. Values the json module cannot
    encode natively are passed to ``_json_default``.
    """
    parent = os.path.dirname(path)
    # dirname() is "" for a bare filename like "out.json", and
    # os.makedirs("") raises FileNotFoundError; only create real directories.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, default=_json_default)
|
|
|
|
def load_json(path: str):
    """Read and return the JSON document stored at ``path``.

    The file is decoded as UTF-8 (the encoding JSON mandates) rather than the
    platform's locale-dependent default.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
| def _json_default(o): |
| if isinstance(o, (np.integer,)): |
| return int(o) |
| if isinstance(o, (np.floating,)): |
| return float(o) |
| if isinstance(o, (np.ndarray,)): |
| return o.tolist() |
| return str(o) |
|
|