import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvLSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size, bias, device=None):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        device: torch.device, optional
            Device on which initial states are allocated. Defaults to the
            device of the convolution weights.
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        # "Same" padding, so the spatial size of the hidden state matches the input.
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias
        self.device = device

        # A single convolution computes all four gates at once; its output is
        # split channel-wise into input, forget, output and cell-update gates.
        self.conv = nn.Conv2d(
            in_channels=self.input_dim + self.hidden_dim,
            out_channels=4 * self.hidden_dim,
            kernel_size=self.kernel_size,
            padding=self.padding,
            bias=self.bias,
        )

    def __initStates(self, size):
        # Allocate zero-initialised hidden and cell states directly on the
        # target device, falling back to the device of the conv weights when
        # no device was given at construction time.
        device = self.device if self.device is not None else self.conv.weight.device
        return torch.zeros(size, device=device), torch.zeros(size, device=device)

    def forward(self, input_tensor, cur_state):
        if cur_state is None:
            # No previous state: start from zero-initialised states matching
            # the batch size and spatial dimensions of the input.
            h_cur, c_cur = self.__initStates(
                [
                    input_tensor.shape[0],
                    self.hidden_dim,
                    input_tensor.shape[2],
                    input_tensor.shape[3],
                ]
            )
        else:
            h_cur, c_cur = cur_state

        # Concatenate input and hidden state along the channel axis and run
        # the shared gate convolution.
        combined = torch.cat([input_tensor, h_cur], dim=1)
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)

        # Standard LSTM gating, applied pointwise over the feature maps.
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        height, width = image_size
        return (
            torch.zeros(
                batch_size,
                self.hidden_dim,
                height,
                width,
                device=self.conv.weight.device,
            ),
            torch.zeros(
                batch_size,
                self.hidden_dim,
                height,
                width,
                device=self.conv.weight.device,
            ),
        )
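

# Example (a minimal sketch, not part of the original module): running a single
# ConvLSTMCell step on a random frame. Shapes and values are illustrative.
#
#     cell = ConvLSTMCell(input_dim=3, hidden_dim=16, kernel_size=(3, 3), bias=True)
#     frame = torch.rand(8, 3, 64, 64)   # (batch, channels, height, width)
#     h, c = cell(frame, None)           # zero-initialised state on the first call
#     h, c = cell(frame, (h, c))         # subsequent calls reuse the state
#     assert h.shape == (8, 16, 64, 64)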


class ConvLSTM(nn.Module):
    """
    Parameters:
        input_dim: Number of channels in input
        hidden_dim: Number of hidden channels (int, or list with one entry per layer)
        kernel_size: Size of kernel in convolutions (tuple, or list of tuples per layer)
        num_layers: Number of LSTM layers stacked on each other
        batch_first: Whether or not dimension 0 is the batch
        bias: Whether or not to add a bias in the convolutions
        return_all_layers: Return the outputs and states of all layers, not just the last

    Note: Uses "same" padding, so spatial dimensions are preserved.

    Input:
        A tensor of size (B, T, C, H, W) or (T, B, C, H, W)
    Output:
        A tuple of two lists, each of length num_layers (or length 1 if
        return_all_layers is False):
            0 - layer_output_list: per layer, a tensor of shape
                (B, T, hidden_dim, H, W) holding the hidden state at every
                time step
            1 - last_state_list: per layer, a tuple (h, c) with the final
                hidden state and memory
    Example:
        >>> x = torch.rand((32, 10, 64, 128, 128))
        >>> convlstm = ConvLSTM(64, 16, (3, 3), 1, True, True, False)
        >>> _, last_states = convlstm(x)
        >>> h = last_states[0][0]  # 0 for layer index, 0 for h index
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        kernel_size,
        num_layers,
        batch_first=False,
        bias=True,
        return_all_layers=False,
    ):
        super().__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Make sure that both `kernel_size` and `hidden_dim` are lists with
        # one entry per layer.
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError("Inconsistent list length.")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(0, self.num_layers):
            # Each layer consumes the hidden state of the layer below it.
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]

            cell_list.append(
                ConvLSTMCell(
                    input_dim=cur_input_dim,
                    hidden_dim=self.hidden_dim[i],
                    kernel_size=self.kernel_size[i],
                    bias=self.bias,
                )
            )

        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """
        Parameters
        ----------
        input_tensor: torch.Tensor
            5-D tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
        hidden_state: None
            Stateful operation is not implemented yet; must be None.

        Returns
        -------
        layer_output_list, last_state_list
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        b, _, _, h, w = input_tensor.size()

        if hidden_state is not None:
            raise NotImplementedError()
        else:
            # Stateful operation is not supported, so all hidden states are
            # zero-initialised here.
            hidden_state = self._init_hidden(batch_size=b, image_size=(h, w))

        layer_output_list = []
        last_state_list = []

        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):

            h, c = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):
                h, c = self.cell_list[layer_idx](
                    input_tensor=cur_layer_input[:, t, :, :, :], cur_state=[h, c]
                )
                output_inner.append(h)

            # The sequence of hidden states of this layer is the input of the next.
            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, image_size):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
        return init_states

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        if not (
            isinstance(kernel_size, tuple)
            or (
                isinstance(kernel_size, list)
                and all([isinstance(elem, tuple) for elem in kernel_size])
            )
        ):
            raise ValueError("`kernel_size` must be tuple or list of tuples")

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        # Broadcast a scalar parameter to a list with one entry per layer.
        if not isinstance(param, list):
            param = [param] * num_layers
        return param
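

# Example (a minimal sketch, not in the original file): a two-layer ConvLSTM
# with per-layer hidden sizes and kernel sizes, returning the outputs of every
# layer. Shapes are illustrative.
#
#     model = ConvLSTM(
#         input_dim=3,
#         hidden_dim=[32, 16],
#         kernel_size=[(3, 3), (5, 5)],
#         num_layers=2,
#         batch_first=True,
#         return_all_layers=True,
#     )
#     x = torch.rand(4, 8, 3, 64, 64)          # (B, T, C, H, W)
#     layer_outputs, last_states = model(x)
#     assert layer_outputs[1].shape == (4, 8, 16, 64, 64)
#     h, c = last_states[0]                    # final states of the first layer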


def normal_init(m, mean, std):
    # Re-initialise convolution weights from a normal distribution; modules of
    # other types are left untouched.
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        m.weight.data.normal_(mean, std)
        if m.bias is not None:
            m.bias.data.zero_()
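

# Typical usage (an assumption, mirroring DCGAN-style initialisation):
#
#     generator = Generator(torch.device("cpu"))
#     generator.weight_init(mean=0.0, std=0.02)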


class Generator(nn.Module):
    def __init__(self, device, inputChannels=4, outputChannels=3, d=64):
        super().__init__()
        self.d = d
        self.device = device

        # Encoder: seven stride-2 convolutions, each halving the spatial size.
        self.conv1 = nn.Conv2d(inputChannels, d, 3, 2, 1)
        self.conv2 = nn.Conv2d(d, d * 2, 3, 2, 1)
        self.conv3 = nn.Conv2d(d * 2, d * 4, 3, 2, 1)
        self.conv4 = nn.Conv2d(d * 4, d * 8, 3, 2, 1)
        self.conv5 = nn.Conv2d(d * 8, d * 8, 3, 2, 1)
        self.conv6 = nn.Conv2d(d * 8, d * 8, 3, 2, 1)
        self.conv7 = nn.Conv2d(d * 8, d * 8, 3, 2, 1)

        # Decoder ConvLSTM cells; the doubled input dims account for the
        # U-Net style skip connections concatenated in forward_step.
        self.conv_lstm_d1 = ConvLSTMCell(d * 8, d * 8, (3, 3), False, device)
        self.conv_lstm_d2 = ConvLSTMCell(d * 8 * 2, d * 8, (3, 3), False, device)
        self.conv_lstm_d3 = ConvLSTMCell(d * 8 * 2, d * 8, (3, 3), False, device)
        self.conv_lstm_d4 = ConvLSTMCell(d * 8 * 2, d * 4, (3, 3), False, device)
        self.conv_lstm_d5 = ConvLSTMCell(d * 4 * 2, d * 2, (3, 3), False, device)
        self.conv_lstm_d6 = ConvLSTMCell(d * 2 * 2, d, (3, 3), False, device)
        self.conv_lstm_d7 = ConvLSTMCell(d * 2, d, (3, 3), False, device)

        # Encoder ConvLSTM cells, one after each encoder convolution.
        self.conv_lstm_e1 = ConvLSTMCell(d, d, (3, 3), False, device)
        self.conv_lstm_e2 = ConvLSTMCell(d * 2, d * 2, (3, 3), False, device)
        self.conv_lstm_e3 = ConvLSTMCell(d * 4, d * 4, (3, 3), False, device)
        self.conv_lstm_e4 = ConvLSTMCell(d * 8, d * 8, (3, 3), False, device)
        self.conv_lstm_e5 = ConvLSTMCell(d * 8, d * 8, (3, 3), False, device)
        self.conv_lstm_e6 = ConvLSTMCell(d * 8, d * 8, (3, 3), False, device)
        # Note: conv_lstm_e7 is defined but not used in forward_step.
        self.conv_lstm_e7 = ConvLSTMCell(d * 8, d * 8, (3, 3), False, device)

        self.up = nn.Upsample(scale_factor=2)
        self.conv_out = nn.Conv2d(d, outputChannels, 3, 1, 1)

        self.slope = 0.2

    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    def forward_step(self, input, states_encoder, states_decoder):
        # Encoder: strided convolutions interleaved with ConvLSTM cells.
        e1 = self.conv1(input)
        states_e1 = self.conv_lstm_e1(e1, states_encoder[0])
        e2 = self.conv2(F.leaky_relu(states_e1[0], self.slope))
        states_e2 = self.conv_lstm_e2(e2, states_encoder[1])
        e3 = self.conv3(F.leaky_relu(states_e2[0], self.slope))
        states_e3 = self.conv_lstm_e3(e3, states_encoder[2])
        e4 = self.conv4(F.leaky_relu(states_e3[0], self.slope))
        states_e4 = self.conv_lstm_e4(e4, states_encoder[3])
        e5 = self.conv5(F.leaky_relu(states_e4[0], self.slope))
        states_e5 = self.conv_lstm_e5(e5, states_encoder[4])
        e6 = self.conv6(F.leaky_relu(states_e5[0], self.slope))
        states_e6 = self.conv_lstm_e6(e6, states_encoder[5])
        e7 = self.conv7(F.leaky_relu(states_e6[0], self.slope))

        # Decoder: ConvLSTM cells followed by upsampling, with U-Net style
        # skip connections from the matching encoder features.
        states1 = self.conv_lstm_d1(F.relu(e7), states_decoder[0])
        d1 = self.up(states1[0])
        d1 = torch.cat([d1, e6], 1)

        states2 = self.conv_lstm_d2(F.relu(d1), states_decoder[1])
        d2 = self.up(states2[0])
        d2 = torch.cat([d2, e5], 1)

        states3 = self.conv_lstm_d3(F.relu(d2), states_decoder[2])
        d3 = self.up(states3[0])
        d3 = torch.cat([d3, e4], 1)

        states4 = self.conv_lstm_d4(F.relu(d3), states_decoder[3])
        d4 = self.up(states4[0])
        d4 = torch.cat([d4, e3], 1)

        states5 = self.conv_lstm_d5(F.relu(d4), states_decoder[4])
        d5 = self.up(states5[0])
        d5 = torch.cat([d5, e2], 1)

        states6 = self.conv_lstm_d6(F.relu(d5), states_decoder[5])
        d6 = self.up(states6[0])
        d6 = torch.cat([d6, e1], 1)

        states7 = self.conv_lstm_d7(F.relu(d6), states_decoder[6])
        d7 = self.up(states7[0])

        # tanh maps to [-1, 1]; the clip keeps only the positive part, so the
        # output lands in [0, 1].
        o = torch.clip(torch.tanh(self.conv_out(d7)), min=0.0, max=1.0)

        states_e = [states_e1, states_e2, states_e3, states_e4, states_e5, states_e6]
        states_d = [states1, states2, states3, states4, states5, states6, states7]

        return o, (states_e, states_d)

    def forward(self, tensor):
        # The encoder uses six ConvLSTM states and the decoder seven.
        states_encoder = (None, None, None, None, None, None)
        states_decoder = (None, None, None, None, None, None, None)
        # The output has outputChannels channels, which may differ from the
        # input's channel count, so it cannot be allocated with empty_like;
        # collect the per-step outputs and stack them along the time axis.
        outputs = []
        for timeStep in range(tensor.shape[4]):
            o, states = self.forward_step(
                tensor[:, :, :, :, timeStep], states_encoder, states_decoder
            )
            outputs.append(o)
            states_encoder, states_decoder = states[0], states[1]
        output = torch.stack(outputs, dim=4)
        return output, states
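

if __name__ == "__main__":
    # Smoke test (a minimal sketch, not part of the original module): the seven
    # stride-2 encoder convolutions need the spatial size to be divisible by
    # 2**7, so 128x128 inputs are assumed here; batch size, sequence length and
    # channel counts are illustrative.
    device = torch.device("cpu")
    generator = Generator(device, inputChannels=4, outputChannels=3)
    video = torch.rand(1, 4, 128, 128, 2)  # (batch, channels, height, width, time)
    frames, _ = generator(video)
    print(frames.shape)  # expected: torch.Size([1, 3, 128, 128, 2])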