| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class AdaptiveInstanceNorm1d(nn.Module):
    """Adaptive instance normalisation for sequence tensors.

    The affine parameters are not learned by this module: ``weight`` and
    ``bias`` must be assigned from outside (e.g. by ``assign_adain_params``)
    before ``forward`` is called. They are indexed per (batch, channel) pair,
    so each must hold ``batch * num_features`` elements.

    Args:
        num_features: number of channels ``C`` of the input.
        eps: value added to the variance for numerical stability.
        momentum: momentum passed to ``F.batch_norm``. It only affects the
            temporary repeated running statistics built in ``forward``, so
            the registered buffers keep their initial values.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Assigned externally before each forward pass.
        self.weight = None
        self.bias = None
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x, direct_weighting=False, no_std=False):
        """Apply the assigned per-(batch, channel) affine, optionally normalising.

        Args:
            x: 3-D tensor laid out as ``(L, B, C)`` with ``C == num_features``.
            direct_weighting: if True, skip normalisation and compute
                ``x * weight + bias`` directly (for any length ``L``).
            no_std: only used with ``direct_weighting``; if True the weight is
                ignored and only ``bias`` is added.

        Returns:
            Tensor with the same ``(L, B, C)`` layout as ``x``.
        """
        assert self.weight is not None and \
            self.bias is not None, "Please assign AdaIN weight first"

        # (L, B, C) -> (B, C, L) so channels sit on dim 1.
        x = x.permute(1, 2, 0)
        b, c, length = x.size(0), x.size(1), x.size(2)

        if direct_weighting:
            # One column per (batch, channel) pair after the transpose, so the
            # flat (B*C,) parameters broadcast along every position of L.
            # An explicit length (not -1) keeps empty batches working.
            x_flat = x.contiguous().view(b * c, length).t()
            if no_std:
                out = x_flat + self.bias
            else:
                out = x_flat.mul(self.weight) + self.bias
            out = out.t().reshape(b, c, length)
        else:
            # Repeat the per-channel buffers once per sample. F.batch_norm
            # updates these temporary copies, not the registered buffers.
            running_mean = self.running_mean.repeat(b)
            running_var = self.running_var.repeat(b)
            # Folding the batch into the channel dim makes the batch-norm
            # statistics per instance: one "channel" per (b, c) pair,
            # reduced over the length axis only.
            x_reshaped = x.contiguous().view(1, b * c, length)
            out = F.batch_norm(
                x_reshaped, running_mean, running_var, self.weight, self.bias,
                True, self.momentum, self.eps)
            out = out.view(b, c, length)

        # (B, C, L) -> (L, B, C)
        return out.permute(2, 0, 1)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_features})"
|
|
def assign_adain_params(adain_params, model):
    """Distribute a 2-D parameter tensor over the model's AdaIN layers.

    Modules are visited in ``model.modules()`` order. Every module whose class
    is named ``AdaptiveInstanceNorm1d`` consumes the next
    ``2 * num_features`` columns:

    - the first ``num_features`` columns become its ``bias``;
    - the following ``num_features`` columns become its ``weight``.

    Both are flattened row-major, so they hold ``rows * num_features``
    elements. Rows are presumably one per batch sample (confirm against the
    caller).

    Args:
        adain_params: 2-D tensor whose columns are the concatenated per-layer
            parameters (``get_num_adain_params`` gives the expected count).
        model: module whose AdaIN layers are assigned in place.

    Raises:
        ValueError: if fewer than ``2 * num_features`` columns remain for
            some AdaIN layer.
    """
    for m in model.modules():
        if m.__class__.__name__ != "AdaptiveInstanceNorm1d":
            continue
        n = m.num_features
        if adain_params.size(1) < 2 * n:
            raise ValueError(
                f"adain_params has {adain_params.size(1)} columns left, but "
                f"AdaptiveInstanceNorm1d({n}) needs {2 * n}")
        m.bias = adain_params[:, :n].contiguous().view(-1)
        m.weight = adain_params[:, n:2 * n].contiguous().view(-1)
        # Always advance, so a following layer never reuses these columns.
        adain_params = adain_params[:, 2 * n:]
|
|
|
|
def get_num_adain_params(model):
    """Return how many AdaIN parameters ``model`` needs in total.

    Each module whose class is named ``AdaptiveInstanceNorm1d`` needs one
    bias and one weight per feature, i.e. ``2 * num_features`` values. The
    result is 0 when the model contains no such module.
    """
    target = "AdaptiveInstanceNorm1d"
    return sum(
        2 * layer.num_features
        for layer in model.modules()
        if layer.__class__.__name__ == target
    )
|
|