|
|
|
|
|
|
|
|
from typing import Tuple, Dict |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torch |
|
|
|
|
|
|
|
|
class Multiresblock(nn.Module): |
|
|
def __init__(self,input_features : int, corresponding_unet_filters : int ,alpha : float =1.67)->None: |
|
|
""" |
|
|
MultiResblock |
|
|
Arguments: |
|
|
x - input layer |
|
|
corresponding_unet_filters - Unet filters for the same stage |
|
|
alpha - 1.67 - factor used in the paper to dervie number of filters for multiresunet filters from Unet filters |
|
|
Returns - None |
|
|
""" |
|
|
super().__init__() |
|
|
self.corresponding_unet_filters = corresponding_unet_filters |
|
|
self.alpha = alpha |
|
|
self.W = corresponding_unet_filters * alpha |
|
|
self.conv2d_bn_1x1 = Conv2d_batchnorm(input_features=input_features,num_of_filters = int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5), |
|
|
kernel_size = (1,1),activation='None',padding = 0) |
|
|
|
|
|
self.conv2d_bn_3x3 = Conv2d_batchnorm(input_features=input_features,num_of_filters = int(self.W*0.167), |
|
|
kernel_size = (3,3),activation='relu',padding = 1) |
|
|
self.conv2d_bn_5x5 = Conv2d_batchnorm(input_features=int(self.W*0.167),num_of_filters = int(self.W*0.333), |
|
|
kernel_size = (3,3),activation='relu',padding = 1) |
|
|
self.conv2d_bn_7x7 = Conv2d_batchnorm(input_features=int(self.W*0.333),num_of_filters = int(self.W*0.5), |
|
|
kernel_size = (3,3),activation='relu',padding = 1) |
|
|
self.batch_norm1 = nn.BatchNorm2d(int(self.W*0.5)+int(self.W*0.167)+int(self.W*0.333) ,affine=False) |
|
|
|
|
|
def forward(self,x: torch.Tensor)->torch.Tensor: |
|
|
|
|
|
temp = self.conv2d_bn_1x1(x) |
|
|
a = self.conv2d_bn_3x3(x) |
|
|
b = self.conv2d_bn_5x5(a) |
|
|
c = self.conv2d_bn_7x7(b) |
|
|
x = torch.cat([a,b,c],axis=1) |
|
|
x = self.batch_norm1(x) |
|
|
x = x + temp |
|
|
x = self.batch_norm1(x) |
|
|
return x |
|
|
|
|
|
class Conv2d_batchnorm(nn.Module): |
|
|
def __init__(self,input_features : int,num_of_filters : int ,kernel_size : Tuple = (2,2),stride : Tuple = (1,1), activation : str = 'relu',padding : int= 0)->None: |
|
|
""" |
|
|
Arguments: |
|
|
x - input layer |
|
|
num_of_filters - no. of filter outputs |
|
|
filters - shape of the filters to be used |
|
|
stride - stride dimension |
|
|
activation -activation function to be used |
|
|
Returns - None |
|
|
""" |
|
|
super().__init__() |
|
|
self.activation = activation |
|
|
self.conv1 = nn.Conv2d(in_channels=input_features,out_channels=num_of_filters,kernel_size=kernel_size,stride=stride,padding = padding) |
|
|
self.batchnorm = nn.BatchNorm2d(num_of_filters,affine=False) |
|
|
|
|
|
def forward(self,x : torch.Tensor)->torch.Tensor: |
|
|
x = self.conv1(x) |
|
|
x = self.batchnorm(x) |
|
|
if self.activation == 'relu': |
|
|
return F.relu(x) |
|
|
else: |
|
|
return x |
|
|
|
|
|
|
|
|
class Respath(nn.Module): |
|
|
def __init__(self,input_features : int,filters : int,respath_length : int)->None: |
|
|
""" |
|
|
Arguments: |
|
|
input_features - input layer filters |
|
|
filters - output channels |
|
|
respath_length - length of the Respath |
|
|
|
|
|
Returns - None |
|
|
""" |
|
|
super().__init__() |
|
|
self.filters = filters |
|
|
self.respath_length = respath_length |
|
|
self.conv2d_bn_1x1 = Conv2d_batchnorm(input_features=input_features,num_of_filters = self.filters, |
|
|
kernel_size = (1,1),activation='None',padding = 0) |
|
|
self.conv2d_bn_3x3 = Conv2d_batchnorm(input_features=input_features,num_of_filters = self.filters, |
|
|
kernel_size = (3,3),activation='relu',padding = 1) |
|
|
self.conv2d_bn_1x1_common = Conv2d_batchnorm(input_features=self.filters,num_of_filters = self.filters, |
|
|
kernel_size = (1,1),activation='None',padding = 0) |
|
|
self.conv2d_bn_3x3_common = Conv2d_batchnorm(input_features=self.filters,num_of_filters = self.filters, |
|
|
kernel_size = (3,3),activation='relu',padding = 1) |
|
|
self.batch_norm1 = nn.BatchNorm2d(filters,affine=False) |
|
|
|
|
|
def forward(self,x : torch.Tensor)->torch.Tensor: |
|
|
shortcut = self.conv2d_bn_1x1(x) |
|
|
x = self.conv2d_bn_3x3(x) |
|
|
x = x + shortcut |
|
|
x = F.relu(x) |
|
|
x = self.batch_norm1(x) |
|
|
if self.respath_length>1: |
|
|
for i in range(self.respath_length): |
|
|
shortcut = self.conv2d_bn_1x1_common(x) |
|
|
x = self.conv2d_bn_3x3_common(x) |
|
|
x = x + shortcut |
|
|
x = F.relu(x) |
|
|
x = self.batch_norm1(x) |
|
|
return x |
|
|
else: |
|
|
return x |
|
|
|
|
|
class MultiResUnet(nn.Module): |
|
|
def __init__(self,channels : int,filters : int =32,nclasses : int =1)->None: |
|
|
|
|
|
""" |
|
|
Arguments: |
|
|
channels - input image channels |
|
|
filters - filters to begin with (Unet) |
|
|
nclasses - number of classes |
|
|
Returns - None |
|
|
""" |
|
|
super().__init__() |
|
|
self.alpha = 1.67 |
|
|
self.filters = filters |
|
|
self.nclasses = nclasses |
|
|
self.multiresblock1 = Multiresblock(input_features=channels,corresponding_unet_filters=self.filters) |
|
|
self.pool1 = nn.MaxPool2d(2,stride= 2) |
|
|
self.in_filters1 = int(self.filters*self.alpha* 0.5)+int(self.filters*self.alpha*0.167)+int(self.filters*self.alpha*0.333) |
|
|
self.respath1 = Respath(input_features=self.in_filters1 ,filters=self.filters,respath_length=4) |
|
|
self.multiresblock2 = Multiresblock(input_features= self.in_filters1,corresponding_unet_filters=self.filters*2) |
|
|
self.pool2 = nn.MaxPool2d(2, 2) |
|
|
self.in_filters2 = int(self.filters*2*self.alpha* 0.5)+int(self.filters*2*self.alpha*0.167)+int(self.filters*2*self.alpha*0.333) |
|
|
self.respath2 = Respath(input_features=self.in_filters2,filters=self.filters*2,respath_length=3) |
|
|
self.multiresblock3 = Multiresblock(input_features= self.in_filters2,corresponding_unet_filters=self.filters*4) |
|
|
self.pool3 = nn.MaxPool2d(2, 2) |
|
|
self.in_filters3 = int(self.filters*4*self.alpha* 0.5)+int(self.filters*4*self.alpha*0.167)+int(self.filters*4*self.alpha*0.333) |
|
|
self.respath3 = Respath(input_features=self.in_filters3,filters=self.filters*4,respath_length=2) |
|
|
self.multiresblock4 = Multiresblock(input_features= self.in_filters3,corresponding_unet_filters=self.filters*8) |
|
|
self.pool4 = nn.MaxPool2d(2, 2) |
|
|
self.in_filters4 = int(self.filters*8*self.alpha* 0.5)+int(self.filters*8*self.alpha*0.167)+int(self.filters*8*self.alpha*0.333) |
|
|
self.respath4 = Respath(input_features=self.in_filters4,filters=self.filters*8,respath_length=1) |
|
|
self.multiresblock5 = Multiresblock(input_features= self.in_filters4,corresponding_unet_filters=self.filters*16) |
|
|
self.in_filters5 = int(self.filters*16*self.alpha* 0.5)+int(self.filters*16*self.alpha*0.167)+int(self.filters*16*self.alpha*0.333) |
|
|
|
|
|
|
|
|
self.upsample6 = nn.ConvTranspose2d(in_channels=self.in_filters5,out_channels=self.filters*8,kernel_size=(2,2),stride=(2,2),padding = 0) |
|
|
self.concat_filters1 = self.filters*8+self.filters*8 |
|
|
self.multiresblock6 = Multiresblock(input_features=self.concat_filters1,corresponding_unet_filters=self.filters*8) |
|
|
self.in_filters6 = int(self.filters*8*self.alpha* 0.5)+int(self.filters*8*self.alpha*0.167)+int(self.filters*8*self.alpha*0.333) |
|
|
self.upsample7 = nn.ConvTranspose2d(in_channels=self.in_filters6,out_channels=self.filters*4,kernel_size=(2,2),stride=(2,2),padding = 0) |
|
|
self.concat_filters2 = self.filters*4+self.filters*4 |
|
|
self.multiresblock7 = Multiresblock(input_features=self.concat_filters2,corresponding_unet_filters=self.filters*4) |
|
|
self.in_filters7 = int(self.filters*4*self.alpha* 0.5)+int(self.filters*4*self.alpha*0.167)+int(self.filters*4*self.alpha*0.333) |
|
|
self.upsample8 = nn.ConvTranspose2d(in_channels=self.in_filters7,out_channels=self.filters*2,kernel_size=(2,2),stride=(2,2),padding = 0) |
|
|
self.concat_filters3 = self.filters*2+self.filters*2 |
|
|
self.multiresblock8 = Multiresblock(input_features=self.concat_filters3,corresponding_unet_filters=self.filters*2) |
|
|
self.in_filters8 = int(self.filters*2*self.alpha* 0.5)+int(self.filters*2*self.alpha*0.167)+int(self.filters*2*self.alpha*0.333) |
|
|
self.upsample9 = nn.ConvTranspose2d(in_channels=self.in_filters8,out_channels=self.filters,kernel_size=(2,2),stride=(2,2),padding = 0) |
|
|
self.concat_filters4 = self.filters+self.filters |
|
|
self.multiresblock9 = Multiresblock(input_features=self.concat_filters4,corresponding_unet_filters=self.filters) |
|
|
self.in_filters9 = int(self.filters*self.alpha* 0.5)+int(self.filters*self.alpha*0.167)+int(self.filters*self.alpha*0.333) |
|
|
self.conv_final = Conv2d_batchnorm(input_features=self.in_filters9,num_of_filters = self.nclasses, |
|
|
kernel_size = (1,1),activation='None') |
|
|
|
|
|
def forward(self,x : torch.Tensor)->torch.Tensor: |
|
|
x_multires1 = self.multiresblock1(x) |
|
|
x_pool1 = self.pool1(x_multires1) |
|
|
x_multires1 = self.respath1(x_multires1) |
|
|
x_multires2 = self.multiresblock2(x_pool1) |
|
|
x_pool2 = self.pool2(x_multires2) |
|
|
x_multires2 = self.respath2(x_multires2) |
|
|
x_multires3 = self.multiresblock3(x_pool2) |
|
|
x_pool3 = self.pool3(x_multires3) |
|
|
x_multires3 = self.respath3(x_multires3) |
|
|
x_multires4 = self.multiresblock4(x_pool3) |
|
|
x_pool4 = self.pool4(x_multires4) |
|
|
x_multires4 = self.respath4(x_multires4) |
|
|
x_multires5 = self.multiresblock5(x_pool4) |
|
|
up6 = torch.cat([self.upsample6(x_multires5),x_multires4],axis=1) |
|
|
x_multires6 = self.multiresblock6(up6) |
|
|
up7 = torch.cat([self.upsample7(x_multires6),x_multires3],axis=1) |
|
|
x_multires7 = self.multiresblock7(up7) |
|
|
up8 = torch.cat([self.upsample8(x_multires7),x_multires2],axis=1) |
|
|
x_multires8 = self.multiresblock8(up8) |
|
|
up9 = torch.cat([self.upsample9(x_multires8),x_multires1],axis=1) |
|
|
x_multires9 = self.multiresblock9(up9) |
|
|
if self.nclasses > 1: |
|
|
conv_final_layer = self.conv_final(x_multires9) |
|
|
else: |
|
|
conv_final_layer = torch.sigmoid(self.conv_final(x_multires9)) |
|
|
return conv_final_layer |