pneumonia-detection / src /optimizers /dynamic_nesterov.py
tknight
Upload 6 files
2b74065 verified
import torch
from torch.optim import Optimizer
class DynamicNesterov(Optimizer):
    """Nesterov-style momentum optimizer with a per-step momentum coefficient.

    For every parameter the momentum coefficient is recomputed at each step as
    ``beta = min(beta_max, s / (||g|| + epsilon))`` where ``||g||`` is the
    norm of the parameter's gradient and ``s`` is the value returned by
    :meth:`_estimate_spectral_norm`.  The update is then::

        m <- beta * m - lr * g
        p <- p + beta * m - lr * g

    NOTE: with the default estimator (``s = 0.1 * ||g||``) the ratio is
    ``0.1 * ||g|| / (||g|| + epsilon)``, i.e. approximately ``0.1`` for any
    non-negligible gradient, so ``beta_max`` only takes effect when it is
    below that value.  Override :meth:`_estimate_spectral_norm` to get a
    genuinely varying coefficient.

    Args:
        params: Iterable of parameters (or parameter-group dicts) to optimize.
        lr: Learning rate; must be non-negative.
        beta_max: Upper bound on the momentum coefficient; must be
            non-negative.
        epsilon: Term added to the gradient norm in the denominator to avoid
            division by zero; must be non-negative.

    Raises:
        ValueError: If ``lr``, ``beta_max`` or ``epsilon`` is negative.
    """

    def __init__(self, params, lr=1e-3, beta_max=0.9, epsilon=1e-8):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if beta_max < 0.0:
            raise ValueError(f"Invalid beta_max value: {beta_max}")
        if epsilon < 0.0:
            raise ValueError(f"Invalid epsilon value: {epsilon}")
        defaults = dict(lr=lr, beta_max=beta_max, epsilon=epsilon)
        super().__init__(params, defaults)

    def _estimate_spectral_norm(self, grad_norm):
        """Return a simplified stand-in for the Hessian spectral norm.

        The estimate is a fixed fraction (0.1) of the gradient norm.

        Args:
            grad_norm: Norm of the current gradient (0-dim tensor or number).

        Returns:
            ``grad_norm * 0.1``.
        """
        return grad_norm * 0.1

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss.  It is invoked with gradient computation enabled so
                that it may call ``backward()``.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure needs autograd.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta_max = group['beta_max']
            epsilon = group['epsilon']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad

                # Lazily create the per-parameter state on first use.
                state = self.state[p]
                if not state:
                    state['momentum_buffer'] = torch.zeros_like(p)
                    state['prev_grad'] = torch.zeros_like(p)
                momentum_buffer = state['momentum_buffer']
                prev_grad = state['prev_grad']

                grad_norm = torch.norm(grad)
                spectral_norm = self._estimate_spectral_norm(grad_norm)

                # Convert the ratio to a Python float once so it can be used
                # both as a multiplier and as the ``alpha`` scalar below.
                ratio = float(spectral_norm / (grad_norm + epsilon))
                beta = min(beta_max, ratio)

                # Nesterov update: refresh the velocity, then apply the
                # look-ahead momentum term plus the plain gradient step.
                momentum_buffer.mul_(beta).add_(grad, alpha=-lr)
                p.add_(momentum_buffer, alpha=beta)
                p.add_(grad, alpha=-lr)

                # NOTE: prev_grad is recorded but not read by this update; it
                # is kept so the per-parameter state layout stays unchanged.
                prev_grad.copy_(grad)

        return loss