# Source: uploaded by marijanic - "Upload 3 files" (commit 8432aac, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
class Block(nn.Module):
    """Convolutional block whose features are modulated by a time embedding.

    Two parallel 3x3 convolutions (padding 1, so spatial size is preserved)
    are applied to the input features:

    * ``conv_param`` -> ReLU, then scaled channel-wise by a ReLU'd linear
      projection of the time embedding ``x_ts``;
    * ``conv_out``, left unmodulated.

    The two branches are summed, layer-normalised over
    ``[out_channels, size, size]`` and passed through a final ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        size: Spatial height/width of the input. Required because the
            LayerNorm's normalised shape includes the spatial dimensions.
        out_channels: Number of output channels (previously hard-coded 128).
        ts_dim: Width of the time embedding ``x_ts`` (previously hard-coded 192).
    """

    def __init__(self, in_channels=128, size=32, out_channels=128, ts_dim=192):
        super().__init__()
        self.conv_param = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
        )
        self.conv_out = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
        )
        self.dense_ts = nn.Linear(ts_dim, out_channels)
        self.layer_norm = nn.LayerNorm([out_channels, size, size])

    def forward(self, x_img, x_ts):
        """Apply the time-modulated convolution block.

        Args:
            x_img: Input feature map; its trailing dimensions must be
                ``(in_channels, size, size)`` (assumed ``(batch, C, H, W)``).
            x_ts: Time embedding whose last dimension is ``ts_dim``.

        Returns:
            Feature map whose trailing dimensions are
            ``(out_channels, size, size)``.
        """
        x_parameter = F.relu(self.conv_param(x_img))
        # One non-negative scale per output channel, broadcast over H and W.
        # The width is taken from the Linear layer itself so it can never
        # disagree with the number of conv output channels.
        time_parameter = F.relu(self.dense_ts(x_ts))
        time_parameter = time_parameter.view(-1, self.dense_ts.out_features, 1, 1)
        x_parameter = x_parameter * time_parameter
        x_out = self.conv_out(x_img) + x_parameter
        return F.relu(self.layer_norm(x_out))
class Model(nn.Module):
    """U-Net-style image network conditioned on a time input.

    Expected inputs (fixed by the layer shapes below): an image batch with
    3 channels at 32x32 resolution, and a time tensor whose last dimension
    is 1. The output has 3 channels at the same 32x32 resolution.

    The constructor also creates an Adam optimiser over all of the model's
    parameters and stores it as ``self.opt``.
    """

    def __init__(self):
        super().__init__()
        # NOTE: submodules are created in the same order as before so that
        # seeded initialisation and parameter ordering are unchanged.

        # Time input (width 1) -> 192-wide embedding.
        self.l_ts = nn.Sequential(
            nn.Linear(1, 192),
            nn.LayerNorm([192]),
            nn.ReLU(),
        )

        # Encoder: 32 -> 16 -> 8 -> 4 resolution, 128 channels at every level.
        self.down_x32 = Block(in_channels=3, size=32)
        self.down_x16 = Block(size=16)
        self.down_x8 = Block(size=8)
        self.down_x4 = Block(size=4)

        # Bottleneck: flattened 128x4x4 encoder output concatenated with the
        # 192-wide time embedding, squeezed and reshaped to 32x4x4.
        self.mlp = nn.Sequential(
            nn.Linear(128 * 4 * 4 + 192, 128),
            nn.LayerNorm([128]),
            nn.ReLU(),
            nn.Linear(128, 32 * 4 * 4),
            nn.LayerNorm([32 * 4 * 4]),
            nn.ReLU(),
        )

        # Decoder: each block sees the upsampled features plus the 128-channel
        # encoder skip connection of the same resolution.
        self.up_x4 = Block(in_channels=32 + 128, size=4)
        self.up_x8 = Block(in_channels=128 + 128, size=8)
        self.up_x16 = Block(in_channels=128 + 128, size=16)
        self.up_x32 = Block(in_channels=128 + 128, size=32)

        # 1x1 projection from 128 feature channels back to 3 image channels.
        self.cnn_output = nn.Conv2d(in_channels=128, out_channels=3, kernel_size=1, padding=0)

        # Must be created last: self.parameters() only yields what has been
        # registered so far.
        self.opt = torch.optim.Adam(self.parameters(), lr=8e-4)

    def forward(self, x, x_ts):
        """Run the encoder, bottleneck MLP and decoder.

        Args:
            x: Image batch; trailing dimensions must be ``(3, 32, 32)``.
            x_ts: Time input whose last dimension is 1.

        Returns:
            3-channel output at 32x32 resolution.
        """
        t_emb = self.l_ts(x_ts)

        # Encoder: downsample between levels and remember every level's
        # output as a skip connection.
        skips = []
        encoder = (self.down_x32, self.down_x16, self.down_x8, self.down_x4)
        for level, block in enumerate(encoder):
            if level > 0:
                x = F.max_pool2d(x, 2)
            x = block(x, t_emb)
            skips.append(x)

        # Bottleneck MLP on the flattened 4x4 features plus the time embedding.
        flat = torch.cat([x.view(-1, 128 * 4 * 4), t_emb], dim=1)
        h = self.mlp(flat).view(-1, 32, 4, 4)

        # Decoder: walk the skips from coarsest to finest, upsampling between
        # levels.
        decoder = (self.up_x4, self.up_x8, self.up_x16, self.up_x32)
        for level, (block, skip) in enumerate(zip(decoder, reversed(skips))):
            if level > 0:
                h = F.interpolate(h, scale_factor=2, mode='bilinear')
            h = block(torch.cat([h, skip], dim=1), t_emb)

        return self.cnn_output(h)