| |
|
|
| |
| """ |
| |
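| Purpose : 3D U-Net building blocks (ConvBlock, SeparableConvBlock, UpConv) and the models built from them: UNet and UNetDeepSup (UNet with two auxiliary deep-supervision heads). |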
| |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.utils.data |
|
|
| __author__ = "Kartik Prabhu, Mahantesh Pattadkal, and Soumick Chatterjee" |
| __copyright__ = "Copyright 2020, Faculty of Computer Science, Otto von Guericke University Magdeburg, Germany" |
| __credits__ = ["Kartik Prabhu", "Mahantesh Pattadkal", "Soumick Chatterjee"] |
| __license__ = "GPL" |
| __version__ = "1.0.0" |
| __maintainer__ = "Soumick Chatterjee" |
| __email__ = "soumick.chatterjee@ovgu.de" |
| __status__ = "Production" |
|
|
|
|
class ConvBlock(nn.Module):
    """
    Double convolution block: two stacked ``Conv3d -> BatchNorm3d -> LeakyReLU`` stages.

    The first Conv3d maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. ``k_size``, ``stride``, ``padding`` and ``bias`` are applied to
    *both* convolutions, so e.g. ``stride=2`` shrinks the spatial size twice.
    The LeakyReLUs run in place on the BatchNorm outputs.
    """

    def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1, bias=True):
        super(ConvBlock, self).__init__()
        layers = []
        # Stage 1 reads in_channels, stage 2 reads the output of stage 1.
        for stage_in in (in_channels, out_channels):
            layers.append(nn.Conv3d(in_channels=stage_in, out_channels=out_channels, kernel_size=k_size,
                                    stride=stride, padding=padding, bias=bias))
            layers.append(nn.BatchNorm3d(num_features=out_channels))
            layers.append(nn.LeakyReLU(inplace=True))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        return self.conv(x)
|
|
|
|
class SeparableConvBlock(nn.Module):
    """
    Pointwise + spatial convolution block, applied twice.

    Each of the two stages is
    ``Conv3d(kernel 1) -> Conv3d(k_size) -> PReLU -> Dropout3d -> BatchNorm3d``.
    Despite the class name, the ``k_size`` convolution is an ordinary (non-grouped)
    Conv3d, not a depthwise one.

    Args:
        in_channels: channels of the input tensor (read by the first 1x1x1 conv only).
        out_channels: channels produced by every convolution in the block.
        k_size, stride, padding: settings of the two ``k_size`` convolutions. They are
            applied in *both* stages, so ``stride=2`` shrinks the spatial size twice.
        bias: whether every Conv3d has a bias term.

    Notes:
        PReLU uses one learnable slope per channel, initialised to 0.25.
        Dropout3d uses PyTorch's default drop probability (0.5).
    """

    def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1, bias=True):
        super(SeparableConvBlock, self).__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            pointwise = nn.Conv3d(in_channels=stage_in, out_channels=out_channels, kernel_size=1,
                                  bias=bias)
            spatial = nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=k_size,
                                stride=stride, padding=padding, bias=bias)
            stages.extend([
                pointwise,
                spatial,
                nn.PReLU(num_parameters=out_channels, init=0.25),
                nn.Dropout3d(),
                nn.BatchNorm3d(num_features=out_channels),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both pointwise/spatial stages to ``x``."""
        return self.conv(x)
|
|
|
|
class UpConv(nn.Module):
    """
    Decoder upsampling block: ``Upsample(x2) -> Conv3d -> BatchNorm3d -> LeakyReLU``.

    ``nn.Upsample`` is used with its default interpolation mode ("nearest") and doubles
    every spatial dimension. The convolution then maps ``in_channels`` to ``out_channels``
    using ``k_size``, ``stride``, ``padding`` and ``bias``.
    """

    def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1, bias=True):
        super(UpConv, self).__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
                         stride=stride, padding=padding, bias=bias)
        norm = nn.BatchNorm3d(num_features=out_channels)
        self.up = nn.Sequential(upsample, conv, norm, nn.LeakyReLU(inplace=True))

    def forward(self, x):
        """Upsample ``x`` by 2 and convolve it."""
        return self.up(x)
|
|
|
|
class UNet(nn.Module):
    """
    3D U-Net (basic implementation).

    Five levels. ``ConvBlock`` is used at the outermost level on both the encoder and
    decoder side, and ``SeparableConvBlock`` everywhere else. Each decoder level
    upsamples with ``UpConv``, concatenates the matching encoder features along the
    channel axis (encoder features first) and fuses them with a conv block. A final
    1x1x1 Conv3d maps to ``out_ch`` channels; no final activation is applied.

    Input: tensor [batch, channels, depth, height, width]. Depth, height and width
    must be divisible by 16 (four 2x max-pools are undone by four 2x upsamplings),
    otherwise the skip concatenations fail.
    Paper : https://arxiv.org/abs/1505.04597

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        init_features: channel width of the first level; doubled at every level.
    """

    def __init__(self, in_ch=1, out_ch=1, init_features=64):
        super(UNet, self).__init__()

        # Channel width of each level: init_features x 1, 2, 4, 8, 16.
        w0, w1, w2, w3, w4 = (init_features * 2 ** level for level in range(5))

        self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2)

        # Encoder.
        self.Conv1 = ConvBlock(in_ch, w0)
        self.Conv2 = SeparableConvBlock(w0, w1)
        self.Conv3 = SeparableConvBlock(w1, w2)
        self.Conv4 = SeparableConvBlock(w2, w3)
        self.Conv5 = SeparableConvBlock(w3, w4)

        # Decoder: every fuse block reads the skip features plus the upsampled
        # features, each with the target width, hence twice the target width in.
        self.Up5 = UpConv(w4, w3)
        self.Up_conv5 = SeparableConvBlock(w4, w3)
        self.Up4 = UpConv(w3, w2)
        self.Up_conv4 = SeparableConvBlock(w3, w2)
        self.Up3 = UpConv(w2, w1)
        self.Up_conv3 = SeparableConvBlock(w2, w1)
        self.Up2 = UpConv(w1, w0)
        self.Up_conv2 = ConvBlock(w1, w0)

        self.Conv = nn.Conv3d(w0, out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """
        Run the network.

        Args:
            x: input tensor [batch, in_ch, depth, height, width].

        Returns:
            Tensor [batch, out_ch, depth, height, width].
        """
        # Encoder path: level 1 convolves at full resolution, deeper levels pool first.
        skip1 = self.Conv1(x)
        skip2 = self.Conv2(self.Maxpool1(skip1))
        skip3 = self.Conv3(self.Maxpool2(skip2))
        skip4 = self.Conv4(self.Maxpool3(skip3))
        y = self.Conv5(self.Maxpool4(skip4))

        # Decoder path, deepest level first.
        decoder = (
            (self.Up5, self.Up_conv5, skip4),
            (self.Up4, self.Up_conv4, skip3),
            (self.Up3, self.Up_conv3, skip2),
            (self.Up2, self.Up_conv2, skip1),
        )
        for upsample, fuse, skip in decoder:
            y = fuse(torch.cat((skip, upsample(y)), dim=1))

        return self.Conv(y)
|
|
|
|
class UNetDeepSup(nn.Module):
    """
    3D U-Net with two auxiliary deep-supervision heads.

    The encoder/decoder matches ``UNet``. In addition, ``Conv_d4`` maps the decoder
    features after ``Up_conv4`` (1/4 of the input resolution) to one channel, and
    ``Conv_d3`` does the same for the features after ``Up_conv3`` (1/2 resolution).

    Input: tensor [batch, channels, depth, height, width]. Depth, height and width
    must be divisible by 16, otherwise the skip concatenations fail.
    Paper : https://arxiv.org/abs/1505.04597

    Args:
        in_ch: number of input channels.
        out_ch: number of channels of the main output.
        init_features: channel width of the first level; doubled at every level.
        return_deep_sup: keyword-only. If False (default), ``forward`` returns only the
            main output and the auxiliary heads are not evaluated. If True, ``forward``
            returns ``[out, d3_out, d4_out]``.
        check_nan: keyword-only. If True (default), ``nan_hook`` is registered as a
            forward hook on every submodule, including this model. Each check calls
            ``.any()`` inside an ``if``, which on GPU requires a host/device sync;
            pass False to disable the checks.
        nan_dump_path: keyword-only. If given, the inputs of the module whose output
            contained NaN are written there with ``torch.save`` before raising.
    """

    def __init__(self, in_ch=1, out_ch=1, init_features=64, *, return_deep_sup=False,
                 check_nan=True, nan_dump_path=None):
        super(UNetDeepSup, self).__init__()

        # Plain (non-module) settings. forward/nan_hook read them through getattr with
        # a default so that instances pickled before these attributes existed still work.
        self.return_deep_sup = return_deep_sup
        self.nan_dump_path = nan_dump_path

        n1 = init_features
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        # Module creation order is kept unchanged so parameter order (and therefore
        # saved optimizer states) and state_dict keys stay compatible.
        self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2)

        self.Conv1 = ConvBlock(in_ch, filters[0])
        self.Conv2 = SeparableConvBlock(filters[0], filters[1])
        self.Conv3 = SeparableConvBlock(filters[1], filters[2])
        self.Conv4 = SeparableConvBlock(filters[2], filters[3])
        self.Conv5 = SeparableConvBlock(filters[3], filters[4])

        # Deep-supervision heads.
        # NOTE(review): they emit 1 channel regardless of out_ch; with out_ch > 1 they
        # cannot be supervised against the same target as the main output — confirm.
        self.Conv_d3 = SeparableConvBlock(filters[1], 1)
        self.Conv_d4 = SeparableConvBlock(filters[2], 1)

        self.Up5 = UpConv(filters[4], filters[3])
        self.Up_conv5 = SeparableConvBlock(filters[4], filters[3])

        self.Up4 = UpConv(filters[3], filters[2])
        self.Up_conv4 = SeparableConvBlock(filters[3], filters[2])

        self.Up3 = UpConv(filters[2], filters[1])
        self.Up_conv3 = SeparableConvBlock(filters[2], filters[1])

        self.Up2 = UpConv(filters[1], filters[0])
        self.Up_conv2 = ConvBlock(filters[1], filters[0])

        self.Conv = nn.Conv3d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)

        if check_nan:
            for submodule in self.modules():
                submodule.register_forward_hook(self.nan_hook)

    def nan_hook(self, module, inp, output):
        """
        Forward hook that raises if ``module`` produced a NaN.

        Args:
            module: the module whose forward just ran.
            inp: the positional inputs that were passed to ``module``.
            output: what ``module`` returned. A tensor is checked as a whole; for a
                tuple/list every tensor entry is checked and other entries are ignored.

        Raises:
            RuntimeError: if a checked output contains NaN. The message names the
                offending module class, the output index, the NaN count and up to the
                first 10 NaN indices.
        """
        outputs = output if isinstance(output, (tuple, list)) else (output,)
        for i, out in enumerate(outputs):
            if not torch.is_tensor(out):
                continue
            nan_mask = torch.isnan(out)
            if not nan_mask.any():
                continue
            # Optional dump of the offending module's inputs for offline debugging.
            dump_path = getattr(self, "nan_dump_path", None)
            if dump_path is not None:
                torch.save(inp, dump_path)
            nan_idx = nan_mask.nonzero()
            raise RuntimeError(
                f"Found NaN in output {i} of {type(module).__name__} "
                f"(hook registered by {type(self).__name__}): "
                f"{nan_idx.shape[0]} NaN element(s), first indices: {nan_idx[:10].tolist()}"
            )

    def forward(self, x):
        """
        Run the network.

        Args:
            x: input tensor [batch, in_ch, depth, height, width].

        Returns:
            If ``return_deep_sup`` is False: the main output
            [batch, out_ch, depth, height, width].
            Otherwise: the list ``[out, d3_out, d4_out]``. ``d3_out`` and ``d4_out``
            each have one channel, at 1/2 and 1/4 of the input resolution respectively.
        """
        deep_sup = getattr(self, "return_deep_sup", False)

        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool1(e1))
        e3 = self.Conv3(self.Maxpool2(e2))
        e4 = self.Conv4(self.Maxpool3(e3))
        e5 = self.Conv5(self.Maxpool4(e4))

        d5 = self.Up_conv5(torch.cat((e4, self.Up5(e5)), dim=1))

        d4 = self.Up_conv4(torch.cat((e3, self.Up4(d5)), dim=1))
        # The auxiliary heads are evaluated only when their outputs are returned,
        # so the default path does no discarded work.
        d4_out = self.Conv_d4(d4) if deep_sup else None

        d3 = self.Up_conv3(torch.cat((e2, self.Up3(d4)), dim=1))
        d3_out = self.Conv_d3(d3) if deep_sup else None

        d2 = self.Up_conv2(torch.cat((e1, self.Up2(d3)), dim=1))
        out = self.Conv(d2)

        if deep_sup:
            return [out, d3_out, d4_out]
        return out