WJAD / src /wjad /train /schedule.py
fuzirui's picture
Sync WJAD codebase
0cfefd2 verified
raw
history blame contribute delete
885 Bytes
"""学习率调度:线性 warmup + 余弦退火。"""
from __future__ import annotations
import math
import torch
def build_scheduler(
    optimizer: torch.optim.Optimizer,
    warmup_steps: int,
    total_steps: int,
    base_lr: float,
    min_lr: float,
) -> torch.optim.lr_scheduler.LambdaLR:
    """Build a ``LambdaLR`` with linear warmup followed by cosine annealing.

    The wrapped function returns a multiplicative factor defined as
    ``target_lr / base_lr``; ``LambdaLR`` applies that factor to the
    learning rate the optimizer was constructed with.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        warmup_steps: Number of steps in the linear warmup phase. The factor
            is ``(step + 1) / warmup_steps`` there, so it reaches 1.0 on the
            last warmup step.
        total_steps: Step at which the cosine decay reaches ``min_lr``; the
            factor stays at that floor for any later step.
        base_lr: Peak learning rate, used as the denominator of the factor.
        min_lr: Learning rate reached at the end of the cosine decay.

    Returns:
        A ``LambdaLR`` scheduler attached to ``optimizer``.
    """
    # Step-independent quantities are computed once, outside the closure.
    # Each ``max`` guard keeps the matching division away from zero.
    warmup_span = max(1, warmup_steps)
    decay_span = max(1, total_steps - warmup_steps)
    lr_range = base_lr - min_lr
    safe_base_lr = max(base_lr, 1e-12)

    def factor_at(step: int) -> float:
        """Return the LR multiplier for the given scheduler step."""
        in_warmup = step < warmup_steps
        if in_warmup:
            # Linear ramp; ``step + 1`` makes the first factor non-zero.
            return float(step + 1) / warmup_span

        # Fraction of the decay phase completed, clamped to [0, 1] so steps
        # past ``total_steps`` stay at the floor.
        elapsed = (step - warmup_steps) / decay_span
        elapsed = max(0.0, min(1.0, elapsed))

        # Half-cosine weight going from 1 (decay start) to 0 (decay end).
        cosine_weight = 0.5 * (1.0 + math.cos(math.pi * elapsed))
        target_lr = min_lr + lr_range * cosine_weight
        return float(target_lr / safe_base_lr)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, factor_at)