# NOTE(review): the three lines that preceded this file's code ("Spaces:",
# "Runtime error", "Runtime error") were residue of a Hugging Face Spaces
# error-page paste, not code; converted to this comment so the module parses.
# Standard library
import pickle

# Third-party
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Normalize

# Project-local
from feature_networks.constants import VITS
from pg_modules.blocks import DownBlock, DownBlockPatch, conv2d
from pg_modules.diffaug import DiffAugment
from pg_modules.projector import F_RandomProj
class SingleDisc(nn.Module):
    """Discriminator head for a single feature resolution.

    Downsamples the input from ``start_sz`` to ``end_sz`` with a stack of
    DownBlocks, then projects to one logit channel with a 4x4 convolution.

    Args:
        nc: input channel count; used when the input is a raw feature map
            (no head) or as the head's input channels.
        ndf: if given, every layer uses this channel width instead of the
            per-resolution table below.
        start_sz: spatial size of the input (snapped to the nearest known
            resolution if not a power of two).
        end_sz: spatial size at which downsampling stops.
        head: truthy to prepend a conv head for full-modality (e.g. RGB) input.
        patch: use patch-style down blocks instead of the default ones.
    """

    def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, patch=False):
        super().__init__()

        # Channel count per spatial resolution (MiDaS backbone layout).
        nfc_midas = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
                     256: 32, 512: 16, 1024: 8}

        # Snap start sizes that are not powers of two to the closest entry.
        if start_sz not in nfc_midas:
            known = np.array(list(nfc_midas.keys()))
            start_sz = known[np.argmin(abs(known - start_sz))]
        self.start_sz = start_sz

        # A fixed ndf overrides the per-resolution channel table.
        nfc = nfc_midas if ndf is None else {res: ndf for res in nfc_midas}

        # Feature-map discriminators (pretrained backbone features) feed raw
        # channels straight in, so the entry resolution takes nc directly.
        if nc is not None and head is None:
            nfc[start_sz] = nc

        layers = []

        # Optional head when the initial input is the full modality.
        if head:
            layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
                       nn.LeakyReLU(0.2, inplace=True)]

        # Downsampling blocks until the target resolution is reached.
        DB = DownBlockPatch if patch else DownBlock
        sz = start_sz
        while sz > end_sz:
            layers.append(DB(nfc[sz], nfc[sz // 2]))
            sz //= 2

        # Final projection to a single logit channel.
        layers.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, x, c):
        # c (conditioning) is accepted for interface parity but unused here.
        return self.main(x)
class MultiScaleD(nn.Module):
    """Bundle of SingleDisc heads, one per backbone feature scale.

    The heads score feature maps at up to ``num_discs`` scales; their
    flattened logits are concatenated along dim 1 in ``forward``.
    """

    def __init__(
        self,
        channels,
        resolutions,
        num_discs=4,
        proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
        cond=0,
        patch=False,
        **kwargs,
    ):
        super().__init__()

        assert num_discs in [1, 2, 3, 4, 5]

        # The first disc sits on the lowest level of the backbone.
        self.disc_in_channels = channels[:num_discs]
        self.disc_in_res = resolutions[:num_discs]

        Disc = SingleDisc
        mini_discs = []
        for idx, (cin, res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)):
            # Patch discriminators always enter at 16x16 regardless of res.
            entry_sz = 16 if patch else res
            mini_discs.append(
                (str(idx), Disc(nc=cin, start_sz=entry_sz, end_sz=8, patch=patch))
            )
        self.mini_discs = nn.ModuleDict(mini_discs)

    def forward(self, features, c, rec=False):
        # Score each scale, flatten per sample, and concatenate the logits.
        per_scale = [
            disc(features[name], c).view(features[name].size(0), -1)
            for name, disc in self.mini_discs.items()
        ]
        return torch.cat(per_scale, dim=1)
class ProjectedDiscriminator(torch.nn.Module):
    """Discriminator operating on features of pretrained backbones.

    Each backbone projects the image into multi-scale features which a
    MultiScaleD then scores. The feature networks are kept frozen: ``train``
    always leaves them in eval mode and only toggles the discriminators.
    """

    def __init__(
        self,
        backbones,
        diffaug=True,
        interp224=True,
        backbone_kwargs={},
        **kwargs
    ):
        super().__init__()
        self.backbones = backbones
        self.diffaug = diffaug
        self.interp224 = interp224

        # Build one (projector, multi-scale disc) pair per backbone name.
        feature_networks, discriminators = [], []
        for bb_name in backbones:
            feat = F_RandomProj(bb_name, **backbone_kwargs)
            disc = MultiScaleD(
                channels=feat.CHANNELS,
                resolutions=feat.RESOLUTIONS,
                **backbone_kwargs,
            )
            feature_networks.append([bb_name, feat])
            discriminators.append([bb_name, disc])

        self.feature_networks = nn.ModuleDict(feature_networks)
        self.discriminators = nn.ModuleDict(discriminators)

    def train(self, mode=True):
        # Feature networks stay frozen (eval); only the heads follow `mode`.
        self.feature_networks = self.feature_networks.train(False)
        self.discriminators = self.discriminators.train(mode)
        return self

    def eval(self):
        return self.train(False)

    def forward(self, x, c):
        logits = []
        for bb_name, feat in self.feature_networks.items():
            # Apply augmentation in the generator's output range, x in [-1, 1].
            if self.diffaug:
                x_aug = DiffAugment(x, policy='color,translation,cutout')
            else:
                x_aug = x

            # Map to [0, 1], then apply the backbone-specific normalization.
            x_aug = x_aug.add(1).div(2)
            x_n = Normalize(feat.normstats['mean'], feat.normstats['std'])(x_aug)

            # Resize to 224: always when interp224 is set, and for ViTs which
            # require that input size.
            if self.interp224 or bb_name in VITS:
                x_n = F.interpolate(x_n, 224, mode='bilinear', align_corners=False)

            features = feat(x_n)
            # Extending with a 2-D tensor appends one logit row per sample
            # (same effect as the original `logits += tensor`).
            logits.extend(self.discriminators[bb_name](features, c))

        return logits