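# IMTalker / renderer/modules.py
# NOTE: the three lines below look like residue from the repository web file
# viewer (uploader avatar caption, commit message, commit id); kept as comments.
# cbsjtu01's picture
# add app.py
# 9400036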
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from renderer.lia_resblocks import ConvLayer
import torch.nn.utils.spectral_norm as spectral_norm
class NormLayer(nn.Module):
    """Select a 2-D normalization layer by name and apply it.

    Args:
        num_features: Channel count handed to the underlying layer.
        norm_type: ``'batch'`` builds ``nn.BatchNorm2d``, ``'instance'``
            builds ``nn.InstanceNorm2d`` and ``'layer'`` builds
            ``nn.GroupNorm`` with a single group.

    Raises:
        ValueError: If ``norm_type`` is none of the names above.
    """

    def __init__(self, num_features, norm_type='batch'):
        super().__init__()
        # (name, factory) pairs; the first matching name wins.
        factories = (
            ('batch', nn.BatchNorm2d),
            ('instance', nn.InstanceNorm2d),
            ('layer', lambda channels: nn.GroupNorm(1, channels)),
        )
        for name, make_norm in factories:
            if norm_type == name:
                self.norm = make_norm(num_features)
                break
        else:
            # No entry matched the requested name.
            raise ValueError(f"Unsupported normalization type: {norm_type}")

    def forward(self, x):
        """Return ``x`` passed through the selected normalization layer."""
        return self.norm(x)
class ConvBlock(nn.Module):
    """Bias-free 2-D convolution followed by normalization and an optional activation.

    Args:
        in_channels: Input channel count of the convolution.
        out_channels: Output channel count of the convolution and the norm.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        activation: Activation *class* instantiated with ``inplace=True``;
            a falsy value (e.g. ``None``) disables the activation.
        norm_type: Name forwarded to ``NormLayer``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
                 activation=nn.LeakyReLU, norm_type='batch'):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False,
        )
        self.norm = NormLayer(out_channels, norm_type)
        if activation:
            self.activation = activation(inplace=True)
        else:
            self.activation = None

    def forward(self, x):
        """Run conv -> norm -> (activation, when configured)."""
        features = self.norm(self.conv(x))
        return self.activation(features) if self.activation else features
class FeatResBlock(nn.Module):
    """Channel-preserving residual block built from two ``ConvBlock`` layers.

    The first ``ConvBlock`` uses the configured activation, the second has
    none; the block input is added back before the final activation.

    Args:
        channels: Channel count of both the input and the output.
        dropout_rate: Accepted for signature compatibility; not used here.
        activation: Activation *class* instantiated with ``inplace=True``;
            a falsy value disables it.
        norm_type: Name forwarded to each ``ConvBlock``.
    """

    def __init__(self, channels, dropout_rate=0, activation=nn.LeakyReLU, norm_type='batch'):
        super().__init__()
        self.conv1 = ConvBlock(channels, channels, activation=activation, norm_type=norm_type)
        self.conv2 = ConvBlock(channels, channels, activation=None, norm_type=norm_type)
        if activation:
            self.activation = activation(inplace=True)
        else:
            self.activation = None

    def forward(self, x):
        """Return ``activation(conv2(conv1(x)) + x)``."""
        branch = self.conv2(self.conv1(x))
        # In-place accumulation of the skip connection, as in ``branch += x``.
        branch.add_(x)
        if not self.activation:
            return branch
        return self.activation(branch)
class ResBlock(nn.Module):
    """Residual block with a downsampling main path and a downsampling skip path.

    The main path is two ``ConvLayer`` modules (the second constructed with
    ``downsample=True``); the skip path is one ``ConvLayer`` constructed with
    ``downsample=True``. The two outputs are summed.

    Args:
        in_channel: Channel count of the input.
        out_channel: Channel count of the output.
        blur_kernel: Accepted for signature compatibility; not used here.
    """

    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()
        kernel = 3  # every ConvLayer in this block uses a 3x3 kernel
        self.conv1 = ConvLayer(in_channel, in_channel, kernel)
        self.conv2 = ConvLayer(in_channel, out_channel, kernel, downsample=True)
        self.skip = ConvLayer(in_channel, out_channel, kernel, downsample=True)

    def forward(self, input):
        """Return ``conv2(conv1(input)) + skip(input)``."""
        main_path = self.conv2(self.conv1(input))
        skip_path = self.skip(input)
        return main_path + skip_path
class ConvResBlock(nn.Module):
    """Resolution-preserving conv stack followed by two ``FeatResBlock`` layers.

    Pipeline: conv1 -> norm -> (activation) -> conv2 -> FeatResBlock x2.
    Both convolutions are 3x3, stride 1, padding 1 and bias-free.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count produced by ``conv1`` and kept afterwards.
        dropout_rate: Forwarded to the ``FeatResBlock`` constructors.
        activation: Activation *class* instantiated with ``inplace=True``;
            a falsy value (e.g. ``None``) disables the activation.
        norm_type: Name forwarded to ``NormLayer`` and ``FeatResBlock``.
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0, activation=nn.LeakyReLU,
                 norm_type='batch'):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.norm = NormLayer(out_channels, norm_type)
        self.activation = activation(inplace=True) if activation else None
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.feat_res_block1 = FeatResBlock(out_channels, dropout_rate, activation, norm_type)
        self.feat_res_block2 = FeatResBlock(out_channels, dropout_rate, activation, norm_type)

    def forward(self, x):
        """Apply the block to ``x`` and return the resulting feature map."""
        out = self.conv1(x)
        out = self.norm(out)
        # The constructor stores None when no activation is requested, so the
        # call must be guarded (same convention as ConvBlock / FeatResBlock).
        if self.activation is not None:
            out = self.activation(out)
        out = self.conv2(out)
        out = self.feat_res_block1(out)
        out = self.feat_res_block2(out)
        return out
# this is only used on the densefeatureencoder
class DownConvResBlock(nn.Module):
    """Downsampling conv stack followed by two ``FeatResBlock`` layers.

    Pipeline: conv1 -> norm -> (activation) -> 2x2 average pool (stride 2)
    -> conv2 -> FeatResBlock x2. Both convolutions are 3x3, stride 1,
    padding 1 and bias-free, so only the pooling changes the spatial size.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count produced by ``conv1`` and kept afterwards.
        dropout_rate: Forwarded to the ``FeatResBlock`` constructors.
        activation: Activation *class* instantiated with ``inplace=True``;
            a falsy value (e.g. ``None``) disables the activation.
        norm_type: Name forwarded to ``NormLayer`` and ``FeatResBlock``.
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0, activation=nn.LeakyReLU,
                 norm_type='batch'):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.norm = NormLayer(out_channels, norm_type)
        self.activation = activation(inplace=True) if activation else None
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.feat_res_block1 = FeatResBlock(out_channels, dropout_rate, activation, norm_type)
        self.feat_res_block2 = FeatResBlock(out_channels, dropout_rate, activation, norm_type)

    def forward(self, x):
        """Apply the block to ``x`` and return the pooled, refined feature map."""
        out = self.conv1(x)
        out = self.norm(out)
        # The constructor stores None when no activation is requested, so the
        # call must be guarded (same convention as ConvBlock / FeatResBlock).
        if self.activation is not None:
            out = self.activation(out)
        out = self.avgpool(out)
        out = self.conv2(out)
        out = self.feat_res_block1(out)
        out = self.feat_res_block2(out)
        return out
# this is used on the framedecoder / enhancedframedecoder
class UpConvResBlock(nn.Module):
    """Upsampling conv stack followed by two ``FeatResBlock`` layers.

    Pipeline: 2x upsample -> conv1 -> norm -> (activation) -> conv2
    -> FeatResBlock x2. Both convolutions are 3x3, stride 1, padding 1 and
    bias-free, so only the upsampling changes the spatial size.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count produced by ``conv1`` and kept afterwards.
        dropout_rate: Forwarded to the ``FeatResBlock`` constructors.
        activation: Activation *class* instantiated with ``inplace=True``;
            a falsy value (e.g. ``None``) disables the activation.
        norm_type: Name forwarded to ``NormLayer`` and ``FeatResBlock``.
        upsample_mode: Interpolation mode passed to ``nn.Upsample``.
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0, activation=nn.LeakyReLU,
                 norm_type='batch', upsample_mode='nearest'):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode=upsample_mode)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.norm = NormLayer(out_channels, norm_type)
        self.activation = activation(inplace=True) if activation else None
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.feat_res_block1 = FeatResBlock(out_channels, dropout_rate, activation, norm_type)
        self.feat_res_block2 = FeatResBlock(out_channels, dropout_rate, activation, norm_type)

    def forward(self, x):
        """Apply the block to ``x`` and return the upsampled, refined feature map."""
        out = self.upsample(x)
        out = self.conv1(out)
        out = self.norm(out)
        # The constructor stores None when no activation is requested, so the
        # call must be guarded (same convention as ConvBlock / FeatResBlock).
        if self.activation is not None:
            out = self.activation(out)
        out = self.conv2(out)
        out = self.feat_res_block1(out)
        out = self.feat_res_block2(out)
        return out
class SPADE(nn.Module):
    """Spatially-adaptive normalization driven by a conditioning map.

    The input is instance-normalized without learned affine parameters and
    then modulated per pixel as ``normalized * (1 + gamma) + beta``, where
    ``gamma`` and ``beta`` are predicted from the conditioning map by small
    convolutional heads.

    Args:
        norm_nc: Channel count of the tensor being normalized.
        label_nc: Channel count of the conditioning map.
    """

    def __init__(self, norm_nc, label_nc):
        super().__init__()
        hidden_channels = 128  # width of the shared conditioning embedding
        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.mlp_gamma = nn.Conv2d(hidden_channels, norm_nc, kernel_size=3, padding=1)
        self.mlp_beta = nn.Conv2d(hidden_channels, norm_nc, kernel_size=3, padding=1)

    def forward(self, x, segmap):
        """Normalize ``x`` and modulate it with parameters predicted from ``segmap``.

        ``segmap`` is resized (nearest neighbour) to the spatial size of ``x``
        before the modulation parameters are computed.
        """
        # Bring the conditioning map to the resolution of the features.
        resized_map = F.interpolate(segmap, size=x.shape[2:], mode='nearest')
        shared = self.mlp_shared(resized_map)
        scale = self.mlp_gamma(shared)
        shift = self.mlp_beta(shared)
        return self.param_free_norm(x) * (1 + scale) + shift
class SPADEResnetBlock(nn.Module):
    """Residual block whose normalization layers are ``SPADE`` modules.

    The residual branch is SPADE -> leaky ReLU -> conv, applied twice. The
    shortcut is the identity when ``fin == fout``; otherwise it is a SPADE
    layer followed by a bias-free 1x1 convolution.

    Args:
        fin: Input channel count.
        fout: Output channel count.
        norm_G: Configuration string; convolutions are wrapped in spectral
            norm when it contains ``'spectral'``.
        label_nc: Channel count of the conditioning map given to each SPADE.
        use_se: Stored on the instance; not otherwise used in this class.
        dilation: Dilation (and matching padding) of the two 3x3 convolutions.
    """

    def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
        super().__init__()
        self.learned_shortcut = (fin != fout)
        self.use_se = use_se
        fmiddle = min(fin, fout)

        # Convolutions of the residual branch and, if needed, the shortcut.
        conv_kwargs = dict(kernel_size=3, padding=dilation, dilation=dilation)
        self.conv_0 = nn.Conv2d(fin, fmiddle, **conv_kwargs)
        self.conv_1 = nn.Conv2d(fmiddle, fout, **conv_kwargs)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # Optional spectral normalization of every convolution.
        if 'spectral' in norm_G:
            wrapped = ['conv_0', 'conv_1']
            if self.learned_shortcut:
                wrapped.append('conv_s')
            for attr in wrapped:
                setattr(self, attr, spectral_norm(getattr(self, attr)))

        # Conditioning-driven normalization layers.
        self.norm_0 = SPADE(fin, label_nc)
        self.norm_1 = SPADE(fmiddle, label_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(fin, label_nc)

    def forward(self, x, seg1):
        """Return ``shortcut(x) + residual(x)``, both conditioned on ``seg1``."""
        x_s = self.shortcut(x, seg1)
        dx = self.norm_0(x, seg1)
        dx = self.conv_0(self.actvn(dx))
        dx = self.norm_1(dx, seg1)
        dx = self.conv_1(self.actvn(dx))
        return x_s + dx

    def shortcut(self, x, seg1):
        """Return the skip-path tensor (identity unless channel counts differ)."""
        if not self.learned_shortcut:
            return x
        return self.conv_s(self.norm_s(x, seg1))

    def actvn(self, x):
        """Leaky ReLU with negative slope 0.2."""
        return F.leaky_relu(x, 2e-1)
class SPADEDecoder(nn.Module):
    """Decoder that turns a feature map into a 3-channel image via SPADE blocks.

    The input feature map is used both as the signal (after ``fc``) and as the
    conditioning map of every ``SPADEResnetBlock``. Six blocks run at the
    input resolution, followed by two 2x upsample + block stages and a final
    convolution; the result is passed through a sigmoid, so values lie in
    (0, 1).

    Args:
        upscale: When ``None`` or ``<= 1`` the output head is a plain 3x3
            convolution. Any larger value selects a conv + ``PixelShuffle``
            head. NOTE(review): that head always uses factor 2, regardless of
            the actual ``upscale`` value.
        max_features: Upper bound on the derived input channel count.
        block_expansion: Base channel count; the expected input channel count
            is ``min(max_features, block_expansion * 2 ** num_down_blocks)``.
        out_channels: Channel count fed to the image head.
        num_down_blocks: Number of downsampling stages assumed upstream;
            must be at least 1.

    Raises:
        ValueError: If ``num_down_blocks`` is smaller than 1.
    """

    def __init__(self, upscale=1, max_features=256, block_expansion=64, out_channels=64, num_down_blocks=2):
        # nn.Module must be initialised before attributes are assigned to it.
        super().__init__()
        if num_down_blocks < 1:
            # The former loop left ``input_channels`` unbound in this case and
            # failed later with an unrelated UnboundLocalError.
            raise ValueError(f"num_down_blocks must be >= 1, got {num_down_blocks}")
        # Closed form of the value the former loop ended on (its last iteration).
        input_channels = min(max_features, block_expansion * (2 ** num_down_blocks))
        self.upscale = upscale
        norm_G = 'spadespectralinstance'
        label_num_channels = input_channels  # conditioning map == input feature
        wide_channels = 2 * input_channels

        self.fc = nn.Conv2d(input_channels, wide_channels, 3, padding=1)
        self.G_middle_0 = SPADEResnetBlock(wide_channels, wide_channels, norm_G, label_num_channels)
        self.G_middle_1 = SPADEResnetBlock(wide_channels, wide_channels, norm_G, label_num_channels)
        self.G_middle_2 = SPADEResnetBlock(wide_channels, wide_channels, norm_G, label_num_channels)
        self.G_middle_3 = SPADEResnetBlock(wide_channels, wide_channels, norm_G, label_num_channels)
        self.G_middle_4 = SPADEResnetBlock(wide_channels, wide_channels, norm_G, label_num_channels)
        self.G_middle_5 = SPADEResnetBlock(wide_channels, wide_channels, norm_G, label_num_channels)
        self.up_0 = SPADEResnetBlock(wide_channels, input_channels, norm_G, label_num_channels)
        self.up_1 = SPADEResnetBlock(input_channels, out_channels, norm_G, label_num_channels)
        self.up = nn.Upsample(scale_factor=2)

        if self.upscale is None or self.upscale <= 1:
            self.conv_img = nn.Conv2d(out_channels, 3, 3, padding=1)
        else:
            # 3 * (2 * 2) channels rearranged by PixelShuffle(2) -> 3 channels
            # at twice the resolution.
            self.conv_img = nn.Sequential(
                nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1),
                nn.PixelShuffle(upscale_factor=2)
            )

    def forward(self, feature):
        """Decode ``feature`` into an image tensor with values in (0, 1).

        The shape notes below are the original author's annotations and
        presumably describe the default configuration with a 64x64 input.
        """
        seg = feature  # Bx256x64x64, reused as conditioning for every block
        x = self.fc(feature)  # Bx512x64x64
        x = self.G_middle_0(x, seg)
        x = self.G_middle_1(x, seg)
        x = self.G_middle_2(x, seg)
        x = self.G_middle_3(x, seg)
        x = self.G_middle_4(x, seg)
        x = self.G_middle_5(x, seg)
        x = self.up(x)  # Bx512x64x64 -> Bx512x128x128
        x = self.up_0(x, seg)  # Bx512x128x128 -> Bx256x128x128
        x = self.up(x)  # Bx256x128x128 -> Bx256x256x256
        x = self.up_1(x, seg)  # Bx256x256x256 -> Bx64x256x256
        x = self.conv_img(F.leaky_relu(x, 2e-1))  # Bx64x256x256 -> Bx3xHxW
        return torch.sigmoid(x)  # Bx3xHxW