| import math |
| from collections.abc import Iterable |
| from math import cos, floor, log, pi |
|
|
| import skorch |
|
|
| class CyclicCosineDecayLR(skorch.callbacks.Callback): |
| def __init__( |
| self, |
| optimizer, |
| init_interval, |
| min_lr, |
| len_param_groups, |
| base_lrs, |
| restart_multiplier=None, |
| restart_interval=None, |
| restart_lr=None, |
| last_epoch=-1, |
| ): |
|         """
|         Initialize a new CyclicCosineDecayLR callback.
|         :param optimizer: (Optimizer) - Wrapped optimizer.
|         :param init_interval: (int) - Initial decay cycle interval.
|         :param min_lr: (float or iterable of floats) - Minimal learning rate.
|         :param len_param_groups: (int) - Number of parameter groups in the optimizer.
|         :param base_lrs: (iterable of floats) - Initial learning rates of the parameter groups.
|         :param restart_multiplier: (float) - Multiplication coefficient for increasing cycle intervals;
|             if this parameter is set, restart_interval must be None.
|         :param restart_interval: (int) - Restart interval for fixed cycle intervals;
|             if this parameter is set, restart_multiplier must be None.
|         :param restart_lr: (float or iterable of floats) - Optional learning rate at cycle restarts;
|             if not provided, the initial learning rates are used.
|         :param last_epoch: (int) - Index of the last epoch.
|         """
| self.len_param_groups = len_param_groups |
| if restart_interval is not None and restart_multiplier is not None: |
| raise ValueError( |
| "You can either set restart_interval or restart_multiplier but not both" |
| ) |
|
|
| if isinstance(min_lr, Iterable) and len(min_lr) != self.len_param_groups: |
| raise ValueError( |
| "Expected len(min_lr) to be equal to len(optimizer.param_groups), " |
| "got {} and {} instead".format(len(min_lr), self.len_param_groups) |
| ) |
|
|
|         if (
|             isinstance(restart_lr, Iterable)
|             and len(restart_lr) != self.len_param_groups
|         ):
| raise ValueError( |
| "Expected len(restart_lr) to be equal to len(optimizer.param_groups), " |
| "got {} and {} instead".format(len(restart_lr), self.len_param_groups) |
| ) |
|
|
| if init_interval <= 0: |
| raise ValueError( |
| "init_interval must be a positive number, got {} instead".format( |
| init_interval |
| ) |
| ) |
|
|
|         self.optimizer = optimizer
|         group_num = self.len_param_groups
| self._init_interval = init_interval |
| self._min_lr = [min_lr] * group_num if isinstance(min_lr, float) else min_lr |
| self._restart_lr = ( |
| [restart_lr] * group_num if isinstance(restart_lr, float) else restart_lr |
| ) |
| self._restart_interval = restart_interval |
| self._restart_multiplier = restart_multiplier |
| self.last_epoch = last_epoch |
| self.base_lrs = base_lrs |
| super().__init__() |
|
|
| def on_batch_end(self, net, training, **kwargs): |
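|         """Step the cyclic cosine schedule once per training batch.
|
|         Three regimes are covered: the initial decay cycle, fixed-length restart
|         cycles (restart_interval), and geometrically growing restart cycles
|         (restart_multiplier). With neither restart option set, the learning rate
|         stays at min_lr once the initial cycle is over.
|         """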
|         if not training:
|             return
|         # Advance the step counter, then compute the rates for the new step.
|         self.last_epoch += 1
|         if self.last_epoch < self._init_interval:
|             lrs = self._calc(self.last_epoch, self._init_interval, self.base_lrs)
|         elif self._restart_interval is not None:
|             cycle_epoch = (
|                 self.last_epoch - self._init_interval
|             ) % self._restart_interval
|             lrs = self.base_lrs if self._restart_lr is None else self._restart_lr
|             lrs = self._calc(cycle_epoch, self._restart_interval, lrs)
|         elif self._restart_multiplier is not None:
|             n = self._get_n(self.last_epoch)
|             sn_prev = self._partial_sum(n)
|             cycle_epoch = self.last_epoch - sn_prev
|             interval = self._init_interval * self._restart_multiplier ** n
|             lrs = self.base_lrs if self._restart_lr is None else self._restart_lr
|             lrs = self._calc(cycle_epoch, interval, lrs)
|         else:
|             lrs = self._min_lr
|         # Apply the scheduled rates to the wrapped optimizer's parameter groups.
|         for param_group, lr in zip(self.optimizer.param_groups, lrs):
|             param_group["lr"] = lr
|
|
| def _calc(self, t, T, lrs): |
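|         # Cosine interpolation at step t of a cycle of length T:
|         # lr(t) = min_lr + (base_lr - min_lr) * (1 + cos(pi * t / T)) / 2.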
| return [ |
| min_lr + (lr - min_lr) * (1 + cos(pi * t / T)) / 2 |
| for lr, min_lr in zip(lrs, self._min_lr) |
| ] |
|
|
| def _get_n(self, epoch): |
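|         # Index of the restart cycle that `epoch` falls into: the largest n with
|         # _partial_sum(n) <= epoch, obtained by inverting the geometric partial
|         # sum in closed form.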
| a = self._init_interval |
| r = self._restart_multiplier |
| _t = 1 - (1 - r) * epoch / a |
| return floor(log(_t, r)) |
|
|
| def _partial_sum(self, n): |
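|         # Combined length of the first n restart cycles, whose intervals grow
|         # geometrically: a + a*r + ... + a*r**(n - 1) = a * (1 - r**n) / (1 - r).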
| a = self._init_interval |
| r = self._restart_multiplier |
| return a * (1 - r ** n) / (1 - r) |
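|
|
| # A minimal construction sketch for CyclicCosineDecayLR (illustrative only: the
| # torch model, optimizer, and hyperparameter values below are assumptions, not
| # part of this module; assumes `import torch`):
| #
| #     model = torch.nn.Linear(10, 2)
| #     optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
| #     cyclic_lr = CyclicCosineDecayLR(
| #         optimizer=optimizer,
| #         init_interval=10,
| #         min_lr=1e-5,
| #         len_param_groups=len(optimizer.param_groups),
| #         base_lrs=[group["lr"] for group in optimizer.param_groups],
| #         restart_interval=10,
| #     )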
|
|
|
|
| class LearningRateDecayCallback(skorch.callbacks.Callback): |
| def __init__( |
| self, |
| config, |
| ): |
| super().__init__() |
| self.lr_warmup_end = config.lr_warmup_end |
| self.lr_warmup_start = config.lr_warmup_start |
| self.learning_rate = config.learning_rate |
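|         # Convert the warmup and decay horizons from epochs to optimizer steps.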
| self.warmup_batch = config.warmup_epoch * config.batch_per_epoch |
| self.final_batch = config.final_epoch * config.batch_per_epoch |
|
|
| self.batch_idx = 0 |
|
|
| def on_batch_end(self, net, training, **kwargs): |
|         """Adjust the learning rate after every training batch.
|
|         The rate is warmed up linearly from lr_warmup_start to lr_warmup_end over
|         the first warmup_batch batches, then decayed with a cosine schedule toward
|         learning_rate until final_batch.
|
|         :param net: (NeuralNet) - The skorch net being trained.
|         :param training: (bool) - Whether the batch was a training batch.
|         """
| |
| if training: |
|
|
| if self.batch_idx < self.warmup_batch: |
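|                 # Linear warmup: ramp from lr_warmup_start to lr_warmup_end over
|                 # the first warmup_batch batches.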
| |
| lr_mult = float(self.batch_idx) / float(max(1, self.warmup_batch)) |
| lr = self.lr_warmup_start + lr_mult * ( |
| self.lr_warmup_end - self.lr_warmup_start |
| ) |
| else: |
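|                 # Cosine decay from lr_warmup_end toward learning_rate between
|                 # warmup_batch and final_batch, floored at learning_rate.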
| |
| progress = float(self.batch_idx - self.warmup_batch) / float( |
| max(1, self.final_batch - self.warmup_batch) |
| ) |
| lr = max( |
| self.learning_rate |
| + 0.5 |
| * (1.0 + math.cos(math.pi * progress)) |
| * (self.lr_warmup_end - self.learning_rate), |
| self.learning_rate, |
| ) |
| net.lr = lr |
| |
| |
|
|
| self.batch_idx += 1 |
|
|
|
|
| class LRAnnealing(skorch.callbacks.Callback): |
| def on_epoch_end(self, net, **kwargs): |
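|         # Quarter the learning rate whenever an epoch fails to improve the best
|         # validation loss.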
| if not net.history[-1]["valid_loss_best"]: |
| net.lr /= 4.0 |
|
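| # A minimal sketch of wiring LearningRateDecayCallback into a skorch net (the
| # config object, its values, and MyModule are assumptions for illustration, not
| # part of this module); LRAnnealing can be registered the same way:
| #
| #     from types import SimpleNamespace
| #
| #     lr_config = SimpleNamespace(
| #         learning_rate=1e-4,
| #         lr_warmup_start=1e-6,
| #         lr_warmup_end=1e-3,
| #         warmup_epoch=2,
| #         final_epoch=50,
| #         batch_per_epoch=100,
| #     )
| #     net = skorch.NeuralNetClassifier(
| #         MyModule,
| #         lr=lr_config.learning_rate,
| #         max_epochs=lr_config.final_epoch,
| #         callbacks=[LearningRateDecayCallback(lr_config)],
| #     )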
|