| """Optimizer and scheduler factories.""" |
|
|
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass |
|
|
| import torch |
|
|
|
|
| @dataclass(frozen=True) |
| class ScheduleConfig: |
| """Training schedule settings.""" |
|
|
| peak_learning_rate: float = 3.0e-4 |
| min_learning_rate: float = 3.0e-5 |
| warmup_steps: int = 2000 |
| weight_decay: float = 0.1 |
| betas: tuple[float, float] = (0.9, 0.95) |
| adam_eps: float = 1.0e-8 |
| total_steps: int = 25_000 |
|
|
|
|
| def create_optimizer(model: torch.nn.Module, config: ScheduleConfig) -> torch.optim.Optimizer: |
| """Create an AdamW optimizer with correct weight-decay exclusions.""" |
| decay: list[torch.nn.Parameter] = [] |
| no_decay: list[torch.nn.Parameter] = [] |
| for name, param in model.named_parameters(): |
| if not param.requires_grad: |
| continue |
| if param.ndim == 1 or "norm" in name: |
| no_decay.append(param) |
| else: |
| decay.append(param) |
| return torch.optim.AdamW( |
| [ |
| {"params": decay, "weight_decay": config.weight_decay}, |
| {"params": no_decay, "weight_decay": 0.0}, |
| ], |
| lr=config.peak_learning_rate, |
| betas=config.betas, |
| eps=config.adam_eps, |
| ) |
|
|
|
|
| def lr_lambda(current_step: int, config: ScheduleConfig) -> float: |
| """Warm up linearly and then decay with cosine.""" |
| if current_step < config.warmup_steps: |
| return float(current_step + 1) / float(max(1, config.warmup_steps)) |
| progress = (current_step - config.warmup_steps) / float(max(1, config.total_steps - config.warmup_steps)) |
| cosine = 0.5 * (1.0 + math.cos(math.pi * progress)) |
| floor = config.min_learning_rate / config.peak_learning_rate |
| return floor + (1.0 - floor) * cosine |
|
|
|
|
| def create_scheduler(optimizer: torch.optim.Optimizer, config: ScheduleConfig) -> torch.optim.lr_scheduler.LambdaLR: |
| """Create the training LR scheduler.""" |
| return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: lr_lambda(step, config)) |
|
|