File size: 2,142 Bytes
3b4941f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""shared utils: seeding, device selection, timing, json i/o."""
from __future__ import annotations

import json
import os
import random
import time
from contextlib import contextmanager

import numpy as np


def set_seed(seed: int = 0) -> None:
    """Seed python's ``random``, numpy's global RNG and torch (if installed).

    When torch is importable this also seeds every CUDA device and puts cuDNN
    into deterministic mode with ``benchmark`` autotuning disabled, trading
    some GPU speed for run-to-run reproducibility.

    Args:
        seed: value passed to every RNG seeding call.

    Note:
        ``PYTHONHASHSEED`` is only read by the interpreter at startup, so
        setting it here does not change ``hash()`` randomization of the
        current process; only child processes started afterwards inherit it.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import torch
    except ImportError:
        # torch is optional; python and numpy are already seeded above.
        # Only the import is guarded so genuine seeding errors still surface.
        return

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # determinism for reproducibility; benchmark off
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def pick_device(prefer_index: int | None = None):
    """Choose the torch device this process should run on.

    Returns ``torch.device("cpu")`` when CUDA is unavailable. Otherwise it
    returns ``cuda:<prefer_index>`` if an index was supplied. Failing that, it
    returns the device among ``range(torch.cuda.device_count())`` that reports
    the most free memory, with the lowest index winning a tie. This suits a
    shared node where some GPUs may already be partly occupied.

    NOTE(review): ``prefer_index`` is not range-checked here. Querying
    ``mem_get_info`` on every device may also initialize a CUDA context on
    each one; confirm whether that overhead matters on the shared node.
    """
    import torch

    if not torch.cuda.is_available():
        return torch.device("cpu")

    index = prefer_index
    if index is None:
        # mem_get_info(i) returns (free, total); rank devices by free memory.
        # max() keeps the first of several equal maxima, i.e. lowest index.
        index = max(
            range(torch.cuda.device_count()),
            key=lambda dev: torch.cuda.mem_get_info(dev)[0],
            default=0,
        )
    return torch.device(f"cuda:{index}")


@contextmanager
def timer(name: str = "block"):
    """Measure how long the body of a ``with`` statement takes.

    Usage::

        with timer() as tm:
            work()
        elapsed = tm.t

    The yielded object's ``t`` attribute reads ``0.0`` while the body runs.
    On exit it is set to the elapsed ``time.perf_counter()`` seconds,
    whether the body finished normally or raised. ``name`` is accepted but
    not used by this implementation.
    """

    class _T:
        t = 0.0

    handle = _T()
    began = time.perf_counter()
    try:
        yield handle
    finally:
        # runs on both normal exit and exception, so .t is always filled in
        finished = time.perf_counter()
        handle.t = finished - began


def save_json(obj, path: str) -> None:
    """Write ``obj`` to ``path`` as 2-space-indented JSON.

    Missing parent directories are created first. Values the stdlib encoder
    cannot serialize are handed to ``_json_default`` for conversion.

    Args:
        obj: JSON-serializable data (plus whatever ``_json_default`` handles).
        path: destination file; may be a bare filename in the current
            working directory.
    """
    parent = os.path.dirname(path)
    # A bare filename has no directory part, and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, default=_json_default)


def load_json(path: str):
    """Parse and return the JSON document stored at ``path``.

    The file is decoded as UTF-8, the encoding RFC 8259 requires for JSON
    exchanged between systems, rather than the platform's locale default.
    Non-ASCII text therefore loads the same on every OS.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _json_default(o):
    """``default=`` hook for ``json.dump``: convert values JSON can't encode.

    numpy booleans, integers and floats become the matching python scalars.
    numpy arrays become (nested) lists via ``tolist()``. Anything else falls
    back to ``str(o)``.
    """
    # np.bool_ is not a subclass of np.integer. Without this branch it hit
    # the str() fallback and was written as the string "True"/"False".
    if isinstance(o, np.bool_):
        return bool(o)
    if isinstance(o, np.integer):
        return int(o)
    if isinstance(o, np.floating):
        return float(o)
    if isinstance(o, np.ndarray):
        return o.tolist()
    return str(o)