import torch
from torch.optim import Optimizer


class DynamicNesterov(Optimizer):
    """
    Dynamic Nesterov Accelerated Gradient optimizer with adaptive momentum estimation.

    The momentum coefficient is recomputed each step from a crude estimate of the
    Hessian spectral norm and capped at ``beta_max``.
    """

    def __init__(self, params, lr=1e-3, beta_max=0.9, epsilon=1e-8):
        defaults = dict(lr=lr, beta_max=beta_max, epsilon=epsilon)
        super().__init__(params, defaults)

    def _estimate_spectral_norm(self, grad_norm):
        """Simplified approximation of the Hessian spectral norm."""
        return grad_norm * 0.1

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta_max = group['beta_max']
            epsilon = group['epsilon']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad

                # Lazily initialize per-parameter state.
                state = self.state[p]
                if len(state) == 0:
                    state['momentum_buffer'] = torch.zeros_like(p)
                    state['prev_grad'] = torch.zeros_like(p)

                momentum_buffer = state['momentum_buffer']
                prev_grad = state['prev_grad']

                # Gradient norm and crude spectral-norm estimate.
                grad_norm = torch.norm(grad)
                spectral_norm = self._estimate_spectral_norm(grad_norm)

                # Dynamic momentum coefficient, capped at beta_max.
                # .item() keeps beta a plain Python float rather than a 0-d tensor.
                beta = min(beta_max, (spectral_norm / (grad_norm + epsilon)).item())

                # Nesterov-style update: decay the buffer, accumulate the scaled
                # gradient, then apply a look-ahead step plus the gradient step.
                momentum_buffer.mul_(beta).add_(grad, alpha=-lr)
                p.add_(momentum_buffer, alpha=beta)
                p.add_(grad, alpha=-lr)

                # Keep the previous gradient for potential future use.
                prev_grad.copy_(grad)

        return loss
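

# Minimal usage sketch (illustrative only, not part of the optimizer): the model,
# data, and loss below are assumptions chosen to show the call pattern.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(10, 1)
    optimizer = DynamicNesterov(model.parameters(), lr=1e-3, beta_max=0.9)
    criterion = nn.MSELoss()

    # One toy training step on random data.
    inputs = torch.randn(32, 10)
    targets = torch.randn(32, 1)

    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    print(f"loss: {loss.item():.4f}")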