lingbot-vla / lingbotvla /optim /optimizer.py
bazaar-research's picture
Upload folder using huggingface_hub
fb11af9 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Sequence, Tuple
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.optimizer import Optimizer
from ..utils.import_utils import is_torch_npu_available
# https://github.com/meta-llama/llama-recipes/blob/v0.0.4/src/llama_recipes/policies/anyprecision_optimizer.py
class AnyPrecisionAdamW(Optimizer):
    """AdamW whose optimizer states may be stored in caller-chosen dtypes.

    The momentum, variance and (optional) Kahan compensation buffers are
    allocated with ``momentum_dtype``, ``variance_dtype`` and
    ``compensation_buffer_dtype`` respectively, independent of the parameter
    dtype. When ``use_kahan_summation`` is enabled, each update is accumulated
    into a compensation buffer that also carries over the rounding error made
    when the update is added to the parameter.

    Args:
        params: Iterable of parameters or of parameter-group dicts.
        lr: Learning rate.
        betas: Coefficients for the running averages of the gradient and of
            its square.
        eps: Term added to the denominator.
        weight_decay: Decoupled (AdamW-style) weight decay coefficient.
        use_kahan_summation: Whether to keep and use the compensation buffer.
        momentum_dtype: dtype of the ``exp_avg`` state.
        variance_dtype: dtype of the ``exp_avg_sq`` state.
        compensation_buffer_dtype: dtype of the ``compensation`` state.
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.0,
        use_kahan_summation=True,
        momentum_dtype=torch.bfloat16,
        variance_dtype=torch.bfloat16,
        compensation_buffer_dtype=torch.bfloat16,
    ):
        defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "use_kahan_summation": use_kahan_summation,
            "momentum_dtype": momentum_dtype,
            "variance_dtype": variance_dtype,
            "compensation_buffer_dtype": compensation_buffer_dtype,
        }
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss. It is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given.

        Raises:
            RuntimeError: If any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                # Keep the result so callers get the loss back, as with the
                # built-in torch optimizers.
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            use_kahan_summation = group["use_kahan_summation"]
            momentum_dtype = group["momentum_dtype"]
            variance_dtype = group["variance_dtype"]
            compensation_buffer_dtype = group["compensation_buffer_dtype"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError("AnyPrecisionAdamW does not support sparse gradients.")

                state = self.state[p]

                # Lazy state initialization on the first step for this parameter.
                if len(state) == 0:
                    state["step"] = torch.tensor(0.0)
                    # momentum - EMA of gradient values
                    state["exp_avg"] = torch.zeros_like(p, dtype=momentum_dtype)
                    # variance uncentered - EMA of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p, dtype=variance_dtype)
                    # optional Kahan summation - accumulated error tracker
                    if use_kahan_summation:
                        state["compensation"] = torch.zeros_like(p, dtype=compensation_buffer_dtype)

                # The step counter is tracked per parameter.
                state["step"] += 1
                step = state["step"]
                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]
                grad = p.grad

                if weight_decay:  # decoupled weight decay, AdamW style
                    p.data.mul_(1 - lr * weight_decay)

                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)  # update momentum
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # update uncentered variance

                bias_correction1 = 1 - beta1**step  # adjust using bias1
                step_size = lr / bias_correction1
                denom_correction = (1 - beta2**step) ** 0.5  # adjust using bias2 and avoids math import
                centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_(eps, alpha=1)

                if use_kahan_summation:
                    # Fold the new update into the compensation buffer first.
                    compensation = state["compensation"]
                    compensation.addcdiv_(exp_avg, centered_variance, value=-step_size)
                    # Apply the compensated update, then store what was lost to
                    # rounding (intended minus actual change) for the next step.
                    temp_buffer = p.detach().clone()
                    p.data.add_(compensation)
                    compensation.add_(temp_buffer.sub_(p.data))
                else:  # usual AdamW update
                    p.data.addcdiv_(exp_avg, centered_variance, value=-step_size)

        return loss
def build_optimizer(
    model: "nn.Module",
    lr: float = 1e-3,
    betas: Tuple[float, float] = (0.9, 0.95),
    eps: float = 1e-8,
    weight_decay: float = 1e-2,
    fused: bool = False,
    optimizer_type: str = "adamw",
    param_groups: Optional[Sequence[Dict[str, Any]]] = None,
    post_training: bool = False,
) -> "torch.optim.Optimizer":
    """Build an AdamW-style optimizer for ``model``.

    When ``param_groups`` is not given, the trainable parameters are grouped
    automatically: if any parameter name contains ``"depth"``, trainable
    parameters are split into a regular group (``lr``) and a "depth" group
    (``lr * 10``, or ``lr`` when ``post_training`` is true). Otherwise all
    trainable parameters form a single default group.

    Args:
        model: Module whose parameters are optimized.
        lr: Base learning rate.
        betas: Adam beta coefficients.
        eps: Adam epsilon.
        weight_decay: Weight decay coefficient.
        fused: Passed to ``AdamW``; ``foreach`` is set to its negation. Both
            are forced to ``False`` when ``is_torch_npu_available()`` is true.
            Ignored for ``"anyprecision_adamw"``.
        optimizer_type: ``"adamw"`` or ``"anyprecision_adamw"``.
        param_groups: Explicit parameter groups; used as-is when provided.
        post_training: Disables the learning-rate gain of the "depth" group.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: If ``optimizer_type`` is not supported.
    """
    if param_groups is None:
        # One pass over the parameters: remember whether any name mentions
        # "depth" (trainable or not) and bucket the trainable ones.
        has_align_parameters = False
        base_params = []
        align_params = []
        for name, param in model.named_parameters():
            is_align = "depth" in name
            has_align_parameters = has_align_parameters or is_align
            if not param.requires_grad:
                continue
            (align_params if is_align else base_params).append(param)

        if has_align_parameters:
            lr_gain = 1.0 if post_training else 10.0
            param_groups = [
                {"params": base_params, "lr": lr},
                {"params": align_params, "lr": lr * lr_gain},
            ]
        else:
            # No "depth" parameters: every trainable parameter landed here.
            param_groups = base_params

    if optimizer_type == "adamw":
        npu_available = is_torch_npu_available()
        foreach = False if npu_available else (not fused)
        fused = False if npu_available else fused
        optim = AdamW(param_groups, lr, betas, eps, weight_decay, fused=fused, foreach=foreach)
    elif optimizer_type == "anyprecision_adamw":
        optim = AnyPrecisionAdamW(param_groups, lr, betas, eps, weight_decay)
    else:
        raise ValueError("Only adamw and anyprecision_adamw are supported as optimizers.")

    return optim