import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Normalize
import pickle
from pg_modules.diffaug import DiffAugment
from pg_modules.blocks import conv2d, DownBlock, DownBlockPatch
from pg_modules.projector import F_RandomProj
from feature_networks.constants import VITS
class SingleDisc(nn.Module):
    """Single-scale discriminator head operating on one feature map.

    Stacks DownBlocks that halve the spatial size from ``start_sz`` down to
    ``end_sz``, then applies a 4x4 conv producing a 1-channel logit map.

    Args:
        nc: input channel count; when given (and ``head`` is None) it
            overrides the channel table at the entry resolution.
        ndf: if given, use this same width at every resolution instead of
            the MiDaS-derived per-resolution table.
        start_sz: spatial size of the input feature map (snapped to the
            nearest power of two in the table if needed).
        end_sz: spatial size at which downsampling stops.
        head: if truthy, prepend a conv+LeakyReLU head mapping ``nc``
            channels into the table width (for full-modality input).
        patch: use DownBlockPatch instead of DownBlock.
    """

    def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, patch=False):
        super().__init__()
        # Channel count per spatial resolution (taken from the MiDaS backbone).
        nfc_midas = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
                     256: 32, 512: 16, 1024: 8}

        # Snap start sizes that are not powers of two to the nearest table key.
        # Cast back to a plain int so later dict lookups/arithmetic stay native.
        if start_sz not in nfc_midas:
            sizes = np.array(list(nfc_midas.keys()))
            start_sz = int(sizes[np.argmin(abs(sizes - start_sz))])
        self.start_sz = start_sz

        # If given ndf, allocate all layers with the same width.
        if ndf is None:
            nfc = nfc_midas
        else:
            nfc = {k: ndf for k in nfc_midas}  # values of nfc_midas unused here

        # For feature-map discriminators whose channel count is not in the
        # table — the case for the pretrained backbone (midas.pretrained).
        if nc is not None and head is None:
            nfc[start_sz] = nc

        layers = []

        # Head if the initial input is the full modality.
        # NOTE(review): this head outputs nfc[256] channels, but the first
        # DownBlock below consumes nfc[start_sz]; the two agree only when
        # start_sz == 256. Kept as-is — confirm intended usage with callers.
        if head:
            layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
                       nn.LeakyReLU(0.2, inplace=True)]

        # Down blocks: halve the resolution until end_sz is reached.
        DB = DownBlockPatch if patch else DownBlock
        while start_sz > end_sz:
            layers.append(DB(nfc[start_sz], nfc[start_sz // 2]))
            start_sz = start_sz // 2

        # Final projection to a single logit channel.
        layers.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, x, c):
        # Conditioning `c` is accepted for interface parity but unused here.
        return self.main(x)
class MultiScaleD(nn.Module):
    """A set of SingleDisc heads, one per backbone feature scale.

    Logits from all heads are flattened per sample and concatenated
    along the feature dimension in ``forward``.
    """

    def __init__(
        self,
        channels,
        resolutions,
        num_discs=4,
        proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
        cond=0,
        patch=False,
        **kwargs,
    ):
        super().__init__()
        assert num_discs in [1, 2, 3, 4, 5]

        # The first disc sits on the lowest level of the backbone.
        self.disc_in_channels = channels[:num_discs]
        self.disc_in_res = resolutions[:num_discs]

        discs = []
        for idx, (in_ch, in_res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)):
            # Patch discriminators always enter at a fixed 16x16 grid.
            entry_sz = 16 if patch else in_res
            discs.append(
                (str(idx), SingleDisc(nc=in_ch, start_sz=entry_sz, end_sz=8, patch=patch))
            )
        self.mini_discs = nn.ModuleDict(discs)

    def forward(self, features, c, rec=False):
        # Run every per-scale head and concatenate its flattened logits.
        per_scale = [
            disc(features[key], c).view(features[key].size(0), -1)
            for key, disc in self.mini_discs.items()
        ]
        return torch.cat(per_scale, dim=1)
class ProjectedDiscriminator(torch.nn.Module):
    """Ensemble of frozen feature backbones, each paired with a MultiScaleD.

    For each backbone: optionally DiffAugment the input, map it from
    [-1, 1] to [0, 1], apply the backbone-specific normalization,
    optionally resize to 224, extract features, and score them with the
    matching multi-scale discriminator.

    Args:
        backbones: iterable of backbone names understood by F_RandomProj.
        diffaug: apply DiffAugment to the input in ``forward``.
        interp224: resize inputs to 224x224 before the backbone.
        backbone_kwargs: extra kwargs forwarded to both F_RandomProj and
            MultiScaleD (``None`` means no extras).
    """

    def __init__(
        self,
        backbones,
        diffaug=True,
        interp224=True,
        backbone_kwargs=None,
        **kwargs
    ):
        super().__init__()
        self.backbones = backbones
        self.diffaug = diffaug
        self.interp224 = interp224
        # Fix: a `{}` default would be a shared mutable default argument.
        backbone_kwargs = {} if backbone_kwargs is None else backbone_kwargs

        # Build backbones and their multi-scale discriminators.
        feature_networks, discriminators = [], []
        for bb_name in backbones:
            feat = F_RandomProj(bb_name, **backbone_kwargs)
            disc = MultiScaleD(
                channels=feat.CHANNELS,
                resolutions=feat.RESOLUTIONS,
                **backbone_kwargs,
            )
            feature_networks.append([bb_name, feat])
            discriminators.append([bb_name, disc])

        self.feature_networks = nn.ModuleDict(feature_networks)
        self.discriminators = nn.ModuleDict(discriminators)

    def train(self, mode=True):
        # Feature backbones are kept frozen in eval mode regardless of `mode`;
        # only the discriminator heads toggle between train/eval.
        self.feature_networks = self.feature_networks.train(False)
        self.discriminators = self.discriminators.train(mode)
        return self

    def eval(self):
        return self.train(False)

    def forward(self, x, c):
        logits = []
        for bb_name, feat in self.feature_networks.items():
            # apply augmentation (x in [-1, 1])
            x_aug = DiffAugment(x, policy='color,translation,cutout') if self.diffaug else x
            # transform to [0,1]
            x_aug = x_aug.add(1).div(2)
            # apply F-specific normalization
            x_n = Normalize(feat.normstats['mean'], feat.normstats['std'])(x_aug)
            # upsample if smaller, downsample if larger; ViT backbones need 224
            if self.interp224 or bb_name in VITS:
                x_n = F.interpolate(x_n, 224, mode='bilinear', align_corners=False)
            # forward pass through the frozen backbone, then its discs
            features = feat(x_n)
            # `+=` with a tensor extends the list with one row per sample,
            # so callers get a flat list of per-sample logit tensors.
            logits += self.discriminators[bb_name](features, c)
        return logits