| """AdamW optimizer creation with Weight Decay separation.""" |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from llm_lab.config import TrainConfig |
|
|
|
|
def create_optimizer(model: nn.Module, config: TrainConfig) -> torch.optim.AdamW:
    """Create an AdamW optimizer with weight decay separated by parameter type.

    Weight decay separation rules:
    - Apply decay: linear weights (attention projections, FFN layers, etc.)
    - No decay: embeddings, LayerNorm/RMSNorm scales, biases

    Why separate?
    - Weight decay penalizes large weights to prevent overfitting
    - Applying it to norm scale parameters, however, interferes with normalization
    - Applying it to embeddings shrinks rare-token representations toward zero
    - By convention, 1-D parameters (biases, norm weights) are excluded from decay
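
    Example of how typical parameter names map to groups (illustrative names,
    not taken from any specific llm_lab model):
        "blocks.0.attn.q_proj.weight"  (2-D linear weight)     -> decay
        "token_embedding.weight"       (name has "embedding")  -> no decay
        "blocks.0.norm.weight"         (1-D norm scale)        -> no decay
        "blocks.0.attn.q_proj.bias"    (1-D bias)              -> no decay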
| """ |
| |
    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
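
        # 1-D tensors (biases, LayerNorm/RMSNorm scales) are routed to the
        # no-decay group by the dim() check; embedding weights are 2-D, so the
        # name check catches them instead. Everything else (linear/projection
        # weights) receives the configured weight decay.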
        if param.dim() <= 1 or "embedding" in name:
            no_decay_params.append(param)
        else:
            decay_params.append(param)
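
    # AdamW applies each group's weight_decay independently, so a single
    # optimizer can decay the linear weights while leaving the rest untouched.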
    param_groups = [
        {"params": decay_params, "weight_decay": config.weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
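
    # Log the size of each group so the decay/no-decay split can be sanity-checked.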
    n_decay = sum(p.numel() for p in decay_params)
    n_no_decay = sum(p.numel() for p in no_decay_params)
    print(f"[Optimizer] Decay parameters: {n_decay:,} ({n_decay/1e6:.1f}M)")
    print(f"[Optimizer] No-decay parameters: {n_no_decay:,} ({n_no_decay/1e6:.1f}M)")
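
    # The fused CUDA kernel is generally faster than the default implementation,
    # but it requires the parameters to live on a supported device, so it is
    # only requested when CUDA is available.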
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=config.learning_rate,
        betas=(config.beta1, config.beta2),
        eps=config.adam_eps,
        fused=torch.cuda.is_available(),
    )

    return optimizer
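

# Minimal smoke-test sketch, not part of the library API. To avoid guessing
# TrainConfig's constructor, it uses a stand-in namespace carrying only the
# fields create_optimizer actually reads (weight_decay, learning_rate, beta1,
# beta2, adam_eps); with a real TrainConfig instance, pass it in directly.
if __name__ == "__main__":
    from types import SimpleNamespace

    class ToyModel(nn.Module):
        """Tiny model that exercises every branch of the decay/no-decay split."""

        def __init__(self) -> None:
            super().__init__()
            self.token_embedding = nn.Embedding(100, 32)  # "embedding" in name -> no decay
            self.proj = nn.Linear(32, 32)                 # weight -> decay, bias -> no decay
            self.norm = nn.LayerNorm(32)                  # 1-D scale/bias -> no decay

    device = "cuda" if torch.cuda.is_available() else "cpu"
    toy_config = SimpleNamespace(
        weight_decay=0.1, learning_rate=3e-4, beta1=0.9, beta2=0.95, adam_eps=1e-8
    )
    opt = create_optimizer(ToyModel().to(device), toy_config)  # type: ignore[arg-type]
    print([group["weight_decay"] for group in opt.param_groups])  # expected: [0.1, 0.0]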