"""
Gradient Descent Optimizers
===========================

This module implements several gradient descent optimization algorithms
with explicit mathematical formulations, operating directly on PyTorch
tensors.

Algorithms implemented:
- SGD (Stochastic Gradient Descent)
- Adam (Adaptive Moment Estimation)
- AdamW (Adam with Decoupled Weight Decay)

All optimizers follow their standard mathematical formulations and handle
momentum, bias correction, and weight decay where applicable.
"""

import math
import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Any

import torch

logger = logging.getLogger(__name__)


class Optimizer(ABC):
    """Abstract base class for all optimizers"""

    def __init__(self, params: List[torch.Tensor], lr: float = 1e-3, **kwargs):
        self.params = params
        self.lr = lr
        self.step_count = 0
        self.state = {}

    @abstractmethod
    def step(self, closure=None):
        """Perform a single optimization step"""
        pass

    def zero_grad(self):
        """Zero the gradients of all parameters"""
        for param in self.params:
            if param.grad is not None:
                param.grad.zero_()

    def state_dict(self):
        """Return the state of the optimizer"""
        return {
            'state': self.state,
            'param_groups': [{'params': self.params, 'lr': self.lr}],
            'step_count': self.step_count
        }

    def load_state_dict(self, state_dict):
        """Load the state of the optimizer"""
        self.state = state_dict['state']
        self.lr = state_dict['param_groups'][0]['lr']
        self.step_count = state_dict['step_count']
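

# Illustrative sketch (not part of the module API): the save/restore cycle
# supported by the base class above. It uses the SGD subclass defined later
# in this module (the name is resolved when the function is called); the
# helper name and toy tensor are hypothetical.
def _example_state_roundtrip() -> Optimizer:
    param = torch.zeros(4, requires_grad=True)
    opt = SGD([param], lr=0.1, momentum=0.9)

    checkpoint = opt.state_dict()         # captures lr, step count, and buffers

    restored = SGD([param], lr=0.05, momentum=0.9)
    restored.load_state_dict(checkpoint)  # lr and step count revert to 0.1 / 0
    return restored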


class SGD(Optimizer):
    """
    Stochastic Gradient Descent optimizer

    Mathematical formulation:
        θ_{t+1} = θ_t - α * ∇_θ J(θ_t)

    With momentum:
        v_{t+1} = μ * v_t + ∇_θ J(θ_t)
        θ_{t+1} = θ_t - α * v_{t+1}
    """

    def __init__(self, params: List[torch.Tensor], lr: float = 1e-3,
                 momentum: float = 0.0, weight_decay: float = 0.0,
                 dampening: float = 0.0, nesterov: bool = False):
        super().__init__(params, lr)
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.dampening = dampening
        self.nesterov = nesterov

        # Allocate momentum buffers up front when momentum is used.
        if momentum != 0:
            for param in self.params:
                self.state[param] = {'momentum_buffer': torch.zeros_like(param)}

        logger.info(f"Initialized SGD optimizer: lr={lr}, momentum={momentum}, weight_decay={weight_decay}")

    def step(self, closure=None):
        """Perform SGD optimization step"""
        loss = None
        if closure is not None:
            loss = closure()

        for param in self.params:
            if param.grad is None:
                continue

            grad = param.grad.data

            # L2 regularization: add λ * θ to the gradient.
            if self.weight_decay != 0:
                grad = grad.add(param.data, alpha=self.weight_decay)

            if self.momentum != 0:
                param_state = self.state.setdefault(param, {})
                if 'momentum_buffer' not in param_state:
                    param_state['momentum_buffer'] = torch.zeros_like(param.data)

                # v_{t+1} = μ * v_t + (1 - dampening) * g_t
                momentum_buffer = param_state['momentum_buffer']
                momentum_buffer.mul_(self.momentum).add_(grad, alpha=1 - self.dampening)

                if self.nesterov:
                    grad = grad.add(momentum_buffer, alpha=self.momentum)
                else:
                    grad = momentum_buffer

            # θ_{t+1} = θ_t - α * g_t (or - α * v_{t+1} with momentum)
            param.data.add_(grad, alpha=-self.lr)

        self.step_count += 1
        return loss
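

# Illustrative sketch (not part of the module API): fitting a toy linear
# least-squares problem with the SGD class above. The variable names
# (`w`, `x`, `y_true`) are hypothetical and chosen only for this example.
def _example_sgd_usage(steps: int = 200) -> torch.Tensor:
    torch.manual_seed(0)
    x = torch.randn(128, 3)
    y_true = x @ torch.tensor([1.0, -2.0, 0.5])

    w = torch.zeros(3, requires_grad=True)
    opt = SGD([w], lr=0.1, momentum=0.9)

    for _ in range(steps):
        opt.zero_grad()
        loss = ((x @ w - y_true) ** 2).mean()
        loss.backward()
        opt.step()

    return w.detach()  # should approach [1.0, -2.0, 0.5]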


class Adam(Optimizer):
    """
    Adam (Adaptive Moment Estimation) optimizer

    Mathematical formulation:
        m_t = β₁ * m_{t-1} + (1 - β₁) * g_t
        v_t = β₂ * v_{t-1} + (1 - β₂) * g_t²
        m̂_t = m_t / (1 - β₁ᵗ)
        v̂_t = v_t / (1 - β₂ᵗ)
        θ_{t+1} = θ_t - α * m̂_t / (√v̂_t + ε)
    """

    def __init__(self, params: List[torch.Tensor], lr: float = 1e-3,
                 betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8,
                 weight_decay: float = 0.0, amsgrad: bool = False):
        super().__init__(params, lr)
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.amsgrad = amsgrad

        # Per-parameter state: step counter plus first and second moment estimates.
        for param in self.params:
            self.state[param] = {
                'step': 0,
                'exp_avg': torch.zeros_like(param.data),
                'exp_avg_sq': torch.zeros_like(param.data)
            }
            if amsgrad:
                self.state[param]['max_exp_avg_sq'] = torch.zeros_like(param.data)

        logger.info(f"Initialized Adam optimizer: lr={lr}, betas={betas}, eps={eps}")

    def step(self, closure=None):
        """Perform Adam optimization step"""
        loss = None
        if closure is not None:
            loss = closure()

        for param in self.params:
            if param.grad is None:
                continue

            grad = param.grad.data

            # L2 regularization: add λ * θ to the gradient (coupled weight decay).
            if self.weight_decay != 0:
                grad = grad.add(param.data, alpha=self.weight_decay)

            param_state = self.state[param]
            exp_avg, exp_avg_sq = param_state['exp_avg'], param_state['exp_avg_sq']
            beta1, beta2 = self.betas

            param_state['step'] += 1
            bias_correction1 = 1 - beta1 ** param_state['step']
            bias_correction2 = 1 - beta2 ** param_state['step']

            # m_t = β₁ * m_{t-1} + (1 - β₁) * g_t
            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

            # v_t = β₂ * v_{t-1} + (1 - β₂) * g_t²
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

            if self.amsgrad:
                # AMSGrad variant: keep a running maximum of v_t in the denominator.
                torch.max(param_state['max_exp_avg_sq'], exp_avg_sq, out=param_state['max_exp_avg_sq'])
                denom = (param_state['max_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)).add_(self.eps)
            else:
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)

            # θ_{t+1} = θ_t - α * m̂_t / (√v̂_t + ε)
            step_size = self.lr / bias_correction1
            param.data.addcdiv_(exp_avg, denom, value=-step_size)

        self.step_count += 1
        return loss
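

# Illustrative sketch (not part of the module API): a hand check of Adam's
# bias correction against the class above. At step t = 1, m̂₁ = g and v̂₁ = g²,
# so the update direction reduces to g / (|g| + ε). The helper name and the
# single-gradient setup are hypothetical and exist only for this example.
def _example_adam_first_step() -> None:
    param = torch.tensor([2.0], requires_grad=True)
    opt = Adam([param], lr=0.1)

    param.grad = torch.tensor([0.5])   # stand-in for a gradient from backward()
    opt.step()

    # Expected: θ₁ = 2.0 - 0.1 * 0.5 / (0.5 + 1e-8) ≈ 1.9
    assert torch.allclose(param.data, torch.tensor([1.9]), atol=1e-6)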


class AdamW(Optimizer):
    """
    AdamW (Adam with Decoupled Weight Decay) optimizer

    Mathematical formulation:
        θ_t = θ_{t-1} - α * (m̂_t / (√v̂_t + ε) + λ * θ_{t-1})

    Weight decay is applied directly to the parameters, not to the gradients.
    """

    def __init__(self, params: List[torch.Tensor], lr: float = 1e-3,
                 betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8,
                 weight_decay: float = 0.01, amsgrad: bool = False):
        super().__init__(params, lr)
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.amsgrad = amsgrad

        # Per-parameter state: step counter plus first and second moment estimates.
        for param in self.params:
            self.state[param] = {
                'step': 0,
                'exp_avg': torch.zeros_like(param.data),
                'exp_avg_sq': torch.zeros_like(param.data)
            }
            if amsgrad:
                self.state[param]['max_exp_avg_sq'] = torch.zeros_like(param.data)

        logger.info(f"Initialized AdamW optimizer: lr={lr}, betas={betas}, weight_decay={weight_decay}")

    def step(self, closure=None):
        """Perform AdamW optimization step"""
        loss = None
        if closure is not None:
            loss = closure()

        for param in self.params:
            if param.grad is None:
                continue

            grad = param.grad.data
            param_state = self.state[param]
            exp_avg, exp_avg_sq = param_state['exp_avg'], param_state['exp_avg_sq']
            beta1, beta2 = self.betas

            param_state['step'] += 1
            bias_correction1 = 1 - beta1 ** param_state['step']
            bias_correction2 = 1 - beta2 ** param_state['step']

            # m_t = β₁ * m_{t-1} + (1 - β₁) * g_t
            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

            # v_t = β₂ * v_{t-1} + (1 - β₂) * g_t²
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

            if self.amsgrad:
                # AMSGrad variant: keep a running maximum of v_t in the denominator.
                torch.max(param_state['max_exp_avg_sq'], exp_avg_sq, out=param_state['max_exp_avg_sq'])
                denom = (param_state['max_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)).add_(self.eps)
            else:
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)

            # Decoupled weight decay: shrink θ directly, then apply the Adam step.
            step_size = self.lr / bias_correction1
            param.data.mul_(1 - self.lr * self.weight_decay)
            param.data.addcdiv_(exp_avg, denom, value=-step_size)

        self.step_count += 1
        return loss
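

# Illustrative sketch (not part of the module API): with a zero gradient, the
# AdamW class above still shrinks the parameter by a factor of
# (1 - lr * weight_decay) per step, whereas Adam's coupled L2 term would be
# rescaled by the adaptive denominator. The helper name and values are
# hypothetical.
def _example_decoupled_weight_decay() -> torch.Tensor:
    param = torch.tensor([1.0], requires_grad=True)
    opt = AdamW([param], lr=0.1, weight_decay=0.01)

    param.grad = torch.zeros(1)   # no loss gradient; only weight decay acts
    opt.step()

    # Expected: 1.0 * (1 - 0.1 * 0.01) = 0.999
    return param.data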


class OptimizerFactory:
    """Factory class for creating optimizers"""

    @staticmethod
    def create_optimizer(optimizer_type: str, params: List[torch.Tensor], **kwargs) -> Optimizer:
        """Create an optimizer instance"""
        optimizers = {
            'sgd': SGD,
            'adam': Adam,
            'adamw': AdamW
        }

        if optimizer_type.lower() not in optimizers:
            raise ValueError(f"Unknown optimizer type: {optimizer_type}")

        optimizer_class = optimizers[optimizer_type.lower()]
        return optimizer_class(params, **kwargs)

    @staticmethod
    def get_default_config(optimizer_type: str) -> Dict[str, Any]:
        """Get default configuration for optimizer"""
        configs = {
            'sgd': {
                'lr': 1e-3,
                'momentum': 0.9,
                'weight_decay': 1e-4
            },
            'adam': {
                'lr': 1e-3,
                'betas': (0.9, 0.999),
                'eps': 1e-8,
                'weight_decay': 1e-4
            },
            'adamw': {
                'lr': 1e-3,
                'betas': (0.9, 0.999),
                'eps': 1e-8,
                'weight_decay': 0.01
            }
        }

        return configs.get(optimizer_type.lower(), {})
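

# Illustrative sketch (not part of the module API): combining the factory's
# default config with create_optimizer. The parameter tensor and the logging
# setup are hypothetical and included only to make the example self-contained.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    weights = torch.randn(10, requires_grad=True)
    config = OptimizerFactory.get_default_config('adamw')
    optimizer = OptimizerFactory.create_optimizer('adamw', [weights], **config)

    weights.grad = torch.ones(10)   # stand-in for a real backward() pass
    optimizer.step()
    logger.info("Completed %d step(s) with lr=%s", optimizer.step_count, optimizer.lr)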