File size: 4,011 Bytes
ee3e701
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context

from contextlib import contextmanager

import torch
import torch.cuda
from torch import Tensor

from .process_group_initializer import ParallelMode


class SeedManager:
    """Manage the CUDA random-number-generator state associated with each parallel mode.

    For every registered mode the manager keeps the integer seed it was registered with and a
    snapshot of the CUDA RNG state, so callers can switch between independent random streams
    with `set_mode`.
    """

    def __init__(self):
        # Mode whose RNG state is currently active; None until set_mode() succeeds.
        self._current_mode = None
        # parallel mode -> integer seed passed to add_seed().
        self._seeds = {}
        # parallel mode -> RNG state snapshot obtained from torch.cuda.get_rng_state().
        self._seed_states = {}

    @property
    def current_mode(self):
        """The currently active parallel mode, or None if no mode has been set."""
        return self._current_mode

    @property
    def seeds(self):
        """Mapping from parallel mode to its registered seed (the internal dict, not a copy)."""
        return self._seeds

    @property
    def seed_states(self):
        """Mapping from parallel mode to its stored RNG state (the internal dict, not a copy)."""
        return self._seed_states

    def set_state(self, parallel_mode: ParallelMode, state: Tensor):
        """Replace the stored RNG state of `parallel_mode` with `state`.

        Only the stored snapshot is updated; the live CUDA RNG state is not touched.

        Args:
            parallel_mode: Mode whose stored state is replaced; must already be registered.
            state: RNG state to store for the mode.

        Raises:
            AssertionError: If `parallel_mode` has no stored state.
        """
        assert parallel_mode in self._seed_states, f"{parallel_mode} not found in seed manager"
        self._seed_states[parallel_mode] = state

    def set_mode(self, parallel_mode: ParallelMode):
        """Make `parallel_mode` the active mode and load its RNG state into CUDA.

        The live CUDA RNG state is first saved as the state of the previously active mode
        (if there is one), so switching back later resumes that stream where it left off.

        Args:
            parallel_mode: Mode to activate; must have been registered with `add_seed`.

        Raises:
            KeyError: If `parallel_mode` has no stored state. The manager is left unchanged.
        """
        # Validate before mutating anything so a failed switch cannot leave `_current_mode`
        # pointing at a mode that has no stored state.
        if parallel_mode not in self._seed_states:
            raise KeyError(f"{parallel_mode} not found in seed manager")

        # Compare against None explicitly: a mode object that happens to be falsy must
        # still have its state saved.
        if self._current_mode is not None:
            # save state for current mode
            self._seed_states[self._current_mode] = torch.cuda.get_rng_state()

        # set new state for new mode
        self._current_mode = parallel_mode
        torch.cuda.set_rng_state(self._seed_states[parallel_mode])

    def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
        """Register `seed` for `parallel_mode` and store the RNG state it produces.

        Args:
            parallel_mode: Mode to register the seed for.
            seed: Integer seed passed to `torch.cuda.manual_seed`.
            overwrite: If False, registering a mode twice is an error; if True, the existing
                entry is replaced and a warning is printed.

        Raises:
            AssertionError: If `parallel_mode` is not a `ParallelMode`, or if it is already
                registered and `overwrite` is False.
        """
        assert isinstance(parallel_mode, ParallelMode), "Invalid ParallelMode"
        if not overwrite:
            assert parallel_mode not in self._seed_states, f"Seed for {parallel_mode} exists"
        elif parallel_mode in self._seed_states:
            print(f"Warning: {parallel_mode} seed overwritten.", flush=True)

        # Seed the generator only long enough to snapshot the resulting state, then put the
        # previous state back so registering a seed does not disturb the live RNG stream.
        current_state = torch.cuda.get_rng_state()
        torch.cuda.manual_seed(seed)
        self._seed_states[parallel_mode] = torch.cuda.get_rng_state()
        self._seeds[parallel_mode] = seed
        torch.cuda.set_rng_state(current_state)

    def reset(self):
        """Forget the active mode and all registered seeds and states.

        The live CUDA RNG state is not modified.
        """
        self._current_mode = None
        self._seeds = {}
        self._seed_states = {}


# Module-level SeedManager instance that the helper functions in this module operate on.
_SEED_MANAGER = SeedManager()


def get_seeds() -> dict:
    """Returns the seeds registered with the module-level seed manager.

    Returns:
        dict: Mapping from parallel mode to the seed registered for it. This is the
            manager's own dict, not a copy.
    """
    return _SEED_MANAGER.seeds


def get_states(copy=False):
    """Returns the RNG states held by the module-level seed manager.

    Args:
        copy: When true, build a new dict whose values are clones of the stored states;
            otherwise hand back the manager's own dict.

    Returns:
        dict: Mapping from parallel mode to its stored RNG state.
    """
    stored = _SEED_MANAGER.seed_states
    if not copy:
        return stored
    return {mode: rng_state.clone() for mode, rng_state in stored.items()}


def get_current_mode():
    """Returns the current mode of the module-level seed manager.

    Returns:
        The parallel mode most recently activated through the seed manager, or None if
        no mode is active.
    """
    return _SEED_MANAGER.current_mode


def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False) -> None:
    """Adds a seed to the module-level seed manager for `parallel_mode`.

    Args:
        parallel_mode: Mode to register the seed for.
        seed: Integer seed for the mode.
        overwrite: Whether an existing entry for `parallel_mode` may be replaced.

    Raises:
        AssertionError: If `parallel_mode` is not a `ParallelMode`, or if it is already
            registered and `overwrite` is False.
    """
    _SEED_MANAGER.add_seed(parallel_mode, seed, overwrite)


def set_mode(parallel_mode: ParallelMode) -> None:
    """Sets the current mode of the module-level seed manager.

    The manager saves the live CUDA RNG state for the previously active mode (if any) and
    loads the stored state of `parallel_mode`.

    Raises:
        KeyError: If `parallel_mode` has no stored state.
    """
    _SEED_MANAGER.set_mode(parallel_mode)


def set_seed_states(parallel_mode: ParallelMode, state: Tensor) -> None:
    """Sets the stored RNG state of `parallel_mode` in the module-level seed manager.

    Only the stored snapshot is replaced; the live CUDA RNG state is not changed.

    Raises:
        AssertionError: If `parallel_mode` has no stored state.
    """
    _SEED_MANAGER.set_state(parallel_mode, state)


def sync_states():
    """Record the live CUDA RNG state as the stored state of the currently active mode.

    The active mode is read first, then the RNG state, and the pair is handed to
    `set_seed_states`, which rejects a mode that has no stored state.
    """
    set_seed_states(get_current_mode(), torch.cuda.get_rng_state())


@contextmanager
def seed(parallel_mode: ParallelMode):
    """A context for seed switch.

    On entry the seed manager is switched to `parallel_mode`; on exit the mode that was
    active before is restored. If no mode was active on entry, the CUDA RNG state seen on
    entry is put back and the manager returns to having no active mode.

    Args:
        parallel_mode: Mode to activate for the duration of the `with` block.

    Yields:
        None (the return value of `SeedManager.set_mode`).

    Raises:
        KeyError: If `parallel_mode` has no stored state in the seed manager.
    """
    previous_mode = _SEED_MANAGER.current_mode
    # With no active mode there is no stored state to switch back to, so remember the raw
    # RNG state instead; calling set_mode(None) on exit would raise KeyError.
    ambient_state = torch.cuda.get_rng_state() if previous_mode is None else None
    try:
        yield _SEED_MANAGER.set_mode(parallel_mode)
    finally:
        if previous_mode is not None:
            _SEED_MANAGER.set_mode(previous_mode)
        else:
            # Keep the progress made inside the context, mirroring what set_mode() does
            # when it leaves a mode. Skipped if entry failed or the mode is unregistered.
            active_mode = _SEED_MANAGER.current_mode
            if active_mode in _SEED_MANAGER.seed_states:
                _SEED_MANAGER.set_state(active_mode, torch.cuda.get_rng_state())
            torch.cuda.set_rng_state(ambient_state)
            # There is no public way to return to "no active mode", so clear it directly.
            _SEED_MANAGER._current_mode = None