"""PyTorch building blocks: ESRGAN-style residual dense blocks, UNet-style
up/down blocks, and 3D-convolutional matrix predictors."""
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualDenseBlock_4C(nn.Module):
    """Residual dense block with four convolutions and dense skip connections."""

    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_4C, self).__init__()
        # nf: number of input/output feature channels; gc: growth channels
        # contributed by each intermediate convolution.
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Each convolution sees the input plus all previously produced features.
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        # Scale the residual before adding it back (ESRGAN-style residual scaling).
        return x4 * 0.2 + x
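
# Usage sketch (sizes here are illustrative assumptions, not from the source):
#   block = ResidualDenseBlock_4C(nf=64, gc=32)
#   y = block(torch.randn(1, 64, 32, 32))  # shape preserved: (1, 64, 32, 32)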

class RRDB(nn.Module):
    '''Residual in Residual Dense Block'''

    def __init__(self, nf):
        super(RRDB, self).__init__()
        gc = nf // 2
        self.RDB1 = ResidualDenseBlock_4C(nf, gc)
        self.RDB2 = ResidualDenseBlock_4C(nf, gc)
        self.RDB3 = ResidualDenseBlock_4C(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        # Outer residual connection, again scaled by 0.2.
        return out * 0.2 + x
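
# Usage sketch (illustrative): RRDB(nf=64) maps (N, 64, H, W) -> (N, 64, H, W),
# e.g. RRDB(64)(torch.randn(1, 64, 32, 32)).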

class Up(nn.Module):
    """Upscaling followed by a ConvLayer, with an optional skip connection."""

    def __init__(self, in_channels, out_channels, bilinear=True, skip=True, scale=2, bn=True, motion=False):
        super().__init__()
        factor = scale

        if bilinear:
            # Bilinear upsampling keeps the channel count, so `in_channels`
            # must already account for the concatenation when skip=True.
            self.up = nn.Upsample(scale_factor=factor, mode='bilinear', align_corners=True)
            if skip:
                self.conv = ConvLayer(in_channels, out_channels, bn=bn)
            else:
                self.conv = ConvLayer(in_channels, out_channels)
        else:
            # The transposed convolution already reduces to `out_channels`;
            # a skip connection doubles the channels again before the conv.
            self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=factor, stride=factor)
            if skip:
                self.conv = ConvLayer(out_channels * 2, out_channels, bn=bn, motion=motion)
            else:
                self.conv = ConvLayer(out_channels, out_channels, bn=bn, motion=motion)

    def forward(self, x1, x2=None):
        x1 = self.up(x1)
        if x2 is None:
            return self.conv(x1)

        # Pad x1 so its spatial size matches the skip tensor x2.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
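
# Usage sketch (illustrative): with bilinear upsampling and a skip connection,
# `in_channels` is the channel count AFTER concatenation, e.g.
#   up = Up(96, 64, bilinear=True, skip=True)
#   y = up(torch.randn(1, 64, 16, 16), torch.randn(1, 32, 32, 32))  # (1, 64, 32, 32)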

class ResBlock(nn.Module):
    # `skip` is accepted for interface compatibility but is not used here.
    def __init__(self, in_channels, out_channels, downsample=False,
                 upsample=False, skip=False, factor=2, motion=False):
        super().__init__()
        self.upsample = upsample
        self.maxpool = None
        if downsample:
            # Strided conv halves the resolution; a 1x1 strided conv matches the
            # shortcut to the new shape. With factor=4, an extra max-pool after
            # the residual sum brings the total downscaling to 4x.
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2)
            if factor == 4:
                self.maxpool = nn.MaxPool2d(2)
        elif upsample:
            self.conv1 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=factor, stride=factor)
            if motion:
                self.shortcut = nn.Sequential(
                    nn.Upsample(scale_factor=factor, mode='bilinear', align_corners=True),
                    nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1),
                    nn.BatchNorm2d(out_channels))
            else:
                self.shortcut = nn.Sequential(
                    nn.Upsample(scale_factor=factor, mode='bilinear', align_corners=True),
                    nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1))
        else:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            self.shortcut = nn.Sequential()

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, input):
        shortcut = self.shortcut(input)
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        input = input + shortcut
        if self.maxpool is not None:
            input = self.maxpool(input)
        return F.leaky_relu(input)
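
# Usage sketch (illustrative sizes):
#   down = ResBlock(32, 64, downsample=True)  # (1, 32, 32, 32) -> (1, 64, 16, 16)
#   up   = ResBlock(64, 32, upsample=True)    # (1, 64, 16, 16) -> (1, 32, 32, 32)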

class ConvLayer(nn.Module):
    """A single convolution => [BN] => ReLU block.

    Note: despite the original "(convolution => [BN] => ReLU) * 2" description,
    only one convolution is applied. `mid_channels` and `bn` are accepted for
    interface compatibility but unused; batch norm is applied only when
    `motion=True`.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, bn=True, motion=False, dilation=1):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ) if motion else nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=dilation, bias=False, dilation=dilation),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)
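
# Usage sketch (illustrative): ConvLayer(32, 64)(torch.randn(1, 32, 16, 16))
# yields (1, 64, 16, 16); pass motion=True to add batch norm after the conv.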

class Conv3D(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride, padding):
        super(Conv3D, self).__init__()
        self.conv3d = nn.Conv3d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding)
        self.bn3d = nn.BatchNorm3d(out_channel)

    def forward(self, x):
        # Input arrives as (B, T, C, H, W); nn.Conv3d expects (B, C, T, H, W).
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        x = F.leaky_relu(self.bn3d(self.conv3d(x)))
        # Restore the (B, T, C, H, W) layout for the caller.
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        return x
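
# Usage sketch (illustrative): with kernel (3, 3, 3) and padding (1, 1, 1),
# Conv3D(16, 16, (3, 3, 3), 1, (1, 1, 1)) preserves a (2, 5, 16, 8, 8) input's shape.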

class MatrixPredictor3DConv(nn.Module):
    def __init__(self, hidden_len=64):
        super(MatrixPredictor3DConv, self).__init__()
        self.unet_base = hidden_len
        self.hidden_len = hidden_len
        self.conv_pre_1 = nn.Conv2d(hidden_len, hidden_len, kernel_size=3, stride=1, padding=1)
        self.conv_pre_2 = nn.Conv2d(hidden_len, hidden_len, kernel_size=3, stride=1, padding=1)

        self.conv3d_1 = Conv3D(self.unet_base, self.unet_base, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1))
        # Zero temporal padding here shrinks the sequence axis by 2.
        self.conv3d_2 = Conv3D(self.unet_base * 2, self.unet_base * 2, kernel_size=(3, 3, 3), stride=1, padding=(0, 1, 1))

        # Encoder: two stride-2 convs halve the resolution twice.
        self.conv1_1 = nn.Conv2d(hidden_len, self.unet_base, kernel_size=3, stride=2, padding=1)
        self.conv2_1 = nn.Conv2d(self.unet_base, self.unet_base * 2, kernel_size=3, stride=2, padding=1)

        # Decoder: fuse the two scales, then project back to hidden_len.
        self.conv3_1 = nn.Conv2d(self.unet_base * 3, self.unet_base, kernel_size=3, stride=1, padding=1)
        self.conv4_1 = nn.Conv2d(self.unet_base, self.hidden_len, kernel_size=3, stride=1, padding=1)

        self.bn_pre_1 = nn.BatchNorm2d(hidden_len)
        self.bn_pre_2 = nn.BatchNorm2d(hidden_len)
        self.bn1_1 = nn.BatchNorm2d(self.unet_base)
        self.bn2_1 = nn.BatchNorm2d(self.unet_base * 2)
        self.bn3_1 = nn.BatchNorm2d(self.unet_base)
        self.bn4_1 = nn.BatchNorm2d(self.hidden_len)

    def forward(self, x):
        # x: (batch, seq, hidden_len, h, w). Frames are folded into the batch
        # axis for the 2D convs and unfolded again for the 3D convs.
        batch, seq, z, h, w = x.size()
        x = x.reshape(-1, x.size(-3), x.size(-2), x.size(-1))
        x = F.leaky_relu(self.bn_pre_1(self.conv_pre_1(x)))
        x = F.leaky_relu(self.bn_pre_2(self.conv_pre_2(x)))
        x_1 = F.leaky_relu(self.bn1_1(self.conv1_1(x)))

        x_1 = x_1.view(batch, -1, x_1.size(1), x_1.size(2), x_1.size(3)).contiguous()
        x_1 = self.conv3d_1(x_1)
        x_1 = x_1.view(-1, x_1.size(2), x_1.size(3), x_1.size(4)).contiguous()
        x_2 = F.leaky_relu(self.bn2_1(self.conv2_1(x_1)))
        x_2 = x_2.view(batch, -1, x_2.size(1), x_2.size(2), x_2.size(3)).contiguous()
        x_2 = self.conv3d_2(x_2)
        x_2 = x_2.view(-1, x_2.size(2), x_2.size(3), x_2.size(4)).contiguous()

        # Collapse the temporal axis of x_1 with a max over time.
        x_1 = x_1.view(batch, -1, x_1.size(1), x_1.size(2), x_1.size(3))
        x_1 = x_1.permute(0, 2, 1, 3, 4).contiguous()
        x_1 = F.adaptive_max_pool3d(x_1, (1, None, None))
        x_1 = x_1.permute(0, 2, 1, 3, 4).contiguous()
        x_1 = x_1.view(-1, x_1.size(2), x_1.size(3), x_1.size(4)).contiguous()
        x_3 = F.leaky_relu(self.bn3_1(self.conv3_1(torch.cat((F.interpolate(x_2, scale_factor=(2, 2)), x_1), dim=1))))
        x = F.leaky_relu(self.bn4_1(self.conv4_1(F.interpolate(x_3, scale_factor=(2, 2)))))
        return x
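
# Usage sketch (illustrative; because the temporally un-padded conv3d_2 removes
# two frames, the scale fusion only lines up for a 3-frame input):
#   m = MatrixPredictor3DConv(hidden_len=64)
#   y = m(torch.randn(2, 3, 64, 16, 16))  # -> (2, 64, 16, 16)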

class SimpleMatrixPredictor3DConv_direct(nn.Module):
    def __init__(self, T, hidden_len=64, image_pred=False, aft_seq_length=10):
        super(SimpleMatrixPredictor3DConv_direct, self).__init__()
        self.unet_base = hidden_len
        self.hidden_len = hidden_len
        self.conv_pre_1 = nn.Conv2d(hidden_len, hidden_len, kernel_size=3, stride=1, padding=1)
        self.conv_pre_2 = nn.Conv2d(hidden_len, hidden_len, kernel_size=3, stride=1, padding=1)
        self.fut_len = aft_seq_length

        self.conv3d_1 = Conv3D(self.unet_base, self.unet_base, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1))

        # With a single future frame there is no temporal axis to convolve over.
        if self.fut_len > 1:
            self.temporal_layer = Conv3D(self.unet_base * 2, self.unet_base * 2, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1))
        else:
            self.temporal_layer = nn.Sequential(
                nn.Conv2d(self.unet_base * 2, self.unet_base * 2, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU())

        # A 1x1 conv maps the observed frames directly to `fut_len` future frames.
        input_len = T if image_pred else T - 1
        self.conv_translate = nn.Sequential(
            nn.Conv2d(self.unet_base * input_len, self.unet_base * self.fut_len, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU())

        self.conv1_1 = nn.Conv2d(hidden_len, self.unet_base, kernel_size=3, stride=2, padding=1)
        self.conv2_1 = nn.Conv2d(self.unet_base, self.unet_base * 2, kernel_size=3, stride=2, padding=1)

        self.conv3_1 = nn.Conv2d(self.unet_base * 3, self.unet_base, kernel_size=3, stride=1, padding=1)
        self.conv4_1 = nn.Conv2d(self.unet_base, self.hidden_len, kernel_size=3, stride=1, padding=1)

        self.bn_pre_1 = nn.BatchNorm2d(hidden_len)
        self.bn_pre_2 = nn.BatchNorm2d(hidden_len)
        self.bn1_1 = nn.BatchNorm2d(self.unet_base)
        self.bn2_1 = nn.BatchNorm2d(self.unet_base * 2)
        self.bn3_1 = nn.BatchNorm2d(self.unet_base)
        self.bn4_1 = nn.BatchNorm2d(self.hidden_len)
        self.bn_translate = nn.BatchNorm2d(self.unet_base * self.fut_len)

    def forward(self, x):
        # x: (batch, seq, hidden_len, h, w) with seq == input_len from __init__.
        batch, seq, z, h, w = x.size()
        x = x.reshape(-1, x.size(-3), x.size(-2), x.size(-1))
        x = F.leaky_relu(self.bn_pre_1(self.conv_pre_1(x)))
        x = F.leaky_relu(self.bn_pre_2(self.conv_pre_2(x)))
        x_1 = F.leaky_relu(self.bn1_1(self.conv1_1(x)))

        x_1 = x_1.view(batch, -1, x_1.size(1), x_1.size(2), x_1.size(3)).contiguous()
        x_1 = self.conv3d_1(x_1)

        # Fold frames into channels and translate seq observed frames into
        # fut_len predicted frames (note: h and w now refer to the half-size maps).
        batch, seq, c, h, w = x_1.shape
        x_tmp = x_1.reshape(batch, -1, h, w)
        x_tmp = self.bn_translate(self.conv_translate(x_tmp))
        x_1 = x_tmp.reshape(batch, self.fut_len, c, h, w)
        x_1 = x_1.view(-1, x_1.size(2), x_1.size(3), x_1.size(4)).contiguous()
        x_2 = F.leaky_relu(self.bn2_1(self.conv2_1(x_1)))
        if self.fut_len > 1:
            x_2 = x_2.view(batch, -1, x_2.size(1), x_2.size(2), x_2.size(3)).contiguous()
            x_2 = self.temporal_layer(x_2)
            x_2 = x_2.view(-1, x_2.size(2), x_2.size(3), x_2.size(4)).contiguous()
        else:
            x_2 = self.temporal_layer(x_2)

        # Fuse the two scales and project back to hidden_len at full resolution.
        x_3 = F.leaky_relu(self.bn3_1(self.conv3_1(torch.cat((F.interpolate(x_2, size=x_1.shape[2:]), x_1), dim=1))))
        x = x.view(batch, -1, x.size(1), x.size(2), x.size(3))
        x = F.leaky_relu(self.bn4_1(self.conv4_1(F.interpolate(x_3, size=x.shape[3:]))))
        return x
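
# Usage sketch (illustrative; seq must equal T-1 when image_pred=False):
#   m = SimpleMatrixPredictor3DConv_direct(T=3, hidden_len=64, aft_seq_length=2)
#   y = m(torch.randn(2, 2, 64, 16, 16))  # -> (4, 64, 16, 16), i.e. (B*fut_len, C, h, w)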

class PredictModel(nn.Module):
    def __init__(self, T, hidden_len=32, aft_seq_length=10, mx_h=32, mx_w=32, use_direct_predictor=True):
        super(PredictModel, self).__init__()
        self.mx_h = mx_h
        self.mx_w = mx_w
        self.hidden_len = hidden_len
        self.fut_len = aft_seq_length
        self.conv1 = nn.Conv2d(1, hidden_len, kernel_size=3, padding=1, bias=False)
        self.fuse_conv = nn.Conv2d(hidden_len * 2, hidden_len, kernel_size=3, padding=1, bias=False)
        if use_direct_predictor:
            self.predictor = SimpleMatrixPredictor3DConv_direct(T=T, hidden_len=hidden_len, aft_seq_length=aft_seq_length)
        else:
            self.predictor = MatrixPredictor3DConv(hidden_len)
        self.out_conv = nn.Conv2d(hidden_len, 1, kernel_size=3, padding=1, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

    def res_interpolate(self, in_tensor, template_tensor):
        '''
        in_tensor:       (batch, c, h'w', H'W')
        template_tensor: (batch, c, hw, HW)
        out_tensor:      (batch, c, hw, HW)
        '''
        out_tensor = F.interpolate(in_tensor, template_tensor.shape[-2:])
        return out_tensor

    def forward(self, matrix_seq, softmax=False, res=None):
        B, T, hw, window_size = matrix_seq.size()

        # Unfold each position's window (window_size == mx_h * mx_w) into a
        # single-channel (mx_h, mx_w) map.
        matrix_seq = matrix_seq.reshape(B * T * hw, self.mx_h, self.mx_w).unsqueeze(1)

        x = self.conv1(matrix_seq)
        x = x.reshape(B, T, hw, -1, self.mx_h, self.mx_w)
        # Treat each spatial position as an independent sequence for the predictor.
        x = x.permute(0, 2, 1, 3, 4, 5).reshape(B * hw, T, -1, self.mx_h, self.mx_w)
        emb = self.predictor(x)

        emb = emb.reshape(B * hw * self.fut_len, -1, self.mx_h, self.mx_w)
        res_emb = emb.clone()
        if res is not None:
            # Fuse a lower-resolution residual embedding, upsampled to match.
            template = emb.clone().reshape(B, hw, emb.shape[1], -1).permute(0, 2, 1, 3)
            in_tensor = res.clone().reshape(B, hw // 4, emb.shape[1], -1).permute(0, 2, 1, 3)
            res_tensor = self.res_interpolate(in_tensor, template).permute(0, 2, 1, 3).reshape(emb.shape)
            emb = self.fuse_conv(torch.cat([emb, res_tensor], dim=1))

        out = self.out_conv(emb)

        # Back to (B, fut_len, hw, window_size).
        out = out.reshape(B, hw, -1, self.mx_h, self.mx_w)
        out = out.permute(0, 2, 1, 3, 4)
        out = out.reshape(B, -1, hw, window_size)

        if softmax:
            # Normalize each predicted frame over all positions and window entries.
            out = out.view(B, out.shape[1], -1)
            out = self.softmax(out)
            out = out.reshape(B, -1, hw, window_size)

        return out, res_emb
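
if __name__ == "__main__":
    # Quick shape checks: a minimal sketch, with all sizes below being
    # illustrative assumptions rather than values from the original setup.
    x = torch.randn(1, 64, 32, 32)
    assert RRDB(nf=64)(x).shape == x.shape

    # PredictModel consumes (B, T-1, hw, mx_h*mx_w) and returns
    # (B, aft_seq_length, hw, mx_h*mx_w) plus the pre-output embedding.
    model = PredictModel(T=3, hidden_len=32, aft_seq_length=2, mx_h=8, mx_w=8)
    seq = torch.randn(2, 2, 16, 64)
    out, res_emb = model(seq, softmax=True)
    print(out.shape)      # torch.Size([2, 2, 16, 64])
    print(res_emb.shape)  # torch.Size([64, 32, 8, 8])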