File size: 1,513 Bytes
5850885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""Shared seeding helper for training and eval entrypoints."""

from __future__ import annotations

import random
from typing import Any


def set_seed(seed: int) -> None:
    """Apply *seed* to each random number generator that can be imported.

    The standard-library ``random`` module is always seeded. NumPy, PyTorch
    and ``transformers`` are treated as optional: each one is imported
    lazily and skipped without error when the import raises ``ImportError``.
    Exceptions raised by the seeding calls themselves are not caught.

    Args:
        seed: Value passed unchanged to every seeding function.
    """
    random.seed(seed)

    # NumPy: seed the legacy global generator when the package is present.
    try:
        import numpy as numpy_lib
    except ImportError:
        pass
    else:
        numpy_lib.random.seed(seed)

    # PyTorch: every hook is looked up with getattr and only invoked when
    # it is callable, so a torch module lacking one of them is tolerated.
    try:
        import torch as torch_lib
    except ImportError:
        pass
    else:
        seed_cpu = getattr(torch_lib, "manual_seed", None)
        if callable(seed_cpu):
            seed_cpu(seed)

        cuda_ns = getattr(torch_lib, "cuda", None)
        cuda_ready = getattr(cuda_ns, "is_available", None)
        if cuda_ns is not None and callable(cuda_ready) and cuda_ready():
            # Same order as before: current device first, then all devices.
            for hook_name in ("manual_seed", "manual_seed_all"):
                cuda_hook = getattr(cuda_ns, hook_name, None)
                if callable(cuda_hook):
                    cuda_hook(seed)

    # transformers: delegate to its own helper when it can be imported.
    try:
        from transformers import set_seed as hf_set_seed
    except ImportError:
        return
    hf_set_seed(seed)


# Names exported by ``from <module> import *``; ``set_seed`` is the only one.
__all__ = ["set_seed"]