| | from torch.optim import Optimizer |
| | import math |
| |
|
class _LRScheduler(object):
    """Base class for epoch-indexed learning-rate schedulers.

    Subclasses override :meth:`get_lr` to return the learning rates for the
    epoch stored in ``self.last_epoch``. :meth:`step` pairs those values, in
    order, with ``optimizer.param_groups`` (``zip`` semantics, so surplus
    entries on either side are ignored) and writes them into each group's
    ``'lr'`` entry.

    Args:
        optimizer: the ``torch.optim.Optimizer`` whose learning rates are set.
        last_epoch: ``-1`` for a fresh start, in which case each group's
            current ``'lr'`` is recorded as ``'initial_lr'`` (unless the group
            already has one). Any other value means resuming, and every param
            group must already contain ``'initial_lr'``.

    Raises:
        TypeError: if ``optimizer`` is not an ``Optimizer``.
        KeyError: if resuming and some param group lacks ``'initial_lr'``.
    """

    def __init__(self, optimizer, last_epoch=-1):
        if not isinstance(optimizer, Optimizer):
            kind = type(optimizer).__name__
            raise TypeError('{} is not an Optimizer'.format(kind))
        self.optimizer = optimizer

        if last_epoch != -1:
            # Resuming: every group must already carry its starting lr.
            for idx, group in enumerate(optimizer.param_groups):
                if 'initial_lr' in group:
                    continue
                raise KeyError("param 'initial_lr' is not specified "
                               "in param_groups[{}] when resuming an optimizer".format(idx))
        else:
            # Fresh start: remember each group's current lr as its starting lr.
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])

        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]

        # Apply the lr for epoch ``last_epoch + 1`` right away, then rewind the
        # counter so that the caller's first argument-less ``step()`` targets
        # that same epoch again.
        self.step(epoch=last_epoch + 1)
        self.last_epoch = last_epoch

    def state_dict(self):
        """Return the scheduler state as a :class:`dict`.

        Every instance attribute is included except the ``optimizer``
        reference.
        """
        return {name: val for name, val in vars(self).items() if name != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Restore scheduler state.

        Args:
            state_dict (dict): a mapping produced by :meth:`state_dict`; each
                entry overwrites the instance attribute of the same name.
        """
        vars(self).update(state_dict)

    def get_lr(self):
        """Return the learning rates for ``self.last_epoch``; subclasses must override."""
        raise NotImplementedError

    def step(self, epoch=None):
        """Move to ``epoch`` (default: ``last_epoch + 1``) and apply its learning rates."""
        self.last_epoch = self.last_epoch + 1 if epoch is None else epoch
        for group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            group['lr'] = lr
| |
|
class CosineSchedule(_LRScheduler):
    """Cosine decay of every param group's learning rate over ``K`` epochs.

    At epoch ``e`` the lr is ``base_lr * cos(99*pi*e / (200*(K-1)))``. It
    starts at ``base_lr`` for ``e == 0`` and reaches
    ``base_lr * cos(0.495*pi)`` (about 1.6% of ``base_lr``) at ``e == K-1``.
    For ``K == 1`` a denominator of ``200`` is used instead, avoiding a
    division by zero.

    NOTE(review): the factor is not clamped. Once ``last_epoch`` exceeds
    ``100 * (K-1) / 99`` (for K > 1) the cosine turns negative and so do the
    learning rates. Confirm callers never step past ``K - 1``.
    """

    def __init__(self, optimizer, K):
        # ``K`` must exist before the base constructor's step() calls get_lr().
        self.K = K
        super().__init__(optimizer, -1)

    def cosine(self, base_lr):
        """Return ``base_lr`` scaled by the cosine factor for ``self.last_epoch``."""
        span = 1 if self.K == 1 else self.K - 1
        angle = 99 * math.pi * self.last_epoch / (200 * span)
        return base_lr * math.cos(angle)

    def get_lr(self):
        """Return the cosine-scaled learning rate for each base lr."""
        return list(map(self.cosine, self.base_lrs))

    def get_last_lr(self):
        """Recompute and return the learning rates for the current ``last_epoch``."""
        return self.get_lr()
| |
|
class CosineAnnealingWarmUp(_LRScheduler):
    """Linear warm-up followed by cosine annealing.

    While ``0 < warmup_length`` and ``last_epoch < warmup_length``, the lr
    ramps linearly as ``base_lr * (last_epoch + 1) / warmup_length``.
    Otherwise it follows ``base_lr * 0.5 * (1 + cos(pi * last_epoch / T_max))``.

    The cosine phase uses the absolute epoch, not the number of epochs since
    warm-up ended. Its first value is therefore
    ``0.5 * (1 + cos(pi * warmup_length / T_max)) * base_lr``.
    NOTE(review): confirm that this drop from ``base_lr`` at the end of
    warm-up is intended.

    ``T_max`` defaults to 0. It must be positive before the cosine phase is
    reached and before :meth:`get_last_lr` is called.

    Raises:
        ValueError: from :meth:`cosine_lr` / :meth:`get_last_lr` when
            ``T_max <= 0``.
    """

    def __init__(self, optimizer, warmup_length, T_max=0, last_epoch=-1):
        # Both attributes are read by get_lr(), which the base constructor
        # invokes through step(), so they must be set before super().__init__.
        self.warmup_length = warmup_length
        self.T_max = T_max
        super().__init__(optimizer, last_epoch)

    def _check_t_max(self):
        """Raise ValueError unless ``T_max`` is positive."""
        # An explicit raise (not ``assert``) so the check survives ``python -O``.
        if self.T_max <= 0:
            raise ValueError('CosineAnnealingWarmUp is called with T_max <= 0, Check your code')

    def cosine_lr(self, base_lr):
        """Return the cosine-annealed lr for ``self.last_epoch``.

        Raises:
            ValueError: if ``T_max <= 0``. ``T_max == 0`` would otherwise divide
                by zero, and a negative ``T_max`` would silently act like
                ``abs(T_max)``.
        """
        self._check_t_max()
        return base_lr * 0.5 * (1 + math.cos(math.pi * self.last_epoch / self.T_max))

    def warmup_lr(self, base_lr):
        """Return the linear warm-up lr for ``self.last_epoch``."""
        return base_lr * (self.last_epoch + 1) / self.warmup_length

    def get_lr(self):
        """Return one learning rate per base lr for the current epoch."""
        # A non-positive warmup_length means "no warm-up". The guard also stops
        # warmup_lr from dividing by zero when warmup_length == 0 and
        # last_epoch == -1 (e.g. get_last_lr() right after construction).
        if 0 < self.warmup_length and self.last_epoch < self.warmup_length:
            return [self.warmup_lr(base_lr) for base_lr in self.base_lrs]
        return [self.cosine_lr(base_lr) for base_lr in self.base_lrs]

    def get_last_lr(self):
        """Recompute and return the learning rates for the current ``last_epoch``.

        Raises:
            ValueError: if ``T_max <= 0``, even while still in warm-up.
        """
        self._check_t_max()
        return self.get_lr()
| |
|
class PatienceSchedule(_LRScheduler):
    """Divide every param group's lr by ``factor`` when the loss stops improving.

    Each :meth:`step` call that supplies a loss compares it with the best
    (lowest) loss seen so far. After ``patience`` consecutive non-improving
    calls, the lr of every param group is divided by ``factor`` and the
    counter restarts.

    Args:
        optimizer: the optimizer whose param groups are adjusted.
        patience: number of consecutive non-improving steps tolerated before
            a reduction.
        factor: divisor applied to each group's lr on a reduction.
    """

    def __init__(self, optimizer, patience, factor):
        self.factor = factor
        self.patience = patience
        self.best_loss = float('inf')
        self.counter = 0  # consecutive non-improving step() calls
        # The base constructor calls step(epoch=0); with no loss it is a no-op.
        super().__init__(optimizer, -1)

    def step(self, current_loss=None, **kwargs):
        """Record ``current_loss`` and reduce the lr if it has plateaued.

        Args:
            current_loss: loss for the latest epoch. ``None`` makes the call a
                no-op.
            **kwargs: accepted and ignored (e.g. the ``epoch=`` keyword passed
                by the base-class constructor).

        Returns:
            ``0`` when ``current_loss`` is ``None``, otherwise ``None``.
        """
        if current_loss is None:
            return 0

        if current_loss < self.best_loss:
            self.best_loss = current_loss
            self.counter = 0
            # Only non-improving steps count toward patience. Checking patience
            # here as well would cut the lr on an improving step whenever
            # patience <= 0.
            return

        self.counter += 1
        if self.counter >= self.patience:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] /= self.factor
            print(f"Reducing learning rate to {self.optimizer.param_groups[0]['lr']:.5f}")
            self.counter = 0

    def get_last_lr(self):
        """Return the current lr of the first param group (a scalar, not a list)."""
        return self.optimizer.param_groups[0]['lr']