Spaces:
Paused
Paused
| import torch | |
| import torch.nn.functional as F | |
| from torch import Tensor | |
| from torch.nn import Parameter | |
| from torch import nn | |
| from torch.autograd import Variable | |
def l2normalize(vector, eps = 1e-15):
    """Scale *vector* by the inverse of its norm.

    ``eps`` is added to the norm before dividing so that an all-zero
    input does not produce a division by zero.
    """
    magnitude = vector.norm()
    return vector / (magnitude + eps)
class SpectralNorm(nn.Module):
    """Wrap a module and rescale one of its weights on every forward pass.

    The parameter called ``name`` on ``module`` is replaced by three
    parameters registered on the wrapped module:

    * ``<name>_bar`` -- the raw, trainable weight,
    * ``<name>_u`` / ``<name>_v`` -- non-trainable vectors updated by
      power iteration on the weight reshaped to ``(shape[0], -1)``.

    Before each forward call the attribute ``<name>`` is set to
    ``<name>_bar / sigma`` where ``sigma = u . (W v)``.

    Args:
        module: the module whose weight is to be normalized.
        name: attribute name of the weight on ``module``.
        power_iterations: number of power-iteration steps per forward call.
    """

    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        # Only create the auxiliary parameters once; a module that already
        # carries them (e.g. wrapped earlier) is reused as-is.
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Run the power iteration and refresh the normalized weight."""
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        # The 2-D view of the weight does not change inside the loop, so
        # build it once. ``.data`` keeps the u/v updates out of autograd.
        w_mat = w.view(height, -1)
        w_mat_data = w_mat.data
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w_mat_data), u.data))
            u.data = l2normalize(torch.mv(w_mat_data, v.data))

        # sigma is computed from ``w`` itself (not ``w.data``) so that the
        # division below is differentiable with respect to ``<name>_bar``.
        sigma = u.dot(w_mat.mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        """Return True if the wrapped module already has u, v and w_bar."""
        return all(
            hasattr(self.module, self.name + suffix)
            for suffix in ("_u", "_v", "_bar")
        )

    def _make_params(self):
        """Replace the weight parameter with ``_u``, ``_v`` and ``_bar``."""
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        # u and v start as random unit vectors and are never trained.
        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # Remove the original parameter so that ``<name>`` can later be
        # assigned as a plain (non-Parameter) tensor in ``_update_u_v``.
        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args, **kwargs):
        """Refresh the normalized weight, then run the wrapped module.

        Positional and keyword arguments are passed through unchanged.
        The module is invoked via ``__call__`` so that hooks registered
        on it are executed.
        """
        self._update_u_v()
        return self.module(*args, **kwargs)