| """WSD (Warmup-Stable-Decay) LR scheduler — manual implementation. |
| |
| Per MiniCPM (Hu et al. 2024) and the data-constrained scaling literature: |
| Phase 1 (warmup, 1-5% of total_steps): linear 0 → peak_lr |
| Phase 2 (stable, 60-80%): constant peak_lr |
Phase 3 (decay, 10-25%): linear, cosine, or 1/sqrt down to peak_lr * min_lr_ratio (default 0.1)
| |
| Beats cosine for: |
| - data-limited regimes (we can extend stable phase if loss still falls) |
| - continue-pretrain (sharp decay enables clean fine-tune handoff) |
| """ |
| from __future__ import annotations |
| import math |
| import torch |
| from torch.optim.lr_scheduler import LambdaLR |
|
|
|
|
def wsd_lr_schedule(step: int, total_steps: int,
                    warmup_steps: int = 500,
                    stable_frac: float = 0.80,
                    decay_frac: float = 0.15,
                    min_lr_ratio: float = 0.1,
                    decay_type: str = "linear") -> float:
    """Return the LR multiplier for ``step`` under a warmup-stable-decay schedule.

    Phases, with ``remaining = total_steps - warmup_steps``:

    * warmup: ``step < warmup_steps`` ramps linearly from 0.0 toward 1.0;
    * stable: the next ``int(stable_frac * remaining)`` steps return 1.0;
    * decay: the next ``int(decay_frac * remaining)`` steps anneal
      continuously from 1.0 to ``min_lr_ratio`` using ``decay_type``;
    * every later step returns ``min_lr_ratio``.

    If ``stable_frac + decay_frac > 1``, the decay window extends past
    ``total_steps``, so ``min_lr_ratio`` is not reached by ``total_steps``.
    If ``total_steps <= warmup_steps``, every post-warmup step returns 1.0.

    Args:
        step: Current scheduler step (0-based).
        total_steps: Total planned steps, including warmup.
        warmup_steps: Number of linear-warmup steps.
        stable_frac: Fraction of post-warmup steps held at the peak LR.
        decay_frac: Fraction of post-warmup steps spent decaying.
        min_lr_ratio: Multiplier reached at the end of the decay phase.
        decay_type: ``"linear"``, ``"cosine"`` or ``"inv_sqrt"``.

    Returns:
        A multiplier in ``[0.0, 1.0]``. It is below ``min_lr_ratio`` only
        during early warmup.

    Raises:
        ValueError: If ``decay_type`` is not one of the supported names.
            This is checked on every call, not only during the decay phase.
    """
    if decay_type not in ("linear", "cosine", "inv_sqrt"):
        # Validate up front. Otherwise a typo would only surface once training
        # reaches the decay phase, possibly hours into a run.
        raise ValueError(f"unknown decay_type: {decay_type}")

    if step < warmup_steps:
        return step / max(1, warmup_steps)

    remaining = total_steps - warmup_steps
    if remaining <= 0:
        return 1.0
    stable_steps = int(stable_frac * remaining)
    decay_steps = int(decay_frac * remaining)
    pos = step - warmup_steps
    if pos < stable_steps:
        return 1.0
    decay_pos = pos - stable_steps
    if decay_pos >= decay_steps:
        return min_lr_ratio
    progress = decay_pos / max(1, decay_steps)  # in [0, 1)

    if decay_type == "linear":
        return 1.0 - (1.0 - min_lr_ratio) * progress
    if decay_type == "cosine":
        return min_lr_ratio + 0.5 * (1.0 - min_lr_ratio) * (1.0 + math.cos(math.pi * progress))

    # inv_sqrt: keep the 1/sqrt(1 + k*p) shape, but rescale it so it equals
    # exactly 1.0 at p=0 and min_lr_ratio at p=1. The raw curve only falls to
    # 1/sqrt(1 + k) (~0.30 for k=10), which caused a jump down to
    # min_lr_ratio when the decay window ended.
    k = 10.0
    end = 1.0 / math.sqrt(1.0 + k)
    shape = (1.0 / math.sqrt(1.0 + k * progress) - end) / (1.0 - end)
    return min_lr_ratio + (1.0 - min_lr_ratio) * shape
|
|
|
|
def get_wsd_scheduler(optimizer: torch.optim.Optimizer,
                      total_steps: int,
                      warmup_steps: int = 500,
                      stable_frac: float = 0.80,
                      decay_frac: float = 0.15,
                      min_lr_ratio: float = 0.1,
                      decay_type: str = "linear") -> LambdaLR:
    """Create a ``LambdaLR`` that drives ``optimizer`` with the WSD multiplier.

    ``LambdaLR`` scales each parameter group's initial learning rate by the
    value that ``wsd_lr_schedule`` returns for the scheduler's current step.
    All schedule arguments are forwarded to ``wsd_lr_schedule`` unchanged.

    Args:
        optimizer: Optimizer whose parameter-group learning rates are scheduled.
        total_steps: Total planned steps, including warmup.
        warmup_steps: Number of linear-warmup steps.
        stable_frac: Fraction of post-warmup steps held at the peak LR.
        decay_frac: Fraction of post-warmup steps spent decaying.
        min_lr_ratio: Multiplier reached at the end of decay.
        decay_type: Decay curve name understood by ``wsd_lr_schedule``.

    Returns:
        The configured ``LambdaLR`` scheduler.
    """
    def _multiplier(current_step):
        # Closure over the schedule configuration; LambdaLR supplies only the step.
        return wsd_lr_schedule(
            current_step,
            total_steps,
            warmup_steps=warmup_steps,
            stable_frac=stable_frac,
            decay_frac=decay_frac,
            min_lr_ratio=min_lr_ratio,
            decay_type=decay_type,
        )

    scheduler = LambdaLR(optimizer, lr_lambda=_multiplier)
    return scheduler
|
|
|
|
if __name__ == "__main__":
    # Quick visual sanity check of the schedule shape: one row per probe step,
    # one column per supported decay curve.
    total = 10000
    warmup = 500
    stable_frac = 0.80
    decay_frac = 0.15
    min_lr_ratio = 0.1
    decay_types = ("linear", "cosine", "inv_sqrt")

    # Derive the banner from the actual settings so it cannot drift from them.
    print(f"WSD schedule preview: total={total}, warmup={warmup}, "
          f"stable={stable_frac:.0%}, decay={decay_frac:.0%}")
    # Header cells use the same widths as the data rows so the columns line up.
    print(f" {'step':>5} " + " ".join(f"{name:>9}" for name in decay_types))
    for s in [0, 250, 500, 1000, 5000, 8000, 8500, 9000, 9500, 9800, 9999]:
        mults = (
            wsd_lr_schedule(s, total, warmup, stable_frac, decay_frac,
                            min_lr_ratio, dt)
            for dt in decay_types
        )
        print(f" {s:>5} " + " ".join(f"{m:>9.4f}" for m in mults))
|
|