# Adapted from https://raw.githubusercontent.com/soumickmj/NCC1701/main/Bridge/models/ResNet/MickResNet.py import torch import torch.nn as nn import numpy as np import sys import torch.nn.functional as F from tricorder.torch.transforms import Interpolator __author__ = "Soumick Chatterjee" __copyright__ = "Copyright 2019, Soumick Chatterjee & OvGU:ESF:MEMoRIAL" __credits__ = ["Soumick Chatterjee"] __license__ = "GPL" __version__ = "1.0.0" __email__ = "soumick.chatterjee@ovgu.de" __status__ = "Published" class ResidualBlock(nn.Module): def __init__(self, in_features, drop_prob=0.2): #drop_prob=0.2 super(ResidualBlock, self).__init__() conv_block = [ layer_pad(1), layer_conv(in_features, in_features, 3), layer_norm(in_features), act_relu(), layer_drop(p=drop_prob, inplace=True), layer_pad(1), layer_conv(in_features, in_features, 3) , layer_norm(in_features) ] self.conv_block = nn.Sequential(*conv_block) def forward(self, x): return x + self.conv_block(x) class DownsamplingBlock(nn.Module): def __init__(self, in_features, out_features): super(DownsamplingBlock, self).__init__() conv_block = [ layer_conv(in_features, out_features, 3, stride=2, padding=1), layer_norm(out_features), act_relu() ] self.conv_block = nn.Sequential(*conv_block) def forward(self, x): return self.conv_block(x) class UpsamplingBlock(nn.Module): def __init__(self, in_features, out_features, mode="upconv", interpolator=None, post_interp_convtrans=False): super(UpsamplingBlock, self).__init__() self.interpolator = interpolator self.mode = mode self.post_interp_convtrans = post_interp_convtrans if self.post_interp_convtrans: self.post_conv = layer_conv(out_features, out_features, 1) if mode == "upconv": conv_block = [ layer_convtrans(in_features, out_features, 3, stride=2, padding=1, output_padding=1), ] else: conv_block = [ layer_pad(1), layer_conv(in_features, out_features, 3), ] conv_block += [ layer_norm(out_features), act_relu() ] self.conv_block = nn.Sequential(*conv_block) def forward(self, x, out_shape=None): if self.mode == "upconv": if self.post_interp_convtrans: x = self.conv_block(x) if x.shape[2:] != out_shape: return self.post_conv(self.interpolator(x, out_shape)) else: return x else: return self.conv_block(x) else: return self.conv_block(self.interpolator(x, out_shape)) class GP_ReconResNet(nn.Module): def __init__(self, in_channels=1, n_classes=1, res_blocks=14, starting_nfeatures=64, updown_blocks=2, is_relu_leaky=True, do_batchnorm=False, res_drop_prob=0.2, #res_drop_prob=0.2 out_act="softmax", forwardV=0, upinterp_algo='upconv', post_interp_convtrans=False, is3D=False): #should use 14 as that gives number of trainable parameters close to number of possible pixel values in a image 256x256 super(GP_ReconResNet, self).__init__() layers = {} if is3D: sys.exit("ResNet: for implemented for 3D, ReflectionPad3d code is required") layers["layer_conv"] = nn.Conv3d layers["layer_convtrans"] = nn.ConvTranspose3d if do_batchnorm: layers["layer_norm"] = nn.BatchNorm3d else: layers["layer_norm"] = nn.InstanceNorm3d layers["layer_drop"] = nn.Dropout3d layers["layer_pad"] = ReflectionPad3d layers["interp_mode"] = 'trilinear' else: layers["layer_conv"] = nn.Conv2d layers["layer_convtrans"] = nn.ConvTranspose2d if do_batchnorm: layers["layer_norm"] = nn.BatchNorm2d else: layers["layer_norm"] = nn.InstanceNorm2d layers["layer_drop"] = nn.Dropout2d layers["layer_pad"] = nn.ReflectionPad2d layers["interp_mode"] = 'bilinear' if is_relu_leaky: layers["act_relu"] = nn.PReLU else: layers["act_relu"] = nn.ReLU globals().update(layers) self.forwardV = forwardV self.upinterp_algo = upinterp_algo interpolator = Interpolator(mode=layers["interp_mode"] if self.upinterp_algo == "upconv" else self.upinterp_algo) in_channels = in_channels out_channels = n_classes # Initial convolution block intialConv = [ layer_pad(3), layer_conv(in_channels, starting_nfeatures, 7), layer_norm(starting_nfeatures), act_relu() ] # Downsampling [need to save the shape for upsample] downsam = [] in_features = starting_nfeatures out_features = in_features*2 for _ in range(updown_blocks): downsam.append(DownsamplingBlock(in_features, out_features)) in_features = out_features out_features = in_features*2 # Residual blocks resblocks = [] for _ in range(res_blocks): resblocks += [ResidualBlock(in_features, res_drop_prob)] # Upsampling upsam = [] out_features = in_features//2 for _ in range(updown_blocks): upsam.append(UpsamplingBlock(in_features, out_features, self.upinterp_algo, interpolator, post_interp_convtrans)) in_features = out_features out_features = in_features//2 # Output layer finalconv = [ layer_conv(starting_nfeatures, out_channels, 1), ] #kernel size changed from 7 to 1 to make GMP work if out_act == "sigmoid": finalconv += [ nn.Sigmoid(), ] elif out_act == "relu": finalconv += [ act_relu(), ] elif out_act == "tanh": finalconv += [ nn.Tanh(), ] elif out_act == "softmax": finalconv += [ nn.Softmax2d(), ] self.intialConv = nn.Sequential(*intialConv) self.downsam = nn.ModuleList(downsam) self.resblocks = nn.Sequential(*resblocks) self.upsam = nn.ModuleList(upsam) self.finalconv = nn.Sequential(*finalconv) ### For Classification, following Florian's GP-UNet self.GMP = nn.AdaptiveMaxPool2d((1, 1)) if self.forwardV == 0: self.forward = self.forwardV0 elif self.forwardV == 1: sys.exit("ResNet: its identical to V0 in case of GP_ResNet") elif self.forwardV == 2: self.forward = self.forwardV2 elif self.forwardV == 3: self.forward = self.forwardV3 elif self.forwardV == 4: self.forward = self.forwardV4 elif self.forwardV == 5: self.forward = self.forwardV5 def final_step(self, x): if self.training: x = self.GMP(x) return self.finalconv(x).view(x.shape[0],-1) else: mask = self.finalconv(x) x = self.GMP(x) pred = self.finalconv(x).view(x.shape[0],-1) return pred, mask def forwardV0(self, x): #v0: Original Version x = self.intialConv(x) shapes = [] for downblock in self.downsam: shapes.append(x.shape[2:]) x = downblock(x) x = self.resblocks(x) for i, upblock in enumerate(self.upsam): x = upblock(x, shapes[-1-i]) return self.final_step(x) def forwardV2(self, x): #v2: residual of v1 + input to the residual blocks added back with the output out = self.intialConv(x) shapes = [] for downblock in self.downsam: shapes.append(out.shape[2:]) out = downblock(out) out = out + self.resblocks(out) for i, upblock in enumerate(self.upsam): out = upblock(out, shapes[-1-i]) return self.final_step(out) def forwardV3(self, x): #v3: residual of v2 + input of the initial conv added back with the output out = x + self.intialConv(x) shapes = [] for downblock in self.downsam: shapes.append(out.shape[2:]) out = downblock(out) out = out + self.resblocks(out) for i, upblock in enumerate(self.upsam): out = upblock(out, shapes[-1-i]) return self.final_step(out) def forwardV4(self, x): #v4: residual of v3 + output of the initial conv added back with the input of final conv iniconv = x + self.intialConv(x) shapes = [] if len(self.downsam) > 0: for i, downblock in enumerate(self.downsam): if i == 0: shapes.append(iniconv.shape[2:]) out = downblock(iniconv) else: shapes.append(out.shape[2:]) out = downblock(out) else: out = iniconv out = out + self.resblocks(out) for i, upblock in enumerate(self.upsam): out = upblock(out, shapes[-1-i]) out = iniconv + out return self.final_step(out) def forwardV5(self, x): #v5: residual of v4 + individual down blocks with individual up blocks outs = [x + self.intialConv(x)] shapes = [] for i, downblock in enumerate(self.downsam): shapes.append(outs[-1].shape[2:]) outs.append(downblock(outs[-1])) outs[-1] = outs[-1] + self.resblocks(outs[-1]) for i, upblock in enumerate(self.upsam): outs[-1] = upblock(outs[-1], shapes[-1-i]) outs[-1] = outs[-2] + outs.pop() return self.final_step(outs.pop()) #to run it here from this script, uncomment the following if __name__ == "__main__": #to run it image = torch.rand(2, 1, 240, 240) #specify your image: batch size, Channel, height, width model = GP_ReconResNet(in_channels=1, n_classes=3, upinterp_algo='sinc') #Initialize the model # model.eval() out = model(image) print(model(image))