"""
Gradient Descent Optimizers
===========================
This module implements several gradient descent optimization algorithms from
their standard mathematical formulations, built on PyTorch tensor operations.
Algorithms implemented:
- SGD (Stochastic Gradient Descent)
- Adam (Adaptive Moment Estimation)
- AdamW (Adam with Decoupled Weight Decay)
All optimizers share a common interface (step, zero_grad, state_dict,
load_state_dict) and handle momentum, bias correction, and weight decay
where applicable.
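
Illustrative usage (a minimal sketch; the parameter tensor and loss below are
placeholders, not part of this module):

    >>> w = torch.randn(8, requires_grad=True)
    >>> optimizer = Adam([w], lr=1e-3)
    >>> loss = (w ** 2).sum()
    >>> loss.backward()
    >>> optimizer.step()
    >>> optimizer.zero_grad()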
"""
import math
import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Any
import torch
logger = logging.getLogger(__name__)
class Optimizer(ABC):
"""Abstract base class for all optimizers"""
def __init__(self, params: List[torch.Tensor], lr: float = 1e-3, **kwargs):
self.params = params
self.lr = lr
self.step_count = 0
self.state = {}
@abstractmethod
def step(self, closure=None):
"""Perform a single optimization step"""
pass
def zero_grad(self):
"""Zero the gradients of all parameters"""
for param in self.params:
if param.grad is not None:
param.grad.zero_()
def state_dict(self):
"""Return the state of the optimizer"""
return {
'state': self.state,
'param_groups': [{'params': self.params, 'lr': self.lr}],
'step_count': self.step_count
}
def load_state_dict(self, state_dict):
"""Load the state of the optimizer"""
self.state = state_dict['state']
self.lr = state_dict['param_groups'][0]['lr']
self.step_count = state_dict['step_count']
class SGD(Optimizer):
"""
Stochastic Gradient Descent optimizer
Mathematical formulation:
θ_{t+1} = θ_t - α * ∇_θ J(θ_t)
    With momentum μ and dampening τ:
    v_{t+1} = μ * v_t + (1 - τ) * ∇_θ J(θ_t)
    θ_{t+1} = θ_t - α * v_{t+1}
    With Nesterov momentum, the step uses ∇_θ J(θ_t) + μ * v_{t+1} in place of v_{t+1}.
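
    Example (illustrative only; a single parameter with a hand-set gradient):

        >>> p = torch.ones(3, requires_grad=True)
        >>> p.grad = torch.full((3,), 0.5)
        >>> opt = SGD([p], lr=0.1, momentum=0.9)
        >>> opt.step()
        >>> torch.allclose(p.data, torch.full((3,), 0.95))
        True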
"""
def __init__(self, params: List[torch.Tensor], lr: float = 1e-3,
momentum: float = 0.0, weight_decay: float = 0.0,
dampening: float = 0.0, nesterov: bool = False):
        super().__init__(params, lr)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a positive momentum and zero dampening")
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.dampening = dampening
        self.nesterov = nesterov
# Initialize momentum buffers
for param in self.params:
if momentum > 0:
self.state[param] = {'momentum_buffer': torch.zeros_like(param)}
logger.info(f"Initialized SGD optimizer: lr={lr}, momentum={momentum}, weight_decay={weight_decay}")
def step(self, closure=None):
"""Perform SGD optimization step"""
loss = None
if closure is not None:
loss = closure()
for param in self.params:
if param.grad is None:
continue
grad = param.grad.data
# Apply weight decay
if self.weight_decay != 0:
grad = grad.add(param.data, alpha=self.weight_decay)
# Apply momentum
if self.momentum != 0:
param_state = self.state[param]
if 'momentum_buffer' not in param_state:
param_state['momentum_buffer'] = torch.zeros_like(param.data)
momentum_buffer = param_state['momentum_buffer']
momentum_buffer.mul_(self.momentum).add_(grad, alpha=1 - self.dampening)
if self.nesterov:
grad = grad.add(momentum_buffer, alpha=self.momentum)
else:
grad = momentum_buffer
# Update parameters
param.data.add_(grad, alpha=-self.lr)
self.step_count += 1
return loss
class Adam(Optimizer):
"""
Adam (Adaptive Moment Estimation) optimizer
Mathematical formulation:
m_t = β₁ * m_{t-1} + (1 - β₁) * g_t
v_t = β₂ * v_{t-1} + (1 - β₂) * g_t²
m̂_t = m_t / (1 - β₁ᵗ)
v̂_t = v_t / (1 - β₂ᵗ)
θ_{t+1} = θ_t - α * m̂_t / (√v̂_t + ε)
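
    Example (illustrative only; on the very first step the update is roughly
    -lr * sign(g), because the bias-corrected moments reduce to g and g²):

        >>> p = torch.ones(2, requires_grad=True)
        >>> p.grad = torch.tensor([0.5, -2.0])
        >>> opt = Adam([p], lr=0.01)
        >>> opt.step()
        >>> torch.allclose(p.data, torch.tensor([0.99, 1.01]), atol=1e-4)
        True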
"""
def __init__(self, params: List[torch.Tensor], lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8,
weight_decay: float = 0.0, amsgrad: bool = False):
super().__init__(params, lr)
self.betas = betas
self.eps = eps
self.weight_decay = weight_decay
self.amsgrad = amsgrad
# Initialize moment estimates
for param in self.params:
self.state[param] = {
'step': 0,
'exp_avg': torch.zeros_like(param.data),
'exp_avg_sq': torch.zeros_like(param.data)
}
if amsgrad:
self.state[param]['max_exp_avg_sq'] = torch.zeros_like(param.data)
logger.info(f"Initialized Adam optimizer: lr={lr}, betas={betas}, eps={eps}")
def step(self, closure=None):
"""Perform Adam optimization step"""
loss = None
if closure is not None:
loss = closure()
for param in self.params:
if param.grad is None:
continue
grad = param.grad.data
            # Apply L2-style (coupled) weight decay: added to the gradient, unlike AdamW
if self.weight_decay != 0:
grad = grad.add(param.data, alpha=self.weight_decay)
param_state = self.state[param]
exp_avg, exp_avg_sq = param_state['exp_avg'], param_state['exp_avg_sq']
beta1, beta2 = self.betas
param_state['step'] += 1
bias_correction1 = 1 - beta1 ** param_state['step']
bias_correction2 = 1 - beta2 ** param_state['step']
# Update biased first moment estimate
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
# Update biased second raw moment estimate
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
if self.amsgrad:
# Maintains the maximum of all 2nd moment running avg. of squared gradients
torch.max(param_state['max_exp_avg_sq'], exp_avg_sq, out=param_state['max_exp_avg_sq'])
# Use the max. for normalizing running avg. of squared gradients
denom = (param_state['max_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)).add_(self.eps)
else:
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)
# Update parameters
step_size = self.lr / bias_correction1
param.data.addcdiv_(exp_avg, denom, value=-step_size)
self.step_count += 1
return loss
class AdamW(Optimizer):
"""
AdamW (Adam with Decoupled Weight Decay) optimizer
Mathematical formulation:
θ_t = θ_{t-1} - α * (m̂_t / (√v̂_t + ε) + λ * θ_{t-1})
Where weight decay is applied directly to parameters, not gradients.
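
    Example (illustrative only; the parameter is first scaled by
    1 - lr * weight_decay, then the Adam-style update is applied):

        >>> p = torch.ones(1, requires_grad=True)
        >>> p.grad = torch.tensor([1.0])
        >>> opt = AdamW([p], lr=0.1, weight_decay=0.01)
        >>> opt.step()
        >>> torch.allclose(p.data, torch.tensor([0.899]), atol=1e-4)
        True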
"""
def __init__(self, params: List[torch.Tensor], lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8,
weight_decay: float = 0.01, amsgrad: bool = False):
super().__init__(params, lr)
self.betas = betas
self.eps = eps
self.weight_decay = weight_decay
self.amsgrad = amsgrad
# Initialize moment estimates
for param in self.params:
self.state[param] = {
'step': 0,
'exp_avg': torch.zeros_like(param.data),
'exp_avg_sq': torch.zeros_like(param.data)
}
if amsgrad:
self.state[param]['max_exp_avg_sq'] = torch.zeros_like(param.data)
logger.info(f"Initialized AdamW optimizer: lr={lr}, betas={betas}, weight_decay={weight_decay}")
def step(self, closure=None):
"""Perform AdamW optimization step"""
loss = None
if closure is not None:
loss = closure()
for param in self.params:
if param.grad is None:
continue
grad = param.grad.data
param_state = self.state[param]
exp_avg, exp_avg_sq = param_state['exp_avg'], param_state['exp_avg_sq']
beta1, beta2 = self.betas
param_state['step'] += 1
bias_correction1 = 1 - beta1 ** param_state['step']
bias_correction2 = 1 - beta2 ** param_state['step']
# Update biased first moment estimate
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
# Update biased second raw moment estimate
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
if self.amsgrad:
# Maintains the maximum of all 2nd moment running avg. of squared gradients
torch.max(param_state['max_exp_avg_sq'], exp_avg_sq, out=param_state['max_exp_avg_sq'])
# Use the max. for normalizing running avg. of squared gradients
denom = (param_state['max_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)).add_(self.eps)
else:
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)
# Update parameters with decoupled weight decay
step_size = self.lr / bias_correction1
param.data.mul_(1 - self.lr * self.weight_decay)
param.data.addcdiv_(exp_avg, denom, value=-step_size)
self.step_count += 1
return loss
class OptimizerFactory:
"""Factory class for creating optimizers"""
@staticmethod
def create_optimizer(optimizer_type: str, params: List[torch.Tensor], **kwargs) -> Optimizer:
"""Create an optimizer instance"""
optimizers = {
'sgd': SGD,
'adam': Adam,
'adamw': AdamW
}
        if optimizer_type.lower() not in optimizers:
            raise ValueError(f"Unknown optimizer type: {optimizer_type!r}. Available: {sorted(optimizers)}")
optimizer_class = optimizers[optimizer_type.lower()]
return optimizer_class(params, **kwargs)
@staticmethod
def get_default_config(optimizer_type: str) -> Dict[str, Any]:
"""Get default configuration for optimizer"""
configs = {
'sgd': {
'lr': 1e-3,
'momentum': 0.9,
'weight_decay': 1e-4
},
'adam': {
'lr': 1e-3,
'betas': (0.9, 0.999),
'eps': 1e-8,
'weight_decay': 1e-4
},
'adamw': {
'lr': 1e-3,
'betas': (0.9, 0.999),
'eps': 1e-8,
'weight_decay': 0.01
}
}
return configs.get(optimizer_type.lower(), {})
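

# -----------------------------------------------------------------------------
# Illustrative demo (not part of the public API): a minimal sketch showing how
# the factory and optimizers above could be used to fit a tiny linear model.
# The synthetic data, hyperparameters, and variable names below are assumptions
# for demonstration only.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    torch.manual_seed(0)

    # Synthetic 1-D regression data: y = 3x + noise
    x = torch.randn(64, 1)
    y = 3.0 * x + 0.1 * torch.randn(64, 1)

    # Parameters of a simple linear model y_hat = x @ w + b
    w = torch.randn(1, 1, requires_grad=True)
    b = torch.zeros(1, requires_grad=True)

    # lr is raised above the default purely so this toy problem moves in few steps
    optimizer = OptimizerFactory.create_optimizer('adamw', [w, b], lr=0.05, weight_decay=0.01)

    initial_loss = ((x @ w + b - y) ** 2).mean().item()
    for _ in range(200):
        optimizer.zero_grad()
        loss = ((x @ w + b - y) ** 2).mean()
        loss.backward()
        optimizer.step()

    # Round-trip the optimizer state through its checkpointing interface
    checkpoint = optimizer.state_dict()
    optimizer.load_state_dict(checkpoint)

    logger.info(f"MSE before: {initial_loss:.4f}, after 200 steps: {loss.item():.4f}")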