# PebbleLM-117M — src/training/optimizer.py
# Author: nameissakthi — "Add model architecture code" (commit 27871e7)
"""
Optimizer and learning rate scheduler for SLM training.
Uses AdamW with cosine annealing and warmup.
"""
import math
from typing import Optional, Tuple, List
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
def create_optimizer(
    model: torch.nn.Module,
    learning_rate: float = 3e-4,
    weight_decay: float = 0.1,
    betas: Tuple[float, float] = (0.9, 0.95),
    eps: float = 1e-8,
) -> AdamW:
    """Build an AdamW optimizer with selective weight decay.

    Parameters are split into two groups: matrices (dim >= 2, excluding
    embeddings) receive weight decay, while 1D tensors (biases, norm
    scales) and embedding weights do not — the usual GPT-style recipe.

    Args:
        model: The model whose parameters will be optimized
        learning_rate: Base learning rate
        weight_decay: Decay coefficient applied to the decay group only
        betas: Adam beta parameters
        eps: Adam epsilon for numerical stability

    Returns:
        Configured AdamW optimizer with two parameter groups
    """
    decay: List[torch.nn.Parameter] = []
    no_decay: List[torch.nn.Parameter] = []
    for pname, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # 1D params (biases, norms) and embeddings are exempt from decay.
        skip_decay = p.dim() == 1 or "embedding" in pname.lower()
        (no_decay if skip_decay else decay).append(p)

    return AdamW(
        [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=learning_rate,
        betas=betas,
        eps=eps,
    )
def create_scheduler(
    optimizer: torch.optim.Optimizer,
    num_training_steps: int,
    warmup_ratio: float = 0.1,
    min_lr_ratio: float = 0.1,
    scheduler_type: str = "cosine",
) -> LambdaLR:
    """Build a warmup-then-decay learning rate scheduler.

    All variants linearly ramp the LR from 0 to the peak over the warmup
    steps, then either anneal it (cosine/linear) toward
    ``min_lr_ratio * peak`` or hold it constant.

    Args:
        optimizer: The optimizer whose LR will be scheduled
        num_training_steps: Total number of training steps
        warmup_ratio: Fraction of steps used for linear warmup
        min_lr_ratio: Floor LR expressed as a fraction of the peak LR
        scheduler_type: One of "cosine", "linear", "constant"

    Returns:
        LambdaLR scheduler wrapping the chosen multiplier function

    Raises:
        ValueError: If ``scheduler_type`` is not recognized
    """
    warmup_steps = int(num_training_steps * warmup_ratio)

    def _warmup(step: int) -> float:
        # max(1, ...) guards against division by zero when warmup is empty.
        return float(step) / float(max(1, warmup_steps))

    def _progress(step: int) -> float:
        # Fraction of the post-warmup phase completed, in [0, 1].
        return float(step - warmup_steps) / float(
            max(1, num_training_steps - warmup_steps)
        )

    def _cosine(step: int) -> float:
        if step < warmup_steps:
            return _warmup(step)
        decay = 0.5 * (1.0 + math.cos(math.pi * _progress(step)))
        return min_lr_ratio + (1.0 - min_lr_ratio) * decay

    def _linear(step: int) -> float:
        if step < warmup_steps:
            return _warmup(step)
        return max(min_lr_ratio, 1.0 - _progress(step) * (1.0 - min_lr_ratio))

    def _constant(step: int) -> float:
        if step < warmup_steps:
            return _warmup(step)
        return 1.0

    schedules = {"cosine": _cosine, "linear": _linear, "constant": _constant}
    if scheduler_type not in schedules:
        raise ValueError(f"Unknown scheduler type: {scheduler_type}")
    return LambdaLR(optimizer, schedules[scheduler_type])
def get_parameter_count(model: torch.nn.Module) -> dict:
    """Summarize a model's parameter counts.

    Embedding parameters are identified by "embedding" appearing in the
    parameter's name (case-insensitive).

    Args:
        model: The model to analyze

    Returns:
        Dict with keys "total", "trainable", "embedding", and
        "non_embedding" (total minus embedding)
    """
    total = trainable = embedding = 0
    for name, p in model.named_parameters():
        count = p.numel()
        total += count
        if p.requires_grad:
            trainable += count
        if "embedding" in name.lower():
            embedding += count
    return {
        "total": total,
        "trainable": trainable,
        "embedding": embedding,
        "non_embedding": total - embedding,
    }
def get_optimizer_state(optimizer: torch.optim.Optimizer) -> dict:
    """Summarize an optimizer's parameter groups.

    Args:
        optimizer: The optimizer to analyze

    Returns:
        Dict with the number of param groups, the total parameter count
        across all groups, and the current per-group learning rates
    """
    groups = optimizer.param_groups
    total = 0
    for group in groups:
        for p in group["params"]:
            total += p.numel()
    return {
        "num_param_groups": len(groups),
        "total_params": total,
        "learning_rates": [group["lr"] for group in groups],
    }
def clip_grad_norm(
    model: torch.nn.Module,
    max_norm: float = 1.0,
) -> float:
    """Clip the model's gradients in place and report their norm.

    Args:
        model: The model whose gradients are clipped
        max_norm: Maximum allowed total gradient norm

    Returns:
        The total gradient norm as measured before clipping, or 0.0
        when no parameter currently holds a gradient
    """
    with_grads = [p for p in model.parameters() if p.grad is not None]
    if not with_grads:
        return 0.0
    return torch.nn.utils.clip_grad_norm_(with_grads, max_norm).item()