import functools
from typing import Tuple, List, Dict

import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf
from torch import Tensor

from src.torch_utils import persistence
from src.training.networks import Discriminator as ImageDiscriminator

#----------------------------------------------------------------------------

@persistence.persistent_class
class Discriminator(nn.Module):
    """
    MoCoGAN discriminator, consisting of 2 parts: ImageDiscriminator and VideoDiscriminator.

    The image discriminator scores individual frames; the video discriminator
    scores the whole clip as a 5-D `[batch, channels, time, height, width]` volume.
    """
    def __init__(self, cfg: DictConfig, img_channels: int, img_resolution: int, *img_discr_args, **img_discr_kwargs):
        super().__init__()

        self.cfg = cfg

        # Per-frame discriminator. It is configured for single-frame input
        # (`num_frames_per_video: 1`) since frames are scored independently.
        self.image_discr = ImageDiscriminator(
            img_resolution=img_resolution,
            img_channels=img_channels,
            cfg=OmegaConf.create({
                'sampling': {'num_frames_per_video': 1},
                'dummy_c': False,
                # Higher-resolution models use full feature-map capacity.
                'fmaps': 1.0 if img_resolution >= 512 else 0.5,
                'mbstd_group_size': 4,
                'concat_res': -1,
            }),
            *img_discr_args,
            **img_discr_kwargs,
        )

        # Clip-level 3D-conv discriminator.
        self.video_discr = MoCoGANVideoDiscriminator(
            n_channels=img_channels,
            n_output_neurons=1,
            bn_use_gamma=True,
            use_noise=True,
            noise_sigma=0.1,
            image_size=img_resolution,
            num_t_paddings=cfg.video_discr_num_t_paddings,
        )
        # DCGAN-style weight init for the 3D convs / batch norms.
        self.video_discr.apply(weights_init)

    def params_with_lr(self, lr: float) -> List[Dict]:
        """
        Build optimizer parameter groups: the image discriminator uses the
        optimizer's base LR; the video discriminator's LR is scaled by
        `cfg.video_discr_lr_multiplier`.
        """
        return [
            {'params': self.image_discr.parameters()},
            {'params': self.video_discr.parameters(), 'lr': self.cfg.video_discr_lr_multiplier * lr},
        ]

    def forward(self, img: Tensor, c: Tensor, t: Tensor, **img_discr_kwargs) -> Dict[str, Tensor]:
        """
        Score frames and the assembled video.

        - img has shape [batch_size * num_frames_per_video, c, h, w]
        - c has shape [batch_size, c_dim]
        - t has shape [batch_size, num_frames_per_video]

        Returns a dict with:
        - 'image_logits': per-frame logits of shape [batch_size * num_frames]
        - 'video_logits': per-clip logits, flattened to [batch_size, -1]
        """
        batch_size, num_frames_per_video = t.shape
        image_logits = self.image_discr(img, c, t, **img_discr_kwargs)['image_logits'] # [batch_size * num_frames]

        # Reassemble the flat frame batch into [batch_size, c, t, h, w]
        # volumes for the 3D-conv video discriminator.
        videos = img.view(batch_size, num_frames_per_video, *img.shape[1:]) # [batch_size, t, c, h, w]
        videos = videos.permute(0, 2, 1, 3, 4).contiguous() # [batch_size, c, t, h, w]
        video_logits = self.video_discr(videos) # [batch_size, 1, out_t, out_h, out_w] (after squeeze inside)

        return {'image_logits': image_logits, 'video_logits': video_logits.flatten(start_dim=1)}

#----------------------------------------------------------------------------

def weights_init(m):
    """
    DCGAN-style initialization: N(0, 0.02) for conv weights,
    N(1, 0.02) weights / zero bias for BatchNorm3d layers.

    Intended for use with `nn.Module.apply`.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 and hasattr(m, 'weight'):
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm3d') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def get_norm_layer(norm_type: str='instance'):
    """
    Return a 3D normalization-layer factory for the given type
    ('batch' or 'instance'); raises NotImplementedError otherwise.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm3d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm3d, affine=False, track_running_stats=True)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer

#----------------------------------------------------------------------------

@persistence.persistent_class
class VideoDiscriminator(nn.Module):
    """
    Multi-scale video discriminator (pix2pixHD-style): `num_sub_discrs`
    sub-discriminators operate on progressively downsampled copies of the
    input video and optionally expose their intermediate features.
    """
    def __init__(self, num_input_channels, ndf=64, n_layers=3, n_frames_per_sample=16,
                 norm_layer=nn.InstanceNorm3d, num_sub_discrs=2, get_intermediate_feat=True):
        super().__init__()

        self.num_sub_discrs = num_sub_discrs
        self.n_layers = n_layers
        self.get_intermediate_feat = get_intermediate_feat
        ndf_max = 64

        for i in range(num_sub_discrs):
            # Coarser scales get fewer base channels (capped at ndf_max).
            block = SubVideoDiscriminator(
                num_input_channels=num_input_channels,
                ndf=min(ndf_max, ndf * (2 ** (num_sub_discrs - 1 - i))),
                n_layers=n_layers,
                norm_layer=norm_layer,
                get_intermediate_feat=get_intermediate_feat)

            # Register sub-modules as flat attributes so they are tracked by
            # nn.Module (the layer lists are reassembled in forward()).
            if get_intermediate_feat:
                for j in range(n_layers + 2):
                    setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(block, 'model' + str(j)))
            else:
                setattr(self, 'layer' + str(i), block.model)

        # Only pool temporally when clips are long enough; otherwise keep the
        # time dimension and pool spatially only.
        stride = 2 if n_frames_per_sample > 16 else [1, 2, 2]
        self.downsample = nn.AvgPool3d(3, stride=stride, padding=[1, 1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        """Run one sub-discriminator; returns its per-layer outputs (or just the final one)."""
        if self.get_intermediate_feat:
            result = [input]
            for i in range(len(model)):
                result.append(model[i](result[-1]))
            return result[1:]
        else:
            return [model(input)]

    def forward(self, x):
        """
        Returns a list with one entry per sub-discriminator (finest scale
        handled by the highest-index sub-discriminator first), each entry a
        list of that sub-discriminator's layer outputs.
        """
        result = []
        for block_idx in range(self.num_sub_discrs):
            if self.get_intermediate_feat:
                model = [getattr(self, 'scale' + str(self.num_sub_discrs - 1 - block_idx) + '_layer' + str(j))
                         for j in range(self.n_layers + 2)]
            else:
                model = getattr(self, 'layer' + str(self.num_sub_discrs - 1 - block_idx))
            result.append(self.singleD_forward(model, x))
            # Downsample the input for the next (coarser) sub-discriminator.
            if block_idx != (self.num_sub_discrs - 1):
                x = self.downsample(x)
        return result

#----------------------------------------------------------------------------

@persistence.persistent_class
class SubVideoDiscriminator(nn.Module):
    """
    PatchGAN-style 3D-conv discriminator: `n_layers` strided conv blocks,
    one stride-1 conv block, and a final 1-channel projection. With
    `get_intermediate_feat=True` each block is stored separately so forward()
    can return all intermediate activations.
    """
    def __init__(self, num_input_channels, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm3d, get_intermediate_feat=True):
        super().__init__()

        self.get_intermediate_feat = get_intermediate_feat
        self.n_layers = n_layers

        kernel_size = 4
        padw = int(np.ceil((kernel_size - 1.0) / 2))

        # Stem: no normalization on the first block.
        sequence = [[
            nn.Conv3d(num_input_channels, ndf, kernel_size=kernel_size, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]]

        # Strided downsampling blocks, doubling channels up to 512.
        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence += [[
                nn.Conv3d(nf_prev, nf, kernel_size=kernel_size, stride=2, padding=padw),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True)
            ]]

        # One stride-1 block before the output projection.
        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[
            nn.Conv3d(nf_prev, nf, kernel_size=kernel_size, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]]

        # Final 1-channel patch-logit projection.
        sequence += [[
            nn.Conv3d(nf, 1, kernel_size=kernel_size, stride=1, padding=padw)
        ]]

        if get_intermediate_feat:
            for n in range(len(sequence)):
                setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))
        else:
            self.model = nn.Sequential(*[s for ss in sequence for s in ss])

    def forward(self, x):
        """Return either the list of per-block activations or the final output alone."""
        if self.get_intermediate_feat:
            res = [x]
            for n in range(self.n_layers + 2):
                model = getattr(self, 'model' + str(n))
                res.append(model(res[-1]))
            return res[1:]
        else:
            return self.model(x)

#----------------------------------------------------------------------------

class MoCoGANVideoDiscriminator(nn.Module):
    """
    Video discriminator from MoCoGAN: a stack of 3D convolutions with spatial
    stride 2 (temporal stride 1), optional input noise before each conv, and
    a single-channel output head.

    `num_t_paddings` controls how many of the conv layers get temporal padding
    (so short clips are not shrunk to zero along the time axis).
    """
    def __init__(self, n_channels, n_output_neurons=1, bn_use_gamma=True, use_noise=False,
                 noise_sigma=None, ndf=64, image_size: int=64, num_t_paddings: int=0):
        super().__init__()

        self.n_channels = n_channels
        self.n_output_neurons = n_output_neurons
        self.use_noise = use_noise
        self.bn_use_gamma = bn_use_gamma

        layers = [
            Noise(use_noise, sigma=noise_sigma),
            nn.Conv3d(n_channels, ndf, 4, stride=(1, 2, 2), padding=(2 if num_t_paddings > 0 else 0, 1, 1), bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            Noise(use_noise, sigma=noise_sigma),
            nn.Conv3d(ndf, ndf * 2, 4, stride=(1, 2, 2), padding=(2 if num_t_paddings > 1 else 0, 1, 1), bias=False),
            nn.BatchNorm3d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            Noise(use_noise, sigma=noise_sigma),
            nn.Conv3d(ndf * 2, ndf * 4, 4, stride=(1, 2, 2), padding=(2 if num_t_paddings > 2 else 0, 1, 1), bias=False),
            nn.BatchNorm3d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            Noise(use_noise, sigma=noise_sigma),
            nn.Conv3d(ndf * 4, ndf * 8, 4, stride=(1, 2, 2), padding=(2 if num_t_paddings > 3 else 0, 1, 1), bias=False),
            nn.BatchNorm3d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Two extra stride-1 blocks for 256x256 inputs to keep the receptive
        # field matched to the larger spatial resolution.
        if image_size == 256:
            layers.extend([
                Noise(use_noise, sigma=noise_sigma),
                nn.Conv3d(ndf * 8, ndf * 8, 3, stride=(1, 1, 1), padding=(1 + (1 if num_t_paddings > 4 else 0), 1, 1), bias=False),
                nn.BatchNorm3d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True),

                Noise(use_noise, sigma=noise_sigma),
                nn.Conv3d(ndf * 8, ndf * 8, 3, stride=(1, 1, 1), padding=(1 + (1 if num_t_paddings > 5 else 0), 1, 1), bias=False),
                nn.BatchNorm3d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True),
            ])

        layers.extend([
            nn.Conv3d(ndf * 8, n_output_neurons, kernel_size=4, stride=1, padding=(2 if num_t_paddings > 5 else 0, 0, 0), bias=False),
        ])

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        # NOTE(review): `.squeeze()` drops ALL size-1 dims, including the batch
        # dim when batch_size == 1 — callers appear to rely on the current
        # shape, so this is kept as-is; confirm before changing.
        return self.main(input).squeeze()

#----------------------------------------------------------------------------

class Noise(nn.Module):
    """Additive Gaussian noise layer: x + sigma * N(0, I) when enabled, identity otherwise."""
    def __init__(self, use_noise, sigma=0.2):
        super().__init__()
        self.use_noise = use_noise
        self.sigma = sigma

    def forward(self, x):
        if self.use_noise:
            return x + self.sigma * torch.randn_like(x)
        return x

#----------------------------------------------------------------------------