| """ | |
| File copied from | |
| https://github.com/nicola-decao/diffmask/blob/master/diffmask/optim/lookahead.py | |
| """ | |
import torch
import torch.optim as optim
from collections import defaultdict
from torch import Tensor
from torch.optim.optimizer import Optimizer
from typing import Callable, Iterable, Optional, Union

_params_type = Union[Iterable[Tensor], Iterable[dict]]

class Lookahead(Optimizer):
    """Lookahead optimizer: https://arxiv.org/abs/1907.08610"""

    # noinspection PyMissingConstructor
    def __init__(self, base_optimizer: Optimizer, alpha: float = 0.5, k: int = 6):
        # Deliberately does not call Optimizer.__init__: the wrapper shares the
        # param_groups of base_optimizer instead of creating its own.
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"Invalid slow update rate: {alpha}")
        if not 1 <= k:
            raise ValueError(f"Invalid lookahead steps: {k}")
        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        self.state = defaultdict(dict)
        # manually add our defaults to the param groups
        for name, default in defaults.items():
            for group in self.param_groups:
                group.setdefault(name, default)
    def update_slow(self, group: dict):
        # Interpolate the slow weights towards the fast weights,
        #   slow <- slow + alpha * (fast - slow),
        # then reset the fast weights to the updated slow weights.
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if "slow_buffer" not in param_state:
                param_state["slow_buffer"] = torch.empty_like(fast_p.data)
                param_state["slow_buffer"].copy_(fast_p.data)
            slow = param_state["slow_buffer"]
            slow.add_(fast_p.data - slow, alpha=group["lookahead_alpha"])
            fast_p.data.copy_(slow)
    def sync_lookahead(self):
        for group in self.param_groups:
            self.update_slow(group)
    def step(self, closure: Optional[Callable] = None) -> Optional[float]:
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group["lookahead_step"] += 1
            if group["lookahead_step"] % group["lookahead_k"] == 0:
                self.update_slow(group)
        return loss
    def state_dict(self) -> dict:
        fast_state_dict = self.base_optimizer.state_dict()
        # Replace tensor keys with their id() so the slow state can be packed
        # into a plain dict alongside the base optimizer's state.
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict["state"]
        param_groups = fast_state_dict["param_groups"]
        return {
            "state": fast_state,
            "slow_state": slow_state,
            "param_groups": param_groups,
        }
    def load_state_dict(self, state_dict: dict):
        fast_state_dict = {
            "state": state_dict["state"],
            "param_groups": state_dict["param_groups"],
        }
        self.base_optimizer.load_state_dict(fast_state_dict)

        # We want to restore the slow state, but share the param_groups reference
        # with base_optimizer. This is a bit redundant but keeps the code minimal.
        slow_state_new = False
        if "slow_state" not in state_dict:
            print("Loading state_dict from optimizer without Lookahead applied.")
            state_dict["slow_state"] = defaultdict(dict)
            slow_state_new = True
        slow_state_dict = {
            "state": state_dict["slow_state"],
            "param_groups": state_dict["param_groups"],  # pointless but saves code
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        # make both ref the same container
        self.param_groups = self.base_optimizer.param_groups
        if slow_state_new:
            # reapply defaults to catch missing lookahead-specific ones
            for name, default in self.defaults.items():
                for group in self.param_groups:
                    group.setdefault(name, default)

def LookaheadAdam(
    params: _params_type,
    lr: float = 1e-3,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0,
    amsgrad: bool = False,
    lalpha: float = 0.5,
    k: int = 6,
):
    return Lookahead(
        torch.optim.Adam(params, lr, betas, eps, weight_decay, amsgrad), lalpha, k
    )

def LookaheadRAdam(
    params: _params_type,
    lr: float = 1e-3,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0,
    lalpha: float = 0.5,
    k: int = 6,
):
    return Lookahead(optim.RAdam(params, lr, betas, eps, weight_decay), lalpha, k)

def LookaheadRMSprop(
    params: _params_type,
    lr: float = 1e-2,
    alpha: float = 0.99,
    eps: float = 1e-8,
    weight_decay: float = 0,
    momentum: float = 0,
    centered: bool = False,
    lalpha: float = 0.5,
    k: int = 6,
):
    return Lookahead(
        torch.optim.RMSprop(params, lr, alpha, eps, weight_decay, momentum, centered),
        lalpha,
        k,
    )
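

# ---------------------------------------------------------------------------
# Minimal usage sketch: wraps a toy linear model with LookaheadAdam and runs a
# few update steps. The model, data, and hyper-parameter values below are
# illustrative assumptions only, not part of the original module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = torch.nn.Linear(10, 1)
    optimizer = LookaheadAdam(model.parameters(), lr=1e-3, lalpha=0.5, k=6)

    x = torch.randn(32, 10)
    y = torch.randn(32, 1)
    for _ in range(12):
        model.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()  # the slow weights are updated on every k-th call

    # Copy the slow weights into the model before evaluation or checkpointing.
    optimizer.sync_lookahead()
    print("final loss:", loss.item())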