from torch.optim.lr_scheduler import _LRScheduler, StepLR, MultiStepLR, \
    ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau
import math
import warnings


__all__ = [
    'CosinePowerAnnealingLR', 'StepLRWithWarmup', 'MultiStepLRWithWarmup',
    'ExponentialLRWithWarmup', 'CosineAnnealingLRWithWarmup',
    'CosinePowerAnnealingLRWithWarmup', 'ReduceLROnPlateauWithWarmup']


class _WarmupLR(_LRScheduler):
    """Wrapper adding a warmup phase to a Pytorch Scheduler.

    This class is not intended to be directly instantiated. One should
    instead create child classes with the desired `_SCHEDULER_CLASS`.

    Credit: https://github.com/lehduong/torch-warmup-lr

    :param warmup_init_lr: float
        Learning rate value to start the warmup from. All the optimizer's
        parameter groups are warmed up from `warmup_init_lr` to their
        initial value as set in the optimizer (groups whose learning rate
        is already below `warmup_init_lr` start from their own value)
    :param num_warmup: int
        Number of scheduler steps (i.e. epochs, most of the time)
        dedicated to warming up
    :param warmup_strategy: str
        Warmup strategy, among ['linear', 'cos', 'constant']
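
    Example (a minimal sketch; the model, optimizer and hyper-parameter
    values below are placeholders, not defaults of this module):

        >>> import torch
        >>> model = torch.nn.Linear(10, 2)
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        >>> scheduler = StepLRWithWarmup(
        ...     optimizer, step_size=30, gamma=0.1,
        ...     warmup_init_lr=1e-6, num_warmup=5, warmup_strategy='linear')
        >>> for epoch in range(100):
        ...     ...  # one training epoch here
        ...     scheduler.step()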
    """
    _SCHEDULER_CLASS = None

    def __init__(
            self, *args, warmup_init_lr=1e-6, num_warmup=1,
            warmup_strategy='cos', **kwargs):

        assert warmup_strategy in ['linear', 'cos', 'constant'], \
            f"Expect warmup_strategy to be one of ['linear', 'cos', " \
            f"'constant'] but got {warmup_strategy}"

        self._scheduler = self._SCHEDULER_CLASS(*args, **kwargs)
        self._init_lr = warmup_init_lr
        self._num_warmup = num_warmup
        self._step_count = 0

        # Define the strategy to warm up learning rate
        self._warmup_strategy = warmup_strategy
        if warmup_strategy == 'cos':
            self._warmup_func = self._warmup_cos
        elif warmup_strategy == 'linear':
            self._warmup_func = self._warmup_linear
        else:
            self._warmup_func = self._warmup_const

        # Save the initial learning rate of each param group. Only useful
        # when param groups have different learning rates
        self._format_param()

        # A first step is needed to initialize the LR
        self.step()

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails: guard against
        # infinite recursion when `_scheduler` has not been set yet, and
        # delegate any other missing attribute to the wrapped scheduler.
        if name == '_scheduler':
            if name in self.__dict__:
                return self.__dict__['_scheduler']
            return None
        return getattr(self._scheduler, name)

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.
        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        wrapper_state_dict = {
            key: value
            for key, value in self.__dict__.items()
            if (key != 'optimizer' and key != '_scheduler')}

        wrapped_state_dict = {
            key: value
            for key, value in self._scheduler.__dict__.items()
            if key != 'optimizer'}

        return {'wrapped': wrapped_state_dict, 'wrapper': wrapper_state_dict}

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        :param state_dict: dict
            Scheduler state. Should be an object returned from a call
            to :meth:`state_dict`.
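
        Example (a sketch; the `scheduler` name and file name are
        placeholders)::

            torch.save(scheduler.state_dict(), 'scheduler.pt')
            scheduler.load_state_dict(torch.load('scheduler.pt'))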
        """
        self.__dict__.update(state_dict['wrapper'])
        self._scheduler.__dict__.update(state_dict['wrapped'])

    def _format_param(self):
        """Set the first and last learning rates for the warmup phase,
        for each parameter group. All parameter groups will start the
        warmup at the same value `self._init_lr`.
        """
        for group in self._scheduler.optimizer.param_groups:
            group['warmup_max_lr'] = group['lr']
            group['warmup_initial_lr'] = min(self._init_lr, group['lr'])

    def _warmup_cos(self, start, end, pct):
        """Cosine warmup scheme.
        """
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    def _warmup_const(self, start, end, pct):
        """Constant warmup scheme.
        """
        return start if pct < 0.9999 else end

    def _warmup_linear(self, start, end, pct):
        """Linear warmup scheme.
        """
        return (end - start) * pct + start
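
    # Illustration of the three schemes above (values are arbitrary): with
    # start=1e-6, end=0.1 and pct=0.5, the linear and cosine schemes both
    # yield ~0.05, while the constant scheme stays at 1e-6 until the final
    # warmup step (pct >= 0.9999).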

    def get_lr(self):
        lrs = []
        step_num = self._step_count

        # warm up learning rate
        if step_num <= self._num_warmup:
            for group in self._scheduler.optimizer.param_groups:
                computed_lr = self._warmup_func(
                    group['warmup_initial_lr'], group['warmup_max_lr'],
                    step_num / self._num_warmup)
                lrs.append(computed_lr)
        else:
            lrs = self._scheduler.get_lr()
        return lrs

    def step(self, *args, **kwargs):
        if self._step_count <= self._num_warmup:
            values = self.get_lr()
            for param_group, lr in zip(
                    self._scheduler.optimizer.param_groups, values):
                param_group['lr'] = lr
            self._step_count += 1
        else:
            self._scheduler.step(*args, **kwargs)


class CosinePowerAnnealingLR(CosineAnnealingLR):
    """Same as CosineAnnealingLR, but with an additional `power`
    parameter, to mitigate the annealing time spent on large learning
    rates (i.e. `power < 1`) or small learning rates (i.e. `power > 1`).
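
    In closed form (cf. `_get_closed_form_lr` below), the learning rate at
    epoch `t` for a group with base learning rate `base_lr` is::

        lr_t = eta_min + (base_lr - eta_min)
               * ((1 + cos(pi * t / T_max)) / 2) ** power

    which reduces to the standard cosine annealing schedule for
    `power == 1`.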
    """

    def __init__(
            self, optimizer, T_max, eta_min=0, power=2, last_epoch=-1,
            verbose=False):
        # `power` must be set before the parent constructor runs, since the
        # parent's __init__ performs an initial step that calls `get_lr()`.
        self.power = power
        super().__init__(
            optimizer, T_max, eta_min=eta_min, last_epoch=last_epoch,
            verbose=verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return [group['lr'] for group in self.optimizer.param_groups]
        elif self._step_count == 1 and self.last_epoch > 0:
            return [
                self.eta_min + (base_lr - self.eta_min) *
                ((1 + math.cos((self.last_epoch) * math.pi / self.T_max)) / 2) ** self.power
                for base_lr, group in
                zip(self.base_lrs, self.optimizer.param_groups)]
        elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            return [
                group['lr'] + (base_lr - self.eta_min) *
                ((1 - math.cos(math.pi / self.T_max)) / 2) ** self.power
                for base_lr, group in
                zip(self.base_lrs, self.optimizer.param_groups)]
        return [
            ((1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
             (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max))) ** self.power *
            (group['lr'] - self.eta_min) + self.eta_min
            for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [
            self.eta_min + (base_lr - self.eta_min) *
            ((1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2) ** self.power
            for base_lr in self.base_lrs]


class StepLRWithWarmup(_WarmupLR):
    """StepLRWithWarmup with warmup.
    """
    _SCHEDULER_CLASS = StepLR


class MultiStepLRWithWarmup(_WarmupLR):
    """MultiStepLR with warmup.
    """
    _SCHEDULER_CLASS = MultiStepLR


class ExponentialLRWithWarmup(_WarmupLR):
    """ExponentialLR with warmup.
    """
    _SCHEDULER_CLASS = ExponentialLR


class CosineAnnealingLRWithWarmup(_WarmupLR):
    """CosineAnnealingLR with warmup.
    """
    _SCHEDULER_CLASS = CosineAnnealingLR


class CosinePowerAnnealingLRWithWarmup(_WarmupLR):
    """CosinePowerAnnealingLR with warmup.
    """
    _SCHEDULER_CLASS = CosinePowerAnnealingLR


class ReduceLROnPlateauWithWarmup(_WarmupLR):
    """ReduceLROnPlateau with warmup.
    """
    _SCHEDULER_CLASS = ReduceLROnPlateau


ON_PLATEAU_SCHEDULERS = (ReduceLROnPlateau, ReduceLROnPlateauWithWarmup)
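
# Sketch of how calling code can use `ON_PLATEAU_SCHEDULERS` (the `scheduler`
# and `val_loss` names are placeholders): plateau-based schedulers need the
# monitored metric at each step, the other schedulers do not.
#
#     if isinstance(scheduler, ON_PLATEAU_SCHEDULERS):
#         scheduler.step(val_loss)
#     else:
#         scheduler.step()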