import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


class Block(nn.Module):
    """Conv block whose features are gated by a timestep embedding."""

    def __init__(self, in_channels=128, size=32):
        super(Block, self).__init__()
        self.conv_param = nn.Conv2d(in_channels=in_channels, out_channels=128, kernel_size=3, padding=1)
        self.conv_out = nn.Conv2d(in_channels=in_channels, out_channels=128, kernel_size=3, padding=1)
        self.dense_ts = nn.Linear(192, 128)
        self.layer_norm = nn.LayerNorm([128, size, size])

    def forward(self, x_img, x_ts):
        # Feature branch, scaled channel-wise by the timestep embedding.
        x_parameter = F.relu(self.conv_param(x_img))
        time_parameter = F.relu(self.dense_ts(x_ts))
        time_parameter = time_parameter.view(-1, 128, 1, 1)
        x_parameter = x_parameter * time_parameter

        # Main branch plus the time-gated features.
        x_out = self.conv_out(x_img)
        x_out = x_out + x_parameter
        x_out = F.relu(self.layer_norm(x_out))
        return x_out


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # Timestep embedding: scalar t -> 192-dim vector.
        self.l_ts = nn.Sequential(
            nn.Linear(1, 192),
            nn.LayerNorm([192]),
            nn.ReLU(),
        )

        # Encoder (down) path: 32 -> 16 -> 8 -> 4.
        self.down_x32 = Block(in_channels=3, size=32)
        self.down_x16 = Block(size=16)
        self.down_x8 = Block(size=8)
        self.down_x4 = Block(size=4)

        # Bottleneck MLP: flattened 4x4 features (128 * 4 * 4 = 2048)
        # concatenated with the 192-dim timestep embedding -> 2240 inputs.
        self.mlp = nn.Sequential(
            nn.Linear(2240, 128),
            nn.LayerNorm([128]),
            nn.ReLU(),
            nn.Linear(128, 32 * 4 * 4),  # make [-1, 32, 4, 4]
            nn.LayerNorm([32 * 4 * 4]),
            nn.ReLU(),
        )

        # Decoder (up) path; in_channels include the concatenated skip features.
        self.up_x4 = Block(in_channels=32 + 128, size=4)
        self.up_x8 = Block(in_channels=256, size=8)
        self.up_x16 = Block(in_channels=256, size=16)
        self.up_x32 = Block(in_channels=256, size=32)

        self.cnn_output = nn.Conv2d(in_channels=128, out_channels=3, kernel_size=1, padding=0)

        # make optimizer
        self.opt = torch.optim.Adam(self.parameters(), lr=0.0008)

    def forward(self, x, x_ts):
        x_ts = self.l_ts(x_ts)

        # ----- left ( down ) -----
        blocks = [
            self.down_x32,
            self.down_x16,
            self.down_x8,
            self.down_x4,
        ]
        x_left_layers = []
        for i, block in enumerate(blocks):
            x = block(x, x_ts)
            x_left_layers.append(x)  # keep for the skip connections
            if i < len(blocks) - 1:
                x = F.max_pool2d(x, 2)

        # ----- MLP -----
        x = x.view(-1, 128 * 4 * 4)
        x = torch.cat([x, x_ts], dim=1)
        x = self.mlp(x)
        x = x.view(-1, 32, 4, 4)

        # ----- right ( up ) -----
        blocks = [
            self.up_x4,
            self.up_x8,
            self.up_x16,
            self.up_x32,
        ]
        for i, block in enumerate(blocks):
            # cat left: skip connection from the matching down block
            x_left = x_left_layers[len(blocks) - i - 1]
            x = torch.cat([x, x_left], dim=1)
            x = block(x, x_ts)
            if i < len(blocks) - 1:
                x = F.interpolate(x, scale_factor=2, mode='bilinear')

        # ----- output -----
        x = self.cnn_output(x)
        return x
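

# Usage sketch: a minimal smoke test of the model above. The shapes are
# assumptions, not taken from the original file: a batch of 32x32 RGB
# images and one scalar timestep per sample (shape [batch, 1]); the batch
# size of 8 and the timestep range [0, 1] are likewise illustrative.
if __name__ == "__main__":
    model = Model()
    x = torch.randn(8, 3, 32, 32)   # assumed: batch of noisy 32x32 RGB images
    x_ts = torch.rand(8, 1)         # assumed: per-sample timestep in [0, 1]
    pred = model(x, x_ts)
    print(pred.shape)               # expected: torch.Size([8, 3, 32, 32])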