# LLM-1B-Lab / llm_lab/training/optimizer.py
# Author: Vjeong
# docs: translate all Korean comments and docstrings to English
# commit 858e8b2
"""AdamW optimizer creation with Weight Decay separation."""
import torch
import torch.nn as nn
from llm_lab.config import TrainConfig
def create_optimizer(model: nn.Module, config: TrainConfig) -> torch.optim.AdamW:
    """Creates an AdamW optimizer with decay / no-decay parameter groups.

    Weight Decay separation rules:
    - Apply decay: Linear weights (attention proj, FFN, etc.)
    - No decay: Embeddings, LayerNorm/RMSNorm, Bias

    Why separate?
    - Weight Decay penalizes large weights to prevent overfitting
    - However, applying it to Norm scale parameters interferes with normalization
    - Applying it to Embeddings causes rare token representations to shrink toward 0
    - It is convention to exclude 1D parameters (bias, norm weight) from decay

    Args:
        model: Model whose trainable parameters will be optimized.
        config: Training config providing ``weight_decay``, ``learning_rate``,
            ``beta1``, ``beta2``, and ``adam_eps``.

    Returns:
        A ``torch.optim.AdamW`` instance with two parameter groups
        (decay and no-decay).
    """
    # Separate parameters into decay / no-decay groups
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # 1D tensors (bias, norm weight) or embedding → no decay
        if param.dim() <= 1 or "embedding" in name:
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    param_groups = [
        {"params": decay_params, "weight_decay": config.weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
    n_decay = sum(p.numel() for p in decay_params)
    n_no_decay = sum(p.numel() for p in no_decay_params)
    print(f"[Optimizer] Decay parameters: {n_decay:,} ({n_decay/1e6:.1f}M)")
    print(f"[Optimizer] No-decay parameters: {n_no_decay:,} ({n_no_decay/1e6:.1f}M)")
    # The fused AdamW kernel requires every parameter to live on a CUDA device;
    # checking torch.cuda.is_available() alone would crash optimizer.step()
    # when CUDA is installed but the model is (still) on CPU.
    use_fused = torch.cuda.is_available() and all(
        p.is_cuda for p in decay_params + no_decay_params
    )
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=config.learning_rate,
        betas=(config.beta1, config.beta2),
        eps=config.adam_eps,
        fused=use_fused,  # CUDA fused AdamW (faster) — only when params are on GPU
    )
    return optimizer