"""Convolutional building blocks and the matrix-sequence ``PredictModel``.

The module defines residual/dense blocks, up-sampling helpers, a 3D
convolution wrapper, two temporal predictors operating on
``(batch, seq, channels, h, w)`` tensors, and ``PredictModel`` which wires a
predictor between a 1->hidden embedding conv and a hidden->1 output conv.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualDenseBlock_4C(nn.Module):
    """Densely connected block of four 3x3 convolutions with a scaled residual.

    Each convolution sees the block input concatenated with the outputs of
    all previous convolutions. The last convolution maps back to ``nf``
    channels; its output is scaled by 0.2 and added to the block input.

    Args:
        nf: Number of input (and output) channels.
        gc: Growth channels, i.e. channels produced by each intermediate conv.
        bias: Whether the convolutions carry a bias term.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super().__init__()
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # NOTE: no custom weight initialisation is applied here; the layers
        # keep PyTorch's default initialisation.

    def forward(self, x):
        """Return ``x + 0.2 * dense_features(x)`` with the same shape as ``x``."""
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        # The last conv is intentionally left without an activation.
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        return x4 * 0.2 + x


class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Chains three ``ResidualDenseBlock_4C`` modules (growth channels
    ``nf // 2``) and adds the result, scaled by 0.2, to the input.

    Args:
        nf: Number of input (and output) channels.
    """

    def __init__(self, nf):
        super().__init__()
        gc = nf // 2
        self.RDB1 = ResidualDenseBlock_4C(nf, gc)
        self.RDB2 = ResidualDenseBlock_4C(nf, gc)
        self.RDB3 = ResidualDenseBlock_4C(nf, gc)

    def forward(self, x):
        """Return ``x + 0.2 * RDB3(RDB2(RDB1(x)))``."""
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x


class Up(nn.Module):
    """Upscaling followed by a single ``ConvLayer``.

    Args:
        in_channels: Channels of the tensor passed to ``self.up``. In bilinear
            mode this is also the channel count the conv expects, so with a
            skip connection the caller presumably passes the channel count of
            the concatenated tensor (TODO confirm against callers).
        out_channels: Channels produced by the block.
        bilinear: Use ``nn.Upsample`` (bilinear, ``align_corners=True``) when
            true, otherwise a ``nn.ConvTranspose2d`` with kernel and stride
            equal to ``scale``.
        skip: Whether ``forward`` will be given a skip tensor ``x2``; in
            transposed-conv mode this doubles the conv's input channels.
        scale: Spatial up-scaling factor.
        bn: Forwarded to ``ConvLayer`` (which currently ignores it).
        motion: Forwarded to ``ConvLayer`` in transposed-conv mode only.
    """

    def __init__(self, in_channels, out_channels, bilinear=True, skip=True,
                 scale=2, bn=True, motion=False):
        super().__init__()

        if bilinear:
            # Bilinear up-sampling keeps the channel count, so the conv is
            # responsible for reducing the number of channels.
            self.up = nn.Upsample(scale_factor=scale, mode='bilinear',
                                  align_corners=True)
            if skip:
                self.conv = ConvLayer(in_channels, out_channels, bn=bn)
            else:
                self.conv = ConvLayer(in_channels, out_channels)
        else:
            self.up = nn.ConvTranspose2d(in_channels, out_channels,
                                         kernel_size=scale, stride=scale)
            # With a skip connection the up-sampled tensor is concatenated
            # with an ``out_channels``-wide skip tensor.
            conv_in = out_channels * 2 if skip else out_channels
            self.conv = ConvLayer(conv_in, out_channels, bn=bn, motion=motion)

    def forward(self, x1, x2=None):
        """Upsample ``x1`` and, if given, fuse it with the skip tensor ``x2``.

        Args:
            x1: Tensor to upsample; spatial dims are read from axes 2 and 3.
            x2: Optional skip tensor. ``x1`` is padded to its spatial size and
                the two are concatenated along the channel axis (``x2`` first).

        Returns:
            Output of ``self.conv``.
        """
        x1 = self.up(x1)
        if x2 is None:
            return self.conv(x1)

        # Inputs are N x C x H x W: pad x1 so its H/W match the skip tensor.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        return self.conv(torch.cat([x2, x1], dim=1))


class ResBlock(nn.Module):
    """Two-convolution residual block with optional down- or up-sampling.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        downsample: Use stride-2 convolutions on both the main path and the
            shortcut. Takes precedence over ``upsample``.
        upsample: Use a transposed convolution on the main path and
            bilinear upsampling + 1x1 conv on the shortcut.
        skip: Accepted for interface compatibility; not used by this block.
        factor: Up-sampling factor; when down-sampling, ``factor == 4`` adds a
            final 2x2 max-pool after the residual sum.
        motion: When up-sampling, append a ``BatchNorm2d`` to the shortcut.

    Note:
        Without down/up-sampling the shortcut is an identity, so the residual
        sum requires ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels, downsample=False,
                 upsample=False, skip=False, factor=2, motion=False):
        super().__init__()
        self.upsample = upsample
        self.maxpool = None

        if downsample:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                                   stride=2, padding=1)
            self.shortcut = nn.Conv2d(in_channels, out_channels,
                                      kernel_size=1, stride=2)
            if factor == 4:
                self.maxpool = nn.MaxPool2d(2)
        elif upsample:
            self.conv1 = nn.ConvTranspose2d(in_channels, out_channels,
                                            kernel_size=factor, stride=factor)
            shortcut_layers = [
                nn.Upsample(scale_factor=factor, mode='bilinear',
                            align_corners=True),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1),
            ]
            if motion:
                shortcut_layers.append(nn.BatchNorm2d(out_channels))
            self.shortcut = nn.Sequential(*shortcut_layers)
        else:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                                   stride=1, padding=1)
            self.shortcut = nn.Sequential()

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1)

    def forward(self, input):
        """Return ``leaky_relu(maxpool?(relu(conv2(relu(conv1(x)))) + shortcut(x)))``."""
        shortcut = self.shortcut(input)
        # Functional activations avoid constructing new modules on every call.
        out = F.relu(self.conv1(input))
        out = F.relu(self.conv2(out))
        out = out + shortcut
        if self.maxpool is not None:
            out = self.maxpool(out)
        return F.leaky_relu(out)


class ConvLayer(nn.Module):
    """Single 3x3 convolution => [BN] => ReLU.

    BatchNorm is inserted only when ``motion`` is true.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        mid_channels: Accepted for interface compatibility; not used.
        bn: Accepted for interface compatibility; not used (``motion``
            decides whether BatchNorm is applied).
        motion: Select the conv + BatchNorm + ReLU variant (padding 1,
            no dilation).
        dilation: Dilation (and matching padding) of the convolution in the
            non-motion variant.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, bn=True,
                 motion=False, dilation=1):
        super().__init__()
        if motion:
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1,
                          bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3,
                          padding=dilation, bias=False, dilation=dilation),
                nn.ReLU(inplace=True),
            )

    def forward(self, x):
        """Apply the convolutional stack to ``x``."""
        return self.conv(x)


class Conv3D(nn.Module):
    """``Conv3d`` + ``BatchNorm3d`` + leaky ReLU on sequence-major tensors.

    Args:
        in_channel: Input channels.
        out_channel: Output channels.
        kernel_size: Kernel size passed to ``nn.Conv3d``.
        stride: Stride passed to ``nn.Conv3d``.
        padding: Padding passed to ``nn.Conv3d``.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride, padding):
        super().__init__()
        self.conv3d = nn.Conv3d(in_channel, out_channel,
                                kernel_size=kernel_size, stride=stride,
                                padding=padding)
        self.bn3d = nn.BatchNorm3d(out_channel)

    def forward(self, x):
        """Map ``(batch, seq, c, h, w)`` to ``(batch, seq', c', h', w')``."""
        # Conv3d expects channels before the temporal axis.
        x = x.permute(0, 2, 1, 3, 4).contiguous()  # (batch, c, seq, h, w)
        x = F.leaky_relu(self.bn3d(self.conv3d(x)))
        return x.permute(0, 2, 1, 3, 4).contiguous()  # (batch, seq, c, h, w)


class MatrixPredictor3DConv(nn.Module):
    """Small encoder/decoder with two temporal (3D) convolutions.

    The second temporal convolution has no temporal padding and therefore
    shortens the sequence by two steps; the later concatenation with the
    temporally max-pooled first-level features only lines up when a single
    step remains, i.e. for input sequences of length 3.

    Args:
        hidden_len: Channel width of the input, the output and the base
            level of the network.
    """

    def __init__(self, hidden_len=64):
        super().__init__()
        self.unet_base = hidden_len
        self.hidden_len = hidden_len
        self.conv_pre_1 = nn.Conv2d(hidden_len, hidden_len, kernel_size=3,
                                    stride=1, padding=1)
        self.conv_pre_2 = nn.Conv2d(hidden_len, hidden_len, kernel_size=3,
                                    stride=1, padding=1)

        self.conv3d_1 = Conv3D(self.unet_base, self.unet_base,
                               kernel_size=(3, 3, 3), stride=1,
                               padding=(1, 1, 1))
        self.conv3d_2 = Conv3D(self.unet_base * 2, self.unet_base * 2,
                               kernel_size=(3, 3, 3), stride=1,
                               padding=(0, 1, 1))

        self.conv1_1 = nn.Conv2d(hidden_len, self.unet_base, kernel_size=3,
                                 stride=2, padding=1)
        self.conv2_1 = nn.Conv2d(self.unet_base, self.unet_base * 2,
                                 kernel_size=3, stride=2, padding=1)
        self.conv3_1 = nn.Conv2d(self.unet_base * 3, self.unet_base,
                                 kernel_size=3, stride=1, padding=1)
        self.conv4_1 = nn.Conv2d(self.unet_base, self.hidden_len,
                                 kernel_size=3, stride=1, padding=1)

        self.bn_pre_1 = nn.BatchNorm2d(hidden_len)
        self.bn_pre_2 = nn.BatchNorm2d(hidden_len)
        self.bn1_1 = nn.BatchNorm2d(self.unet_base)
        self.bn2_1 = nn.BatchNorm2d(self.unet_base * 2)
        self.bn3_1 = nn.BatchNorm2d(self.unet_base)
        self.bn4_1 = nn.BatchNorm2d(self.hidden_len)

    def forward(self, x):
        """Collapse a ``(batch, seq, c, h, w)`` sequence into ``(batch, c, h, w)``.

        Args:
            x: 5-D tensor ``(batch, seq, hidden_len, h, w)``.

        Returns:
            Tensor of shape ``(batch, hidden_len, h, w)`` (for even ``h``/``w``
            divisible by 4, since the decoder upsamples by fixed factors of 2).
        """
        batch, _, _, _, _ = x.size()  # unpacking also enforces a 5-D input
        x = x.reshape(-1, x.size(-3), x.size(-2), x.size(-1))
        x = F.leaky_relu(self.bn_pre_1(self.conv_pre_1(x)))
        x = F.leaky_relu(self.bn_pre_2(self.conv_pre_2(x)))

        # Level 1: stride-2 conv, then temporal conv that keeps seq length.
        x_1 = F.leaky_relu(self.bn1_1(self.conv1_1(x)))
        x_1 = x_1.reshape(batch, -1, x_1.size(1), x_1.size(2), x_1.size(3))
        x_1 = self.conv3d_1(x_1)  # (batch, seq, c, h/2, w/2)
        x_1 = x_1.reshape(-1, x_1.size(2), x_1.size(3), x_1.size(4))

        # Level 2: stride-2 conv, then temporal conv that drops two steps.
        x_2 = F.leaky_relu(self.bn2_1(self.conv2_1(x_1)))
        x_2 = x_2.reshape(batch, -1, x_2.size(1), x_2.size(2), x_2.size(3))
        x_2 = self.conv3d_2(x_2)  # (batch, seq - 2, 2c, h/4, w/4)
        x_2 = x_2.reshape(-1, x_2.size(2), x_2.size(3), x_2.size(4))

        # Max-pool the level-1 features over the temporal axis.
        x_1 = x_1.reshape(batch, -1, x_1.size(1), x_1.size(2), x_1.size(3))
        x_1 = x_1.permute(0, 2, 1, 3, 4).contiguous()  # (batch, c, seq, h, w)
        x_1 = F.adaptive_max_pool3d(x_1, (1, None, None))
        x_1 = x_1.permute(0, 2, 1, 3, 4).contiguous()  # (batch, 1, c, h, w)
        x_1 = x_1.reshape(-1, x_1.size(2), x_1.size(3), x_1.size(4))

        # Decoder: upsample level 2, fuse with level 1, upsample again.
        fused = torch.cat((F.interpolate(x_2, scale_factor=(2, 2)), x_1), dim=1)
        x_3 = F.leaky_relu(self.bn3_1(self.conv3_1(fused)))
        x_3 = F.interpolate(x_3, scale_factor=(2, 2))
        return F.leaky_relu(self.bn4_1(self.conv4_1(x_3)))


class SimpleMatrixPredictor3DConv_direct(nn.Module):
    """Predictor that maps ``input_len`` steps to ``aft_seq_length`` steps.

    The temporal axis is folded into channels and translated with a 1x1
    convolution from ``input_len`` to ``aft_seq_length`` steps.

    Args:
        T: Sequence length parameter; the network expects ``T`` input steps
            when ``image_pred`` is true and ``T - 1`` otherwise.
        hidden_len: Channel width of the input, the output and the base level.
        image_pred: Selects whether ``T`` or ``T - 1`` input steps are expected.
        aft_seq_length: Number of output steps. When it is 1 the second
            temporal layer is a plain 2D convolution instead of ``Conv3D``.
    """

    def __init__(self, T, hidden_len=64, image_pred=False, aft_seq_length=10):
        super().__init__()
        self.unet_base = hidden_len
        self.hidden_len = hidden_len
        self.conv_pre_1 = nn.Conv2d(hidden_len, hidden_len, kernel_size=3,
                                    stride=1, padding=1)
        self.conv_pre_2 = nn.Conv2d(hidden_len, hidden_len, kernel_size=3,
                                    stride=1, padding=1)
        self.fut_len = aft_seq_length

        self.conv3d_1 = Conv3D(self.unet_base, self.unet_base,
                               kernel_size=(3, 3, 3), stride=1,
                               padding=(1, 1, 1))
        if self.fut_len > 1:
            self.temporal_layer = Conv3D(self.unet_base * 2,
                                         self.unet_base * 2,
                                         kernel_size=(3, 3, 3), stride=1,
                                         padding=(1, 1, 1))
        else:
            self.temporal_layer = nn.Sequential(
                nn.Conv2d(self.unet_base * 2, self.unet_base * 2,
                          kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(),
            )

        input_len = T if image_pred else T - 1
        self.conv_translate = nn.Sequential(
            nn.Conv2d(self.unet_base * input_len,
                      self.unet_base * self.fut_len,
                      kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(),
        )

        self.conv1_1 = nn.Conv2d(hidden_len, self.unet_base, kernel_size=3,
                                 stride=2, padding=1)
        self.conv2_1 = nn.Conv2d(self.unet_base, self.unet_base * 2,
                                 kernel_size=3, stride=2, padding=1)
        self.conv3_1 = nn.Conv2d(self.unet_base * 3, self.unet_base,
                                 kernel_size=3, stride=1, padding=1)
        self.conv4_1 = nn.Conv2d(self.unet_base, self.hidden_len,
                                 kernel_size=3, stride=1, padding=1)

        self.bn_pre_1 = nn.BatchNorm2d(hidden_len)
        self.bn_pre_2 = nn.BatchNorm2d(hidden_len)
        self.bn1_1 = nn.BatchNorm2d(self.unet_base)
        self.bn2_1 = nn.BatchNorm2d(self.unet_base * 2)
        self.bn3_1 = nn.BatchNorm2d(self.unet_base)
        self.bn4_1 = nn.BatchNorm2d(self.hidden_len)
        self.bn_translate = nn.BatchNorm2d(self.unet_base * self.fut_len)

    def forward(self, x):
        """Predict ``fut_len`` feature maps from an input sequence.

        Args:
            x: 5-D tensor ``(batch, input_len, hidden_len, h, w)``.

        Returns:
            Tensor of shape ``(batch * fut_len, hidden_len, h, w)``.
        """
        batch, _, _, _, _ = x.size()  # unpacking also enforces a 5-D input
        x = x.reshape(-1, x.size(-3), x.size(-2), x.size(-1))
        x = F.leaky_relu(self.bn_pre_1(self.conv_pre_1(x)))
        x = F.leaky_relu(self.bn_pre_2(self.conv_pre_2(x)))

        # Level 1: stride-2 conv followed by a length-preserving temporal conv.
        x_1 = F.leaky_relu(self.bn1_1(self.conv1_1(x)))
        x_1 = x_1.reshape(batch, -1, x_1.size(1), x_1.size(2), x_1.size(3))
        x_1 = self.conv3d_1(x_1)  # (batch, seq, c, h/2, w/2)

        # Fold time into channels and translate input steps to future steps.
        _, _, c, h, w = x_1.shape
        x_tmp = x_1.reshape(batch, -1, h, w)
        x_tmp = self.bn_translate(self.conv_translate(x_tmp))
        x_1 = x_tmp.reshape(batch * self.fut_len, c, h, w)

        # Level 2: stride-2 conv followed by the second temporal layer.
        x_2 = F.leaky_relu(self.bn2_1(self.conv2_1(x_1)))
        if self.fut_len > 1:
            x_2 = x_2.reshape(batch, -1, x_2.size(1), x_2.size(2), x_2.size(3))
            x_2 = self.temporal_layer(x_2)  # (batch, fut_len, 2c, h/4, w/4)
            x_2 = x_2.reshape(-1, x_2.size(2), x_2.size(3), x_2.size(4))
        else:
            x_2 = self.temporal_layer(x_2)  # (batch, 2c, h/4, w/4)

        # Decoder: resize level 2 to level 1, fuse, then resize to the input.
        fused = torch.cat((F.interpolate(x_2, size=x_1.shape[2:]), x_1), dim=1)
        x_3 = F.leaky_relu(self.bn3_1(self.conv3_1(fused)))
        x_3 = F.interpolate(x_3, size=x.shape[-2:])
        return F.leaky_relu(self.bn4_1(self.conv4_1(x_3)))


class PredictModel(nn.Module):
    """Embed a matrix sequence, run a temporal predictor and project back.

    Args:
        T: Sequence length parameter forwarded to the direct predictor.
        hidden_len: Channel width of the embedding.
        aft_seq_length: Number of predicted steps (``fut_len``).
        mx_h: Height each flattened matrix row is reshaped to.
        mx_w: Width each flattened matrix row is reshaped to.
        use_direct_predictor: Use ``SimpleMatrixPredictor3DConv_direct`` when
            true, otherwise ``MatrixPredictor3DConv``.
    """

    def __init__(self, T, hidden_len=32, aft_seq_length=10, mx_h=32, mx_w=32,
                 use_direct_predictor=True):
        super().__init__()
        self.mx_h = mx_h
        self.mx_w = mx_w
        self.hidden_len = hidden_len
        self.fut_len = aft_seq_length

        self.conv1 = nn.Conv2d(1, hidden_len, kernel_size=3, padding=1,
                               bias=False)
        self.fuse_conv = nn.Conv2d(hidden_len * 2, hidden_len, kernel_size=3,
                                   padding=1, bias=False)
        if use_direct_predictor:
            self.predictor = SimpleMatrixPredictor3DConv_direct(
                T=T, hidden_len=hidden_len, aft_seq_length=aft_seq_length)
        else:
            self.predictor = MatrixPredictor3DConv(hidden_len)
        self.out_conv = nn.Conv2d(hidden_len, 1, kernel_size=3, padding=1,
                                  bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

    def res_interpolate(self, in_tensor, template_tensor):
        """Resize ``in_tensor`` to the last two dims of ``template_tensor``.

        Uses ``F.interpolate`` with its default mode.

        Args:
            in_tensor: 4-D tensor, presumably ``(batch, c, h'w', H'W')``.
            template_tensor: Tensor whose last two dims give the target size,
                presumably ``(batch, c, hw, HW)``.

        Returns:
            ``in_tensor`` resized to ``template_tensor.shape[-2:]``.
        """
        return F.interpolate(in_tensor, template_tensor.shape[-2:])

    def forward(self, matrix_seq, softmax=False, res=None):
        """Predict future matrices from ``matrix_seq``.

        Args:
            matrix_seq: 4-D tensor ``(B, T, hw, window_size)`` where
                ``window_size`` must equal ``mx_h * mx_w``.
            softmax: If true, normalise each predicted step with a softmax
                over its flattened ``hw * window_size`` entries.
            res: Optional residual features; they are reshaped to
                ``(B, hw // 4, C, -1)``, resized to match the embedding and
                fused with it through ``fuse_conv``.

        Returns:
            Tuple ``(out, res_emb)`` where ``out`` has shape
            ``(B, -1, hw, window_size)`` and ``res_emb`` is a copy of the
            predictor embedding taken before any residual fusion.
        """
        B, T, hw, window_size = matrix_seq.size()

        # Treat every matrix row as a single-channel (mx_h, mx_w) image.
        matrix_seq = matrix_seq.reshape(B * T * hw, 1, self.mx_h, self.mx_w)
        x = self.conv1(matrix_seq)
        x = x.reshape(B, T, hw, -1, self.mx_h, self.mx_w)
        # Move the row axis next to batch so the predictor sees sequences.
        x = x.permute(0, 2, 1, 3, 4, 5).reshape(
            B * hw, T, -1, self.mx_h, self.mx_w)

        emb = self.predictor(x)
        emb = emb.reshape(B * hw * self.fut_len, -1, self.mx_h, self.mx_w)
        # Returned to the caller, so keep it independent of later processing.
        res_emb = emb.clone()

        if res is not None:
            channels = emb.shape[1]
            # Only the template's shape is used, so a view is sufficient.
            template = emb.reshape(B, hw, channels, -1).permute(0, 2, 1, 3)
            in_tensor = res.reshape(B, hw // 4, channels, -1).permute(0, 2, 1, 3)
            res_tensor = self.res_interpolate(in_tensor, template)
            res_tensor = res_tensor.permute(0, 2, 1, 3).reshape(emb.shape)
            emb = self.fuse_conv(torch.cat([emb, res_tensor], dim=1))

        out = self.out_conv(emb)  # (B * hw * fut_len, 1, mx_h, mx_w)
        out = out.reshape(B, hw, -1, self.mx_h, self.mx_w)
        out = out.permute(0, 2, 1, 3, 4)
        out = out.reshape(B, -1, hw, window_size)

        if softmax:
            out = self.softmax(out.reshape(B, out.shape[1], -1))
            out = out.reshape(B, -1, hw, window_size)
        return out, res_emb