| """学习率调度:线性 warmup + 余弦退火。""" |
|
|
| from __future__ import annotations |
|
|
| import math |
|
|
| import torch |
|
|
|
|
def build_scheduler(
    optimizer: torch.optim.Optimizer,
    warmup_steps: int,
    total_steps: int,
    base_lr: float,
    min_lr: float,
) -> torch.optim.lr_scheduler.LambdaLR:
    """Build a ``LambdaLR`` with linear warmup followed by cosine annealing.

    The lambda returns a multiplicative factor defined as
    ``lr_factor = current_lr / base_lr``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        warmup_steps: Number of linear warmup steps. The factor at step ``s``
            of the warmup is ``(s + 1) / warmup_steps``.
        total_steps: Step at which the cosine decay reaches ``min_lr``; the
            decay spans ``total_steps - warmup_steps`` steps (at least 1).
        base_lr: Peak learning rate that the factor is expressed against.
        min_lr: Learning rate at the end of the cosine decay.

    Returns:
        A ``LambdaLR`` wrapping ``optimizer`` with the schedule above.
    """

    def factor_at(step: int) -> float:
        # Linear warmup: ramps up to exactly 1.0 at the last warmup step.
        # max(1, ...) keeps the division safe if warmup_steps is not positive.
        if step < warmup_steps:
            return float(step + 1) / max(1, warmup_steps)

        # Fraction of the decay phase completed, clamped to [0, 1] so steps
        # past total_steps stay pinned at the final value.
        decay_span = max(1, total_steps - warmup_steps)
        fraction = max(0.0, min((step - warmup_steps) / decay_span, 1.0))

        # Cosine weight goes from 1 (start of decay) down to 0 (end of decay).
        cosine_weight = (1.0 + math.cos(math.pi * fraction)) * 0.5
        target_lr = min_lr + (base_lr - min_lr) * cosine_weight

        # Express the target as a factor of base_lr; the floor on the
        # denominator avoids dividing by zero when base_lr is 0.
        return float(target_lr / max(base_lr, 1e-12))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, factor_at)
|
|