File size: 2,142 Bytes
"""shared utils: seeding, device selection, timing, json i/o."""
from __future__ import annotations
import json
import os
import random
import time
from contextlib import contextmanager
import numpy as np
def set_seed(seed: int = 0) -> None:
    """seed python, numpy and torch (if installed) for reproducibility.

    Args:
        seed: value handed to every random number generator.

    torch is treated as optional: when it cannot be imported only the python
    and numpy generators are seeded. errors raised while seeding torch itself
    are allowed to propagate, so a broken setup is not mistaken for a
    reproducible one.
    """
    random.seed(seed)
    np.random.seed(seed)
    # hash randomisation is fixed when the interpreter starts, so this only
    # takes effect in child processes that inherit the environment.
    os.environ["PYTHONHASHSEED"] = str(seed)

    # keep the try body to the one statement that is expected to fail.
    try:
        import torch
    except ImportError:
        return

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # determinism for reproducibility; benchmark off
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def pick_device(prefer_index: int | None = None):
    """choose the torch device to run on.

    without cuda the cpu device is returned. with cuda, ``prefer_index`` (when
    given) selects that gpu directly; otherwise every visible gpu is asked for
    its free memory and the one reporting the most is returned, the lowest
    index winning ties. on this shared node several gpus may be partly
    occupied, hence the default of grabbing the emptiest one.
    """
    import torch

    if not torch.cuda.is_available():
        return torch.device("cpu")
    if prefer_index is not None:
        return torch.device(f"cuda:{prefer_index}")

    def free_bytes(index):
        # mem_get_info returns a (free, total) pair; only free matters here.
        return torch.cuda.mem_get_info(index)[0]

    # max() keeps the first of equal candidates and falls back to gpu 0 when
    # there is nothing to compare.
    emptiest = max(range(torch.cuda.device_count()), key=free_bytes, default=0)
    return torch.device(f"cuda:{emptiest}")
@contextmanager
def timer(name: str = "block"):
    """time the body of a ``with`` statement.

    the yielded object carries a ``.t`` attribute: 0.0 while the body runs,
    then the elapsed ``time.perf_counter()`` seconds once the block exits,
    including when it exits through an exception. ``name`` is accepted but
    not used by the current implementation.
    """

    class _T:
        # class-level default, shadowed by the instance attribute on exit
        t = 0.0

    holder = _T()
    began = time.perf_counter()
    try:
        yield holder
    finally:
        finished = time.perf_counter()
        holder.t = finished - began
def save_json(obj, path: str) -> None:
    """write ``obj`` to ``path`` as indented json, creating parent dirs.

    Args:
        obj: value to serialise; anything json cannot encode natively is
            passed through ``_json_default``.
        path: destination file. a bare filename (no directory part) is
            written relative to the current working directory.
    """
    parent = os.path.dirname(path)
    # dirname is "" for a bare filename and os.makedirs("") raises, so only
    # create directories when the path actually has a directory component.
    if parent:
        os.makedirs(parent, exist_ok=True)
    # json.dump emits ascii-only text by default; the explicit encoding keeps
    # the file independent of the locale.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, default=_json_default)
def load_json(path: str):
    """read the json document stored at ``path`` and return the parsed value.

    Args:
        path: file to read.

    Returns:
        whatever ``json.load`` produces for the document (dict, list, str,
        number, bool or None).

    the file is decoded as utf-8 rather than with the locale's default
    encoding, so non-ascii content loads the same on every platform.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)
def _json_default(o):
if isinstance(o, (np.integer,)):
return int(o)
if isinstance(o, (np.floating,)):
return float(o)
if isinstance(o, (np.ndarray,)):
return o.tolist()
return str(o)
|