"""
@Date: 2021/09/14
@description: Learning-rate scheduler with linear warmup followed by polynomial decay.
"""
class WarmupScheduler:
    """Learning-rate schedule: linear warmup followed by polynomial decay.

    For ``cur_step < warmup_step`` the rate is interpolated linearly from
    ``init_lr`` toward ``warmup_lr``. From ``warmup_step`` onward it decays as
    ``warmup_lr * (1 - progress) ** lr_pow``, where ``progress`` goes from 0
    at ``warmup_step`` to 1 at ``max_step``. Past ``max_step`` the base of the
    power is clamped at 0, so the rate stays at 0 for any positive ``lr_pow``.
    """

    def __init__(self, optimizer, lr_pow, init_lr, warmup_lr, warmup_step, max_step, **kwargs):
        """Create the scheduler; no optimizer state is touched until ``step_update``.

        Args:
            optimizer: Object exposing ``param_groups`` (an iterable of dicts
                with an ``'lr'`` key) -- presumably a ``torch.optim``
                optimizer -- or ``None`` to only track ``running_lr``.
            lr_pow: Exponent of the polynomial decay.
            init_lr: Learning rate at the start of warmup; also the initial
                value of ``running_lr``.
            warmup_lr: Peak rate, reached at ``warmup_step``.
            warmup_step: Step at which warmup ends and decay begins.
            max_step: Step at which the decay reaches its end; must be
                strictly greater than ``warmup_step``.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``max_step <= warmup_step``. Equal values would
                divide by zero after warmup, and a smaller ``max_step`` would
                make the rate grow instead of decay.
        """
        if max_step <= warmup_step:
            raise ValueError(
                'max_step ({}) must be greater than warmup_step ({})'.format(
                    max_step, warmup_step))
        self.lr_pow = lr_pow
        self.init_lr = init_lr
        self.running_lr = init_lr  # most recently computed learning rate
        self.warmup_lr = warmup_lr
        self.warmup_step = warmup_step
        self.max_step = max_step
        self.optimizer = optimizer

    def step_update(self, cur_step):
        """Recompute the learning rate for ``cur_step`` and apply it.

        Stores the new rate in ``self.running_lr``. When an optimizer was
        given, writes that same value into every one of its param groups,
        overwriting any per-group rate.

        Args:
            cur_step: Current step index.

        Returns:
            The new learning rate (identical to ``self.running_lr``).
        """
        if cur_step < self.warmup_step:
            # Linear ramp from init_lr toward warmup_lr.
            frac = cur_step / self.warmup_step
            self.running_lr = self.init_lr + (self.warmup_lr - self.init_lr) * frac
        else:
            # Fraction of the decay phase completed. The base is clamped at 0
            # so steps beyond max_step never produce a negative base.
            frac = (float(cur_step) - self.warmup_step) / (self.max_step - self.warmup_step)
            self.running_lr = self.warmup_lr * max(1. - frac, 0.) ** self.lr_pow
        if self.optimizer is not None:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.running_lr
        return self.running_lr
if __name__ == '__main__':
    # Demo: plot the schedule over its full span so the warmup ramp and the
    # polynomial decay can be inspected visually.
    import matplotlib.pyplot as plt

    scheduler = WarmupScheduler(optimizer=None,
                                lr_pow=4,
                                init_lr=0.0000003,
                                warmup_lr=0.00003,
                                warmup_step=10000,
                                max_step=100000)
    # Derive the range from the scheduler instead of repeating the constant.
    steps = list(range(scheduler.max_step))
    lrs = []
    for step in steps:
        scheduler.step_update(step)
        lrs.append(scheduler.running_lr)
    plt.plot(steps, lrs, linewidth=1)
    plt.xlabel('step')
    plt.ylabel('learning rate')
    plt.show()