"""AdamW optimizer creation with weight decay separation."""

import torch
import torch.nn as nn

from llm_lab.config import TrainConfig


def create_optimizer(model: nn.Module, config: TrainConfig) -> torch.optim.AdamW:
"""Creates an AdamW optimizer.
Weight Decay separation rules:
- Apply decay: Linear weights (attention proj, FFN, etc.)
- No decay: Embeddings, LayerNorm/RMSNorm, Bias
Why separate?
- Weight Decay penalizes large weights to prevent overfitting
- However, applying it to Norm scale parameters interferes with normalization
- Applying it to Embeddings causes rare token representations to shrink toward 0
- It is convention to exclude 1D parameters (bias, norm weight) from decay
"""
    # Separate parameters into decay / no-decay groups
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # 1D tensors (bias, norm weight) or embedding → no decay
        if param.dim() <= 1 or "embedding" in name:
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    param_groups = [
        {"params": decay_params, "weight_decay": config.weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    n_decay = sum(p.numel() for p in decay_params)
    n_no_decay = sum(p.numel() for p in no_decay_params)
    print(f"[Optimizer] Decay parameters: {n_decay:,} ({n_decay/1e6:.1f}M)")
    print(f"[Optimizer] No-decay parameters: {n_no_decay:,} ({n_no_decay/1e6:.1f}M)")

    optimizer = torch.optim.AdamW(
        param_groups,
        lr=config.learning_rate,
        betas=(config.beta1, config.beta2),
        eps=config.adam_eps,
        # Fused CUDA kernel (faster); assumes the model lives on the GPU whenever CUDA is available
        fused=torch.cuda.is_available(),
    )
    return optimizer
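

# --- Minimal usage sketch (illustrative, not part of llm_lab) ---------------
# The toy model and the TrainConfig field values below are assumptions made for
# demonstration only; TrainConfig is assumed to accept these fields as keyword
# arguments. The goal is just to show how parameters split into the two groups.
if __name__ == "__main__":
    class _DemoModel(nn.Module):
        """Hypothetical toy model covering all three cases: embedding, 2D weight, 1D params."""

        def __init__(self) -> None:
            super().__init__()
            self.embedding = nn.Embedding(100, 32)  # "embedding" in name → no decay
            self.proj = nn.Linear(32, 32)           # 2D weight → decay, bias → no decay
            self.norm = nn.LayerNorm(32)            # 1D weight/bias → no decay

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _DemoModel().to(device)  # keep params on GPU so the fused kernel can be used

    # Hypothetical config values, chosen only to exercise create_optimizer()
    config = TrainConfig(
        weight_decay=0.1,
        learning_rate=3e-4,
        beta1=0.9,
        beta2=0.95,
        adam_eps=1e-8,
    )
    optimizer = create_optimizer(model, config)
    # Expected: group 0 holds only proj.weight; group 1 holds embedding, norm, and bias tensors
    for i, group in enumerate(optimizer.param_groups):
        print(f"group {i}: weight_decay={group['weight_decay']}, tensors={len(group['params'])}")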