File size: 2,012 Bytes
ef18673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""Optimizer and scheduler factories."""

from __future__ import annotations

import math
from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class ScheduleConfig:
    """Training schedule and AdamW settings.

    Read by ``create_optimizer`` (``weight_decay``, ``betas``, ``adam_eps`` and
    ``peak_learning_rate`` as the base LR) and by ``lr_lambda`` (the LR and step
    fields). Inconsistent learning-rate/step values are rejected at construction
    time with ``ValueError`` rather than failing later inside the scheduler.
    """

    # Base LR given to the optimizer; the schedule multiplier 1.0 corresponds to it.
    peak_learning_rate: float = 3.0e-4
    # LR the cosine decay reaches at ``total_steps``.
    min_learning_rate: float = 3.0e-5
    # Number of linear-warmup steps before the cosine decay starts.
    warmup_steps: int = 2000
    # Weight decay applied to the decayed parameter group.
    weight_decay: float = 0.1
    # AdamW (beta1, beta2).
    betas: tuple[float, float] = (0.9, 0.95)
    # AdamW epsilon.
    adam_eps: float = 1.0e-8
    # Step index at which the cosine decay reaches ``min_learning_rate``.
    total_steps: int = 25_000

    def __post_init__(self) -> None:
        """Validate the schedule fields.

        Raises:
            ValueError: if ``peak_learning_rate`` is not strictly positive, if
                ``min_learning_rate`` is outside ``[0, peak_learning_rate]``, or
                if ``warmup_steps`` / ``total_steps`` is negative.
        """
        # lr_lambda divides by peak_learning_rate; ``not x > 0`` also rejects NaN.
        if not self.peak_learning_rate > 0.0:
            raise ValueError(
                f"peak_learning_rate must be > 0, got {self.peak_learning_rate}"
            )
        # A floor above the peak would turn the "decay" phase into an LR increase.
        if not 0.0 <= self.min_learning_rate <= self.peak_learning_rate:
            raise ValueError(
                "min_learning_rate must be in [0, peak_learning_rate], got "
                f"{self.min_learning_rate} (peak {self.peak_learning_rate})"
            )
        if self.warmup_steps < 0:
            raise ValueError(f"warmup_steps must be >= 0, got {self.warmup_steps}")
        if self.total_steps < 0:
            raise ValueError(f"total_steps must be >= 0, got {self.total_steps}")


def create_optimizer(model: torch.nn.Module, config: ScheduleConfig) -> torch.optim.Optimizer:
    """Create an AdamW optimizer with the usual weight-decay exclusions.

    Trainable parameters are split into two groups:

    * decayed: parameters with two or more dimensions (weight matrices,
      embedding tables, convolution kernels) whose name does not contain
      ``"norm"``;
    * not decayed: 0-d and 1-d parameters (biases, normalization gains,
      learnable scalars) plus any parameter whose name contains ``"norm"``.

    Parameters with ``requires_grad=False`` are skipped entirely.

    Args:
        model: Module whose parameters are optimized.
        config: Supplies ``weight_decay`` (decayed group only),
            ``peak_learning_rate`` (base LR), ``betas`` and ``adam_eps``.

    Returns:
        A ``torch.optim.AdamW`` whose first param group is the decayed group
        and whose second is the non-decayed group.
    """
    decay: list[torch.nn.Parameter] = []
    no_decay: list[torch.nn.Parameter] = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # ``ndim < 2`` (not ``== 1``) so 0-d scalar parameters are exempt from
        # decay as well; the name test is a heuristic for normalization layers.
        if param.ndim < 2 or "norm" in name:
            no_decay.append(param)
        else:
            decay.append(param)
    return torch.optim.AdamW(
        [
            {"params": decay, "weight_decay": config.weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=config.peak_learning_rate,
        betas=config.betas,
        eps=config.adam_eps,
    )


def lr_lambda(current_step: int, config: ScheduleConfig) -> float:
    """Return the LR multiplier for ``current_step``: linear warmup, then cosine decay.

    The value multiplies each param group's base LR (``create_optimizer`` sets
    that base to ``config.peak_learning_rate``), as expected by
    ``torch.optim.lr_scheduler.LambdaLR``:

    * for ``current_step < warmup_steps`` it rises linearly from
      ``1 / warmup_steps`` to ``1.0``;
    * from ``warmup_steps`` to ``total_steps`` it follows a half cosine from
      ``1.0`` down to ``min_learning_rate / peak_learning_rate``;
    * beyond ``total_steps`` it stays at that floor.

    Args:
        current_step: Zero-based scheduler step index.
        config: Schedule settings (warmup/total steps, peak and minimum LR).

    Returns:
        The multiplicative LR factor for this step.
    """
    if current_step < config.warmup_steps:
        # ``+ 1`` so the very first step already uses a non-zero LR.
        return float(current_step + 1) / float(max(1, config.warmup_steps))
    decay_steps = max(1, config.total_steps - config.warmup_steps)
    progress = (current_step - config.warmup_steps) / float(decay_steps)
    # Clamp: without it, cos() turns around once progress > 1 and the LR would
    # climb back toward the peak if training runs past ``total_steps``.
    progress = min(1.0, progress)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    floor = config.min_learning_rate / config.peak_learning_rate
    return floor + (1.0 - floor) * cosine


def create_scheduler(optimizer: torch.optim.Optimizer, config: ScheduleConfig) -> torch.optim.lr_scheduler.LambdaLR:
    """Build the ``LambdaLR`` that applies the warmup + cosine schedule.

    Args:
        optimizer: Optimizer whose param groups' base LRs are scaled each step.
        config: Schedule settings forwarded to ``lr_lambda``.

    Returns:
        A ``LambdaLR`` whose multiplier at step ``s`` is ``lr_lambda(s, config)``.
    """

    def _multiplier(step: int) -> float:
        # Bind ``config`` so LambdaLR only needs to supply the step index.
        return lr_lambda(step, config)

    scheduler_type = torch.optim.lr_scheduler.LambdaLR
    return scheduler_type(optimizer, lr_lambda=_multiplier)