| """ |
| Spectral Normalization from https://arxiv.org/abs/1802.05957 |
| """ |
| import torch |
| from torch.nn.functional import normalize |
|
|
|
|
| class SpectralNorm(object): |
| |
| |
| |
|
|
| _version = 1 |
|
|
| |
| |
| |
| |
|
|

    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError(
                'Expected n_power_iterations to be positive, but '
                'got n_power_iterations={}'.format(n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight):
        weight_mat = weight
        if self.dim != 0:
            # permute `self.dim` to the front
            weight_mat = weight_mat.permute(
                self.dim,
                *[d for d in range(weight_mat.dim()) if d != self.dim])
        height = weight_mat.size(0)
        return weight_mat.reshape(height, -1)
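
    # For illustration (hypothetical shapes, not part of the original code):
    # a Conv2d weight of shape (16, 3, 5, 5) with dim=0 flattens to a
    # 16 x 75 matrix; with dim=1 it would flatten to 3 x 400.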

    def compute_weight(self, module, do_power_iteration):
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        #     updated **in-place** (via the `out=` kwarg of `normalize`).
        #     This is important for `DataParallel`, where the buffers are
        #     broadcast to each replica: only writes into the broadcast
        #     storage propagate the power-iteration updates back to the
        #     parallelized module. Plain reassignment on a replica would be
        #     lost, and every forward would restart from the same randomly
        #     initialized vectors.
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    # The spectral norm of `W` equals `u^T W v`, where `u`
                    # and `v` are the first left and right singular vectors.
                    # This power iteration produces approximations of them.
                    v = normalize(torch.mv(weight_mat.t(), u),
                                  dim=0,
                                  eps=self.eps,
                                  out=v)
                    u = normalize(torch.mv(weight_mat, v),
                                  dim=0,
                                  eps=self.eps,
                                  out=u)
                if self.n_power_iterations > 0:
                    # Clone so that `sigma` below does not alias the buffers,
                    # which the next forward will again overwrite in-place.
                    u = u.clone()
                    v = v.clone()

        sigma = torch.dot(u, torch.mv(weight_mat, v))
        weight = weight / sigma
        return weight
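
    # For intuition, an out-of-place sketch of the same estimate (names are
    # illustrative; the hook itself uses the in-place version above):
    #
    #   W = weight_mat
    #   for _ in range(n_power_iterations):
    #       v = normalize(W.t() @ u, dim=0)
    #       u = normalize(W @ v, dim=0)
    #   sigma = u @ W @ v   # approximates the largest singular value of W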

    def remove(self, module):
        # Re-register the (already normalized) weight as a plain Parameter
        # and drop the reparameterization tensors.
        with torch.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_v')
        delattr(module, self.name + '_orig')
        module.register_parameter(self.name,
                                  torch.nn.Parameter(weight.detach()))

    def __call__(self, module, inputs):
        # Forward pre-hook: recompute the normalized weight before every
        # forward call. Power iteration only runs in training mode; in eval
        # mode the last estimates of `u` and `v` are reused.
        setattr(
            module, self.name,
            self.compute_weight(module, do_power_iteration=module.training))

    def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
        # Tries to return a vector `v` such that `u = normalize(W @ v)` (the
        # invariant at the top of this class) and `u @ W @ v = target_sigma`.
        # Uses `pinverse` in case `W^T W` is not invertible.
        v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(),
                               weight_mat.t(), u.unsqueeze(1)).squeeze(1)
        return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))
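
    # Derivation sketch: the least-squares solution of `W @ x = u` is
    # `x = (W^T W)^+ W^T u`; multiplying by `target_sigma / (u @ W @ x)`
    # then enforces `u @ W @ v == target_sigma`.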

    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError(
                    "Cannot register two spectral_norm hooks on "
                    "the same parameter {}".format(name))

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]

        with torch.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)

            h, w = weight_mat.size()
            # randomly initialize `u` and `v`
            u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
            v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign `weight` back as `fn.name` because all sorts
        # of things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign it, as it could be an nn.Parameter
        # and get added as a parameter. Instead, we register `weight.data` as
        # a plain attribute.
        setattr(module, fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)

        module._register_state_dict_hook(SpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(
            SpectralNormLoadStateDictPreHook(fn))
        return fn
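
    # After `apply` with name == 'weight' (for illustration), the module ends
    # up with:
    #   weight_orig        -- the trainable nn.Parameter
    #   weight_u, weight_v -- power-iteration buffers
    #   weight             -- plain attribute, recomputed by the hook on each
    #                         forward call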


# This is a top-level class because Py2 pickle doesn't like inner classes
# or instance methods.
class SpectralNormLoadStateDictPreHook(object):
    # See the docstring of SpectralNorm._version for what changed across
    # versions of spectral_norm.
    def __init__(self, fn):
        self.fn = fn

    # For a state_dict with version None (assuming it has gone through at
    # least one training forward), we have
    #
    #    u = normalize(W_orig @ v)
    #    W = W_orig / sigma, where sigma = u @ W_orig @ v
    #
    # To recover the missing `v`, we solve `W_orig @ x = u` and rescale so
    # that `u @ W_orig @ v = sigma`.
    def __call__(self, state_dict, prefix, local_metadata, strict,
                 missing_keys, unexpected_keys, error_msgs):
        fn = self.fn
        version = local_metadata.get('spectral_norm',
                                     {}).get(fn.name + '.version', None)
        if version is None or version < 1:
            with torch.no_grad():
                weight_orig = state_dict[prefix + fn.name + '_orig']
                weight = state_dict.pop(prefix + fn.name)
                sigma = (weight_orig / weight).mean()
                weight_mat = fn.reshape_weight_to_matrix(weight_orig)
                u = state_dict[prefix + fn.name + '_u']
                # Version-0 checkpoints stored `W` but no `v`; recover `v`
                # so the class invariant holds after loading.
                v = fn._solve_v_and_rescale(weight_mat, u, sigma)
                state_dict[prefix + fn.name + '_v'] = v


# This is a top-level class because Py2 pickle doesn't like inner classes
# or instance methods.
class SpectralNormStateDictHook(object):

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, module, state_dict, prefix, local_metadata):
        if 'spectral_norm' not in local_metadata:
            local_metadata['spectral_norm'] = {}
        key = self.fn.name + '.version'
        if key in local_metadata['spectral_norm']:
            raise RuntimeError(
                "Unexpected key in metadata['spectral_norm']: {}".format(key))
        local_metadata['spectral_norm'][key] = self.fn._version
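
# For example (illustrative, with name == 'weight'): after saving, the
# module's local_metadata contains {'spectral_norm': {'weight.version': 1}},
# which the load hook above uses to decide whether migration is needed.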


def spectral_norm(module,
                  name='weight',
                  n_power_iterations=1,
                  eps=1e-12,
                  dim=None):
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with the spectral norm :math:`\sigma` of the weight matrix, calculated
    using the power iteration method. If the dimension of the weight tensor is
    greater than 2, it is reshaped to 2D in the power iteration to get the
    spectral norm. This is implemented via a hook that calculates the spectral
    norm and rescales the weight before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs;
            the default is ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, where it is ``1``

    Returns:
        The original module with the spectral norm hook

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_u.size()
        torch.Size([40])

    """
    if dim is None:
        if isinstance(module,
                      (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
                       torch.nn.ConvTranspose3d)):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module
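
# A further illustrative example: for ConvTranspose2d the output-channel
# dimension of the weight is 1, so `dim` defaults to 1 and `weight_u` has
# out_channels entries:
#
#   >>> m = spectral_norm(nn.ConvTranspose2d(8, 4, 3))
#   >>> m.weight_u.size()
#   torch.Size([4])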


def remove_spectral_norm(module, name='weight'):
    r"""Removes the spectral normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError("spectral_norm of '{}' not found in {}".format(
        name, module))


def use_spectral_norm(module, use_sn=False):
    """Conditionally wraps ``module`` with spectral normalization.

    Returns ``spectral_norm(module)`` when ``use_sn`` is True; otherwise
    returns the module unchanged.
    """
    if use_sn:
        return spectral_norm(module)
    return module
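

# A minimal smoke-test sketch (illustrative only; layer sizes are made up):
if __name__ == "__main__":
    import torch.nn as nn

    layer = spectral_norm(nn.Linear(20, 40))
    x = torch.randn(7, 20)
    layer.train()
    y = layer(x)                  # power iteration runs, weight is rescaled
    print(y.shape)                # torch.Size([7, 40])
    print(layer.weight_u.size())  # torch.Size([40])

    layer = remove_spectral_norm(layer)  # bake the normalized weight back in
    print(isinstance(layer.weight, nn.Parameter))  # True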