| import math |
|
|
| from collections import namedtuple |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
| import torchvision.models.vgg as vgg |
|
|
| from op import fused_leaky_relu |
|
|
|
|
# Container for the five VGG activations returned by FeatureExtractor.forward,
# ordered shallow to deep (relu1 .. relu5).
FeatureOutput = namedtuple("FeatureOutput", "relu1 relu2 relu3 relu4 relu5")
|
|
|
|
def gram_matrix(y):
    """Return the per-sample, size-normalized Gram matrix of a feature map.

    Args:
        y: tensor of shape (b, ch, h, w).

    Returns:
        Tensor of shape (b, ch, ch) equal to F @ F^T / (ch * h * w), where F
        is ``y`` flattened to shape (b, ch, h * w).
    """
    b, ch, h, w = y.size()
    # reshape (not view): view raises on non-contiguous inputs such as
    # transposed or channels_last tensors; reshape is still a zero-copy view
    # whenever the input is already contiguous.
    features = y.reshape(b, ch, h * w)
    gram = features.bmm(features.transpose(1, 2))
    return gram / (ch * h * w)
|
|
|
|
class FeatureExtractor(nn.Module):
    """Return selected intermediate activations of a pretrained VGG19.

    Reference:
    https://discuss.pytorch.org/t/how-to-extract-features-of-an-image-from-a-trained-model/119/3
    """

    def __init__(self):
        super(FeatureExtractor, self).__init__()
        # The whole VGG19 ``features`` stack stays registered so the
        # parameter names in state_dict() are unaffected by forward's early exit.
        self.vgg_layers = vgg.vgg19(pretrained=True).features
        # Child name (the string index inside the nn.Sequential) -> field of
        # FeatureOutput. In torchvision's vgg19 layout these indices are the
        # last ReLU of each of the five convolutional blocks.
        self.layer_name_mapping = {
            '3': "relu1",
            '8': "relu2",
            '17': "relu3",
            '26': "relu4",
            '35': "relu5",
        }

    def forward(self, x):
        """Run ``x`` through VGG19 and collect the mapped activations.

        Args:
            x: input image batch passed to the first VGG layer.

        Returns:
            FeatureOutput whose fields relu1..relu5 hold the outputs of the
            layers listed in ``layer_name_mapping``.
        """
        output = {}
        remaining = len(self.layer_name_mapping)
        for name, module in self.vgg_layers.named_children():
            x = module(x)
            field = self.layer_name_mapping.get(name)
            if field is not None:
                output[field] = x
                remaining -= 1
                if remaining == 0:
                    # Every requested activation is captured; layers past
                    # this point (e.g. the final pooling) would be wasted work.
                    break
        return FeatureOutput(**output)
|
|
|
|
class StyleEmbedder(nn.Module):
    """Embed images as concatenated, flattened Gram matrices of VGG features.

    Each image is average-pooled to 256x256, passed through FeatureExtractor,
    and the Gram matrix of every returned activation is flattened and
    concatenated along dim 1.
    """

    def __init__(self):
        super(StyleEmbedder, self).__init__()
        extractor = FeatureExtractor()
        extractor.eval()
        self.feature_extractor = extractor
        # Adaptive pooling to a fixed 256x256 spatial size before VGG.
        self.avg_pool = torch.nn.AdaptiveAvgPool2d((256, 256))

    def forward(self, img):
        """Return a (batch, embedding_len) tensor of flattened Gram matrices."""
        batch_size = img.shape[0]
        activations = self.feature_extractor(self.avg_pool(img))
        flat_grams = [gram_matrix(act).view(batch_size, -1) for act in activations]
        return torch.cat(flat_grams, dim=1)
|
|
|
|
class PixelNorm(nn.Module):
    """Rescale the input so its mean square along dim 1 is (approximately) one."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        # 1e-8 keeps rsqrt finite when every entry along dim 1 is zero.
        mean_square = input.pow(2).mean(dim=1, keepdim=True)
        return input * (mean_square + 1e-8).rsqrt()
|
|
|
|
class EqualLinear(nn.Module):
    """Fully connected layer with run-time (equalized) weight scaling.

    The weight is stored as N(0, 1) / lr_mul and multiplied at run time by
    ``scale = lr_mul / sqrt(in_dim)``; the bias is stored as ``bias_init``
    and multiplied by ``lr_mul`` at run time.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        bias: whether to create a learnable bias (``self.bias`` is None if not).
        bias_init: initial value of every bias entry.
        lr_mul: multiplier applied to both weight and bias at run time.
        activation: if truthy, the linear output is passed through
            ``fused_leaky_relu`` with the bias folded in; otherwise the layer
            is purely affine.
    """

    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        # Scale the bias only if it exists: with bias=False, self.bias is None
        # and multiplying it would raise TypeError.
        bias = self.bias * self.lr_mul if self.bias is not None else None

        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            # NOTE(review): with bias=False this passes bias=None; assumes the
            # op package's fused_leaky_relu accepts that — confirm.
            out = fused_leaky_relu(out, bias)
        else:
            out = F.linear(input, self.weight * self.scale, bias=bias)

        return out

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
        )
|
|
|
|
class StyleEncoder(nn.Module):
    """Encode an image into a ``style_dim``-dimensional style vector.

    The image is embedded by StyleEmbedder (flattened Gram matrices of VGG
    features) and then mapped by an EqualLinear MLP: one leaky-ReLU layer
    from the embedding size to ``style_dim``, ``n_mlp - 2`` further
    leaky-ReLU layers, and a final linear layer.
    """

    def __init__(
        self,
        style_dim=512,
        n_mlp=4,
    ):
        super().__init__()

        self.style_dim = style_dim

        # Length of StyleEmbedder's output: one flattened C x C Gram matrix per
        # VGG19 stage with C = 64, 128, 256, 512, 512, i.e.
        # 64**2 + 128**2 + 256**2 + 512**2 + 512**2 == 610304.
        e_dim = sum(c * c for c in (64, 128, 256, 512, 512))
        self.embedder = StyleEmbedder()

        # Layers are created in input-to-output order, as before, so random
        # initialization under a fixed seed is unchanged.
        first = EqualLinear(e_dim, style_dim, lr_mul=1, activation='fused_lrelu')
        middle = [
            EqualLinear(style_dim, style_dim, lr_mul=1, activation='fused_lrelu')
            for _ in range(n_mlp - 2)
        ]
        last = EqualLinear(style_dim, style_dim, lr_mul=1, activation=None)
        self.embedder_mlp = nn.Sequential(first, *middle, last)

    def forward(self, image):
        """Return the style vector for ``image``."""
        embedding = self.embedder(image)
        return self.embedder_mlp(embedding)
|
|
|
|
class Projector(nn.Module):
    """EqualLinear MLP mapping ``style_dim`` -> ``style_dim``.

    Consists of ``n_mlp - 1`` leaky-ReLU layers followed by one linear layer.
    """

    def __init__(self, style_dim=512, n_mlp=4):
        super().__init__()

        stack = [
            EqualLinear(style_dim, style_dim, lr_mul=1, activation='fused_lrelu')
            for _ in range(n_mlp - 1)
        ]
        stack.append(EqualLinear(style_dim, style_dim, lr_mul=1, activation=None))
        self.projector = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        projected = self.projector(x)
        return projected
|
|