import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Normalize

from pg_modules.diffaug import DiffAugment
from pg_modules.blocks import conv2d, DownBlock, DownBlockPatch
from pg_modules.projector import F_RandomProj
from feature_networks.constants import VITS


class SingleDisc(nn.Module):
    """Single-scale discriminator over one feature map.

    A stack of DownBlocks halves the spatial size from ``start_sz`` down to
    ``end_sz``, followed by a 4x4 conv that produces a 1-channel logit map.

    Args:
        nc: input channel count. When ``head`` is falsy it overrides the
            channel schedule at ``start_sz``; when ``head`` is truthy it is
            the raw-modality channel count fed to the head conv.
        ndf: if given, every level of the channel schedule uses this value.
        start_sz: spatial size of the input; snapped to the nearest
            power-of-two key of the schedule if not already one.
        end_sz: spatial size at which the final logit conv is applied.
        head: if truthy, prepend a 3x3 conv + LeakyReLU head mapping the raw
            modality into the channel schedule.
        patch: use DownBlockPatch instead of DownBlock.
    """

    def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, patch=False):
        super().__init__()

        # channel schedule derived from the MiDaS backbone
        nfc_midas = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64,
                     128: 64, 256: 32, 512: 16, 1024: 8}

        # snap start sizes that are not powers of two to the nearest key
        if start_sz not in nfc_midas.keys():
            sizes = np.array(list(nfc_midas.keys()))
            start_sz = sizes[np.argmin(abs(sizes - start_sz))]
        self.start_sz = start_sz

        # if given ndf, allocate all layers with the same ndf
        if ndf is None:
            nfc = nfc_midas
        else:
            nfc = {k: ndf for k, v in nfc_midas.items()}

        # for feature map discriminators with nfc not in nfc_midas
        # this is the case for the pretrained backbone (midas.pretrained)
        if nc is not None and head is None:
            nfc[start_sz] = nc

        layers = []

        # Head if the initial input is the full modality.
        # FIX: was hard-coded nfc[256], which mismatches the first DownBlock's
        # input channels (nfc[start_sz]) whenever start_sz != 256 and crashes
        # at runtime. nfc[start_sz] is identical for the default start_sz=256.
        if head:
            layers += [conv2d(nc, nfc[start_sz], 3, 1, 1, bias=False),
                       nn.LeakyReLU(0.2, inplace=True)]

        # Down Blocks: halve spatial size until end_sz is reached
        DB = DownBlockPatch if patch else DownBlock
        while start_sz > end_sz:
            layers.append(DB(nfc[start_sz], nfc[start_sz // 2]))
            start_sz = start_sz // 2

        # final 4x4 conv producing the logit map
        layers.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, x, c):
        # class/conditioning input c is accepted for interface
        # compatibility but unused here
        return self.main(x)


class MultiScaleD(nn.Module):
    """Bank of SingleDiscs, one per backbone feature scale."""

    def __init__(
        self,
        channels,
        resolutions,
        num_discs=4,
        proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
        cond=0,
        patch=False,
        **kwargs,
    ):
        """Build one SingleDisc per selected backbone scale.

        Args:
            channels: per-scale channel counts of the backbone features.
            resolutions: per-scale spatial resolutions of those features.
            num_discs: how many scales (from the lowest level up) to use.
            proj_type: kept for config compatibility; not used here.
            cond: kept for config compatibility; not used here.
            patch: build patch-style discs with a fixed start size of 16.
        """
        super().__init__()

        assert num_discs in [1, 2, 3, 4, 5]

        # the first disc is on the lowest level of the backbone
        self.disc_in_channels = channels[:num_discs]
        self.disc_in_res = resolutions[:num_discs]
        Disc = SingleDisc

        # one (name, module) pair per scale; ModuleDict keys are "0", "1", ...
        # and must match the keys of the feature dict passed to forward().
        mini_discs = []
        for i, (cin, res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)):
            start_sz = res if not patch else 16
            mini_discs.append(
                [str(i), Disc(nc=cin, start_sz=start_sz, end_sz=8, patch=patch)]
            )
        self.mini_discs = nn.ModuleDict(mini_discs)

    def forward(self, features, c, rec=False):
        """Run every per-scale disc and concatenate the flattened logits.

        Args:
            features: dict of feature maps keyed by "0", "1", ...
            c: conditioning input forwarded to each disc (unused there).
            rec: accepted for interface compatibility; unused.

        Returns:
            Tensor of shape (batch, total_logits).
        """
        all_logits = []
        for k, disc in self.mini_discs.items():
            all_logits.append(disc(features[k], c).view(features[k].size(0), -1))

        all_logits = torch.cat(all_logits, dim=1)
        return all_logits


class ProjectedDiscriminator(torch.nn.Module):
    """Discriminator built on frozen pretrained feature networks.

    Each backbone extracts multi-scale features from the (optionally
    augmented) input image; a MultiScaleD per backbone turns those features
    into logits.
    """

    def __init__(
        self,
        backbones,
        diffaug=True,
        interp224=True,
        backbone_kwargs=None,
        **kwargs,
    ):
        """
        Args:
            backbones: iterable of backbone names understood by F_RandomProj.
            diffaug: apply DiffAugment to the input before feature extraction.
            interp224: resize inputs to 224x224 (always done for ViTs).
            backbone_kwargs: forwarded to F_RandomProj and MultiScaleD.
                FIX: default was a mutable ``{}``; a None sentinel avoids the
                shared-mutable-default pitfall while staying backward
                compatible.
        """
        super().__init__()
        backbone_kwargs = backbone_kwargs if backbone_kwargs is not None else {}
        self.backbones = backbones
        self.diffaug = diffaug
        self.interp224 = interp224

        # get backbones and multi-scale discs
        feature_networks, discriminators = [], []

        for i, bb_name in enumerate(backbones):
            feat = F_RandomProj(bb_name, **backbone_kwargs)
            disc = MultiScaleD(
                channels=feat.CHANNELS,
                resolutions=feat.RESOLUTIONS,
                **backbone_kwargs,
            )

            feature_networks.append([bb_name, feat])
            discriminators.append([bb_name, disc])

        self.feature_networks = nn.ModuleDict(feature_networks)
        self.discriminators = nn.ModuleDict(discriminators)

    def train(self, mode=True):
        # the pretrained feature networks stay frozen in eval mode;
        # only the discriminators follow the requested mode
        self.feature_networks = self.feature_networks.train(False)
        self.discriminators = self.discriminators.train(mode)
        return self

    def eval(self):
        return self.train(False)

    def forward(self, x, c):
        logits = []

        for bb_name, feat in self.feature_networks.items():
            # apply augmentation (x in [-1, 1])
            x_aug = DiffAugment(x, policy='color,translation,cutout') if self.diffaug else x

            # transform to [0,1]
            x_aug = x_aug.add(1).div(2)

            # apply F-specific normalization
            x_n = Normalize(feat.normstats['mean'], feat.normstats['std'])(x_aug)

            # upsample if smaller, downsample if larger + VIT
            if self.interp224 or bb_name in VITS:
                x_n = F.interpolate(x_n, 224, mode='bilinear', align_corners=False)

            # forward pass
            features = feat(x_n)
            # NOTE: `+=` with a 2-D tensor extends the list with one 1-D row
            # per sample (iteration over dim 0) — deliberate, kept as-is.
            logits += self.discriminators[bb_name](features, c)

        return logits