|
|
|
|
| import torch
|
|
|
| import torch
|
| import torch.nn as nn
|
|
|
# All of torch's batch-norm layers (BatchNorm1d/2d/3d, SyncBatchNorm and the
# lazy variants) derive from this private base class, so matching on it covers
# every flavour instead of only BatchNorm2d.
_BatchNorm = nn.modules.batchnorm._BatchNorm


def disable_running_stats(model):
    """Freeze the running statistics of every batch-norm layer in ``model``.

    Each batch-norm layer's current ``momentum`` is stashed in a
    ``backup_momentum`` attribute and replaced by 0. With momentum 0 the
    running-average update leaves ``running_mean`` / ``running_var`` unchanged
    during training-mode forward passes (``num_batches_tracked`` is still
    incremented by torch). Presumably used around SAM's second forward pass at
    the perturbed weights -- call sites are not visible here.

    Calling this twice in a row is safe: a layer that is already disabled is
    skipped, so the stashed momentum is never overwritten with 0.

    Args:
        model: an ``nn.Module``; the function recurses via ``model.apply``.
    """

    def _disable(module):
        # ``backup_momentum`` present means the layer is already disabled.
        if isinstance(module, _BatchNorm) and not hasattr(module, "backup_momentum"):
            module.backup_momentum = module.momentum
            module.momentum = 0

    model.apply(_disable)


def enable_running_stats(model):
    """Undo ``disable_running_stats``: restore each batch-norm layer's momentum.

    Only layers that carry a ``backup_momentum`` attribute are touched; the
    attribute is removed after restoring, so a later ``disable_running_stats``
    stashes the then-current momentum. A ``momentum`` of ``None`` (cumulative
    average) is restored as ``None``.

    Args:
        model: an ``nn.Module``; the function recurses via ``model.apply``.
    """

    def _enable(module):
        if isinstance(module, _BatchNorm) and hasattr(module, "backup_momentum"):
            module.momentum = module.backup_momentum
            del module.backup_momentum

    model.apply(_enable)
|
|
|
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization (SAM) wrapper around a base optimizer.

    One SAM update is split into two halves:

    1. :meth:`first_step` moves every parameter that has a gradient from ``w``
       to ``w + e(w)`` with ``e(w) = rho * grad / ||grad||_2``, where the norm
       is taken jointly over the gradients of all parameter groups.
    2. After the caller has recomputed loss and gradients at ``w + e(w)``,
       :meth:`second_step` puts the weights back to ``w`` and lets the base
       optimizer update them using those gradients.

    :meth:`step` performs both halves given a closure; the gradients at ``w``
    must already have been computed by the caller before it is invoked.

    Args:
        params: iterable of tensors or param-group dicts, as for any
            ``torch.optim.Optimizer``.
        base_optimizer: optimizer *class* (e.g. ``torch.optim.SGD``); it is
            instantiated here on this optimizer's param groups.
        rho: radius of the perturbation; must be non-negative.
        **kwargs: forwarded to ``base_optimizer`` and also stored in this
            optimizer's defaults.
    """

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Share one param-group list between both optimizers so that edits made
        # through either one (e.g. by an LR scheduler) are seen by the other.
        self.param_groups = self.base_optimizer.param_groups
        # ``add_param_group`` fills missing keys from ``self.defaults``; merge the
        # base optimizer's defaults so groups added later carry every key the
        # base optimizer's ``step`` reads (e.g. SGD's "weight_decay").
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Move each parameter that has a gradient from ``w`` to ``w + e(w)``.

        An exact copy of ``w`` is kept in ``self.state[p]["old_p"]`` so that
        :meth:`second_step` can restore it bit-for-bit (adding and then
        subtracting ``e(w)`` is not an exact round-trip in floating point).

        Args:
            zero_grad: if true, clear gradients afterwards so the caller's next
                backward pass starts from empty gradients.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            # 1e-12 guards against division by zero for an all-zero gradient.
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p]["old_p"] = p.detach().clone()
                # ``scale`` lives on the shared device; match p's device/dtype.
                p.add_(p.grad * scale.to(p))

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the original weights ``w`` and apply the base optimizer step.

        Every parameter perturbed by :meth:`first_step` is restored, whether or
        not it received a gradient in the second pass; the saved copy is then
        discarded so it can never be re-applied in a later step.

        Args:
            zero_grad: if true, clear gradients after the base optimizer step.
        """
        for group in self.param_groups:
            for p in group["params"]:
                # ``.get`` avoids creating empty entries in the defaultdict for
                # parameters that were never perturbed.
                old_p = self.state.get(p, {}).pop("old_p", None)
                if old_p is not None:
                    p.copy_(old_p)

        self.base_optimizer.step()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Run a full SAM update given a closure.

        Args:
            closure: callable that re-evaluates the model and calls
                ``backward()``; it is run with gradients enabled at the
                perturbed weights ``w + e(w)``. Required.
        """
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        # ``step`` itself runs under no_grad; the closure needs autograd.
        closure = torch.enable_grad()(closure)

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def load_state_dict(self, state_dict):
        """Load optimizer state and re-share param groups with the base optimizer.

        ``torch.optim.Optimizer.load_state_dict`` rebinds ``self.param_groups``
        to a new list, which would otherwise detach it from
        ``self.base_optimizer.param_groups`` (hyper-parameter changes would stop
        reaching the base optimizer).

        NOTE(review): this wrapper's ``state_dict()`` only covers SAM's own
        ``self.state``; the base optimizer's per-parameter state (e.g. momentum
        buffers) lives in ``self.base_optimizer.state`` and must be saved and
        loaded separately.
        """
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

    def _grad_norm(self):
        """Return the joint L2 norm of all existing gradients as a 0-d tensor.

        Per-parameter norms are moved to the device of the first parameter
        before being combined, so parameters may live on different devices.
        Returns a zero tensor when no parameter has a gradient.
        """
        shared_device = self.param_groups[0]["params"][0].device
        norms = [
            p.grad.norm(p=2).to(shared_device)
            for group in self.param_groups
            for p in group["params"]
            if p.grad is not None
        ]
        if not norms:
            # torch.stack rejects an empty list; with no gradients there is
            # nothing to perturb anyway.
            return torch.zeros((), device=shared_device)
        return torch.norm(torch.stack(norms), p=2)