# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Sequence, Tuple

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.optimizer import Optimizer

from ..utils.import_utils import is_torch_npu_available


# https://github.com/meta-llama/llama-recipes/blob/v0.0.4/src/llama_recipes/policies/anyprecision_optimizer.py
class AnyPrecisionAdamW(Optimizer):
    """AdamW variant with configurable state dtypes and optional Kahan summation.

    Args:
        params: Iterable of parameters or dicts defining parameter groups.
        lr: Learning rate.
        betas: Coefficients ``(beta1, beta2)`` for the running averages of the
            gradient and of its square.
        eps: Term added to the denominator of the update.
        weight_decay: Decoupled (AdamW-style) weight decay coefficient.
        use_kahan_summation: If true, keep a per-parameter compensation buffer
            and apply updates through it (Kahan summation).
        momentum_dtype: dtype of the ``exp_avg`` state buffer.
        variance_dtype: dtype of the ``exp_avg_sq`` state buffer.
        compensation_buffer_dtype: dtype of the ``compensation`` state buffer
            (only allocated when ``use_kahan_summation`` is true).
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.0,
        use_kahan_summation=True,
        momentum_dtype=torch.bfloat16,
        variance_dtype=torch.bfloat16,
        compensation_buffer_dtype=torch.bfloat16,
    ):
        defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "use_kahan_summation": use_kahan_summation,
            "momentum_dtype": momentum_dtype,
            "variance_dtype": variance_dtype,
            "compensation_buffer_dtype": compensation_buffer_dtype,
        }
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.

        Returns:
            The value returned by ``closure`` if one was given, otherwise ``None``.

        Raises:
            RuntimeError: If any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd to recompute the loss.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            use_kahan_summation = group["use_kahan_summation"]
            momentum_dtype = group["momentum_dtype"]
            variance_dtype = group["variance_dtype"]
            compensation_buffer_dtype = group["compensation_buffer_dtype"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                if p.grad.is_sparse:
                    raise RuntimeError("AnyPrecisionAdamW does not support sparse gradients.")

                state = self.state[p]

                # State initialization (first time this parameter is updated)
                if not state:
                    state["step"] = torch.tensor(0.0)

                    # momentum - EMA of gradient values
                    state["exp_avg"] = torch.zeros_like(p, dtype=momentum_dtype)

                    # variance uncentered - EMA of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p, dtype=variance_dtype)

                    # optional Kahan summation - accumulated error tracker
                    if use_kahan_summation:
                        state["compensation"] = torch.zeros_like(p, dtype=compensation_buffer_dtype)

                # Main processing
                # update the step count for this parameter
                state["step"] += 1
                step = state["step"]

                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]
                grad = p.grad

                if weight_decay:
                    # weight decay, AdamW style (applied directly to the weights)
                    p.data.mul_(1 - lr * weight_decay)

                # update momentum
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # update uncentered variance
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # `step` is a 0-dim tensor, so the bias corrections below are tensors too.
                bias_correction1 = 1 - beta1**step  # adjust using bias1
                step_size = lr / bias_correction1

                # adjust using bias2; `** 0.5` avoids a math import
                denom_correction = (1 - beta2**step) ** 0.5

                centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_(eps, alpha=1)

                if use_kahan_summation:
                    # lr update to compensation
                    compensation = state["compensation"]
                    compensation.addcdiv_(exp_avg, centered_variance, value=-step_size)

                    # update weights with compensation (Kahan summation)
                    # save error back to compensation for next iteration
                    temp_buffer = p.detach().clone()
                    p.data.add_(compensation)
                    compensation.add_(temp_buffer.sub_(p.data))
                else:
                    # usual AdamW updates
                    p.data.addcdiv_(exp_avg, centered_variance, value=-step_size)

        return loss


def build_optimizer(
    model: "nn.Module",
    lr: float = 1e-3,
    betas: Tuple[float, float] = (0.9, 0.95),
    eps: float = 1e-8,
    weight_decay: float = 1e-2,
    fused: bool = False,
    optimizer_type: str = "adamw",
    param_groups: Optional[Sequence[Dict[str, Any]]] = None,
    post_training: bool = False,
) -> "torch.optim.Optimizer":
    """Build an AdamW or AnyPrecisionAdamW optimizer for ``model``.

    When ``param_groups`` is not given, the trainable parameters are collected
    from ``model``. If any parameter name contains ``"depth"``, two groups are
    created: the remaining trainable parameters at ``lr`` and the trainable
    ``"depth"`` parameters at ``lr * 10`` (or at ``lr`` when ``post_training``
    is true). Otherwise all trainable parameters form a single group.

    Args:
        model: Module whose parameters are optimized when ``param_groups`` is None.
        lr: Base learning rate.
        betas: Adam ``(beta1, beta2)`` coefficients.
        eps: Adam epsilon.
        weight_decay: Weight decay coefficient.
        fused: Request the fused AdamW implementation (``optimizer_type="adamw"``
            only; forced off when an NPU is available).
        optimizer_type: ``"adamw"`` or ``"anyprecision_adamw"``.
        param_groups: Explicit parameter groups; used as-is when provided.
        post_training: If true, "depth" parameters get no learning-rate gain.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: If ``optimizer_type`` is not supported.
    """
    if param_groups is None:
        # Single pass over the parameters: bucket the trainable ones and remember
        # whether any name (trainable or not) matched, since that alone decides
        # whether two groups are created.
        has_align_parameters = False
        base_params = []
        align_params = []
        for name, param in model.named_parameters():
            is_align = "depth" in name
            has_align_parameters = has_align_parameters or is_align
            if not param.requires_grad:
                continue
            if is_align:
                align_params.append(param)
            else:
                base_params.append(param)

        if has_align_parameters:
            lr_gain = 10.0 if not post_training else 1.0
            param_groups = [
                {"params": base_params, "lr": lr},
                {"params": align_params, "lr": lr * lr_gain},
            ]
        else:
            param_groups = base_params

    if optimizer_type == "adamw":
        if is_torch_npu_available():
            # Neither the foreach nor the fused implementation is used on NPU.
            foreach = False
            fused = False
        else:
            # foreach and fused are mutually exclusive: use foreach unless fused is requested.
            foreach = not fused
        optim = AdamW(
            param_groups,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            fused=fused,
            foreach=foreach,
        )
    elif optimizer_type == "anyprecision_adamw":
        optim = AnyPrecisionAdamW(param_groups, lr, betas, eps, weight_decay)
    else:
        raise ValueError("Only adamw and anyprecision_adamw are supported as optimizers.")

    return optim