Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| ''' | |
| # -------------------------------------------- | |
| # Kai Zhang (github: https://github.com/cszn) | |
| # 03/Mar/2019 | |
| # -------------------------------------------- | |
| ''' | |
| # -------------------------------------------- | |
| # SVD Orthogonal Regularization | |
| # -------------------------------------------- | |
def _nudge_conv_singular_values(m, bounds, eps):
    """Shift out-of-range singular values of a conv layer's weight by ``eps``.

    Shared worker for ``regularizer_orth`` and ``regularizer_orth2``.

    The weight of shape (A, B, *kernel) is unfolded into a matrix of shape
    (prod(kernel) * B, A), which is the same layout the original 2-D code built
    with ``permute(2, 3, 1, 0)``. The matrix is decomposed with ``torch.svd``.
    Every singular value above the upper bound is decreased by ``eps``, then
    every singular value below the lower bound is increased by ``eps``. The
    rebuilt matrix is folded back and copied into the weight in place.

    Modules are left untouched when:
      * their class name does not contain 'Conv';
      * they have no tensor ``weight``;
      * the weight has fewer than 3 dims or is empty.
    This makes the function safe to pass to ``torch.nn.Module.apply``.

    Args:
        m: module visited by ``Module.apply``.
        bounds: callable mapping the 1-D singular-value tensor to a
            ``(low, high)`` pair of thresholds.
        eps: step added to / subtracted from out-of-range singular values.
    """
    if 'Conv' not in m.__class__.__name__:
        return
    weight = getattr(m, 'weight', None)
    if not isinstance(weight, torch.Tensor) or weight.dim() < 3 or weight.numel() == 0:
        return

    w = weight.data
    # (A, B, *kernel) -> (*kernel, B, A); equals permute(2, 3, 1, 0) for 4-D.
    perm = list(range(2, w.dim())) + [1, 0]
    inverse = [perm.index(axis) for axis in range(w.dim())]
    folded_shape = [w.size(axis) for axis in perm]
    mat = w.permute(*perm).contiguous().view(-1, w.size(0))

    u, s, v = torch.svd(mat)
    low, high = bounds(s)
    # Shrink first, then grow, matching the original statement order.
    s[s > high] -= eps
    s[s < low] += eps
    mat = torch.mm(torch.mm(u, torch.diag(s)), v.t())

    # copy_ writes into the parameter's existing contiguous storage instead of
    # rebinding .data to a non-contiguous permuted view.
    weight.data.copy_(mat.view(*folded_shape).permute(*inverse))


def regularizer_orth(m, *, low=0.5, high=1.5, eps=1e-4):
    """
    # ----------------------------------------
    # SVD Orthogonal Regularization
    # ----------------------------------------
    # Nudge the singular values of each conv layer's unfolded weight toward
    # the fixed band [low, high]. Values above ``high`` drop by ``eps`` and
    # values below ``low`` rise by ``eps`` on every call, which pushes the
    # weight toward (near-)orthogonality.
    # Meant to be passed to torch.nn.Module.apply(); modules that are not
    # conv layers with a >= 3-D weight are ignored.
    # usage: net.apply(regularizer_orth)
    #        net.apply(functools.partial(regularizer_orth, low=0.8, high=1.2))
    # ----------------------------------------
    """
    _nudge_conv_singular_values(m, lambda s: (low, high), eps)


# --------------------------------------------
# SVD Orthogonal Regularization
# --------------------------------------------
def regularizer_orth2(m, *, low=0.5, high=1.5, eps=1e-4):
    """
    # ----------------------------------------
    # Same as regularizer_orth, but the band is relative to the mean
    # singular value: values above ``high * mean(s)`` drop by ``eps`` and
    # values below ``low * mean(s)`` rise by ``eps``. The mean is taken
    # once, before any value is changed.
    # Meant to be passed to torch.nn.Module.apply(); modules that are not
    # conv layers with a >= 3-D weight are ignored.
    # usage: net.apply(regularizer_orth2)
    # ----------------------------------------
    """
    def relative_bounds(s):
        s_mean = s.mean()
        return low * s_mean, high * s_mean

    _nudge_conv_singular_values(m, relative_bounds, eps)
def regularizer_clip(m, *, c_min=-1.5, c_max=1.5, eps=1e-4):
    """
    # ----------------------------------------
    # Pull out-of-range Conv/Linear parameters back toward [c_min, c_max].
    # Every entry of ``weight`` and ``bias`` (when present) above ``c_max`` is
    # decreased by ``eps``; every entry below ``c_min`` is then increased by
    # ``eps``. Entries are nudged, not hard-clamped, so repeated calls move
    # them gradually.
    # Modules whose class name contains neither 'Conv' nor 'Linear' are left
    # untouched, as are missing or None weight/bias attributes. This makes it
    # safe to pass to torch.nn.Module.apply().
    # usage: net.apply(regularizer_clip)
    #        net.apply(functools.partial(regularizer_clip, c_min=-1, c_max=1))
    # ----------------------------------------
    """
    classname = m.__class__.__name__
    if 'Conv' not in classname and 'Linear' not in classname:
        return
    for attr in ('weight', 'bias'):
        param = getattr(m, attr, None)
        if not isinstance(param, torch.Tensor):
            # e.g. bias=False (bias is None) or a container without params
            continue
        data = param.data
        # Shrink first, then grow, matching the original statement order.
        data[data > c_max] -= eps
        data[data < c_min] += eps
    # NOTE: BatchNorm layers are not clipped; the original only had an
    # unfinished, commented-out branch for them.