import torch


class CNN2D(torch.nn.Module):
    """Plain 2-D CNN: stacked Conv2d+BN+ReLU blocks, interleaved max pooling,
    global average pooling, and a linear classifier head.

    Args:
        channels: list of lists; ``channels[i]`` holds the output channel count
            of each conv layer in block ``i``.
        conv_kernels / conv_strides / conv_padding: per-block Conv2d settings
            (one entry per block, shared by the stacked convs inside it).
        pool_padding: padding for each 2x2/stride-2 MaxPool2d applied after the
            corresponding conv block.
        num_classes: size of the final linear layer's output.
    """

    def __init__(self, channels, conv_kernels, conv_strides, conv_padding, pool_padding, num_classes=15):
        assert len(conv_kernels) == len(channels) == len(conv_strides) == len(conv_padding)
        super(CNN2D, self).__init__()
        # Create convolutional blocks
        self.conv_blocks = torch.nn.ModuleList()
        prev_channel = 1  # Assuming the input is a grayscale image, modify if using more channels
        for i in range(len(channels)):
            # Stack conv layers for this block, each followed by BN + ReLU.
            block = []
            for conv_channel in channels[i]:
                block.append(torch.nn.Conv2d(in_channels=prev_channel,
                                             out_channels=conv_channel,
                                             kernel_size=conv_kernels[i],
                                             stride=conv_strides[i],
                                             padding=conv_padding[i]))
                prev_channel = conv_channel
                block.append(torch.nn.BatchNorm2d(prev_channel))
                block.append(torch.nn.ReLU())
            self.conv_blocks.append(torch.nn.Sequential(*block))
        # Create pooling blocks (2x2, stride 2 -> halves spatial dims).
        self.pool_blocks = torch.nn.ModuleList()
        for i in range(len(pool_padding)):
            self.pool_blocks.append(torch.nn.MaxPool2d(kernel_size=2, stride=2,
                                                       padding=pool_padding[i]))
        # Global pooling collapses any remaining spatial extent to 1x1.
        self.global_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = torch.nn.Linear(prev_channel, num_classes)

    def forward(self, inwav):
        """Run the network. ``inwav`` is expected to be [batch, 1, H, W]."""
        for i in range(len(self.conv_blocks)):
            inwav = self.conv_blocks[i](inwav)
            # Apply max pooling where a pool block exists for this index.
            if i < len(self.pool_blocks):
                inwav = self.pool_blocks[i](inwav)
        # FIX: the original used .squeeze(), which also dropped the batch
        # dimension when batch_size == 1 (yielding a 1-D tensor and a wrong
        # output shape). flatten(1) removes only the trailing 1x1 spatial dims:
        # [batch, C, 1, 1] -> [batch, C].
        out = self.global_pool(inwav).flatten(1)
        out = self.linear(out)  # [batch, num_classes]
        return out


class ResBlock2D(torch.nn.Module):
    """Residual block: two Conv2d+BN layers with a skip connection, then a
    final BN + ReLU on the merged result.

    When the residual path changes the channel count, the skip tensor is
    adapted by repeating (widening) or channel-group summing (narrowing);
    the two channel counts must divide evenly or a RuntimeError is raised.
    """

    def __init__(self, prev_channel, channel, conv_kernel, conv_stride, conv_pad):
        super(ResBlock2D, self).__init__()
        self.res = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=prev_channel, out_channels=channel,
                            kernel_size=conv_kernel, stride=conv_stride, padding=conv_pad),
            torch.nn.BatchNorm2d(channel),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=channel, out_channels=channel,
                            kernel_size=conv_kernel, stride=conv_stride, padding=conv_pad),
            torch.nn.BatchNorm2d(channel),
        )
        self.bn = torch.nn.BatchNorm2d(channel)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        identity = x
        out = self.res(x)
        out_c = out.shape[1]
        in_c = identity.shape[1]
        if out_c == in_c:
            # FIX: out-of-place add (the original used += on tensors that
            # participate in autograd).
            out = out + identity
        elif out_c > in_c:
            if out_c % in_c != 0:
                raise RuntimeError("Dims in ResBlock needs to be divisible on the previous dims!!")
            # Widen the skip connection by tiling its channels.
            out = out + identity.repeat(1, out_c // in_c, 1, 1)
        else:
            if in_c % out_c != 0:
                raise RuntimeError("Dims in ResBlock needs to be divisible on the previous dims!!")
            # FIX: the original branch did `identity += x.repeat(...)`, which
            # (a) mutated the block's input tensor in place — corrupting the
            # caller's tensor and breaking autograd, since self.res saved that
            # input for backward — and (b) then fed a tensor with in_c channels
            # into self.bn (BatchNorm2d(out_c)), which raises. Instead, fold
            # the wider identity down to out_c channels by summing groups of
            # in_c // out_c channels, out of place.
            b, _, h, w = identity.shape
            folded = identity.view(b, out_c, in_c // out_c, h, w).sum(dim=2)
            out = out + folded
        out = self.bn(out)
        out = self.relu(out)
        return out


class CNNRes2D(torch.nn.Module):
    """CNN with an initial conv stem followed by residual blocks, interleaved
    max pooling, global average pooling, and a linear classifier head.

    Argument layout matches CNN2D; entry 0 of each list configures the stem
    and entries 1.. configure the residual blocks.
    """

    def __init__(self, channels, conv_kernels, conv_strides, conv_padding, pool_padding, num_classes=15):
        assert len(conv_kernels) == len(channels) == len(conv_strides) == len(conv_padding)
        super(CNNRes2D, self).__init__()
        # Initial (non-residual) convolutional stem.
        prev_channel = 1  # Assuming input has 1 channel, modify if needed
        self.conv_block = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=prev_channel, out_channels=channels[0][0],
                            kernel_size=conv_kernels[0], stride=conv_strides[0],
                            padding=conv_padding[0]),
            torch.nn.BatchNorm2d(channels[0][0]),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=pool_padding[0]),
        )
        # Residual blocks (config entries 1..).
        prev_channel = channels[0][0]
        self.res_blocks = torch.nn.ModuleList()
        for i in range(1, len(channels)):
            block = []
            for conv_channel in channels[i]:
                block.append(ResBlock2D(prev_channel, conv_channel,
                                        conv_kernels[i], conv_strides[i], conv_padding[i]))
                prev_channel = conv_channel
            self.res_blocks.append(torch.nn.Sequential(*block))
        # Pooling after each residual group (entry 0 was used by the stem).
        self.pool_blocks = torch.nn.ModuleList()
        for i in range(1, len(pool_padding)):
            self.pool_blocks.append(torch.nn.MaxPool2d(kernel_size=2, stride=2,
                                                       padding=pool_padding[i]))
        # Global pooling collapses any remaining spatial extent to 1x1.
        self.global_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = torch.nn.Linear(prev_channel, num_classes)

    def forward(self, inwav):
        """Run the network. ``inwav`` is expected to be [batch, 1, H, W]."""
        inwav = self.conv_block(inwav)
        for i in range(len(self.res_blocks)):
            inwav = self.res_blocks[i](inwav)
            if i < len(self.pool_blocks):
                inwav = self.pool_blocks[i](inwav)
        # FIX: flatten(1) instead of .squeeze() so batch_size == 1 keeps its
        # batch dimension: [batch, C, 1, 1] -> [batch, C].
        out = self.global_pool(inwav).flatten(1)
        out = self.linear(out)  # [batch, num_classes]
        return out


# # Example instantiation of the network
# cnn2d_res = CNNRes2D(
#     channels=[[128], [128]*2],
#     conv_kernels=[(3, 3), (3, 3)],
#     conv_strides=[(1, 1), (1, 1)],
#     conv_padding=[(1, 1), (1, 1)],
#     pool_padding=[(0, 0), (0, 0)]
# )