"""AdamW optimizer creation with Weight Decay separation.""" import torch import torch.nn as nn from llm_lab.config import TrainConfig def create_optimizer(model: nn.Module, config: TrainConfig) -> torch.optim.AdamW: """Creates an AdamW optimizer. Weight Decay separation rules: - Apply decay: Linear weights (attention proj, FFN, etc.) - No decay: Embeddings, LayerNorm/RMSNorm, Bias Why separate? - Weight Decay penalizes large weights to prevent overfitting - However, applying it to Norm scale parameters interferes with normalization - Applying it to Embeddings causes rare token representations to shrink toward 0 - It is convention to exclude 1D parameters (bias, norm weight) from decay """ # Separate parameters into decay / no-decay groups decay_params = [] no_decay_params = [] for name, param in model.named_parameters(): if not param.requires_grad: continue # 1D tensors (bias, norm weight) or embedding → no decay if param.dim() <= 1 or "embedding" in name: no_decay_params.append(param) else: decay_params.append(param) param_groups = [ {"params": decay_params, "weight_decay": config.weight_decay}, {"params": no_decay_params, "weight_decay": 0.0}, ] n_decay = sum(p.numel() for p in decay_params) n_no_decay = sum(p.numel() for p in no_decay_params) print(f"[Optimizer] Decay parameters: {n_decay:,} ({n_decay/1e6:.1f}M)") print(f"[Optimizer] No-decay parameters: {n_no_decay:,} ({n_no_decay/1e6:.1f}M)") optimizer = torch.optim.AdamW( param_groups, lr=config.learning_rate, betas=(config.beta1, config.beta2), eps=config.adam_eps, fused=torch.cuda.is_available(), # CUDA fused AdamW (faster) ) return optimizer