"""WSD (Warmup-Stable-Decay) LR scheduler — manual implementation.

Per MiniCPM (Hu et al. 2024) and the data-constrained scaling literature:
  Phase 1 (warmup, 1-5% of total_steps): linear 0 → peak_lr
  Phase 2 (stable, 60-80%): constant peak_lr
  Phase 3 (decay, 10-25%): linear or 1/sqrt to peak_lr * 0.1

Beats cosine for:
  - data-limited regimes (we can extend stable phase if loss still falls)
  - continue-pretrain (sharp decay enables clean fine-tune handoff)
"""
from __future__ import annotations

import math

import torch
from torch.optim.lr_scheduler import LambdaLR

# Decay curve names accepted by wsd_lr_schedule / get_wsd_scheduler.
_DECAY_TYPES = ("linear", "cosine", "inv_sqrt")


def _validate_decay_type(decay_type: str) -> None:
    """Raise ValueError if *decay_type* is not a supported decay curve."""
    if decay_type not in _DECAY_TYPES:
        raise ValueError(f"unknown decay_type: {decay_type}")


def wsd_lr_schedule(step: int, total_steps: int, warmup_steps: int = 500,
                    stable_frac: float = 0.80, decay_frac: float = 0.15,
                    min_lr_ratio: float = 0.1,
                    decay_type: str = "linear") -> float:
    """Return the LR multiplier for a given step of a WSD schedule.

    The multiplier ramps linearly from 0.0 to 1.0 during warmup, holds at 1.0
    during the stable phase, decays towards ``min_lr_ratio`` during the decay
    phase, and stays at ``min_lr_ratio`` for every step after the decay phase.

    Args:
        step: Current step (0-based). Negative values are treated as step 0.
        total_steps: Planned total number of steps, warmup included.
        warmup_steps: Number of linear warmup steps.
        stable_frac: Fraction of the post-warmup steps held at multiplier 1.0.
        decay_frac: Fraction of the post-warmup steps spent decaying.
        min_lr_ratio: Multiplier returned once the decay phase is over.
        decay_type: One of ``"linear"``, ``"cosine"`` or ``"inv_sqrt"``.

    Returns:
        The multiplier to apply to the peak learning rate.

    Raises:
        ValueError: If ``decay_type`` is not a supported value. This is
            checked on every call, not only once the decay phase is reached,
            so a typo fails at step 0 instead of deep into a training run.
    """
    _validate_decay_type(decay_type)

    if step < warmup_steps:
        # Clamp so a negative step can never yield a negative multiplier.
        return max(0, step) / max(1, warmup_steps)

    # Steps left once warmup is over; the stable/decay fractions apply to it.
    remaining = total_steps - warmup_steps
    if remaining <= 0:
        return 1.0

    stable_steps = int(stable_frac * remaining)
    decay_steps = int(decay_frac * remaining)

    pos = step - warmup_steps
    if pos < stable_steps:
        return 1.0

    decay_pos = pos - stable_steps
    if decay_pos >= decay_steps:
        return min_lr_ratio

    # Here 0 <= decay_pos < decay_steps, so decay_steps >= 1 and the division
    # is safe; progress lies in [0, 1).
    progress = decay_pos / decay_steps

    if decay_type == "linear":
        return 1.0 - (1.0 - min_lr_ratio) * progress
    if decay_type == "cosine":
        return min_lr_ratio + 0.5 * (1 - min_lr_ratio) * (
            1 + math.cos(math.pi * progress))
    # decay_type == "inv_sqrt" (guaranteed by the validation above).
    # NOTE(review): this curve approaches 1/sqrt(11) ~= 0.30 as progress -> 1,
    # so with min_lr_ratio below that the multiplier drops abruptly to
    # min_lr_ratio when the decay phase ends. Left unchanged to avoid silently
    # altering existing runs; confirm whether a continuous curve is wanted.
    return max(min_lr_ratio, 1.0 / math.sqrt(1 + progress * 10))


def get_wsd_scheduler(optimizer: torch.optim.Optimizer, total_steps: int,
                      warmup_steps: int = 500, stable_frac: float = 0.80,
                      decay_frac: float = 0.15, min_lr_ratio: float = 0.1,
                      decay_type: str = "linear") -> LambdaLR:
    """Build a LambdaLR scheduler that follows the WSD schedule.

    Args:
        optimizer: Optimizer whose learning rate(s) will be scaled.
        total_steps: Planned total number of steps, warmup included.
        warmup_steps: Number of linear warmup steps.
        stable_frac: Fraction of the post-warmup steps held at multiplier 1.0.
        decay_frac: Fraction of the post-warmup steps spent decaying.
        min_lr_ratio: Multiplier used once the decay phase is over.
        decay_type: One of ``"linear"``, ``"cosine"`` or ``"inv_sqrt"``.

    Returns:
        A ``LambdaLR`` whose multiplier at each step is
        ``wsd_lr_schedule(step, ...)`` with the arguments given here.

    Raises:
        ValueError: If ``decay_type`` is not a supported value.
    """
    # Fail at construction time rather than when training reaches the decay.
    _validate_decay_type(decay_type)

    def lr_lambda(step):
        return wsd_lr_schedule(step, total_steps, warmup_steps, stable_frac,
                               decay_frac, min_lr_ratio, decay_type)

    return LambdaLR(optimizer, lr_lambda=lr_lambda)


def _print_preview() -> None:
    """Print the LR multiplier at a few sample steps of a 10k-step schedule."""
    total = 10000
    warmup = 500
    print(f"WSD schedule preview: total={total}, warmup={warmup}, stable=80%, decay=15%")
    print(" step lr_mult")
    for s in [0, 250, 500, 1000, 5000, 8000, 8500, 9000, 9500, 9800, 9999]:
        m = wsd_lr_schedule(s, total, warmup, stable_frac=0.80,
                            decay_frac=0.15, min_lr_ratio=0.1,
                            decay_type="linear")
        print(f" {s:>5} {m:.4f}")


if __name__ == "__main__":
    # Visualize the schedule
    _print_preview()