# Scraped from a Hugging Face Spaces page (Space status: Sleeping).
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision | |
class Block(nn.Module):
    """Conv block with FiLM-style timestep conditioning.

    Two parallel 3x3 convs read the image: one produces a gating map that is
    scaled per-channel by the (projected) time embedding, the other produces
    the base features. Their sum is layer-normalized and ReLU'd.
    """

    def __init__(self, in_channels=128, size=32):
        super(Block, self).__init__()
        # NOTE: creation order matters for reproducible seeded init.
        self.conv_param = nn.Conv2d(in_channels=in_channels, out_channels=128, kernel_size=3, padding=1)
        self.conv_out = nn.Conv2d(in_channels=in_channels, out_channels=128, kernel_size=3, padding=1)
        self.dense_ts = nn.Linear(192, 128)  # 192-d time embedding -> 128 channel scales
        self.layer_norm = nn.LayerNorm([128, size, size])

    def forward(self, x_img, x_ts):
        # Gating branch, modulated per-channel by the time embedding.
        gate = F.relu(self.conv_param(x_img))
        scale = F.relu(self.dense_ts(x_ts)).view(-1, 128, 1, 1)
        # Base features plus the time-modulated gate, then norm + ReLU.
        out = self.conv_out(x_img) + gate * scale
        return F.relu(self.layer_norm(out))
class Model(nn.Module):
    """U-Net-style denoiser conditioned on a scalar timestep.

    Encoder: four Blocks at 32/16/8/4 resolution with max-pool downsampling.
    Bottleneck: MLP over flattened 4x4 features concatenated with the time
    embedding, reshaped back to [-1, 32, 4, 4].
    Decoder: four Blocks with skip connections and bilinear upsampling, then
    a 1x1 conv back to 3 channels.
    """

    def __init__(self):
        super(Model, self).__init__()
        # Timestep embedding: scalar -> 192-d vector.
        self.l_ts = nn.Sequential(
            nn.Linear(1, 192),
            nn.LayerNorm([192]),
            nn.ReLU(),
        )
        # Encoder (down) path.
        self.down_x32 = Block(in_channels=3, size=32)
        self.down_x16 = Block(size=16)
        self.down_x8 = Block(size=8)
        self.down_x4 = Block(size=4)
        # Bottleneck: 128*4*4 flattened features + 192-d time embedding = 2240 in.
        self.mlp = nn.Sequential(
            nn.Linear(2240, 128),
            nn.LayerNorm([128]),
            nn.ReLU(),
            nn.Linear(128, 32 * 4 * 4),  # reshaped to [-1, 32, 4, 4] below
            nn.LayerNorm([32 * 4 * 4]),
            nn.ReLU(),
        )
        # Decoder (up) path; in_channels include the concatenated skip features.
        self.up_x4 = Block(in_channels=32 + 128, size=4)
        self.up_x8 = Block(in_channels=256, size=8)
        self.up_x16 = Block(in_channels=256, size=16)
        self.up_x32 = Block(in_channels=256, size=32)
        self.cnn_output = nn.Conv2d(in_channels=128, out_channels=3, kernel_size=1, padding=0)
        # Optimizer bundled with the module (must come last so it sees all params).
        self.opt = torch.optim.Adam(self.parameters(), lr=0.0008)

    def forward(self, x, x_ts):
        x_ts = self.l_ts(x_ts)

        # ----- encoder -----
        down_path = (self.down_x32, self.down_x16, self.down_x8, self.down_x4)
        skips = []
        for idx, blk in enumerate(down_path):
            x = blk(x, x_ts)
            skips.append(x)
            if idx != len(down_path) - 1:  # no pooling after the last block
                x = F.max_pool2d(x, 2)

        # ----- bottleneck MLP -----
        x = self.mlp(torch.cat([x.view(-1, 128 * 4 * 4), x_ts], dim=1))
        x = x.view(-1, 32, 4, 4)

        # ----- decoder -----
        up_path = (self.up_x4, self.up_x8, self.up_x16, self.up_x32)
        for idx, blk in enumerate(up_path):
            # Concatenate the matching-resolution skip (deepest first).
            x = blk(torch.cat([x, skips[-1 - idx]], dim=1), x_ts)
            if idx != len(up_path) - 1:  # no upsampling after the last block
                x = F.interpolate(x, scale_factor=2, mode='bilinear')

        # ----- output projection -----
        return self.cnn_output(x)