File size: 2,887 Bytes
a0fa886
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""WSD (Warmup-Stable-Decay) LR scheduler — manual implementation.

Per MiniCPM (Hu et al. 2024) and the data-constrained scaling literature:
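  Phase 1 (warmup, 1-5% of total_steps): linear 0 → peak_lr
  Phase 2 (stable, 60-80% of the steps left after warmup): constant peak_lr
  Phase 3 (decay, 10-25% of the steps left after warmup): linear, cosine, or
      1/sqrt down to peak_lr * min_lr_ratio (0.1 by default); any steps after
      the decay phase stay at that floor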

Beats cosine for:
  - data-limited regimes (we can extend stable phase if loss still falls)
  - continue-pretrain (sharp decay enables clean fine-tune handoff)
"""
from __future__ import annotations
import math
import torch
from torch.optim.lr_scheduler import LambdaLR


def wsd_lr_schedule(step: int, total_steps: int,
                    warmup_steps: int = 500,
                    stable_frac: float = 0.80,
                    decay_frac: float = 0.15,
                    min_lr_ratio: float = 0.1,
                    decay_type: str = "linear") -> float:
    """Return the WSD learning-rate multiplier for ``step``.

    The multiplier ramps linearly from 0 to 1 over ``warmup_steps``, stays at
    1.0 for ``stable_frac`` of the post-warmup steps, decays to
    ``min_lr_ratio`` over the next ``decay_frac`` of the post-warmup steps,
    and is held at ``min_lr_ratio`` for any steps after that.

    Args:
        step: Current step index (step 0 yields 0.0 when ``warmup_steps > 0``).
        total_steps: Planned length of the run, warmup included.
        warmup_steps: Number of linear warmup steps.
        stable_frac: Fraction of ``total_steps - warmup_steps`` spent at 1.0.
        decay_frac: Fraction of ``total_steps - warmup_steps`` spent decaying.
        min_lr_ratio: Multiplier reached at the end of the decay phase.
        decay_type: Shape of the decay: ``"linear"``, ``"cosine"`` or
            ``"inv_sqrt"``.

    Returns:
        The multiplier: ``step / warmup_steps`` during warmup (so it starts
        at 0, below ``min_lr_ratio``), and a value between ``min_lr_ratio``
        and 1.0 afterwards (for ``min_lr_ratio <= 1``).

    Raises:
        ValueError: If ``decay_type`` is not one of the supported names. This
            is checked on every call, whichever phase ``step`` falls in.
    """
    # Validate before any phase logic: otherwise a misspelled decay_type is
    # only noticed when the decay phase begins, deep into a training run.
    if decay_type not in ("linear", "cosine", "inv_sqrt"):
        raise ValueError(f"unknown decay_type: {decay_type}")

    if step < warmup_steps:
        return step / max(1, warmup_steps)

    # Steps left to distribute between the stable and decay phases.
    remaining = total_steps - warmup_steps
    if remaining <= 0:
        # Warmup covers the whole run; stay at the peak afterwards.
        return 1.0

    stable_steps = int(stable_frac * remaining)
    decay_steps = int(decay_frac * remaining)
    pos = step - warmup_steps
    if pos < stable_steps:
        return 1.0

    decay_pos = pos - stable_steps
    if decay_pos >= decay_steps:
        # Past the decay phase (or no decay phase at all): hold the floor.
        return min_lr_ratio

    # 0 <= decay_pos < decay_steps here, so decay_steps is at least 1.
    progress = decay_pos / decay_steps
    span = 1.0 - min_lr_ratio
    if decay_type == "linear":
        return 1.0 - span * progress
    if decay_type == "cosine":
        return min_lr_ratio + 0.5 * span * (1 + math.cos(math.pi * progress))

    # inv_sqrt: 1/sqrt(1 + 10 * progress) on its own only falls to
    # 1/sqrt(11) ~= 0.30 at progress == 1, which would leave an abrupt drop
    # to min_lr_ratio on the first post-decay step. Rescale the curve so it
    # runs from exactly 1.0 down to exactly min_lr_ratio, like the other
    # decay shapes.
    raw = 1.0 / math.sqrt(1 + progress * 10)
    raw_end = 1.0 / math.sqrt(11)
    return min_lr_ratio + span * (raw - raw_end) / (1.0 - raw_end)


def get_wsd_scheduler(optimizer: torch.optim.Optimizer,
                      total_steps: int,
                      warmup_steps: int = 500,
                      stable_frac: float = 0.80,
                      decay_frac: float = 0.15,
                      min_lr_ratio: float = 0.1,
                      decay_type: str = "linear") -> LambdaLR:
    """Create a ``LambdaLR`` whose multiplier follows ``wsd_lr_schedule``.

    Args:
        optimizer: Optimizer whose learning rate the scheduler drives.
        total_steps: Planned length of the run, warmup included.
        warmup_steps: Number of linear warmup steps.
        stable_frac: Fraction of the post-warmup steps held at the peak.
        decay_frac: Fraction of the post-warmup steps spent decaying.
        min_lr_ratio: Multiplier the decay phase ends at.
        decay_type: Decay shape name forwarded to ``wsd_lr_schedule``.

    Returns:
        A ``LambdaLR`` built on ``optimizer`` with the WSD multiplier as its
        ``lr_lambda``.
    """
    # Everything except the step is fixed for the lifetime of the scheduler.
    schedule_kwargs = dict(
        total_steps=total_steps,
        warmup_steps=warmup_steps,
        stable_frac=stable_frac,
        decay_frac=decay_frac,
        min_lr_ratio=min_lr_ratio,
        decay_type=decay_type,
    )

    def multiplier_at(current_step):
        return wsd_lr_schedule(current_step, **schedule_kwargs)

    return LambdaLR(optimizer, lr_lambda=multiplier_at)


if __name__ == "__main__":
    # Quick sanity check: print the multiplier at a few representative steps
    # spread across the warmup, stable, decay and post-decay regions.
    total = 10000
    warmup = 500
    sample_steps = (0, 250, 500, 1000, 5000, 8000, 8500, 9000, 9500, 9800, 9999)

    print(f"WSD schedule preview: total={total}, warmup={warmup}, stable=80%, decay=15%")
    print("  step    lr_mult")
    for sample in sample_steps:
        mult = wsd_lr_schedule(
            sample,
            total,
            warmup,
            stable_frac=0.80,
            decay_frac=0.15,
            min_lr_ratio=0.1,
            decay_type="linear",
        )
        print(f"  {sample:>5}   {mult:.4f}")