File size: 5,562 Bytes
27871e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
"""
Optimizer and learning rate scheduler for SLM training.

Uses AdamW with cosine annealing and warmup.
"""

import math
from typing import Optional, Tuple, List

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR


def create_optimizer(
    model: torch.nn.Module,
    learning_rate: float = 3e-4,
    weight_decay: float = 0.1,
    betas: Tuple[float, float] = (0.9, 0.95),
    eps: float = 1e-8,
) -> AdamW:
    """Build an AdamW optimizer with selective weight decay.

    Weight decay is applied only to matrix-shaped parameters; 1D tensors
    (biases, norm scales) and any parameter whose name contains
    "embedding" are placed in a zero-decay group.

    Args:
        model: The model to optimize
        learning_rate: Base learning rate
        weight_decay: Weight decay coefficient
        betas: Adam beta parameters
        eps: Adam epsilon for numerical stability

    Returns:
        Configured AdamW optimizer
    """
    trainable = [
        (name, param)
        for name, param in model.named_parameters()
        if param.requires_grad
    ]

    # A parameter is exempt from decay when it is 1D (bias / norm scale)
    # or belongs to an embedding layer.
    def _decay_exempt(name: str, param: torch.nn.Parameter) -> bool:
        return param.dim() == 1 or "embedding" in name.lower()

    decay = [p for n, p in trainable if not _decay_exempt(n, p)]
    no_decay = [p for n, p in trainable if _decay_exempt(n, p)]

    return AdamW(
        [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=learning_rate,
        betas=betas,
        eps=eps,
    )


def create_scheduler(
    optimizer: torch.optim.Optimizer,
    num_training_steps: int,
    warmup_ratio: float = 0.1,
    min_lr_ratio: float = 0.1,
    scheduler_type: str = "cosine",
) -> LambdaLR:
    """Create a warmup-then-decay learning rate scheduler.

    Every variant ramps the LR linearly from 0 to the peak value over the
    warmup window, then decays (cosine/linear) or holds (constant).

    Args:
        optimizer: The optimizer to schedule
        num_training_steps: Total number of training steps
        warmup_ratio: Ratio of warmup steps (e.g., 0.1 = 10%)
        min_lr_ratio: Minimum LR as ratio of max (e.g., 0.1 = 10% of peak LR)
        scheduler_type: Type of scheduler ("cosine", "linear", "constant")

    Returns:
        LambdaLR scheduler

    Raises:
        ValueError: If scheduler_type is not one of the supported names.
    """
    warmup_steps = int(num_training_steps * warmup_ratio)
    # max(1, ...) guards against division by zero when warmup covers
    # all steps or there is no warmup at all.
    decay_steps = max(1, num_training_steps - warmup_steps)

    def _warmup(step: int) -> float:
        # Linear ramp 0 -> 1 across the warmup window.
        return float(step) / float(max(1, warmup_steps))

    def _progress(step: int) -> float:
        # Fraction of the post-warmup phase completed.
        return float(step - warmup_steps) / float(decay_steps)

    if scheduler_type == "cosine":
        def lr_lambda(step: int) -> float:
            if step < warmup_steps:
                return _warmup(step)
            # Half-cosine from 1 down to 0, rescaled to [min_lr_ratio, 1].
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * _progress(step)))
            return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay

    elif scheduler_type == "linear":
        def lr_lambda(step: int) -> float:
            if step < warmup_steps:
                return _warmup(step)
            # Straight line from 1 down to min_lr_ratio, clamped below.
            return max(min_lr_ratio, 1.0 - _progress(step) * (1.0 - min_lr_ratio))

    elif scheduler_type == "constant":
        def lr_lambda(step: int) -> float:
            return _warmup(step) if step < warmup_steps else 1.0

    else:
        raise ValueError(f"Unknown scheduler type: {scheduler_type}")

    return LambdaLR(optimizer, lr_lambda)


def get_parameter_count(model: torch.nn.Module) -> dict:
    """Get detailed parameter count for a model.

    Args:
        model: The model to analyze

    Returns:
        Dictionary with keys "total", "trainable", "embedding" (parameters
        whose name contains "embedding"), and "non_embedding".
    """
    counts = {"total": 0, "trainable": 0, "embedding": 0}

    for name, param in model.named_parameters():
        n = param.numel()
        counts["total"] += n
        if param.requires_grad:
            counts["trainable"] += n
        # Name-based match: anything under a module/attribute called
        # "embedding" is bucketed separately.
        if "embedding" in name.lower():
            counts["embedding"] += n

    counts["non_embedding"] = counts["total"] - counts["embedding"]
    return counts


def get_optimizer_state(optimizer: torch.optim.Optimizer) -> dict:
    """Get optimizer state statistics.

    Args:
        optimizer: The optimizer to analyze

    Returns:
        Dictionary with the number of parameter groups, total parameter
        count across all groups, and the current per-group learning rates.
    """
    groups = optimizer.param_groups

    total_params = 0
    for group in groups:
        total_params += sum(p.numel() for p in group["params"])

    return {
        "num_param_groups": len(groups),
        "total_params": total_params,
        "learning_rates": [group["lr"] for group in groups],
    }


def clip_grad_norm(
    model: torch.nn.Module,
    max_norm: float = 1.0,
) -> float:
    """Clip gradient norm and return the norm value.

    Args:
        model: The model with gradients
        max_norm: Maximum gradient norm

    Returns:
        The total gradient norm measured before clipping, or 0.0 when no
        parameter currently holds a gradient.
    """
    with_grads = [p for p in model.parameters() if p.grad is not None]

    # Nothing to clip (e.g. before the first backward pass).
    if not with_grads:
        return 0.0

    return torch.nn.utils.clip_grad_norm_(with_grads, max_norm).item()