import torch
from typing import Iterable


def build_optimizer(model: torch.nn.Module, args) -> torch.optim.Optimizer:
    """Create an AdamW optimizer with common defaults.

    If a fused optimizer is available and requested later, hook it up here.
    """
    kwargs = dict(
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=1e-8,
    )
    # Use fused AdamW if available and requested.
    use_fused = bool(getattr(args, "fused_adam", False)) and torch.cuda.is_available()
    try:
        if use_fused:
            # PyTorch's AdamW accepts fused=True on recent versions; older
            # versions raise TypeError for the unexpected keyword.
            return torch.optim.AdamW(model.parameters(), fused=True, **kwargs)  # type: ignore[arg-type]
    except TypeError:
        # Fall back to the non-fused implementation if unsupported.
        pass
    return torch.optim.AdamW(model.parameters(), **kwargs)


def clip_grads(
    parameters: Iterable[torch.nn.Parameter],
    optimizer: torch.optim.Optimizer,
    scaler,
    max_norm: float,
) -> None:
    """Gradient clipping that is AMP-aware (unscales first if using a scaler)."""
    if max_norm is None or max_norm <= 0:
        return
    if scaler is not None:
        # GradScaler.unscale_ expects the optimizer whose gradients were
        # scaled, not the parameter iterable.
        scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(parameters, max_norm)
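

# --- Usage sketch (illustrative, not part of the module's API) ---
# A minimal AMP training step wiring build_optimizer and clip_grads together.
# The `args` fields, the tiny model, and the random batch below are
# hypothetical placeholders chosen for the example; a disabled GradScaler
# makes the same code path run as a no-op on CPU.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        learning_rate=3e-4,
        weight_decay=0.01,
        adam_beta1=0.9,
        adam_beta2=0.999,
        fused_adam=False,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(16, 4).to(device)
    optimizer = build_optimizer(model, args)
    # GradScaler only does real work on CUDA; enabled=False makes
    # scale/unscale_/step/update pass through unchanged on CPU.
    scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

    x = torch.randn(8, 16, device=device)
    y = torch.randint(0, 4, (8,), device=device)

    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device, enabled=(device == "cuda")):
        loss = torch.nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()
    # Unscale (via the optimizer) before clipping, then step through the
    # scaler so inf/nan-skipping still applies.
    clip_grads(model.parameters(), optimizer, scaler, max_norm=1.0)
    scaler.step(optimizer)
    scaler.update()
    print(f"loss={loss.item():.4f}")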