Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import logging | |
| from collections import defaultdict | |
| from dataclasses import dataclass, field | |
| from typing import Dict, Any, List, Optional | |
| import torch.optim | |
| from fairseq.dataclass import FairseqDataclass | |
| from fairseq.optim import FairseqOptimizer, register_optimizer, _build_optimizer | |
| from fairseq.optim.lr_scheduler import FairseqLRScheduler, build_lr_scheduler | |
| from omegaconf import II, open_dict | |
| logger = logging.getLogger(__name__) | |
| class OptimizerAndSchedulerConfig(FairseqDataclass): | |
| optimizer: Any = None | |
| lr_scheduler: Optional[Any] = None | |
| lr: List[float] = II("optimization.lr") | |
| class CompositeOptimizerConfig(FairseqDataclass): | |
| groups: Dict[str, OptimizerAndSchedulerConfig] = field( | |
| default_factory=lambda: {}, | |
| metadata={ | |
| "help": "optimizer name -> optimizer OptimizerAndSchedulerConfig. " | |
| "Configures a different optimizer and (optionally) lr scheduler for each parameter group" | |
| }, | |
| ) | |
| class FairseqCompositeOptimizer(FairseqOptimizer): | |
| optimizers: Dict[str, FairseqOptimizer] = {} | |
| lr_schedulers: Dict[str, FairseqLRScheduler] = {} | |
| lr_scheduler: FairseqLRScheduler = None | |
| _optimizer: torch.optim.Optimizer | |
| def __init__(self, cfg: CompositeOptimizerConfig, params): | |
| super().__init__(cfg) | |
| assert ( | |
| len(params) > 1 | |
| ), "Composite optimizer only works when there are multiple parameter groups (try fp16_no_flatten_grads: true)" | |
| groupped_params = defaultdict(list) | |
| for p in params: | |
| group = getattr(p, "param_group", "default") | |
| groupped_params[group].append(p) | |
| assert groupped_params.keys() == cfg.groups.keys(), ( | |
| f"Parameter groups {groupped_params.keys()} and optimizer groups {cfg.groups.keys()} are not the same! " | |
| "Try setting 'param_group' on your parameters in the model." | |
| ) | |
| for group, group_params in groupped_params.items(): | |
| group_cfg = cfg.groups[group] | |
| with open_dict(group_cfg): | |
| group_cfg.optimizer.lr = group_cfg.lr | |
| group_cfg.lr_scheduler.lr = group_cfg.lr | |
| self.optimizers[group] = _build_optimizer(group_cfg.optimizer, group_params) | |
| if group_cfg.lr_scheduler is not None: | |
| self.lr_schedulers[group] = build_lr_scheduler( | |
| group_cfg.lr_scheduler, self.optimizers[group] | |
| ) | |
| if len(self.lr_schedulers) > 0: | |
| assert len(self.lr_schedulers) == len(self.optimizers), ( | |
| f"Please provide an lr scheduler for each optimizer to use pass_through scheduler. " | |
| f"Optimizers: {self.optimizers}; Lr scheds: {self.lr_schedulers}" | |
| ) | |
| self.lr_scheduler = CompositeLRScheduler(self.lr_schedulers) | |
| self._optimizer = CompositeOptimizer(self.optimizers) | |
| def supports_groups(self): | |
| return True | |
| def param_groups(self): | |
| for opt in self.optimizers.values(): | |
| for group in opt.param_groups: | |
| yield group | |
| def get_lr(self): | |
| """Return the current learning rate.""" | |
| k = ( | |
| "default" | |
| if "default" in self.optimizers | |
| else next(iter(self.optimizers.keys())) | |
| ) | |
| return self.optimizers[k].param_groups[0]["lr"] | |
| def state_dict(self): | |
| """Return the LR scheduler state dict.""" | |
| return {k: s.state_dict() for k, s in self.optimizers.items()} | |
| def load_state_dict(self, state_dict, optimizer_overrides=None): | |
| """Load an LR scheduler state dict.""" | |
| for k, state in state_dict.items(): | |
| if k not in self.optimizers: | |
| # skip extra keys like "loss_scale" added by fp16 optimizer | |
| continue | |
| overrides = ( | |
| optimizer_overrides[k] | |
| if isinstance(optimizer_overrides, dict) and k in optimizer_overrides | |
| else None | |
| ) | |
| self.optimizers[k].load_state_dict(state, optimizer_overrides=overrides) | |
| class CompositeOptimizer(torch.optim.Optimizer): | |
| def __init__(self, optimizers: Dict[str, FairseqOptimizer]): | |
| self.optimizers = optimizers | |
| def supports_memory_efficient_fp16(self): | |
| return all(o.supports_memory_efficient_fp16 for o in self.optimizers.values()) | |
| def supports_flat_params(self): | |
| return all(o.supports_flat_params for o in self.optimizers.values()) | |
| def step(self, closure=None, groups=None): | |
| """Performs a single optimization step. | |
| Args: | |
| closure (callable, optional): A closure that reevaluates the model | |
| and returns the loss. | |
| """ | |
| loss = None | |
| if closure is not None: | |
| loss = closure() | |
| for k, opt in self.optimizers.items(): | |
| if groups is None or k in groups: | |
| opt.step() | |
| return loss | |
| def zero_grad(self): | |
| for opt in self.optimizers.values(): | |
| opt.zero_grad() | |
| class CompositeLRScheduler(FairseqLRScheduler): | |
| def __init__(self, lr_schedulers): | |
| super().__init__(None, None) | |
| self.lr_schedulers = lr_schedulers | |
| def state_dict(self): | |
| """Return the LR scheduler state dict.""" | |
| return {k: s.state_dict() for k, s in self.lr_schedulers.items()} | |
| def load_state_dict(self, state_dict): | |
| """Load an LR scheduler state dict.""" | |
| for k, state in state_dict.items(): | |
| self.lr_schedulers[k].load_state_dict(state) | |
| def step_begin_epoch(self, epoch): | |
| """Update the learning rate at the beginning of the given epoch.""" | |
| for s in self.lr_schedulers.values(): | |
| s.step_begin_epoch(epoch) | |
| def step(self, epoch, val_loss=None): | |
| """Update the learning rate at the end of the given epoch.""" | |
| for s in self.lr_schedulers.values(): | |
| s.step(epoch) | |
| def step_update(self, num_updates): | |
| """Update the learning rate after each update.""" | |
| return {k: s.step_update(num_updates) for k, s in self.lr_schedulers.items()} | |