PIVOT / src /utils /common.py
bryan7264's picture
pivot: code + trained checkpoints (norman, replogle k562)
3b4941f verified
Raw
History Blame
2.14 kB
"""shared utils: seeding, device selection, timing, json i/o."""
from __future__ import annotations
import json
import os
import random
import time
from contextlib import contextmanager
import numpy as np
def set_seed(seed: int = 0) -> None:
    """Seed python, numpy and torch (if available) for reproducibility.

    Args:
        seed: value passed to every RNG seeded here.

    Note:
        ``PYTHONHASHSEED`` is exported so that child processes started after
        this call inherit it; the already-running interpreter's string hashing
        is fixed at startup and is not affected.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import torch
    except (ImportError, OSError):
        # torch is optional (or its native libs are broken); python/numpy
        # seeding above still applies. Only the import is guarded so genuine
        # errors from the seeding calls below are no longer swallowed.
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # determinism for reproducibility; benchmark off
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def pick_device(prefer_index: int | None = None):
    """Return a torch device, picking the GPU with the most free memory.

    On this shared node several GPUs may be partly occupied, so we grab the
    emptiest one unless ``prefer_index`` is given.

    Args:
        prefer_index: explicit CUDA device index to use instead of probing
            free memory. Ignored (CPU is returned) when CUDA is unavailable.

    Returns:
        ``torch.device("cpu")`` when CUDA is unavailable; otherwise
        ``torch.device("cuda:<i>")``. Falls back to ``cuda:0`` if no device
        could be probed.
    """
    import torch

    if not torch.cuda.is_available():
        return torch.device("cpu")
    if prefer_index is not None:
        return torch.device(f"cuda:{prefer_index}")

    best, best_free = 0, -1
    for i in range(torch.cuda.device_count()):
        try:
            # NOTE(review): probing a device presumably initializes a CUDA
            # context on it — confirm if per-GPU memory overhead matters.
            free, _ = torch.cuda.mem_get_info(i)
        except RuntimeError:
            # a busy/unusable GPU (e.g. exclusive-process mode held by another
            # job) should not abort device selection; just skip it
            continue
        if free > best_free:
            best, best_free = i, free
    return torch.device(f"cuda:{best}")
@contextmanager
def timer(name: str = "block"):
    """Time the body of a ``with`` statement.

    Yields a holder whose ``t`` attribute reads 0.0 while the body is running.
    Once the body exits, normally or by raising, ``t`` holds the elapsed
    wall-clock seconds measured with ``time.perf_counter``.

    ``name`` is accepted for call-site readability but is not used here.
    """

    class _Elapsed:
        t = 0.0

    result = _Elapsed()
    began = time.perf_counter()
    try:
        yield result
    finally:
        # assigned on exit, so the duration is readable after the with-block
        result.t = time.perf_counter() - began
def save_json(obj, path: str) -> None:
    """Write ``obj`` to ``path`` as 2-space-indented JSON.

    Parent directories are created if missing. Values the json encoder
    cannot serialize natively are passed to ``_json_default``.
    """
    parent = os.path.dirname(path)
    # dirname("out.json") == "" and os.makedirs("") raises FileNotFoundError,
    # so only create a directory when the path actually names one
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, default=_json_default)
def load_json(path: str):
    """Read and return the JSON document stored at ``path``.

    The file is decoded as UTF-8 (the encoding JSON mandates) rather than
    the platform's locale default, so non-ASCII content loads consistently.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)
def _json_default(o):
if isinstance(o, (np.integer,)):
return int(o)
if isinstance(o, (np.floating,)):
return float(o)
if isinstance(o, (np.ndarray,)):
return o.tolist()
return str(o)