import torch
from typing import Iterable


def build_optimizer(model: torch.nn.Module, args) -> torch.optim.Optimizer:
    """Create an AdamW optimizer with common defaults.

    A fused AdamW is tried first when requested via ``args.fused_adam`` and CUDA
    is available, falling back to the standard implementation otherwise.
    """
    kwargs = dict(
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=1e-8,
    )
    # Use fused AdamW only when requested and CUDA is available.
    use_fused = bool(getattr(args, "fused_adam", False)) and torch.cuda.is_available()
    try:
        if use_fused:
            # torch.optim.AdamW accepts fused=True on recent PyTorch versions.
            return torch.optim.AdamW(model.parameters(), fused=True, **kwargs)  # type: ignore[arg-type]
    except TypeError:
        # Fall back to non-fused AdamW if the `fused` argument is unsupported.
        pass
    return torch.optim.AdamW(model.parameters(), **kwargs)
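# Usage sketch for build_optimizer above (hedged: the flag names below are
# assumptions, shown only to illustrate how an argparse parser can supply the
# attributes the function reads; any object with those attributes works):
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--learning-rate", dest="learning_rate", type=float, default=3e-4)
#     parser.add_argument("--weight-decay", dest="weight_decay", type=float, default=0.01)
#     parser.add_argument("--adam-beta1", dest="adam_beta1", type=float, default=0.9)
#     parser.add_argument("--adam-beta2", dest="adam_beta2", type=float, default=0.999)
#     parser.add_argument("--fused-adam", dest="fused_adam", action="store_true")
#     optimizer = build_optimizer(model, parser.parse_args())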
def clip_grads(
    parameters: Iterable[torch.nn.Parameter],
    optimizer: torch.optim.Optimizer,
    scaler,
    max_norm: float,
) -> None:
    """Gradient clipping that is AMP-aware (unscales first if using a scaler)."""
    if max_norm is None or max_norm <= 0:
        return
    if scaler is not None:
        # GradScaler.unscale_ operates on the optimizer, not on a parameter iterable.
        scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(parameters, max_norm)
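# Minimal end-to-end sketch (an illustrative addition, not part of the original
# training loop): one optimization step on a tiny model using both helpers. Pass a
# torch.cuda.amp.GradScaler instead of None to exercise the AMP unscale-then-clip path.
if __name__ == "__main__":
    from argparse import Namespace

    demo_args = Namespace(
        learning_rate=3e-4,
        weight_decay=0.01,
        adam_beta1=0.9,
        adam_beta2=0.999,
        fused_adam=False,
    )
    demo_model = torch.nn.Linear(16, 4)
    demo_optimizer = build_optimizer(demo_model, demo_args)

    inputs, targets = torch.randn(8, 16), torch.randn(8, 4)
    loss = torch.nn.functional.mse_loss(demo_model(inputs), targets)
    loss.backward()
    # scaler=None here, so clip_grads clips directly; with AMP it would unscale first.
    clip_grads(demo_model.parameters(), demo_optimizer, scaler=None, max_norm=1.0)
    demo_optimizer.step()
    demo_optimizer.zero_grad(set_to_none=True)
    print(f"demo step complete, loss={loss.item():.4f}")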