"""Learning-rate scheduling: linear warmup followed by cosine annealing."""

from __future__ import annotations

import math

import torch


def build_scheduler(
    optimizer: torch.optim.Optimizer,
    warmup_steps: int,
    total_steps: int,
    base_lr: float,
    min_lr: float,
) -> torch.optim.lr_scheduler.LambdaLR:
    """Return a ``LambdaLR`` whose factor is ``current_lr / base_lr``.

    The factor follows two phases:

    * Warmup (``step < warmup_steps``): rises linearly as
      ``(step + 1) / warmup_steps``, so it reaches 1.0 on the last
      warmup step.
    * Cosine annealing (afterwards): the target learning rate moves from
      ``base_lr`` to ``min_lr`` along half a cosine period over
      ``total_steps - warmup_steps`` steps and is then held at ``min_lr``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        warmup_steps: Number of linear warmup steps.
        total_steps: Step index at which the cosine decay reaches ``min_lr``.
        base_lr: Peak learning rate; the factor is expressed relative to it.
        min_lr: Learning rate at the end of the cosine decay.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping the schedule.

    Note:
        ``LambdaLR`` applies the factor to each parameter group's initial
        learning rate, so the schedule lands on ``base_lr`` / ``min_lr``
        only if the optimizer was created with ``lr=base_lr`` -- confirm
        at call sites.
    """

    def _warmup_factor(step: int) -> float:
        # max(1, ...) keeps the denominator non-zero for degenerate spans.
        return float(step + 1) / max(1, warmup_steps)

    def _cosine_factor(step: int) -> float:
        decay_span = max(1, total_steps - warmup_steps)
        fraction = (step - warmup_steps) / decay_span

        # Clamp to [0, 1] so steps past ``total_steps`` stay at ``min_lr``.
        if fraction < 0.0:
            fraction = 0.0
        elif fraction > 1.0:
            fraction = 1.0

        # Half-cosine weight: 1.0 at the start of the decay, 0.0 at the end.
        weight = 0.5 * (1.0 + math.cos(math.pi * fraction))
        target_lr = min_lr + (base_lr - min_lr) * weight

        # Express the target relative to base_lr; the floor avoids
        # dividing by zero when base_lr is 0.
        return float(target_lr / max(base_lr, 1e-12))

    def lr_lambda(step: int) -> float:
        in_warmup = step < warmup_steps
        return _warmup_factor(step) if in_warmup else _cosine_factor(step)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)