import torch
import torch.nn as nn

from renderer.modules import DownConvResBlock, ResBlock, UpConvResBlock, ConvResBlock
from renderer.attention_modules import CrossAttention, SelfAttention
from renderer.lia_resblocks import StyledConv, EqualConv2d, EqualLinear


class IdentityEncoder(nn.Module):
    def __init__(self, in_channels=3, output_channels=[64, 128, 256, 512, 512, 512],
                 initial_channels=32, dm=512):
        super(IdentityEncoder, self).__init__()
        self.initial_conv = nn.Sequential(
            nn.Conv2d(in_channels, initial_channels, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm2d(initial_channels),
            nn.ReLU(inplace=True)
        )
        self.down_block_0 = DownConvResBlock(initial_channels, initial_channels)

        self.down_blocks = nn.ModuleList()
        current_channels = initial_channels
        for out_channels in output_channels:
            # The 32-channel stage is already handled by down_block_0.
            if out_channels == 32:
                continue
            self.down_blocks.append(DownConvResBlock(current_channels, out_channels))
            current_channels = out_channels

        self.equalconv = EqualConv2d(output_channels[-1], output_channels[-1],
                                     kernel_size=3, stride=1, padding=1)  # note: not applied in forward
        self.linear_layers = nn.ModuleList(
            [EqualLinear(output_channels[-1], output_channels[-1]) for _ in range(4)])
        self.final_linear = EqualLinear(output_channels[-1], dm)
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        features = []
        x = self.initial_conv(x)
        x = self.down_block_0(x)
        features.append(x)
        for block in self.down_blocks:
            x = block(x)
            features.append(x)
        # Global average pooling followed by an MLP to produce the identity embedding.
        x = x.view(x.size(0), x.size(1), -1).mean(dim=2)
        for linear_layer in self.linear_layers:
            x = self.activation(linear_layer(x))
        x = self.final_linear(x)
        # Return the feature pyramid coarsest-first, plus the identity vector.
        return features[::-1], x


class MotionEncoder(nn.Module):
    def __init__(self, initial_channels=64, output_channels=[64, 128, 256, 512, 512, 512], dm=32):
        super(MotionEncoder, self).__init__()
        self.conv1 = nn.Conv2d(3, initial_channels, kernel_size=3, stride=1, padding=1)
        self.activation = nn.LeakyReLU(0.2)
        self.res_blocks = nn.ModuleList()
        in_channels = initial_channels
        for out_channels in output_channels:
            self.res_blocks.append(ResBlock(in_channels, out_channels))
            in_channels = out_channels
        self.equalconv = EqualConv2d(output_channels[-1], output_channels[-1],
                                     kernel_size=3, stride=1, padding=1)
        self.linear_layers = nn.ModuleList(
            [EqualLinear(output_channels[-1], output_channels[-1]) for _ in range(4)])
        self.final_linear = EqualLinear(output_channels[-1], dm)

    def forward(self, x):
        x = self.activation(self.conv1(x))
        for res_block in self.res_blocks:
            x = res_block(x)
        x = self.equalconv(x)
        # Global average pooling followed by an MLP to produce the motion latent.
        x = x.view(x.size(0), x.size(1), -1).mean(dim=2)
        for linear_layer in self.linear_layers:
            x = self.activation(linear_layer(x))
        x = self.final_linear(x)
        return x


class MotionDecoder(nn.Module):
    def __init__(self, latent_dim=32, const_dim=32):
        super().__init__()
        self.const = nn.Parameter(torch.randn(1, const_dim, 4, 4))
        self.style_conv_layers = nn.ModuleList([
            StyledConv(const_dim, 512, 3, latent_dim),
            StyledConv(512, 512, 3, latent_dim, upsample=True),
            StyledConv(512, 512, 3, latent_dim),
            StyledConv(512, 512, 3, latent_dim),
            StyledConv(512, 512, 3, latent_dim, upsample=True),
            StyledConv(512, 512, 3, latent_dim),
            StyledConv(512, 512, 3, latent_dim),
            StyledConv(512, 256, 3, latent_dim, upsample=True),
            StyledConv(256, 256, 3, latent_dim),
            StyledConv(256, 256, 3, latent_dim),
            StyledConv(256, 128, 3, latent_dim, upsample=True),
            StyledConv(128, 128, 3, latent_dim),
            StyledConv(128, 128, 3, latent_dim)
        ])

    def forward(self, t):
        # Broadcast the learned constant to the batch and modulate it with the motion latent t.
        x = self.const.repeat(t.shape[0], 1, 1, 1)
        m1, m2, m3, m4 = None, None, None, None
        for i, layer in enumerate(self.style_conv_layers):
            x = layer(x, t)
            # Tap the activations at the end of each resolution stage.
            if i == 3:
                m1 = x
            elif i == 6:
                m2 = x
            elif i == 9:
                m3 = x
            elif i == 12:
                m4 = x
        return m1, m2, m3, m4
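
# A minimal sanity-check sketch (not part of the original module). It assumes the
# renderer package is importable and that StyledConv upsamples exactly at the four
# layers flagged above, so for the default latent_dim=32 the taps m1..m4 should come
# out at 8x8/512, 16x16/512, 32x32/256 and 64x64/128.
def _sketch_motion_pyramid_shapes(batch_size=2, latent_dim=32):
    decoder = MotionDecoder(latent_dim=latent_dim)
    t = torch.randn(batch_size, latent_dim)
    with torch.no_grad():
        motion_maps = decoder(t)
    for name, m in zip(("m1", "m2", "m3", "m4"), motion_maps):
        print(name, tuple(m.shape))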
class SynthesisNetwork(nn.Module):
    def __init__(self, args, feature_dims, spatial_dims):
        super().__init__()
        self.args = args
        feature_dims_rev = feature_dims[::-1]
        spatial_dims_rev = spatial_dims[::-1]
        self.upconv_blocks = nn.ModuleList([
            UpConvResBlock(feature_dims_rev[i], feature_dims_rev[i + 1])
            for i in range(len(feature_dims_rev) - 1)
        ])
        # Each stage concatenates the upsampled features with the aligned skip features,
        # hence the doubled input channels.
        self.resblocks = nn.ModuleList([
            ConvResBlock(feature_dims_rev[i + 1] * 2, feature_dims_rev[i + 1])
            for i in range(len(feature_dims_rev) - 1)
        ])
        self.transformer_blocks = nn.ModuleList()
        for i in range(len(spatial_dims_rev) - 1):
            s_dim = spatial_dims_rev[i + 1]
            f_dim = feature_dims_rev[i + 1]
            self.transformer_blocks.append(
                SelfAttention(args=args, dim=f_dim, resolution=(s_dim, s_dim))
            )
        # 3*4 output channels plus PixelShuffle(2) yield a 3-channel image at twice
        # the final feature resolution.
        self.final_conv = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(feature_dims_rev[-1], 3 * 4, kernel_size=3, padding=1),
            nn.PixelShuffle(upscale_factor=2),
            nn.Sigmoid()
        )

    def forward(self, features_align):
        # features_align is ordered coarsest-first; decode coarse-to-fine with skip connections.
        x = features_align[0]
        for i in range(len(self.upconv_blocks)):
            x = self.upconv_blocks[i](x)
            x = torch.cat([x, features_align[i + 1]], dim=1)
            x = self.resblocks[i](x)
            x = self.transformer_blocks[i](x)
        return self.final_conv(x)


class IdentidyAdaptive(nn.Module):
    def __init__(self, dim_mot=32, dim_app=512, depth=4):
        super().__init__()
        self.in_layer = EqualLinear(dim_app + dim_mot, dim_app)
        self.linear_layers = nn.ModuleList([EqualLinear(dim_app, dim_app) for _ in range(depth)])
        self.final_linear = EqualLinear(dim_app, dim_mot)
        self.activation = nn.LeakyReLU(0.2)
        self.scale_activation = nn.Sigmoid()  # note: not used in forward

    def forward(self, mot, app):
        # Condition the motion latent on the identity embedding via a small MLP.
        x = torch.cat((mot, app), dim=-1)
        x = self.in_layer(x)
        for linear_layer in self.linear_layers:
            x = self.activation(linear_layer(x))
        out = self.final_linear(x)
        return out
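
# A minimal usage sketch (not part of the original module): with the defaults above,
# IdentidyAdaptive fuses a 32-d motion latent with a 512-d identity embedding and
# returns an identity-adapted 32-d motion latent.
def _sketch_identity_adaptation(batch_size=2):
    adapt = IdentidyAdaptive(dim_mot=32, dim_app=512)
    mot = torch.randn(batch_size, 32)
    app = torch.randn(batch_size, 512)
    with torch.no_grad():
        out = adapt(mot, app)
    print(tuple(out.shape))  # expected: (batch_size, 32)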
class IMTRenderer(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.feature_dims = [32, 64, 128, 256, 512, 512]
        self.motion_dims = self.feature_dims
        self.spatial_dims = [256, 128, 64, 32, 16, 8]
        self.dense_feature_encoder = IdentityEncoder(output_channels=self.feature_dims)
        self.latent_token_encoder = MotionEncoder(initial_channels=64,
                                                  output_channels=[128, 256, 512, 512, 512])
        self.latent_token_decoder = MotionDecoder()
        self.frame_decoder = SynthesisNetwork(args, self.feature_dims, self.spatial_dims)
        self.adapt = IdentidyAdaptive()
        # One cross-attention block per pyramid level, coarsest (8x8) first.
        self.imt = nn.ModuleList()
        for dim, s_dim in zip(self.feature_dims[::-1], self.spatial_dims[::-1]):
            self.imt.append(CrossAttention(args=args, dim=dim, resolution=(s_dim, s_dim)))

    def decode(self, A, B, C):
        # A: motion maps of the current frame, B: motion maps of the reference frame,
        # C: reference feature pyramid. All are ordered coarsest-first.
        num_levels = len(self.spatial_dims)
        aligned_features = [None] * num_levels
        attention_map = None
        for i in range(num_levels):
            attention_block = self.imt[i]
            if attention_block.is_standard_attention:
                # Coarse stage: compute an attention map from the motion maps, align the
                # reference features, and keep the map for later fine-stage levels.
                aligned_feature, attention_map = attention_block.coarse_stage(A[i], B[i], C[i])
                aligned_features[i] = aligned_feature
            else:
                # Fine stage: reuse the most recent coarse-stage attention map on the
                # reference features at this resolution.
                aligned_feature = attention_block.fine_stage(C[i], attn=attention_map)
                aligned_features[i] = aligned_feature
        output_frame = self.frame_decoder(aligned_features)
        return output_frame

    def app_encode(self, x):
        f_r, id = self.dense_feature_encoder(x)
        return f_r, id

    def mot_encode(self, x):
        mot_latent = self.latent_token_encoder(x)
        return mot_latent

    def mot_decode(self, x):
        mot_map = self.latent_token_decoder(x)
        return mot_map

    def id_adapt(self, t, id):
        return self.adapt(t, id)

    def forward(self, x_current, x_reference):
        f_r, i_r = self.app_encode(x_reference)
        t_r = self.mot_encode(x_reference)
        t_c = self.mot_encode(x_current)
        ta_r = self.adapt(t_r, i_r)
        ta_c = self.adapt(t_c, i_r)
        ma_r = self.mot_decode(ta_r)
        ma_c = self.mot_decode(ta_c)
        output_frame = self.decode(ma_c, ma_r, f_r)
        return output_frame, t_c
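
# An end-to-end smoke-test sketch (not part of the original module). It assumes
# 512x512 RGB inputs and that each DownConvResBlock halves the spatial size, which is
# what makes the IdentityEncoder pyramid line up with spatial_dims = [256, ..., 8].
# `args` is passed through untouched because the fields that CrossAttention and
# SelfAttention expect are defined elsewhere in the repo.
def _sketch_renderer_forward(args, batch_size=1, resolution=512):
    renderer = IMTRenderer(args)
    x_current = torch.rand(batch_size, 3, resolution, resolution)
    x_reference = torch.rand(batch_size, 3, resolution, resolution)
    with torch.no_grad():
        output_frame, t_c = renderer(x_current, x_reference)
    print("output frame:", tuple(output_frame.shape))  # expected: (batch_size, 3, resolution, resolution)
    print("motion latent:", tuple(t_c.shape))          # expected: (batch_size, 32)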