|
|
""" |
|
|
Contains an implementation of a Module that performs 3D convolution on an image. |
|
|
|
|
|
Author: Ole-Christian Galbo Engstrøm |
|
|
E-mail: ocge@foss.dk |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from .unet import Permute |
|
|
|
|
|
|
|
|
class Conv3Don2D(nn.Module):
    """Apply a one-input-channel 3D convolution to a multi-channel 2D image.

    The input is reshaped so that its channel axis becomes the depth axis of
    a single-channel volume. An ``nn.Conv3d`` with stride ``(1, 2, 2)`` and
    ``"valid"`` padding is applied, optionally followed by a normalization
    layer, and the filter and depth axes of the result are merged back into
    one channel axis.

    Args:
        in_channels: Number of channels of the input image. Only used to
            compute ``out_channels``; the convolution itself always has one
            input channel.
        num_filters: Number of 3D filters (output channels of the Conv3d).
        kernel_size: Kernel size, either a ``(depth, height, width)``
            sequence or a single int used for all three axes.
        normalization: ``"bn"`` selects ``nn.BatchNorm3d``, ``"ln"`` selects
            a ``nn.LayerNorm`` over the filter axis, and any other value
            (including ``None``) selects ``nn.Identity``. The convolution
            gets a bias term only when this is ``None``.

    Attributes:
        out_channels: Channel count of the tensor returned by ``forward``,
            i.e. ``num_filters * (in_channels - kernel_depth + 1)``.
    """

    def __init__(self, in_channels, num_filters, kernel_size, normalization):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels=1,
            out_channels=num_filters,
            kernel_size=kernel_size,
            stride=(1, 2, 2),
            padding="valid",
            # The bias is dropped whenever a normalization value is given.
            bias=normalization is None,
        )
        self.num_filters = num_filters
        self.normalization = normalization
        self.norm_layer = self._get_norm_layer()

        # nn.Conv3d accepts an int or a 3-sequence; mirror that here so an
        # int kernel size does not fail on indexing.
        if isinstance(kernel_size, int):
            kernel_depth = kernel_size
        else:
            kernel_depth = kernel_size[0]
        # Depth stride is 1 with "valid" padding, so the depth axis shrinks
        # by (kernel_depth - 1); filters and depth are later merged.
        self.out_channels = num_filters * (in_channels - kernel_depth + 1)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            # This module builds BatchNorm3d layers (see _get_norm_layer),
            # so that is the type that must be matched here.
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _get_norm_layer(self):
        """Return the normalization module selected by ``self.normalization``."""
        if self.normalization == "bn":
            return nn.BatchNorm3d(self.num_filters, momentum=0.01)
        if self.normalization == "ln":
            return self._get_layer_norm()
        return nn.Identity()

    def _get_layer_norm(self):
        """Return a LayerNorm over the filter axis wrapped in two permutes.

        The first ``Permute`` moves axis 1 to the last position so that
        ``nn.LayerNorm(num_filters)`` normalizes over it; the second moves
        it back to position 1.
        """
        layers = [
            Permute((0, 2, 3, 4, 1)),
            nn.LayerNorm(self.num_filters),
            Permute((0, 4, 1, 2, 3)),
        ]
        return nn.Sequential(*layers)

    def _reshape3Dto4D(self, x: torch.Tensor) -> torch.Tensor:
        """Insert a singleton channel axis in front of the last three axes.

        Any leading axes are flattened into the first (batch) axis.
        """
        # reshape (not view) so non-contiguous inputs are handled as well.
        return x.reshape(-1, 1, x.size(-3), x.size(-2), x.size(-1))

    def _reshape4Dto3D(self, x: torch.Tensor) -> torch.Tensor:
        """Merge the filter and depth axes into ``self.out_channels`` channels."""
        # reshape (not view): a tensor that has been permuted is generally
        # non-contiguous, and view() raises when merging such axes.
        return x.reshape(-1, self.out_channels, x.size(-2), x.size(-1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run convolution and normalization and return a 4D tensor.

        Args:
            x: Input tensor whose last three axes are used as
                (depth, height, width) of the convolved volume.
                Assumed to be ``(batch, in_channels, height, width)``
                -- TODO confirm against callers.

        Returns:
            Tensor with ``self.out_channels`` channels on axis 1 and the
            convolved spatial axes as the last two axes.
        """
        x = self._reshape3Dto4D(x)
        x = self.conv(x)
        x = self.norm_layer(x)
        x = self._reshape4Dto3D(x)
        return x
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Build an example module and report its parameter counts.
    example_config = {
        "in_channels": 124,
        "num_filters": 1,
        "kernel_size": (7, 2, 2),
        "normalization": None,
    }
    conv3d_model = Conv3Don2D(**example_config)
    print(conv3d_model)

    # Total number of elements across all parameters.
    num_params = 0
    for param in conv3d_model.parameters():
        num_params += param.numel()
    print(f"Number of parameters: {num_params}")

    # Same count restricted to parameters that require gradients.
    num_trainable_params = 0
    for param in conv3d_model.parameters():
        if param.requires_grad:
            num_trainable_params += param.numel()
    print(f"Number of trainable parameters: {num_trainable_params}")
|
|
|