| | import torch |
| | from torch.optim.optimizer import Optimizer |
| | import pytorch_lightning as pl |
| | from torch.optim.lr_scheduler import _LRScheduler |
| |
|
| |
|
class BaseScheduler(object):
    """Base class for the step-wise scheduler logic.

    Args:
        optimizer (Optimizer): Optimizer instance to apply lr schedule on.

    Subclass this and overwrite ``_get_lr`` to write your own step-wise scheduler.
    ``_get_lr`` is called right after ``step_num`` has been incremented, so it
    sees step numbers starting at 1.
    """

    def __init__(self, optimizer):
        self.optimizer = optimizer
        # Number of calls to ``step`` made so far.
        self.step_num = 0

    def zero_grad(self):
        """Forward to ``self.optimizer.zero_grad()``."""
        self.optimizer.zero_grad()

    def _get_lr(self):
        """Return the learning rate for the current ``step_num``; override in subclasses."""
        raise NotImplementedError

    def _set_lr(self, lr):
        """Write ``lr`` into every parameter group of the optimizer."""
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def step(self, metrics=None, epoch=None):
        """Update step-wise learning rate before optimizer.step.

        ``metrics`` and ``epoch`` are accepted but ignored (NOTE(review):
        presumably for call-compatibility with other scheduler APIs).
        """
        self.step_num += 1
        lr = self._get_lr()
        self._set_lr(lr)

    def load_state_dict(self, state_dict):
        """Restore instance attributes from a dict produced by ``state_dict``."""
        self.__dict__.update(state_dict)

    def state_dict(self):
        """Return every instance attribute except the optimizer."""
        return {key: value for key, value in self.__dict__.items() if key != "optimizer"}

    def as_tensor(self, start=0, stop=100_000):
        """Return the scheduler values for steps ``start + 1`` to ``stop``.

        The schedule is simulated from ``step_num = 0`` by calling ``_get_lr``
        once per step. The optimizer is not touched.

        The scheduler's full state (as given by ``state_dict``) is snapshotted
        beforehand and restored afterwards, even if ``_get_lr`` raises. Counters
        mutated by subclasses inside ``_get_lr`` are therefore not left
        advanced. Only ``step_num`` is reset for the simulation; other subclass
        counters start from their current values.

        Args:
            start (int): Number of initial steps to skip in the output.
            stop (int): Last step (inclusive) to report.

        Returns:
            torch.Tensor: 1-D tensor of ``max(stop - max(start, 0), 0)`` values.
        """
        import copy

        saved_state = copy.deepcopy(self.state_dict())
        lr_list = []
        try:
            self.step_num = 0
            # Walk every step from 1 so stateful subclasses (e.g. epoch
            # counters) evolve exactly as during training; only keep the
            # values after ``start``.
            for _ in range(stop):
                self.step_num += 1
                lr = self._get_lr()
                if self.step_num > start:
                    lr_list.append(lr)
        finally:
            self.load_state_dict(saved_state)
        return torch.tensor(lr_list)

    def plot(self, start=0, stop=100_000):
        """Plot the scheduler values from start to stop against their step numbers."""
        import matplotlib.pyplot as plt

        all_lr = self.as_tensor(start=start, stop=stop)
        # ``as_tensor`` returns the values of the last ``len(all_lr)`` steps up
        # to and including ``stop``.
        steps = list(range(stop - len(all_lr) + 1, stop + 1))
        plt.plot(steps, all_lr.numpy())
        plt.show()
| |
|
class DPTNetScheduler(BaseScheduler):
    """Dual Path Transformer Scheduler used in [1]

    Args:
        optimizer (Optimizer): Optimizer instance to apply lr schedule on.
        steps_per_epoch (int): Number of steps per epoch. Must be positive.
        d_model (int): The number of units in the layer output.
        warmup_steps (int): The number of steps in the warmup stage of training.
        noam_scale (float): Scale factor applied to the Noam warmup formula
            in the first phase.
        exp_max (float): Max learning rate in second phase.
        exp_base (float): Exp learning rate base in second phase.

    Raises:
        ValueError: If ``steps_per_epoch`` is not positive.

    Schedule:
        For steps ``1..warmup_steps`` the learning rate increases linearly as
        ``noam_scale * d_model ** -0.5 * step * warmup_steps ** -1.5``.
        Afterwards it is ``exp_max * exp_base ** (max(epoch - 1, 0) // 2)``,
        where ``epoch`` counts completed epochs. The rate is multiplied by
        ``exp_base`` (0.98 by default) every two epochs and never exceeds
        ``exp_max``.

    References
        [1]: Jingjing Chen et al. "Dual-Path Transformer Network: Direct Context-
        Aware Modeling for End-to-End Monaural Speech Separation" Interspeech 2020.
    """

    def __init__(
        self,
        optimizer,
        steps_per_epoch,
        d_model,
        warmup_steps=4000,
        noam_scale=1.0,
        exp_max=0.0004,
        exp_base=0.98,
    ):
        if steps_per_epoch <= 0:
            raise ValueError(f"steps_per_epoch must be positive, got {steps_per_epoch}.")
        super().__init__(optimizer)
        self.noam_scale = noam_scale
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.exp_max = exp_max
        self.exp_base = exp_base
        self.steps_per_epoch = steps_per_epoch
        # Completed epochs. ``_get_lr`` advances it, assuming it runs once per step.
        self.epoch = 0

    def _get_lr(self):
        """Return the lr for ``self.step_num`` (>= 1), advancing the epoch counter."""
        if self.step_num % self.steps_per_epoch == 0:
            self.epoch += 1

        if self.step_num > self.warmup_steps:
            # Exponential decay phase. Clamp the exponent at 0. Otherwise,
            # while still in the first epoch (epoch == 0), it would be -1 and
            # the lr would be exp_max / exp_base, which exceeds exp_max.
            num_decays = max(self.epoch - 1, 0) // 2
            return self.exp_max * (self.exp_base ** num_decays)

        # Noam warmup phase. For step <= warmup_steps the second term of the
        # min is the smaller one, so the lr grows linearly with the step.
        return (
            self.noam_scale
            * self.d_model ** (-0.5)
            * min(self.step_num ** (-0.5), self.step_num * self.warmup_steps ** (-1.5))
        )
| |
|
class CustomExponentialLR(_LRScheduler):
    """Multiply the learning rate by ``gamma`` every ``step_size`` scheduler steps.

    Successive decays compound, so after ``k`` decays the lr is
    ``initial_lr * gamma ** k``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative decay factor.
        step_size (int): Decay period, in calls to ``step()``. A decay is
            applied whenever ``last_epoch + 1`` is a multiple of ``step_size``,
            but never at construction time (``last_epoch == 0``).
        last_epoch (int): Index of the last epoch, forwarded to ``_LRScheduler``.

    Raises:
        ValueError: If ``step_size`` is smaller than 1.
    """

    def __init__(self, optimizer, gamma, step_size, last_epoch=-1):
        if step_size < 1:
            raise ValueError(f"step_size must be a positive integer, got {step_size}.")
        self.gamma = gamma
        self.step_size = step_size
        # ``base_lrs`` is populated by ``_LRScheduler.__init__``.
        super(CustomExponentialLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the new lr for each parameter group at the current ``last_epoch``."""
        if self.last_epoch == 0 or (self.last_epoch + 1) % self.step_size != 0:
            return [group["lr"] for group in self.optimizer.param_groups]
        # Decay the *current* lr so that decays compound. Scaling ``base_lrs``
        # instead would pin the lr at ``base_lr * gamma`` forever.
        return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
| |
|
| |
|
| | |
# Alias of ``BaseScheduler`` under a private name. NOTE(review): presumably
# kept so older code referring to ``_BaseScheduler`` keeps working; confirm
# before removing.
_BaseScheduler = BaseScheduler