gemeo-sus / src /wsd_scheduler.py
timmers's picture
GEMEO/SUS v6 recurrence-aware (RAVEN) — new-onset Top-1 60.1% vs baseline 38.2%, defeats autocorrelation trap. GEMEO Arch v2.0 Principle 7 proven.
908ea05 verified
"""WSD (Warmup-Stable-Decay) LR scheduler — manual implementation.
Per MiniCPM (Hu et al. 2024) and the data-constrained scaling literature:
Phase 1 (warmup, 1-5% of total_steps): linear 0 → peak_lr
Phase 2 (stable, 60-80% of post-warmup steps): constant peak_lr
Phase 3 (decay, 10-25% of post-warmup steps): linear, cosine, or 1/sqrt decay to peak_lr * min_lr_ratio (default 0.1)
Beats cosine for:
- data-limited regimes (we can extend stable phase if loss still falls)
- continue-pretrain (sharp decay enables clean fine-tune handoff)
"""
from __future__ import annotations
import math
import torch
from torch.optim.lr_scheduler import LambdaLR
def wsd_lr_schedule(step: int, total_steps: int,
                    warmup_steps: int = 500,
                    stable_frac: float = 0.80,
                    decay_frac: float = 0.15,
                    min_lr_ratio: float = 0.1,
                    decay_type: str = "linear") -> float:
    """Return the LR multiplier for ``step`` under a Warmup-Stable-Decay schedule.

    The multiplier:
    1. rises linearly from 0 toward 1.0 over ``warmup_steps``;
    2. holds at 1.0 for ``stable_frac`` of the post-warmup steps;
    3. decays to ``min_lr_ratio`` over ``decay_frac`` of the post-warmup steps;
    4. stays at ``min_lr_ratio`` afterwards.

    Args:
        step: Current scheduler step (0-based).
        total_steps: Total number of training steps.
        warmup_steps: Length of the linear warmup phase, in steps.
        stable_frac: Fraction of the post-warmup steps held at the peak LR.
        decay_frac: Fraction of the post-warmup steps spent decaying.
        min_lr_ratio: Multiplier reached at the end of the decay phase.
        decay_type: ``"linear"``, ``"cosine"`` or ``"inv_sqrt"``.

    Returns:
        A multiplier in ``[0, 1)`` during warmup and in
        ``[min_lr_ratio, 1.0]`` afterwards (for ``min_lr_ratio <= 1``).

    Raises:
        ValueError: If ``decay_type`` is not one of the supported values.
    """
    # Validate up front so a typo fails on the very first step instead of
    # only when training reaches the decay phase.
    if decay_type not in ("linear", "cosine", "inv_sqrt"):
        raise ValueError(f"unknown decay_type: {decay_type}")
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    # Stable/decay fractions are relative to the steps remaining after warmup.
    remaining = total_steps - warmup_steps
    if remaining <= 0:
        return 1.0
    stable_steps = int(stable_frac * remaining)
    decay_steps = int(decay_frac * remaining)
    pos = step - warmup_steps
    if pos < stable_steps:
        return 1.0
    decay_pos = pos - stable_steps
    if decay_pos >= decay_steps:
        return min_lr_ratio
    progress = decay_pos / max(1, decay_steps)  # in [0, 1)
    if decay_type == "linear":
        return 1.0 - (1.0 - min_lr_ratio) * progress
    if decay_type == "cosine":
        return min_lr_ratio + 0.5 * (1.0 - min_lr_ratio) * (1.0 + math.cos(math.pi * progress))
    # inv_sqrt: the raw curve 1/sqrt(1 + 10p) only reaches 1/sqrt(11) ~= 0.30
    # at p=1, which left a cliff down to min_lr_ratio when the decay phase
    # ended. Rescale it onto [min_lr_ratio, 1.0] so it is continuous at both
    # ends while keeping the 1/sqrt shape.
    tail = 1.0 / math.sqrt(11.0)
    shape = (1.0 / math.sqrt(1.0 + 10.0 * progress) - tail) / (1.0 - tail)
    return min_lr_ratio + (1.0 - min_lr_ratio) * shape
def get_wsd_scheduler(optimizer: torch.optim.Optimizer,
                      total_steps: int,
                      warmup_steps: int = 500,
                      stable_frac: float = 0.80,
                      decay_frac: float = 0.15,
                      min_lr_ratio: float = 0.1,
                      decay_type: str = "linear") -> LambdaLR:
    """Wrap ``optimizer`` in a ``LambdaLR`` that follows the WSD schedule.

    Each param group's base LR is multiplied by
    ``wsd_lr_schedule(step, total_steps, warmup_steps, stable_frac,
    decay_frac, min_lr_ratio, decay_type)``. The schedule is expressed in
    steps, so the scheduler is presumably meant to be stepped once per
    optimizer step.

    Args:
        optimizer: Optimizer whose param-group LRs are scheduled.
        total_steps: Total number of training steps.
        warmup_steps: Length of the linear warmup phase, in steps.
        stable_frac: Fraction of the post-warmup steps held at the peak LR.
        decay_frac: Fraction of the post-warmup steps spent decaying.
        min_lr_ratio: Multiplier reached at the end of the decay phase.
        decay_type: ``"linear"``, ``"cosine"`` or ``"inv_sqrt"``.

    Returns:
        The configured ``LambdaLR`` scheduler.

    Raises:
        ValueError: If ``decay_type`` is not one of the supported values.
    """
    # Fail at construction time: otherwise a bad decay_type would only be
    # detected when training first enters the decay phase.
    if decay_type not in ("linear", "cosine", "inv_sqrt"):
        raise ValueError(f"unknown decay_type: {decay_type}")

    def fn(step: int) -> float:
        return wsd_lr_schedule(step, total_steps, warmup_steps,
                               stable_frac, decay_frac, min_lr_ratio, decay_type)

    return LambdaLR(optimizer, lr_lambda=fn)
if __name__ == "__main__":
    # Print a small table of LR multipliers so the schedule shape can be
    # eyeballed without running a training loop.
    preview_total = 10000
    preview_warmup = 500
    probe_steps = (0, 250, 500, 1000, 5000, 8000, 8500, 9000, 9500, 9800, 9999)

    print(f"WSD schedule preview: total={preview_total}, "
          f"warmup={preview_warmup}, stable=80%, decay=15%")
    print(" step lr_mult")
    for step_idx in probe_steps:
        mult = wsd_lr_schedule(step_idx, preview_total, preview_warmup,
                               0.80, 0.15, 0.1, "linear")
        print(f" {step_idx:>5} {mult:.4f}")