Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
| from torch.nn.utils import spectral_norm | |
| from modeling.base import BaseNetwork | |
| from layers.blocks import DestyleResBlock, Destyler, ResBlock | |
class IFRNet(BaseNetwork):
    """Encoder/decoder restoration network conditioned on VGG feature codes.

    The encoder is six ``DestyleResBlock`` stages (``ds_res1`` .. ``ds_res6``);
    each stage also receives the output of its own linear layer
    (``ds_fc1`` .. ``ds_fc6``) applied to the ``Destyler`` embedding of the
    flattened VGG features. The decoder is three nearest-neighbour 2x
    upsamplings interleaved with ``ResBlock`` stages, followed by a 3x3
    convolution to 3 output channels.

    Args:
        base_n_channels: width multiplier for all convolutional stages.
        destyler_n_channels: width of the ``Destyler`` output vector.
    """

    def __init__(self, base_n_channels, destyler_n_channels):
        super().__init__()
        c = base_n_channels
        d = destyler_n_channels

        # forward() flattens vgg_feat to (batch, ch * h * w) before the
        # destyler, so ch * h * w must equal this in_features value.
        self.destyler = Destyler(in_features=32768, num_features=d)

        # Encoder: one (linear, residual block) pair per stage. Attribute
        # names and creation order are fixed: they define state_dict keys
        # and the order in which parameters are initialised.
        self.ds_fc1 = nn.Linear(d, c * 2)
        self.ds_res1 = DestyleResBlock(channels_in=3, channels_out=c, kernel_size=5, stride=1, padding=2)
        self.ds_fc2 = nn.Linear(d, c * 4)
        self.ds_res2 = DestyleResBlock(channels_in=c, channels_out=c * 2, kernel_size=3, stride=2, padding=1)
        self.ds_fc3 = nn.Linear(d, c * 4)
        self.ds_res3 = DestyleResBlock(channels_in=c * 2, channels_out=c * 2, kernel_size=3, stride=1, padding=1)
        self.ds_fc4 = nn.Linear(d, c * 8)
        self.ds_res4 = DestyleResBlock(channels_in=c * 2, channels_out=c * 4, kernel_size=3, stride=2, padding=1)
        self.ds_fc5 = nn.Linear(d, c * 8)
        self.ds_res5 = DestyleResBlock(channels_in=c * 4, channels_out=c * 4, kernel_size=3, stride=1, padding=1)
        self.ds_fc6 = nn.Linear(d, c * 16)
        self.ds_res6 = DestyleResBlock(channels_in=c * 4, channels_out=c * 8, kernel_size=3, stride=2, padding=1)

        # Decoder.
        self.upsample = nn.UpsamplingNearest2d(scale_factor=2.0)
        self.res1 = ResBlock(channels_in=c * 8, channels_out=c * 4, kernel_size=3, stride=1, padding=1)
        self.res2 = ResBlock(channels_in=c * 4, channels_out=c * 4, kernel_size=3, stride=1, padding=1)
        self.res3 = ResBlock(channels_in=c * 4, channels_out=c * 2, kernel_size=3, stride=1, padding=1)
        self.res4 = ResBlock(channels_in=c * 2, channels_out=c * 2, kernel_size=3, stride=1, padding=1)
        self.res5 = ResBlock(channels_in=c * 2, channels_out=c, kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(c, 3, kernel_size=3, stride=1, padding=1)

        self.init_weights(init_type="normal", gain=0.02)

    def forward(self, x, vgg_feat):
        """Restore ``x`` using codes derived from ``vgg_feat``.

        Args:
            x: image batch fed to ``ds_res1`` (3 input channels).
            vgg_feat: 4-D feature tensor (batch, ch, h, w); flattened per
                sample before the destyler.

        Returns:
            ``(out, aux)``: the 3-channel decoder output and the output of
            the last encoder stage ``ds_res6``.
        """
        batch, channels, height, width = vgg_feat.size()
        style = self.destyler(vgg_feat.view(batch, channels * height * width))

        # Encoder stages 1..5; each block gets its own projection of `style`.
        out = x
        for idx in range(1, 6):
            fc = getattr(self, "ds_fc%d" % idx)
            res = getattr(self, "ds_res%d" % idx)
            out = res(out, fc(style))
        aux = self.ds_res6(out, self.ds_fc6(style))

        # Decoder: (upsample -> two ResBlocks) twice, then upsample -> res5.
        out = aux
        for first, second in ((self.res1, self.res2), (self.res3, self.res4)):
            out = second(first(self.upsample(out)))
        out = self.res5(self.upsample(out))
        return self.conv1(out), aux
class CIFR_Encoder(IFRNet):
    """IFRNet variant whose forward also exposes every encoder stage output.

    Uses exactly the layers created by ``IFRNet.__init__``; only ``forward``
    differs, returning the six ``ds_res*`` activations instead of just the
    last one.
    """

    def __init__(self, base_n_channels, destyler_n_channels):
        super().__init__(base_n_channels, destyler_n_channels)

    def forward(self, x, vgg_feat):
        """Run the network and collect intermediate encoder features.

        Args:
            x: image batch fed to ``ds_res1``.
            vgg_feat: 4-D feature tensor (batch, ch, h, w); flattened per
                sample before the destyler.

        Returns:
            ``(out, feats)``: the 3-channel decoder output and a list of the
            six encoder stage outputs, ordered ``ds_res1`` .. ``ds_res6``.
        """
        b, c, h, w = vgg_feat.size()
        style = self.destyler(vgg_feat.view(b, c * h * w))

        feats = []
        current = x
        for stage in range(1, 7):
            block_fc = getattr(self, "ds_fc%d" % stage)
            block_res = getattr(self, "ds_res%d" % stage)
            current = block_res(current, block_fc(style))
            feats.append(current)

        # Decoder starts from the deepest encoder feature.
        decoded = self.res2(self.res1(self.upsample(feats[-1])))
        decoded = self.res4(self.res3(self.upsample(decoded)))
        decoded = self.res5(self.upsample(decoded))
        return self.conv1(decoded), feats
class Normalize(nn.Module):
    """Divide ``x`` by ``(sum(x ** power, dim=1)) ** (1 / power)``.

    With the default ``power=2`` this is L2 normalisation along dimension 1.
    A constant 1e-7 is added to the denominator, so slices that are all zero
    along dimension 1 come out as zeros rather than NaN.
    """

    def __init__(self, power=2):
        super().__init__()
        self.power = power

    def forward(self, x):
        exponent = self.power
        magnitude = torch.sum(torch.pow(x, exponent), dim=1, keepdim=True) ** (1.0 / exponent)
        return x / (magnitude + 1e-7)
class PatchSampleF(BaseNetwork):
    """Sample spatial positions from a list of feature maps and embed them.

    For each feature map (B, C, H, W) in ``feats`` the spatial grid is
    flattened to (B, H*W, C); optionally ``num_patches`` positions are
    sampled, optionally projected by a per-level MLP, optionally replaced by
    their Gram matrix (style mode), and finally normalised with
    ``Normalize(2)``.

    NOTE: the same patch ids are used for every image in the batch.

    Args:
        base_n_channels: channel multiplier of the incoming feature maps.
        style_or_content: ``"content"`` keeps per-patch features; any other
            value enables the Gram-matrix branch.
        use_mlp: project sampled features through ``mlp_<level>``.
        nc: hidden/output width of each MLP.
        device: where the MLPs are placed at construction time. ``None``
            (default) picks CUDA when available and CPU otherwise, so the
            module can also be built on CPU-only hosts.
    """

    def __init__(self, base_n_channels, style_or_content, use_mlp=False, nc=256, device=None):
        super(PatchSampleF, self).__init__()
        self.is_content = style_or_content == "content"
        self.l2norm = Normalize(2)
        self.use_mlp = use_mlp
        self.nc = nc  # hard-coded
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # One MLP per feature level. Input widths (base * 1, 2, 2, 4, 4, 8)
        # match the channel counts of CIFR_Encoder's six encoder outputs.
        # Creation order mlp_0 .. mlp_5 is kept so state_dict keys and
        # initialisation order are unchanged.
        for level, mult in enumerate((1, 2, 2, 4, 4, 8)):
            mlp = nn.Sequential(
                nn.Linear(base_n_channels * mult, self.nc),
                nn.ReLU(),
                nn.Linear(self.nc, self.nc),
            )
            setattr(self, "mlp_%d" % level, mlp.to(device))
        self.init_weights(init_type="normal", gain=0.02)

    @staticmethod
    def gram_matrix(x):
        """Return ``x @ x.T`` divided by the number of elements of ``x``.

        Args:
            x: 2-D tensor (rows are samples, columns are features).

        Raises:
            ValueError: if ``x`` is not 2-D (e.g. when ``num_patches == 0``
                leaves the sampled features 3-D).
        """
        if x.dim() != 2:
            raise ValueError("gram_matrix expects a 2-D tensor, got shape %s" % (tuple(x.shape),))
        a, b = x.size()
        return torch.mm(x, x.t()).div(a * b)

    def forward(self, feats, num_patches=64, patch_ids=None):
        """Sample, embed and normalise every feature map in ``feats``.

        Args:
            feats: sequence of 4-D tensors (B, C, H, W).
            num_patches: positions to sample per map; 0 keeps every position
                and reshapes the result back to (B, C', H, W).
            patch_ids: optional per-level index tensors to reuse instead of
                drawing a fresh random permutation.

        Returns:
            ``(return_feats, return_ids)``: the normalised features and the
            patch ids used for each level (``[]`` when ``num_patches == 0``).
        """
        return_ids = []
        return_feats = []
        for feat_id, feat in enumerate(feats):
            B, C, H, W = feat.shape
            # (B, C, H, W) -> (B, H*W, C)
            feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
            if num_patches > 0:
                if patch_ids is not None:
                    patch_id = patch_ids[feat_id]
                else:
                    patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device)
                    patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))]
                # (B, P, C) -> (B*P, C)
                x_sample = feat_reshape[:, patch_id, :].flatten(0, 1)
            else:
                x_sample = feat_reshape
                patch_id = []
            if self.use_mlp:
                mlp = getattr(self, "mlp_%d" % feat_id)
                x_sample = mlp(x_sample)
            if not self.is_content:
                # Style mode: compare second-order statistics, not patches.
                x_sample = self.gram_matrix(x_sample)
            return_ids.append(patch_id)
            x_sample = self.l2norm(x_sample)
            if num_patches == 0:
                x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
            return_feats.append(x_sample)
        return return_feats, return_ids
class MLP(nn.Module):
    """Auxiliary classification head over a (base * 8)-channel feature map.

    Two 3x3 conv + 2x max-pool stages reduce the channels 8*base -> 4*base
    -> 2*base and halve the spatial size twice. The result is flattened and
    mapped to ``out_features`` raw scores by a single linear layer (no
    softmax is applied).

    The linear input width, base_n_channels * 2 * 8 * 8, assumes the pooled
    map is 8x8, i.e. a 32x32 input feature map.
    """

    def __init__(self, base_n_channels, out_features=14):
        super().__init__()
        wide = base_n_channels * 8
        mid = base_n_channels * 4
        narrow = base_n_channels * 2
        pooled_side = 8  # spatial side after two MaxPool2d(2); assumes 32x32 input
        layers = [
            nn.Conv2d(wide, mid, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(2),
            nn.Conv2d(mid, narrow, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(2),
            Flatten(),
            nn.Linear(narrow * pooled_side * pooled_side, out_features),
        ]
        self.aux_classifier = nn.Sequential(*layers)

    def forward(self, x):
        return self.aux_classifier(x)
class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) dimension.

    An input of shape (B, d1, d2, ...) becomes (B, d1 * d2 * ...). This uses
    ``Tensor.view``, so the input must be view-compatible (e.g. contiguous).
    """

    def forward(self, input):
        return input.view(input.size(0), -1)
class Discriminator(BaseNetwork):
    """Spectral-normalised convolutional discriminator giving one score per image.

    Args:
        base_n_channels: width multiplier of the convolution stack.
        img_size: spatial size of the 3-channel input, either an int (square
            image) or a ``(height, width)`` pair. It is only used to size the
            final linear layer. The default 256 gives exactly the value that
            was previously hard-coded (72 * base_n_channels features).

    Raises:
        ValueError: if ``img_size`` is too small for the convolution stack.
    """

    # (kernel_size, stride, padding) of every convolution the input passes
    # through before Flatten, in order: the five convs of image_to_features
    # followed by the first conv of features_to_prob. Keep in sync with them.
    _CONV_GEOMETRY = (
        (5, 2, 2),
        (5, 2, 2),
        (5, 2, 2),
        (5, 2, 2),
        (5, 1, 1),
        (5, 2, 1),
    )

    def __init__(self, base_n_channels, img_size=256):
        super(Discriminator, self).__init__()
        self.image_to_features = nn.Sequential(
            spectral_norm(nn.Conv2d(3, base_n_channels, 5, 2, 2)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(base_n_channels, 2 * base_n_channels, 5, 2, 2)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(2 * base_n_channels, 2 * base_n_channels, 5, 2, 2)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(2 * base_n_channels, 4 * base_n_channels, 5, 2, 2)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(4 * base_n_channels, 8 * base_n_channels, 5, 1, 1)),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # The head conv emits 2 * base_n_channels maps; their spatial size is
        # derived from the conv arithmetic (6x6 for a 256x256 input).
        feat_h, feat_w = self._head_spatial_size(img_size)
        output_size = 2 * base_n_channels * feat_h * feat_w
        self.features_to_prob = nn.Sequential(
            spectral_norm(nn.Conv2d(8 * base_n_channels, 2 * base_n_channels, 5, 2, 1)),
            Flatten(),
            nn.Linear(output_size, 1),
        )
        self.init_weights(init_type="normal", gain=0.02)

    @classmethod
    def _head_spatial_size(cls, img_size):
        """Return the (H, W) of the map reaching Flatten for an ``img_size`` input."""
        if isinstance(img_size, int):
            sizes = (img_size, img_size)
        else:
            sizes = tuple(img_size)
        for kernel, stride, padding in cls._CONV_GEOMETRY:
            # Conv2d output size: floor((n + 2p - k) / s) + 1 (dilation 1).
            sizes = tuple((s + 2 * padding - kernel) // stride + 1 for s in sizes)
            if min(sizes) < 1:
                raise ValueError(
                    "img_size %r is too small for the Discriminator convolution stack" % (img_size,)
                )
        return sizes

    def forward(self, input_data):
        """Return one unnormalised score per image, shape (batch, 1)."""
        x = self.image_to_features(input_data)
        return self.features_to_prob(x)
class PatchDiscriminator(Discriminator):
    """Discriminator variant that scores a grid of locations instead of the image.

    Reuses ``Discriminator.image_to_features`` and replaces the head with a
    spectral-normalised 1x1 conv to one channel, flattened to
    (batch, H' * W') scores.
    """

    def __init__(self, base_n_channels):
        super(PatchDiscriminator, self).__init__(base_n_channels)
        self.features_to_prob = nn.Sequential(
            spectral_norm(nn.Conv2d(8 * base_n_channels, 1, 1)),
            Flatten(),
        )
        # Discriminator.__init__ called init_weights before this head existed,
        # so the head would otherwise keep PyTorch's default init. Re-run the
        # same call so the head is covered too. This also redraws the trunk
        # weights with the same scheme. NOTE(review): assumes
        # BaseNetwork.init_weights (re)initialises all submodules — confirm.
        self.init_weights(init_type="normal", gain=0.02)

    def forward(self, input_data):
        """Return per-location scores, shape (batch, H' * W')."""
        x = self.image_to_features(input_data)
        return self.features_to_prob(x)
if __name__ == '__main__':
    # Smoke test: build each network and print output shapes for a random batch.
    import torchvision

    # Fall back to CPU so the demo also runs on machines without a GPU
    # (the previous hard-coded .cuda() calls failed there).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ifrnet = CIFR_Encoder(32, 128).to(device)
    x = torch.rand((2, 3, 256, 256)).to(device)
    # NOTE: newer torchvision releases prefer `weights=` over `pretrained=`.
    vgg16 = torchvision.models.vgg16(pretrained=True).features.eval().to(device)
    with torch.no_grad():
        vgg_feat = vgg16(x)
        output, feats = ifrnet(x, vgg_feat)
    print(output.size())
    for i, feat in enumerate(feats):
        print(i, feat.size())

    disc = Discriminator(32).to(device)
    with torch.no_grad():
        d_out = disc(output)
    print(d_out.size())

    patch_disc = PatchDiscriminator(32).to(device)
    with torch.no_grad():
        p_d_out = patch_disc(output)
    print(p_d_out.size())