import importlib
from collections.abc import Collection
from dataclasses import dataclass, field
from typing import List

import torch
from fairseq.dataclass import FairseqDataclass
from fairseq.optim import FairseqOptimizer, register_optimizer
from omegaconf import II, DictConfig


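# DeepSpeed provides the C++ CPU Adam kernel used by this optimizer; record
# whether it is importable so a helpful error can be raised later.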
try:
    import deepspeed

    has_deepspeed = True
except ImportError:
    has_deepspeed = False


def _get_cpu_adam():
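    """Load DeepSpeed's CPU Adam op, preferring the JIT op builder."""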
    try:
        from deepspeed.ops.op_builder import CPUAdamBuilder

        return CPUAdamBuilder().load()
    except ImportError:
        # fall back to the DeepSpeedCPUAdam implementation
        from deepspeed.ops.adam import DeepSpeedCPUAdam as ds_opt_adam

        return ds_opt_adam


@dataclass
class FairseqCPUAdamConfig(FairseqDataclass):
    adam_betas: str = field(
        default="(0.9, 0.999)", metadata={"help": "betas for Adam optimizer"}
    )
    adam_eps: float = field(
        default=1e-8, metadata={"help": "epsilon for Adam optimizer"}
    )
    weight_decay: float = field(default=0.0, metadata={"help": "weight decay"})
    fp16_adam_stats: bool = field(
        default=False, metadata={"help": "use FP16 stats (with automatic scaling)"}
    )
    # interpolated from the top-level optimization config
    lr: List[float] = II("optimization.lr")


@register_optimizer("cpu_adam", dataclass=FairseqCPUAdamConfig)
class FairseqCPUAdam(FairseqOptimizer):
    """Adam optimizer for fairseq, optimized for CPU tensors.

    Important note: this optimizer corresponds to the "AdamW" variant of
    Adam in its weight decay behavior. As such, it is most closely
    analogous to torch.optim.AdamW from PyTorch.
    """

    def __init__(self, cfg: DictConfig, params):
        super().__init__(cfg)
        self._optimizer = CPUAdam(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.cfg.lr[0]
            if isinstance(self.cfg.lr, Collection)
            else self.cfg.lr,
            "betas": eval(self.cfg.adam_betas),
            "eps": self.cfg.adam_eps,
            "weight_decay": self.cfg.weight_decay,
            "use_fp16_stats": self.cfg.fp16_adam_stats,
        }


class CPUAdam(torch.optim.Optimizer):
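    """Wrapper around DeepSpeed's C++ CPU Adam (AdamW-mode) implementation.

    Parameter updates run on the CPU through the DeepSpeed op, and the
    optimizer state (``exp_avg``, ``exp_avg_sq``) is kept on the CPU,
    optionally in FP16 with dynamic rescaling to reduce memory use.
    """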

    # class-level counter used to assign each instance a unique id inside the
    # DeepSpeed op (see create_adam / adam_update below)
    optimizer_id = 0

    def __init__(
        self,
        params,
        lr=1e-3,
        bias_correction=True,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        use_fp16_stats=False,
    ):
        defaults = {
            "lr": lr,
            "bias_correction": bias_correction,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
        }
        super().__init__(params, defaults)

        self.use_fp16_stats = use_fp16_stats
        # largest finite float16 value; used to rescale FP16 optimizer state
        self.FLOAT16_MAX = 65504.0

        if not has_deepspeed:
            raise ImportError("Please install DeepSpeed: pip install deepspeed")

        self.opt_id = CPUAdam.optimizer_id
        CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1

        self.ds_opt_adam = _get_cpu_adam()
        adamw_mode = True
        self.ds_opt_adam.create_adam(
            self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode
        )

    @property
    def supports_memory_efficient_fp16(self):
        # allow wrapping by fairseq's memory-efficient FP16 optimizer
        return True

    @property
    def supports_flat_params(self):
        # parameters may be provided as a single flattened tensor
        return True

    @torch.no_grad()
    def step(self, closure=None):
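        """Perform a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """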
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # wait for outstanding CUDA work (e.g., gradient computation) to finish
        # before the parameters and gradients are copied to the CPU below
        torch.cuda.synchronize()

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group["params"]):
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    state["step"] = 0
                    dtype = torch.float16 if self.use_fp16_stats else p.data.dtype
                    # gradient momentums (first moment)
                    state["exp_avg"] = torch.zeros_like(
                        p.data, dtype=dtype, device="cpu"
                    )
                    # gradient variances (second moment)
                    state["exp_avg_sq"] = torch.zeros_like(
                        p.data, dtype=dtype, device="cpu"
                    )
                    if self.use_fp16_stats:
                        assert torch.is_floating_point(p.data)
                        state["exp_avg_scale"] = 1.0
                        state["exp_avg_sq_scale"] = 1.0

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                # keep a reference to the original (possibly GPU/FP16) tensor
                p_data_bak = p.data

                p.data = p.data.to(dtype=torch.float32, device="cpu")
                p.grad.data = p.grad.data.to(dtype=torch.float32, device="cpu")

                if self.use_fp16_stats:
                    # rescale the FP16 moments back to FP32 before the update
                    exp_avg = exp_avg.float() * state["exp_avg_scale"]
                    exp_avg_sq = exp_avg_sq.float() * state["exp_avg_sq_scale"]

                state["step"] += 1
                beta1, beta2 = group["betas"]

                self.ds_opt_adam.adam_update(
                    self.opt_id,
                    state["step"],
                    group["lr"],
                    beta1,
                    beta2,
                    group["eps"],
                    group["weight_decay"],
                    group["bias_correction"],
                    p.data,
                    p.grad.data,
                    exp_avg,
                    exp_avg_sq,
                )

                # copy the updated FP32 parameters back into the original tensor
                if p_data_bak.data_ptr() != p.data.data_ptr():
                    p_data_bak.copy_(p.data)
                    p.data = p_data_bak

                if self.use_fp16_stats:

                    def inf_norm(t):
                        return torch.norm(t, float("inf"))

                    # dynamically rescale the moments so they fit into FP16 range
                    state["exp_avg_scale"], state["exp_avg_sq_scale"] = (
                        1e-8 + inf_norm(exp_avg) / self.FLOAT16_MAX,
                        1e-8 + inf_norm(exp_avg_sq) / self.FLOAT16_MAX,
                    )
                    state["exp_avg"], state["exp_avg_sq"] = (
                        (exp_avg / state["exp_avg_scale"]).half(),
                        (exp_avg_sq / state["exp_avg_sq_scale"]).half(),
                    )

        return loss