""" Optimizer and learning rate scheduler for SLM training. Uses AdamW with cosine annealing and warmup. """ import math from typing import Optional, Tuple, List import torch from torch.optim import AdamW from torch.optim.lr_scheduler import LambdaLR def create_optimizer( model: torch.nn.Module, learning_rate: float = 3e-4, weight_decay: float = 0.1, betas: Tuple[float, float] = (0.9, 0.95), eps: float = 1e-8, ) -> AdamW: """Create AdamW optimizer with weight decay. Applies weight decay only to 2D parameters (weights, not biases/norms). Args: model: The model to optimize learning_rate: Base learning rate weight_decay: Weight decay coefficient betas: Adam beta parameters eps: Adam epsilon for numerical stability Returns: Configured AdamW optimizer """ # Separate parameters into decay and no-decay groups decay_params = [] no_decay_params = [] for name, param in model.named_parameters(): if not param.requires_grad: continue # No weight decay for: # - 1D parameters (biases, layer norms) # - Embedding layers if param.dim() == 1 or "embedding" in name.lower(): no_decay_params.append(param) else: decay_params.append(param) param_groups = [ {"params": decay_params, "weight_decay": weight_decay}, {"params": no_decay_params, "weight_decay": 0.0}, ] optimizer = AdamW( param_groups, lr=learning_rate, betas=betas, eps=eps, ) return optimizer def create_scheduler( optimizer: torch.optim.Optimizer, num_training_steps: int, warmup_ratio: float = 0.1, min_lr_ratio: float = 0.1, scheduler_type: str = "cosine", ) -> LambdaLR: """Create learning rate scheduler. Args: optimizer: The optimizer to schedule num_training_steps: Total number of training steps warmup_ratio: Ratio of warmup steps (e.g., 0.1 = 10%) min_lr_ratio: Minimum LR as ratio of max (e.g., 0.1 = 10% of peak LR) scheduler_type: Type of scheduler ("cosine", "linear", "constant") Returns: LambdaLR scheduler """ num_warmup_steps = int(num_training_steps * warmup_ratio) if scheduler_type == "cosine": def lr_lambda(current_step: int) -> float: # Warmup phase if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) # Cosine annealing phase progress = float(current_step - num_warmup_steps) / float( max(1, num_training_steps - num_warmup_steps) ) cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) # Scale between min_lr_ratio and 1.0 return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay elif scheduler_type == "linear": def lr_lambda(current_step: int) -> float: if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) progress = float(current_step - num_warmup_steps) / float( max(1, num_training_steps - num_warmup_steps) ) return max(min_lr_ratio, 1.0 - progress * (1.0 - min_lr_ratio)) elif scheduler_type == "constant": def lr_lambda(current_step: int) -> float: if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) return 1.0 else: raise ValueError(f"Unknown scheduler type: {scheduler_type}") return LambdaLR(optimizer, lr_lambda) def get_parameter_count(model: torch.nn.Module) -> dict: """Get detailed parameter count for a model. Args: model: The model to analyze Returns: Dictionary with parameter counts """ total_params = 0 trainable_params = 0 embedding_params = 0 for name, param in model.named_parameters(): num_params = param.numel() total_params += num_params if param.requires_grad: trainable_params += num_params if "embedding" in name.lower(): embedding_params += num_params return { "total": total_params, "trainable": trainable_params, "embedding": embedding_params, "non_embedding": total_params - embedding_params, } def get_optimizer_state(optimizer: torch.optim.Optimizer) -> dict: """Get optimizer state statistics. Args: optimizer: The optimizer to analyze Returns: Dictionary with optimizer state info """ num_params = sum( sum(p.numel() for p in group["params"]) for group in optimizer.param_groups ) current_lrs = [group["lr"] for group in optimizer.param_groups] return { "num_param_groups": len(optimizer.param_groups), "total_params": num_params, "learning_rates": current_lrs, } def clip_grad_norm( model: torch.nn.Module, max_norm: float = 1.0, ) -> float: """Clip gradient norm and return the norm value. Args: model: The model with gradients max_norm: Maximum gradient norm Returns: The gradient norm before clipping """ parameters = [p for p in model.parameters() if p.grad is not None] if len(parameters) == 0: return 0.0 total_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm) return total_norm.item()