# ProjectedGANCLC / pg_modules/discriminator.py
# first version (commit 7a59a55, author: ZJW666)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Normalize
import pickle
from pg_modules.diffaug import DiffAugment
from pg_modules.blocks import conv2d, DownBlock, DownBlockPatch
from pg_modules.projector import F_RandomProj
from feature_networks.constants import VITS
class SingleDisc(nn.Module):
    """Discriminator head for a single feature resolution.

    Builds a stack of downsampling blocks from ``start_sz`` down to
    ``end_sz`` followed by a 4x4 conv producing one logit map.
    """

    def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, patch=False):
        super().__init__()

        # channel schedule borrowed from the midas backbone
        nfc_midas = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
                     256: 32, 512: 16, 1024: 8}

        # snap start sizes that are not powers of two to the nearest known resolution
        if start_sz not in nfc_midas:
            known = np.array(list(nfc_midas.keys()))
            start_sz = known[np.argmin(abs(known - start_sz))]
        self.start_sz = start_sz

        # uniform width when ndf is given, otherwise the midas schedule
        nfc = nfc_midas if ndf is None else {res: ndf for res in nfc_midas}

        # feature-map discriminators feed nc channels straight in (no head);
        # this is the case for the pretrained backbone (midas.pretrained)
        if nc is not None and head is None:
            nfc[start_sz] = nc

        layers = []
        # optional head when the initial input is the full modality
        if head:
            layers.extend([conv2d(nc, nfc[256], 3, 1, 1, bias=False),
                           nn.LeakyReLU(0.2, inplace=True)])

        # downsampling blocks until end_sz is reached
        down = DownBlockPatch if patch else DownBlock
        sz = start_sz
        while sz > end_sz:
            layers.append(down(nfc[sz], nfc[sz // 2]))
            sz //= 2

        layers.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, x, c):
        # conditioning input c is accepted for interface parity but unused here
        return self.main(x)
class MultiScaleD(nn.Module):
    """Bundle of SingleDisc heads, one per backbone feature scale.

    Each head scores one entry of the feature dict; the per-head logits
    are flattened and concatenated along dim 1.
    """

    def __init__(
        self,
        channels,
        resolutions,
        num_discs=4,
        proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
        cond=0,
        patch=False,
        **kwargs,
    ):
        super().__init__()
        assert num_discs in [1, 2, 3, 4, 5]

        # the first disc is on the lowest level of the backbone
        self.disc_in_channels = channels[:num_discs]
        self.disc_in_res = resolutions[:num_discs]

        heads = []
        for idx, (cin, res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)):
            sz = 16 if patch else res
            heads.append((str(idx), SingleDisc(nc=cin, start_sz=sz, end_sz=8, patch=patch)))
        self.mini_discs = nn.ModuleDict(heads)

    def forward(self, features, c, rec=False):
        per_head = [
            head(features[key], c).view(features[key].size(0), -1)
            for key, head in self.mini_discs.items()
        ]
        return torch.cat(per_head, dim=1)
class ProjectedDiscriminator(torch.nn.Module):
    """Discriminator that scores images through frozen pretrained backbones.

    For every backbone name a feature projector (``F_RandomProj``) extracts
    multi-scale features, which a matching ``MultiScaleD`` turns into logits.
    The feature networks are kept in eval mode; only the discriminators train.

    Args:
        backbones: iterable of backbone names understood by ``F_RandomProj``.
        diffaug: apply DiffAugment (color/translation/cutout) to the input.
        interp224: resize the input to 224x224 before the backbone.
        backbone_kwargs: extra kwargs forwarded to both ``F_RandomProj``
            and ``MultiScaleD``; defaults to an empty dict.
    """

    def __init__(
        self,
        backbones,
        diffaug=True,
        interp224=True,
        backbone_kwargs=None,
        **kwargs
    ):
        super().__init__()
        self.backbones = backbones
        self.diffaug = diffaug
        self.interp224 = interp224
        # fix: the original used a mutable default ({}), which is shared
        # across instances — use a None sentinel instead
        backbone_kwargs = {} if backbone_kwargs is None else backbone_kwargs

        # build one (feature network, multi-scale disc) pair per backbone
        feature_networks, discriminators = [], []
        for bb_name in backbones:
            feat = F_RandomProj(bb_name, **backbone_kwargs)
            disc = MultiScaleD(
                channels=feat.CHANNELS,
                resolutions=feat.RESOLUTIONS,
                **backbone_kwargs,
            )
            feature_networks.append([bb_name, feat])
            discriminators.append([bb_name, disc])

        self.feature_networks = nn.ModuleDict(feature_networks)
        self.discriminators = nn.ModuleDict(discriminators)

    def train(self, mode=True):
        # feature networks stay frozen (eval); only the discs follow `mode`
        self.feature_networks = self.feature_networks.train(False)
        self.discriminators = self.discriminators.train(mode)
        return self

    def eval(self):
        return self.train(False)

    def forward(self, x, c):
        """Score images x (expected in [-1, 1]) with every backbone's discs."""
        logits = []
        for bb_name, feat in self.feature_networks.items():
            # apply augmentation (x in [-1, 1])
            x_aug = DiffAugment(x, policy='color,translation,cutout') if self.diffaug else x
            # transform to [0, 1]
            x_aug = x_aug.add(1).div(2)
            # apply backbone-specific normalization
            x_n = Normalize(feat.normstats['mean'], feat.normstats['std'])(x_aug)
            # upsample if smaller, downsample if larger; ViT backbones need 224
            if self.interp224 or bb_name in VITS:
                x_n = F.interpolate(x_n, 224, mode='bilinear', align_corners=False)
            # forward pass through the frozen projector, then the discs
            features = feat(x_n)
            # NOTE(review): `+=` with a tensor extends the list with
            # per-sample slices along dim 0 — presumably intended so the
            # loss can weight each backbone's logits; verify against caller.
            logits += self.discriminators[bb_name](features, c)
        return logits