File size: 1,289 Bytes
54c5666
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
from typing import Iterable


def build_optimizer(model: torch.nn.Module, args) -> torch.optim.Optimizer:
    """Create an AdamW optimizer with common defaults.

    Hyperparameters are read from *args*:
      - ``learning_rate``, ``weight_decay``, ``adam_beta1``, ``adam_beta2`` (required)
      - ``adam_eps`` (optional, defaults to 1e-8)
      - ``fused_adam`` (optional bool): request the fused CUDA kernel

    Falls back to the plain (non-fused) AdamW if the running PyTorch build
    does not support ``fused=True``.
    """
    kwargs = dict(
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
        betas=(args.adam_beta1, args.adam_beta2),
        # eps is configurable but keeps the previous hard-coded default
        eps=getattr(args, "adam_eps", 1e-8),
    )
    # Fused AdamW only makes sense when CUDA is present and the caller asked for it.
    use_fused = bool(getattr(args, "fused_adam", False)) and torch.cuda.is_available()
    if use_fused:
        try:
            # Older PyTorch raises TypeError (unknown kwarg); newer builds may
            # accept the kwarg but raise RuntimeError if fused is unsupported
            # for these params/device. Either way, degrade to non-fused.
            return torch.optim.AdamW(model.parameters(), fused=True, **kwargs)  # type: ignore[arg-type]
        except (TypeError, RuntimeError):
            pass
    return torch.optim.AdamW(model.parameters(), **kwargs)


def clip_grads(
    parameters: Iterable[torch.nn.Parameter],
    scaler,
    max_norm: float,
    *,
    optimizer: "torch.optim.Optimizer | None" = None,
) -> None:
    """Gradient clipping that is AMP-aware (unscales first if using a scaler).

    Args:
        parameters: Parameters whose gradients are clipped in place.
        scaler: A ``torch.cuda.amp.GradScaler`` (or ``None`` when not using AMP).
        max_norm: Maximum gradient norm; ``None`` or a non-positive value
            disables clipping entirely.
        optimizer: The optimizer owning *parameters*; required when *scaler*
            is given, because ``GradScaler.unscale_`` operates on an optimizer,
            not on a parameter iterable.

    Raises:
        ValueError: If *scaler* is provided without *optimizer*.
    """
    if max_norm is None or max_norm <= 0:
        return
    if scaler is not None:
        # BUG FIX: the previous code called scaler.unscale_(parameters), but
        # GradScaler.unscale_ expects an optimizer and would fail at runtime.
        if optimizer is None:
            raise ValueError("clip_grads needs `optimizer=` when a GradScaler is used")
        scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(parameters, max_norm)