""" Gradient Descent Optimizers ========================== This module implements various gradient descent optimization algorithms with proper mathematical formulations and PyTorch compatibility. Algorithms implemented: - SGD (Stochastic Gradient Descent) - Adam (Adaptive Moment Estimation) - AdamW (Adam with Decoupled Weight Decay) All optimizers follow the mathematical principles and include proper momentum, learning rate scheduling, and weight decay handling. """ import math import logging from abc import ABC, abstractmethod from typing import Dict, List, Tuple, Any import torch logger = logging.getLogger(__name__) class Optimizer(ABC): """Abstract base class for all optimizers""" def __init__(self, params: List[torch.Tensor], lr: float = 1e-3, **kwargs): self.params = params self.lr = lr self.step_count = 0 self.state = {} @abstractmethod def step(self, closure=None): """Perform a single optimization step""" pass def zero_grad(self): """Zero the gradients of all parameters""" for param in self.params: if param.grad is not None: param.grad.zero_() def state_dict(self): """Return the state of the optimizer""" return { 'state': self.state, 'param_groups': [{'params': self.params, 'lr': self.lr}], 'step_count': self.step_count } def load_state_dict(self, state_dict): """Load the state of the optimizer""" self.state = state_dict['state'] self.lr = state_dict['param_groups'][0]['lr'] self.step_count = state_dict['step_count'] class SGD(Optimizer): """ Stochastic Gradient Descent optimizer Mathematical formulation: θ_{t+1} = θ_t - α * ∇_θ J(θ_t) With momentum: v_{t+1} = μ * v_t + ∇_θ J(θ_t) θ_{t+1} = θ_t - α * v_{t+1} """ def __init__(self, params: List[torch.Tensor], lr: float = 1e-3, momentum: float = 0.0, weight_decay: float = 0.0, dampening: float = 0.0, nesterov: bool = False): super().__init__(params, lr) self.momentum = momentum self.weight_decay = weight_decay self.dampening = dampening self.nesterov = nesterov # Initialize momentum buffers for param in self.params: if momentum > 0: self.state[param] = {'momentum_buffer': torch.zeros_like(param)} logger.info(f"Initialized SGD optimizer: lr={lr}, momentum={momentum}, weight_decay={weight_decay}") def step(self, closure=None): """Perform SGD optimization step""" loss = None if closure is not None: loss = closure() for param in self.params: if param.grad is None: continue grad = param.grad.data # Apply weight decay if self.weight_decay != 0: grad = grad.add(param.data, alpha=self.weight_decay) # Apply momentum if self.momentum != 0: param_state = self.state[param] if 'momentum_buffer' not in param_state: param_state['momentum_buffer'] = torch.zeros_like(param.data) momentum_buffer = param_state['momentum_buffer'] momentum_buffer.mul_(self.momentum).add_(grad, alpha=1 - self.dampening) if self.nesterov: grad = grad.add(momentum_buffer, alpha=self.momentum) else: grad = momentum_buffer # Update parameters param.data.add_(grad, alpha=-self.lr) self.step_count += 1 return loss class Adam(Optimizer): """ Adam (Adaptive Moment Estimation) optimizer Mathematical formulation: m_t = β₁ * m_{t-1} + (1 - β₁) * g_t v_t = β₂ * v_{t-1} + (1 - β₂) * g_t² m̂_t = m_t / (1 - β₁ᵗ) v̂_t = v_t / (1 - β₂ᵗ) θ_{t+1} = θ_t - α * m̂_t / (√v̂_t + ε) """ def __init__(self, params: List[torch.Tensor], lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.0, amsgrad: bool = False): super().__init__(params, lr) self.betas = betas self.eps = eps self.weight_decay = weight_decay self.amsgrad = amsgrad # Initialize moment 
class Adam(Optimizer):
    """
    Adam (Adaptive Moment Estimation) optimizer.

    Mathematical formulation:
        m_t = β₁ * m_{t-1} + (1 - β₁) * g_t
        v_t = β₂ * v_{t-1} + (1 - β₂) * g_t²
        m̂_t = m_t / (1 - β₁ᵗ)
        v̂_t = v_t / (1 - β₂ᵗ)
        θ_{t+1} = θ_t - α * m̂_t / (√v̂_t + ε)
    """

    def __init__(self, params: List[torch.Tensor], lr: float = 1e-3,
                 betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8,
                 weight_decay: float = 0.0, amsgrad: bool = False):
        super().__init__(params, lr)
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.amsgrad = amsgrad

        # Initialize moment estimates
        for param in self.params:
            self.state[param] = {
                'step': 0,
                'exp_avg': torch.zeros_like(param.data),
                'exp_avg_sq': torch.zeros_like(param.data),
            }
            if amsgrad:
                self.state[param]['max_exp_avg_sq'] = torch.zeros_like(param.data)

        logger.info(f"Initialized Adam optimizer: lr={lr}, betas={betas}, eps={eps}")

    def step(self, closure=None):
        """Perform an Adam optimization step."""
        loss = None
        if closure is not None:
            loss = closure()

        for param in self.params:
            if param.grad is None:
                continue
            grad = param.grad.data

            # Apply weight decay
            if self.weight_decay != 0:
                grad = grad.add(param.data, alpha=self.weight_decay)

            param_state = self.state[param]
            exp_avg, exp_avg_sq = param_state['exp_avg'], param_state['exp_avg_sq']
            beta1, beta2 = self.betas

            param_state['step'] += 1
            bias_correction1 = 1 - beta1 ** param_state['step']
            bias_correction2 = 1 - beta2 ** param_state['step']

            # Update biased first moment estimate
            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            # Update biased second raw moment estimate
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

            if self.amsgrad:
                # Maintains the maximum of all 2nd moment running avg. of squared gradients
                torch.max(param_state['max_exp_avg_sq'], exp_avg_sq,
                          out=param_state['max_exp_avg_sq'])
                # Use the max. for normalizing running avg. of squared gradients
                denom = (param_state['max_exp_avg_sq'].sqrt() /
                         math.sqrt(bias_correction2)).add_(self.eps)
            else:
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)

            # Update parameters
            step_size = self.lr / bias_correction1
            param.data.addcdiv_(exp_avg, denom, value=-step_size)

        self.step_count += 1
        return loss
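
# --- Sanity check against torch.optim ---------------------------------------
# A hedged parity sketch: one step of the Adam above compared with
# torch.optim.Adam on identical parameters and gradients. The helper name,
# toy loss, and tolerance are assumptions; only the public torch.optim API
# is relied upon, and minor numerical differences are absorbed by atol.
def _adam_parity_check() -> bool:
    """Return True if one step of this Adam matches torch.optim.Adam."""
    torch.manual_seed(0)
    p_ours = torch.randn(4, requires_grad=True)
    p_ref = p_ours.detach().clone().requires_grad_(True)
    ours = Adam([p_ours], lr=1e-3)
    ref = torch.optim.Adam([p_ref], lr=1e-3)
    for p in (p_ours, p_ref):
        (p ** 2).sum().backward()  # both parameters see the gradient 2 * p
    ours.step()
    ref.step()
    return torch.allclose(p_ours, p_ref, atol=1e-6)
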
class AdamW(Optimizer):
    """
    AdamW (Adam with Decoupled Weight Decay) optimizer.

    Mathematical formulation:
        θ_t = θ_{t-1} - α * (m̂_t / (√v̂_t + ε) + λ * θ_{t-1})

    Weight decay is applied directly to the parameters, not to the gradients.
    """

    def __init__(self, params: List[torch.Tensor], lr: float = 1e-3,
                 betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8,
                 weight_decay: float = 0.01, amsgrad: bool = False):
        super().__init__(params, lr)
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.amsgrad = amsgrad

        # Initialize moment estimates
        for param in self.params:
            self.state[param] = {
                'step': 0,
                'exp_avg': torch.zeros_like(param.data),
                'exp_avg_sq': torch.zeros_like(param.data),
            }
            if amsgrad:
                self.state[param]['max_exp_avg_sq'] = torch.zeros_like(param.data)

        logger.info(f"Initialized AdamW optimizer: lr={lr}, betas={betas}, "
                    f"weight_decay={weight_decay}")

    def step(self, closure=None):
        """Perform an AdamW optimization step."""
        loss = None
        if closure is not None:
            loss = closure()

        for param in self.params:
            if param.grad is None:
                continue
            grad = param.grad.data

            param_state = self.state[param]
            exp_avg, exp_avg_sq = param_state['exp_avg'], param_state['exp_avg_sq']
            beta1, beta2 = self.betas

            param_state['step'] += 1
            bias_correction1 = 1 - beta1 ** param_state['step']
            bias_correction2 = 1 - beta2 ** param_state['step']

            # Update biased first moment estimate
            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            # Update biased second raw moment estimate
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

            if self.amsgrad:
                # Maintains the maximum of all 2nd moment running avg. of squared gradients
                torch.max(param_state['max_exp_avg_sq'], exp_avg_sq,
                          out=param_state['max_exp_avg_sq'])
                # Use the max. for normalizing running avg. of squared gradients
                denom = (param_state['max_exp_avg_sq'].sqrt() /
                         math.sqrt(bias_correction2)).add_(self.eps)
            else:
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)

            # Update parameters with decoupled weight decay
            step_size = self.lr / bias_correction1
            param.data.mul_(1 - self.lr * self.weight_decay)
            param.data.addcdiv_(exp_avg, denom, value=-step_size)

        self.step_count += 1
        return loss


class OptimizerFactory:
    """Factory class for creating optimizers."""

    @staticmethod
    def create_optimizer(optimizer_type: str, params: List[torch.Tensor],
                         **kwargs) -> Optimizer:
        """Create an optimizer instance."""
        optimizers = {
            'sgd': SGD,
            'adam': Adam,
            'adamw': AdamW,
        }
        if optimizer_type.lower() not in optimizers:
            raise ValueError(f"Unknown optimizer type: {optimizer_type}")
        optimizer_class = optimizers[optimizer_type.lower()]
        return optimizer_class(params, **kwargs)

    @staticmethod
    def get_default_config(optimizer_type: str) -> Dict[str, Any]:
        """Get the default configuration for an optimizer."""
        configs = {
            'sgd': {'lr': 1e-3, 'momentum': 0.9, 'weight_decay': 1e-4},
            'adam': {'lr': 1e-3, 'betas': (0.9, 0.999), 'eps': 1e-8,
                     'weight_decay': 1e-4},
            'adamw': {'lr': 1e-3, 'betas': (0.9, 0.999), 'eps': 1e-8,
                      'weight_decay': 0.01},
        }
        return configs.get(optimizer_type.lower(), {})
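
# --- Demo entry point --------------------------------------------------------
# A hedged demo (an assumed entry point, not part of the original module):
# build each optimizer through the factory with its default config and take
# one step on a toy loss. The toy tensor and loss are illustrative choices.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    for name in ("sgd", "adam", "adamw"):
        w = torch.zeros(3, requires_grad=True)
        config = OptimizerFactory.get_default_config(name)
        opt = OptimizerFactory.create_optimizer(name, [w], **config)
        opt.zero_grad()
        ((w - 1.0) ** 2).sum().backward()
        opt.step()
        logger.info("%s after one step: %s", name, w.detach())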