"""AdamW optimizer creation with Weight Decay separation."""

import torch
import torch.nn as nn

from llm_lab.config import TrainConfig


def create_optimizer(model: nn.Module, config: TrainConfig) -> torch.optim.AdamW:
    """Creates an AdamW optimizer.

    Weight Decay separation rules:
      - Apply decay: Linear weights (attention proj, FFN, etc.)
      - No decay: Embeddings, LayerNorm/RMSNorm, Bias

    Why separate?
      - Weight Decay penalizes large weights to prevent overfitting
      - However, applying it to Norm scale parameters interferes with normalization
      - Applying it to Embeddings causes rare token representations to shrink toward 0
      - It is convention to exclude 1D parameters (bias, norm weight) from decay
    """
    # Separate parameters into decay / no-decay groups
    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # 1D tensors (bias, norm weight) or embedding → no decay
        if param.dim() <= 1 or "embedding" in name:
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    param_groups = [
        {"params": decay_params, "weight_decay": config.weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    n_decay = sum(p.numel() for p in decay_params)
    n_no_decay = sum(p.numel() for p in no_decay_params)
    print(f"[Optimizer] Decay parameters: {n_decay:,} ({n_decay/1e6:.1f}M)")
    print(f"[Optimizer] No-decay parameters: {n_no_decay:,} ({n_no_decay/1e6:.1f}M)")

    # The fused CUDA kernel is faster, but it requires every parameter to
    # already live on a CUDA device, so gate it on parameter placement too.
    use_fused = torch.cuda.is_available() and all(
        p.is_cuda for p in decay_params + no_decay_params
    )

    optimizer = torch.optim.AdamW(
        param_groups,
        lr=config.learning_rate,
        betas=(config.beta1, config.beta2),
        eps=config.adam_eps,
        fused=use_fused,
    )

    return optimizer
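

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative assumptions, not part of the
    # training setup): TinyModel and the config values below exist only to
    # show which parameters land in the decay vs. no-decay groups.
    from types import SimpleNamespace

    class TinyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.embedding = nn.Embedding(100, 32)  # no decay (name contains "embedding")
            self.proj = nn.Linear(32, 32)           # weight: decay, bias: no decay
            self.norm = nn.LayerNorm(32)            # no decay (1D scale and bias)

    cfg = SimpleNamespace(  # stand-in for TrainConfig; same field names as used above
        learning_rate=3e-4,
        weight_decay=0.1,
        beta1=0.9,
        beta2=0.95,
        adam_eps=1e-8,
    )

    opt = create_optimizer(TinyModel(), cfg)
    for group in opt.param_groups:
        n = sum(p.numel() for p in group["params"])
        print(f"weight_decay={group['weight_decay']}: {n} parameters")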