|
|
""" |
|
|
Optimizer and learning rate scheduler for SLM training. |
|
|
|
|
|
Uses AdamW with cosine annealing and warmup. |
|
|
""" |
|
|
|
|
|
import math |
|
|
from typing import Optional, Tuple, List |
|
|
|
|
|
import torch |
|
|
from torch.optim import AdamW |
|
|
from torch.optim.lr_scheduler import LambdaLR |
|
|
|
|
|
|
|
|
def create_optimizer(
    model: torch.nn.Module,
    learning_rate: float = 3e-4,
    weight_decay: float = 0.1,
    betas: Tuple[float, float] = (0.9, 0.95),
    eps: float = 1e-8,
) -> AdamW:
    """Build an AdamW optimizer with selective weight decay.

    Parameters are split into two groups: weight decay is applied to
    multi-dimensional weights, while 1-D parameters (biases, norm scales)
    and any parameter whose name contains "embedding" are exempt.
    Frozen parameters (requires_grad=False) are skipped entirely.

    Args:
        model: The model to optimize
        learning_rate: Base learning rate
        weight_decay: Weight decay coefficient
        betas: Adam beta parameters
        eps: Adam epsilon for numerical stability

    Returns:
        Configured AdamW optimizer
    """
    decay: List[torch.nn.Parameter] = []
    no_decay: List[torch.nn.Parameter] = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # 1-D tensors (biases, layernorm gains) and embeddings get no decay.
        exempt = param.dim() == 1 or "embedding" in name.lower()
        (no_decay if exempt else decay).append(param)

    return AdamW(
        [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=learning_rate,
        betas=betas,
        eps=eps,
    )
|
|
|
|
|
|
|
|
def create_scheduler(
    optimizer: torch.optim.Optimizer,
    num_training_steps: int,
    warmup_ratio: float = 0.1,
    min_lr_ratio: float = 0.1,
    scheduler_type: str = "cosine",
) -> LambdaLR:
    """Create learning rate scheduler with linear warmup.

    All schedules warm up linearly from 0 to the peak LR over the first
    ``warmup_ratio`` fraction of training, then decay ("cosine"/"linear")
    toward ``min_lr_ratio`` of the peak, or hold the peak ("constant").

    Args:
        optimizer: The optimizer to schedule
        num_training_steps: Total number of training steps
        warmup_ratio: Ratio of warmup steps (e.g., 0.1 = 10%)
        min_lr_ratio: Minimum LR as ratio of max (e.g., 0.1 = 10% of peak LR)
        scheduler_type: Type of scheduler ("cosine", "linear", "constant")

    Returns:
        LambdaLR scheduler

    Raises:
        ValueError: If scheduler_type is not one of the supported types.
    """
    num_warmup_steps = int(num_training_steps * warmup_ratio)

    def _warmup_factor(current_step: int) -> float:
        # Linear ramp 0 -> 1; max(1, ...) guards against zero warmup steps.
        return float(current_step) / float(max(1, num_warmup_steps))

    def _progress(current_step: int) -> float:
        # Fraction of the post-warmup phase completed, clamped to 1.0 so the
        # schedule stays at the floor if training runs past num_training_steps
        # (without the clamp, cosine would start climbing back up).
        raw = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        return min(1.0, raw)

    if scheduler_type == "cosine":

        def lr_lambda(current_step: int) -> float:
            if current_step < num_warmup_steps:
                return _warmup_factor(current_step)
            # Half-cosine from 1 down to 0, rescaled to [min_lr_ratio, 1].
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * _progress(current_step)))
            return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay

    elif scheduler_type == "linear":

        def lr_lambda(current_step: int) -> float:
            if current_step < num_warmup_steps:
                return _warmup_factor(current_step)
            # Linear decay from 1 to min_lr_ratio; max() keeps the floor.
            return max(
                min_lr_ratio,
                1.0 - _progress(current_step) * (1.0 - min_lr_ratio),
            )

    elif scheduler_type == "constant":

        def lr_lambda(current_step: int) -> float:
            if current_step < num_warmup_steps:
                return _warmup_factor(current_step)
            return 1.0

    else:
        raise ValueError(f"Unknown scheduler type: {scheduler_type}")

    return LambdaLR(optimizer, lr_lambda)
|
|
|
|
|
|
|
|
def get_parameter_count(model: torch.nn.Module) -> dict:
    """Summarize a model's parameter counts.

    A parameter is counted as "embedding" when its qualified name contains
    the substring "embedding" (case-insensitive).

    Args:
        model: The model to analyze

    Returns:
        Dictionary with keys "total", "trainable", "embedding", and
        "non_embedding" mapping to element counts.
    """
    named = list(model.named_parameters())

    total_params = sum(p.numel() for _, p in named)
    trainable_params = sum(p.numel() for _, p in named if p.requires_grad)
    embedding_params = sum(
        p.numel() for name, p in named if "embedding" in name.lower()
    )

    return {
        "total": total_params,
        "trainable": trainable_params,
        "embedding": embedding_params,
        "non_embedding": total_params - embedding_params,
    }
|
|
|
|
|
|
|
|
def get_optimizer_state(optimizer: torch.optim.Optimizer) -> dict:
    """Report summary statistics about an optimizer's parameter groups.

    Args:
        optimizer: The optimizer to analyze

    Returns:
        Dictionary with the group count, total parameter elements across
        all groups, and the current learning rate of each group.
    """
    groups = optimizer.param_groups

    total = 0
    for group in groups:
        total += sum(p.numel() for p in group["params"])

    return {
        "num_param_groups": len(groups),
        "total_params": total,
        "learning_rates": [group["lr"] for group in groups],
    }
|
|
|
|
|
|
|
|
def clip_grad_norm(
    model: torch.nn.Module,
    max_norm: float = 1.0,
) -> float:
    """Clip the model's gradient norm in place and report the pre-clip norm.

    Only parameters that currently hold a gradient participate; if none
    do, no clipping happens and 0.0 is returned.

    Args:
        model: The model with gradients
        max_norm: Maximum gradient norm

    Returns:
        The gradient norm before clipping
    """
    with_grads = [p for p in model.parameters() if p.grad is not None]
    if not with_grads:
        return 0.0
    return float(torch.nn.utils.clip_grad_norm_(with_grads, max_norm))
|
|
|