| from typing import Dict, Any |
|
|
| import torch |
|
|
|
|
class Scheduler:
    """Parameter scheduler base class.

    A scheduler base class that can be used to schedule any field of an
    optimizer's parameter groups (for example ``"lr"``).

    Unlike the builtin PyTorch schedulers, this is intended to be consistently called

    * at the END of each epoch, before incrementing the epoch count, to
      calculate the next epoch's value (``step``);
    * at the END of each optimizer update, after incrementing the update
      count, to calculate the next update's value (``step_update``).

    The schedulers built on this should try to remain as stateless as
    possible (for simplicity).

    This family of schedulers is attempting to avoid the confusion of the
    meaning of 'last_epoch' and -1 values for special behaviour. All epoch and
    update counts must be tracked in the training code and explicitly passed
    in to the schedulers on the corresponding ``step`` or ``step_update`` call.

    Subclasses override ``get_epoch_values`` and/or ``get_update_values``;
    the base implementations return ``None``, which means "leave the
    parameter groups untouched".

    Based on ideas from:

    * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
    * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        param_group_field: str,
        noise_range_t=None,
        noise_type="normal",
        noise_pct=0.67,
        noise_std=1.0,
        noise_seed=None,
        initialize: bool = True,
    ) -> None:
        """Bind the scheduler to ``optimizer`` and record the base values.

        Args:
            optimizer: Optimizer whose ``param_groups`` are scheduled.
            param_group_field: Key in every param group that is scheduled.
            noise_range_t: ``None`` disables noise. A list/tuple ``(lo, hi)``
                applies noise while ``lo <= t < hi``; any other value applies
                noise while ``t >= noise_range_t``.
            noise_type: ``"normal"`` draws from a standard normal, resampling
                until the draw is within ``noise_pct``; any other value draws
                uniformly from ``[-noise_pct, noise_pct)``.
            noise_pct: Bound on the relative perturbation.
            noise_std: Stored on the instance; not read by this base class.
            noise_seed: Base seed for the noise generator (42 when ``None``).
            initialize: When true, record the current value of
                ``param_group_field`` as ``initial_<field>`` in each group
                (unless already present). When false, ``initial_<field>``
                must already exist in every group.

        Raises:
            KeyError: If a required key is missing from a param group.
        """
        self.optimizer = optimizer
        self.param_group_field = param_group_field
        self._initial_param_group_field = f"initial_{param_group_field}"
        if initialize:
            for i, group in enumerate(self.optimizer.param_groups):
                if param_group_field not in group:
                    raise KeyError(
                        f"{param_group_field} missing from param_groups[{i}]"
                    )
                # setdefault: keep an initial value that is already recorded.
                group.setdefault(
                    self._initial_param_group_field, group[param_group_field]
                )
        else:
            for i, group in enumerate(self.optimizer.param_groups):
                if self._initial_param_group_field not in group:
                    raise KeyError(
                        f"{self._initial_param_group_field} missing from param_groups[{i}]"
                    )
        self.base_values = [
            group[self._initial_param_group_field]
            for group in self.optimizer.param_groups
        ]
        self.metric = None
        self.noise_range_t = noise_range_t
        self.noise_pct = noise_pct
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.noise_seed = noise_seed if noise_seed is not None else 42
        # Start every group from its recorded base value.
        self.update_groups(self.base_values)

    def state_dict(self) -> Dict[str, Any]:
        """Return all instance attributes except the optimizer reference."""
        return {
            key: value for key, value in self.__dict__.items() if key != "optimizer"
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Overwrite instance attributes with the entries of ``state_dict``."""
        self.__dict__.update(state_dict)

    def get_epoch_values(self, epoch: int):
        """Return the values for ``epoch``; ``None`` means no change."""
        return None

    def get_update_values(self, num_updates: int):
        """Return the values for ``num_updates``; ``None`` means no change."""
        return None

    def step(self, epoch: int, metric: float = None) -> None:
        """Apply the per-epoch schedule (plus noise, if enabled) to the optimizer."""
        self.metric = metric
        values = self.get_epoch_values(epoch)
        if values is not None:
            values = self._add_noise(values, epoch)
            self.update_groups(values)

    def step_update(self, num_updates: int, metric: float = None):
        """Apply the per-update schedule (plus noise, if enabled) to the optimizer."""
        self.metric = metric
        values = self.get_update_values(num_updates)
        if values is not None:
            values = self._add_noise(values, num_updates)
            self.update_groups(values)

    def update_groups(self, values):
        """Write ``values`` into ``param_group_field`` of each param group.

        A value that is not a list/tuple is broadcast to every group. Groups
        and values are paired with ``zip``, so surplus entries on either side
        are ignored.
        """
        if not isinstance(values, (list, tuple)):
            values = [values] * len(self.optimizer.param_groups)
        for param_group, value in zip(self.optimizer.param_groups, values):
            param_group[self.param_group_field] = value

    def _add_noise(self, lrs, t):
        """Return ``lrs`` scaled by ``1 + noise`` when noise is active at ``t``.

        A list/tuple yields a new list. A scalar yields a scalar, so that
        schedules returning a single value (which ``update_groups`` accepts)
        also work while noise is enabled. When noise is inactive, ``lrs`` is
        returned unchanged.
        """
        if not self._is_apply_noise(t):
            return lrs
        noise = self._calculate_noise(t)
        if isinstance(lrs, (list, tuple)):
            return [v + v * noise for v in lrs]
        return lrs + lrs * noise

    def _is_apply_noise(self, t) -> bool:
        """Return whether ``t`` falls inside the configured noise range."""
        if self.noise_range_t is None:
            return False
        if isinstance(self.noise_range_t, (list, tuple)):
            return self.noise_range_t[0] <= t < self.noise_range_t[1]
        return t >= self.noise_range_t

    def _calculate_noise(self, t) -> float:
        """Draw the relative noise factor for step ``t``.

        The generator is seeded with ``noise_seed + t``, so the same ``t``
        always produces the same noise.
        """
        g = torch.Generator()
        g.manual_seed(self.noise_seed + t)
        if self.noise_type == "normal":
            # No sample can satisfy |noise| < noise_pct when the bound is
            # zero, negative or NaN; return zero noise instead of spinning
            # forever in the rejection loop below.
            if not self.noise_pct > 0:
                return 0.0
            while True:
                # Resample until the draw is within the percent limit; brute
                # force, and may spin for a while if noise_pct is tiny.
                noise = torch.randn(1, generator=g).item()
                if abs(noise) < self.noise_pct:
                    return noise
        return 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
|
|