| """ Plateau Scheduler |
| |
| Adapts PyTorch plateau scheduler and allows application of noise, warmup. |
| |
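A minimal usage sketch (assumes a timm-style training loop where ``step()`` is
called once per epoch with the epoch index and the eval metric tracked by
``mode``; ``train_one_epoch``, ``validate``, and the loop variables are
hypothetical placeholders, not part of this module)::

    scheduler = PlateauLRScheduler(optimizer, decay_rate=0.1, patience_t=10)
    for epoch in range(1, num_epochs + 1):
        train_one_epoch(model, loader, optimizer)
        metric = validate(model, eval_loader)  # e.g. top-1 acc for mode="max"
        scheduler.step(epoch, metric)
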
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch

from .scheduler import Scheduler


class PlateauLRScheduler(Scheduler):
    """Decay the LR by a factor every time the validation loss plateaus.

    Wraps ``torch.optim.lr_scheduler.ReduceLROnPlateau``, adding optional
    linear warmup over the first ``warmup_t`` epochs and optional
    multiplicative LR noise from ``noise_range_t`` onward.
    """
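
    # An illustrative sketch of the plateau behaviour (assumed numbers, not a
    # doctest: mode="max", decay_rate=0.1, patience_t=2, base lr 0.1):
    #   epoch:   1     2     3     4     5
    #   metric:  0.70  0.75  0.75  0.75  0.75   <- no improvement after epoch 2
    #   lr:      0.1   0.1   0.1   0.1   0.01   <- decayed once patience expires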

    def __init__(
        self,
        optimizer,
        decay_rate=0.1,
        patience_t=10,
        verbose=True,
        threshold=1e-4,
        cooldown_t=0,
        warmup_t=0,
        warmup_lr_init=0,
        lr_min=0,
        mode="max",
        noise_range_t=None,
        noise_type="normal",
        noise_pct=0.67,
        noise_std=1.0,
        noise_seed=None,
        initialize=True,
    ):
        super().__init__(optimizer, "lr", initialize=initialize)

        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            patience=patience_t,
            factor=decay_rate,
            verbose=verbose,
            threshold=threshold,
            cooldown=cooldown_t,
            mode=mode,
            min_lr=lr_min,
        )

        self.noise_range = noise_range_t
        self.noise_pct = noise_pct
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.noise_seed = noise_seed if noise_seed is not None else 42
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        if self.warmup_t:
            # per-group LR increment applied on each warmup epoch
            self.warmup_steps = [
                (v - warmup_lr_init) / self.warmup_t for v in self.base_values
            ]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]
        self.restore_lr = None

    def state_dict(self):
        return {
            "best": self.lr_scheduler.best,
            "last_epoch": self.lr_scheduler.last_epoch,
        }

    def load_state_dict(self, state_dict):
        self.lr_scheduler.best = state_dict["best"]
        if "last_epoch" in state_dict:
            self.lr_scheduler.last_epoch = state_dict["last_epoch"]

    # override the base class step fn completely
    def step(self, epoch, metric=None):
        if epoch <= self.warmup_t:
            # linear warmup: interpolate from warmup_lr_init toward each group's base LR
            lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
            super().update_groups(lrs)
        else:
            if self.restore_lr is not None:
                # restore actual LR from before our last noise perturbation
                # before stepping the base scheduler
                for i, param_group in enumerate(self.optimizer.param_groups):
                    param_group["lr"] = self.restore_lr[i]
                self.restore_lr = None

            self.lr_scheduler.step(metric, epoch)  # step the base scheduler

            if self.noise_range is not None:
                # noise_range_t may be a single epoch (noise from then on) or a
                # (start, end) range of epochs
                if isinstance(self.noise_range, (list, tuple)):
                    apply_noise = self.noise_range[0] <= epoch < self.noise_range[1]
                else:
                    apply_noise = epoch >= self.noise_range
                if apply_noise:
                    self._apply_noise(epoch)

    def _apply_noise(self, epoch):
        # seed per-epoch so noise is reproducible for a given (seed, epoch) pair
        g = torch.Generator()
        g.manual_seed(self.noise_seed + epoch)
        if self.noise_type == "normal":
            while True:
                # resample if noise out of percent limit, brute force but shouldn't spin much
                noise = torch.randn(1, generator=g).item()
                if abs(noise) < self.noise_pct:
                    break
        else:
            noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct

        # apply the noise on top of the previous LR, caching the old value so it
        # can be restored for the normal (noise-free) LR calc on the next step
        restore_lr = []
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group["lr"])
            restore_lr.append(old_lr)
            new_lr = old_lr + old_lr * noise
            param_group["lr"] = new_lr
        self.restore_lr = restore_lr