|
|
from ConvLSTM import ConvLSTM |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from collections import defaultdict |
|
|
|
|
|
|
|
|
class MLP_5D(nn.Module):
    """Pointwise MLP applied independently at every spatial location of a 5-D tensor.

    Input is expected as ``(batch, time, channels, height, width)``. Each
    ``channels``-vector at every (b, t, h, w) position is passed through a
    3-layer MLP (softplus activations, light dropout) that maps it to a single
    non-negative scalar, producing an output of shape
    ``(batch, time, 1, height, width)``.

    Args:
        height: expected spatial height of the input (validated in forward).
        width: expected spatial width of the input (validated in forward).
        in_channels: size of the per-location feature vector (default 64,
            matching the original hard-coded layer sizes).
        hidden_dim: width of the hidden layer (default 128).
    """

    def __init__(self, height, width, in_channels=64, hidden_dim=128):
        super(MLP_5D, self).__init__()

        # 3-layer funnel: in_channels -> hidden_dim -> in_channels -> 1.
        # Defaults (64, 128) reproduce the original fixed architecture.
        self.fc1 = nn.Linear(in_channels, hidden_dim)
        self.dropout1 = nn.Dropout(0.05)
        self.fc2 = nn.Linear(hidden_dim, in_channels)
        self.dropout2 = nn.Dropout(0.05)
        self.fc3 = nn.Linear(in_channels, 1)

        self.height = height
        self.width = width
        self.in_channels = in_channels

    def forward(self, x):
        """Map ``(B, T, C, H, W)`` to ``(B, T, 1, H, W)`` with non-negative values.

        Raises:
            ValueError: if the spatial size or channel count does not match
                the dimensions this module was constructed with.
        """
        batch_size, timesteps, channels, height, width = x.shape

        # Explicit validation instead of letting nn.Linear fail with an
        # opaque shape-mismatch error deep inside the matmul.
        if height != self.height or width != self.width:
            raise ValueError("Height and width mismatch")
        if channels != self.in_channels:
            raise ValueError(
                f"Expected {self.in_channels} channels, got {channels}"
            )

        # Move channels last, then flatten every (b, t, h, w) position into
        # a row so the Linear layers act pointwise over space and time.
        x = x.permute(0, 1, 3, 4, 2).reshape(-1, channels)

        x = self.fc1(x)
        x = torch.nn.functional.softplus(x)
        x = self.dropout1(x)
        x = self.fc2(x)
        x = torch.nn.functional.softplus(x)
        x = self.dropout2(x)
        x = self.fc3(x)
        # Final softplus keeps the regression output strictly positive.
        x = torch.nn.functional.softplus(x)

        # Restore 5-D layout with a singleton channel axis: (B, T, 1, H, W).
        x = x.view(batch_size, timesteps, self.height, self.width, 1).permute(0, 1, 4, 2, 3)

        return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConvLSTMNetwork(nn.Module):
    """ConvLSTM backbone with a regression head (MLP_5D) and a classification head.

    Pipeline: a multi-layer ConvLSTM over the input sequence, per-layer
    BatchNorm3d, a Conv3d projection of the last layer's output, then two
    heads applied to that projection:

    * regression head — ``MLP_5D`` producing ``(B, T, 1, H, W)``,
    * classification head — 1x1x1 Conv3d + Sigmoid producing per-voxel
      probabilities ``(B, T, 1, H, W)``.

    Per-layer spatial activation variance is recorded in
    ``self.activation_variance`` (a dict of lists, one entry appended per
    forward call) for diagnostics.

    Args:
        input_dim: number of input channels per timestep.
        hidden_dims: list of hidden channel counts, one per ConvLSTM layer.
        kernel_size: ConvLSTM kernel size(s).
        num_layers: number of ConvLSTM layers.
        output_channels: channels produced by the Conv3d projection.
        batch_first: passed through to ConvLSTM (default True).
        pool_size: accepted for interface compatibility; currently unused.
        mlp_height: spatial height expected by the regression MLP
            (default 81, the original hard-coded value).
        mlp_width: spatial width expected by the regression MLP
            (default 97, the original hard-coded value).
    """

    def __init__(self, input_dim, hidden_dims, kernel_size, num_layers,
                 output_channels, batch_first=True, pool_size=(2, 2),
                 mlp_height=81, mlp_width=97):
        super(ConvLSTMNetwork, self).__init__()

        self.convlstm = ConvLSTM(input_dim=input_dim,
                                 hidden_dim=hidden_dims,
                                 kernel_size=kernel_size,
                                 num_layers=num_layers,
                                 batch_first=batch_first,
                                 bias=True,
                                 return_all_layers=True)

        # One BatchNorm3d per ConvLSTM layer, sized to that layer's channels.
        self.batch_norms = nn.ModuleList([
            nn.BatchNorm3d(hidden_dim) for hidden_dim in hidden_dims
        ])

        # Project the last ConvLSTM layer's channels to output_channels,
        # convolving only spatially (kernel depth 1 keeps timesteps separate).
        self.conv3d = nn.Conv3d(in_channels=hidden_dims[-1],
                                out_channels=output_channels,
                                kernel_size=(1, 3, 3),
                                padding=(0, 1, 1))

        # Regression head; spatial size is now configurable instead of the
        # previously hard-coded 81x97.
        self.mlp = MLP_5D(height=mlp_height, width=mlp_width)

        # Per-voxel binary probability head.
        self.classification_head = nn.Sequential(
            nn.Conv3d(output_channels, 1, kernel_size=(1, 1, 1)),
            nn.Sigmoid()
        )

        # Diagnostic log: layer name -> list of mean spatial variances.
        self.activation_variance = defaultdict(list)

    def forward(self, x):
        """
        x: (B, T, input_dim, H, W)

        Returns:
            (regression_output, classification_output), both (B, T, 1, H, W).
        """
        layer_output_list, last_state_list = self.convlstm(x)

        for i, output in enumerate(layer_output_list):
            # BatchNorm3d expects (B, C, T, H, W); swap back afterwards.
            output = output.permute(0, 2, 1, 3, 4)
            output = self.batch_norms[i](output)
            output = output.permute(0, 2, 1, 3, 4)

            # Log mean spatial variance per layer. detach() keeps this
            # bookkeeping out of the autograd graph (logged value unchanged).
            activation_variance = output.detach().var(dim=(3, 4)).mean().item()
            self.activation_variance[f"ConvLSTM_layer_{i}"].append(activation_variance)

            layer_output_list[i] = output

        final_output = layer_output_list[-1]

        # Conv3d expects (B, C, T, H, W).
        final_output = final_output.permute(0, 2, 1, 3, 4)
        final_output = self.conv3d(final_output)

        # Regression head consumes (B, T, C, H, W).
        final_output_t = final_output.permute(0, 2, 1, 3, 4)
        regression_output = self.mlp(final_output_t)

        # Classification head consumes the (B, C, T, H, W) projection directly.
        final_output_c = final_output
        classification_output = self.classification_head(final_output_c)

        # Return classification in (B, T, 1, H, W) to match the regression head.
        classification_output = classification_output.permute(0, 2, 1, 3, 4)

        return regression_output, classification_output