| """ |
| https://github.com/tarun005/FLAVR/blob/main/model/FLAVR_arch.py |
| https://github.com/tarun005/FLAVR/blob/main/model/resnet_3D.py (only SEGating) |
| """ |
| import math |
| import numpy as np |
| import importlib |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class SEGating(nn.Module):
    """Channel gating: rescale every channel by a learned sigmoid weight.

    The input is average-pooled down to one value per channel
    (``AdaptiveAvgPool3d(1)``), passed through a 1x1x1 convolution and a
    sigmoid, and the resulting gate multiplies the original input.
    """

    def __init__(self, inplanes, reduction=16):
        """Create the pooling and gating layers.

        Args:
            inplanes: Number of channels going in and coming out of the gate.
            reduction: Accepted for interface compatibility; this
                implementation does not use it.
        """
        super().__init__()

        self.pool = nn.AdaptiveAvgPool3d(1)
        gate_conv = nn.Conv3d(inplanes, inplanes, kernel_size=1, stride=1, bias=True)
        self.attn_layer = nn.Sequential(gate_conv, nn.Sigmoid())

    def forward(self, x):
        """Return ``x`` multiplied by its pooled, sigmoid-activated gate."""
        gate = self.attn_layer(self.pool(x))
        return x * gate
|
|
def joinTensors(X1, X2, type="concat"):
    """Combine two tensors according to ``type``.

    Args:
        X1: First tensor.
        X2: Second tensor.
        type: ``"concat"`` concatenates along dimension 1, ``"add"`` sums
            element-wise; any other value returns ``X1`` unchanged.

    Returns:
        The combined tensor, or ``X1`` itself for an unrecognised ``type``.
    """
    if type == "add":
        return X1 + X2
    if type == "concat":
        return torch.cat((X1, X2), dim=1)
    # Any other join type means "no skip connection": pass X1 through.
    return X1
|
|
|
|
class Conv_2d(nn.Module):
    """A 2D convolution optionally followed by batch normalisation."""

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=False, batchnorm=False):
        """Assemble the layer stack.

        Args:
            in_ch: Input channel count.
            out_ch: Output channel count.
            kernel_size: Convolution kernel size.
            stride: Convolution stride.
            padding: Convolution padding.
            bias: Whether the convolution has a bias term.
            batchnorm: When true, append ``BatchNorm2d(out_ch)``.
        """
        super().__init__()

        layers = [
            nn.Conv2d(
                in_ch,
                out_ch,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias,
            )
        ]
        if batchnorm:
            layers.append(nn.BatchNorm2d(out_ch))

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolution (and batch norm, if configured) to ``x``."""
        return self.conv(x)
|
|
class upConv3D(nn.Module):
    """3D upsampling stage followed by ``SEGating`` and optional batch norm.

    With ``upmode="transpose"`` the upsampling is a ``ConvTranspose3d``.
    For any other ``upmode`` it is a trilinear ``Upsample`` with scale
    factor ``(1, 2, 2)`` followed by a 1x1x1 ``Conv3d``; in that branch
    ``kernel_size``, ``stride`` and ``padding`` are not used.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose", batchnorm=False):
        """Assemble the upsampling stack.

        Args:
            in_ch: Input channel count.
            out_ch: Output channel count.
            kernel_size: Kernel size of the transposed convolution.
            stride: Stride of the transposed convolution.
            padding: Padding of the transposed convolution.
            upmode: ``"transpose"`` or anything else for interpolate + conv.
            batchnorm: When true, append ``BatchNorm3d(out_ch)``.
        """
        super().__init__()

        self.upmode = upmode

        if upmode == "transpose":
            stages = [
                nn.ConvTranspose3d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding),
                SEGating(out_ch),
            ]
        else:
            stages = [
                nn.Upsample(mode='trilinear', scale_factor=(1, 2, 2), align_corners=False),
                nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=1),
                SEGating(out_ch),
            ]

        if batchnorm:
            stages.append(nn.BatchNorm3d(out_ch))

        self.upconv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the upsampling stack."""
        return self.upconv(x)
|
|
class Conv_3d(nn.Module):
    """A 3D convolution followed by ``SEGating`` and optional batch norm."""

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=True, batchnorm=False):
        """Assemble the layer stack.

        Args:
            in_ch: Input channel count.
            out_ch: Output channel count.
            kernel_size: Convolution kernel size.
            stride: Convolution stride.
            padding: Convolution padding.
            bias: Whether the convolution has a bias term.
            batchnorm: When true, append ``BatchNorm3d(out_ch)``.
        """
        super().__init__()

        layers = [
            nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
            SEGating(out_ch),
        ]
        if batchnorm:
            layers.append(nn.BatchNorm3d(out_ch))

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply convolution, gating and (if configured) batch norm to ``x``."""
        return self.conv(x)
|
|
class upConv2D(nn.Module):
    """2D upsampling stage with optional batch normalisation.

    With ``upmode="transpose"`` the stage is a ``ConvTranspose2d``. For any
    other ``upmode`` it is a bilinear ``Upsample`` with scale factor 2
    followed by a 1x1 ``Conv2d``; in that branch ``kernel_size``, ``stride``
    and ``padding`` are not used.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose", batchnorm=False):
        """Assemble the upsampling stack.

        Args:
            in_ch: Input channel count.
            out_ch: Output channel count.
            kernel_size: Kernel size of the transposed convolution.
            stride: Stride of the transposed convolution.
            padding: Padding of the transposed convolution.
            upmode: ``"transpose"`` or anything else for interpolate + conv.
            batchnorm: When true, append ``BatchNorm2d(out_ch)``.
        """
        super().__init__()

        self.upmode = upmode

        if upmode == "transpose":
            stages = [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding),
            ]
        else:
            stages = [
                nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1),
            ]

        if batchnorm:
            stages.append(nn.BatchNorm2d(out_ch))

        self.upconv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the upsampling stack."""
        return self.upconv(x)
|
|
|
|
class UNet_3D_3D(nn.Module):
    """U-Net with a 3D encoder/decoder and a 2D output head.

    The encoder is looked up by name in ``vfi_models.flavr.resnet_3D``. The
    decoder has five stages; the outputs of the first four are each joined
    with an encoder feature map via ``joinTensors``. The decoder output is
    unbound along dimension 2, concatenated along channels, fused by a 1x1
    2D convolution and mapped to ``3 * n_outputs`` channels.
    """

    def __init__(self, block, n_inputs, n_outputs, batchnorm=False, joinType="concat", upmode="transpose"):
        """Build encoder, decoder, fusion layer and output head.

        Args:
            block: Name of the encoder factory inside
                ``vfi_models.flavr.resnet_3D``.
            n_inputs: Number of stacked input frames; sizes the input of the
                fusion convolution.
            n_outputs: Number of 3-channel outputs to produce.
            batchnorm: Forwarded to the encoder (as ``bn``) and to every
                decoder / fusion block.
            joinType: Skip-connection mode handed to ``joinTensors``.
            upmode: Upsampling mode handed to ``upConv3D``.
        """
        super().__init__()

        c512, c256, c128, c64 = 512, 256, 128, 64
        out_channels = 3 * n_outputs
        self.joinType = joinType
        self.n_outputs = n_outputs

        # Concatenating a skip tensor doubles the channels fed to the next stage.
        growth = 2 if joinType == "concat" else 1
        self.lrelu = nn.LeakyReLU(0.2, True)

        unet_3D = importlib.import_module(".resnet_3D", "vfi_models.flavr")
        if n_outputs > 1:
            # Sets an attribute on the imported module object itself.
            unet_3D.useBias = True
        encoder_factory = getattr(unet_3D, block)
        self.encoder = encoder_factory(pretrained=False, bn=batchnorm)

        # Settings shared by every upsampling stage of the decoder.
        up_kwargs = dict(
            kernel_size=(3, 4, 4),
            stride=(1, 2, 2),
            padding=(1, 1, 1),
            upmode=upmode,
            batchnorm=batchnorm,
        )
        self.decoder = nn.Sequential(
            Conv_3d(c512, c256, kernel_size=3, padding=1, bias=True, batchnorm=batchnorm),
            upConv3D(c256 * growth, c128, **up_kwargs),
            upConv3D(c128 * growth, c64, **up_kwargs),
            Conv_3d(c64 * growth, c64, kernel_size=3, padding=1, bias=True, batchnorm=batchnorm),
            upConv3D(c64 * growth, c64, **up_kwargs),
        )

        self.feature_fuse = Conv_2d(c64 * n_inputs, c64, kernel_size=1, stride=1, batchnorm=batchnorm)

        self.outconv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(c64, out_channels, kernel_size=7, stride=1, padding=0),
        )

    def forward(self, images):
        """Run the network on a sequence of frames.

        Args:
            images: Sequence of tensors that are stacked along a new
                dimension 2 (presumably each is ``(N, C, H, W)`` — confirm
                against the caller).

        Returns:
            A list of tensors, one per 3-channel slice of the output head,
            each with the input mean added back.
        """
        frames = torch.stack(images, dim=2)

        # Mean over dimensions 2, 3 and 4; subtracted here, restored at the end.
        mean_ = frames.mean(2, keepdim=True)
        mean_ = mean_.mean(3, keepdim=True)
        mean_ = mean_.mean(4, keepdim=True)
        frames = frames - mean_

        x_0, x_1, x_2, x_3, x_4 = self.encoder(frames)

        # First four decoder stages, each joined with an encoder feature map.
        feat = x_4
        for stage, skip in zip(self.decoder, (x_3, x_2, x_1, x_0)):
            feat = joinTensors(self.lrelu(stage(feat)), skip, type=self.joinType)

        feat = self.lrelu(self.decoder[4](feat))
        # Fold dimension 2 into the channel dimension for the 2D head.
        fused_in = torch.cat(torch.unbind(feat, 2), 1)

        fused = self.lrelu(self.feature_fuse(fused_in))
        prediction = self.outconv(fused)

        chunks = torch.split(prediction, dim=1, split_size_or_sections=3)
        mean_ = mean_.squeeze(2)
        return [chunk + mean_ for chunk in chunks]
|
|
class InputPadder:
    """Pad the last two dimensions up to a multiple of ``divisor`` and undo it."""

    def __init__(self, dims, divisor=16):
        """Compute the padding for a tensor whose shape is ``dims``.

        Args:
            dims: Shape sequence; only its last two entries (height, width)
                are read.
            divisor: Value the padded height and width must be divisible by.
        """
        height, width = dims[-2:]
        self.ht, self.wd = height, width

        def shortfall(size):
            # Distance from ``size`` to the next multiple of ``divisor``;
            # zero when ``size`` is already a multiple.
            next_multiple = ((size // divisor) + 1) * divisor
            return (next_multiple - size) % divisor

        extra_h = shortfall(height)
        extra_w = shortfall(width)

        # Split the padding between both sides; the odd unit goes after.
        left, top = extra_w // 2, extra_h // 2
        self._pad = [left, extra_w - left, top, extra_h - top]

    def pad(self, input_tensor):
        """Return ``input_tensor`` replicate-padded by the stored amounts."""
        return F.pad(input_tensor, self._pad, mode='replicate')

    def unpad(self, input_tensor):
        """Crop away the padding added by :meth:`pad`."""
        return self._unpad(input_tensor)

    def _unpad(self, x):
        """Slice the last two dimensions of ``x`` back to the unpadded region."""
        ht, wd = x.shape[-2:]
        left, right, top, bottom = self._pad
        return x[..., top:ht - bottom, left:wd - right]