# Source metadata (extraction artifact): 5,777 bytes, commit f2688f7
import torch
class CNN2D(torch.nn.Module):
    """Plain 2-D CNN: stacked conv stages with stage-wise max pooling,
    global average pooling, and a linear classification head.

    Args:
        channels: list of lists; channels[i] gives the out_channels of each
            Conv2d stacked inside stage i.
        conv_kernels, conv_strides, conv_padding: per-stage Conv2d settings
            (each stage's stacked convs share one kernel/stride/padding).
        pool_padding: padding for the 2x2, stride-2 MaxPool2d applied after
            the corresponding conv stage.
        num_classes: number of output logits per sample.
    """

    def __init__(self, channels, conv_kernels, conv_strides, conv_padding, pool_padding, num_classes=15):
        assert len(conv_kernels) == len(channels) == len(conv_strides) == len(conv_padding)
        super().__init__()
        # Convolutional stages: [Conv2d x len(channels[i])] -> BatchNorm -> ReLU
        self.conv_blocks = torch.nn.ModuleList()
        prev_channel = 1  # assumes a single-channel (grayscale) input — adjust for multi-channel data
        for i in range(len(channels)):
            block = []
            for conv_channel in channels[i]:
                block.append(torch.nn.Conv2d(in_channels=prev_channel, out_channels=conv_channel,
                                             kernel_size=conv_kernels[i], stride=conv_strides[i],
                                             padding=conv_padding[i]))
                prev_channel = conv_channel
            block.append(torch.nn.BatchNorm2d(prev_channel))
            block.append(torch.nn.ReLU())
            self.conv_blocks.append(torch.nn.Sequential(*block))
        # One 2x2 stride-2 max pool per pool_padding entry (halves spatial dims)
        self.pool_blocks = torch.nn.ModuleList(
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=p) for p in pool_padding
        )
        # Collapse whatever spatial size remains to 1x1, then classify
        self.global_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = torch.nn.Linear(prev_channel, num_classes)

    def forward(self, inwav):
        """Return class logits of shape [batch_size, num_classes]."""
        for i, conv in enumerate(self.conv_blocks):
            inwav = conv(inwav)
            # Pool after each stage that has a matching pool block
            if i < len(self.pool_blocks):
                inwav = self.pool_blocks[i](inwav)
        # flatten(1) keeps the batch dimension even when batch_size == 1;
        # the previous squeeze() dropped it, yielding [num_classes] instead
        # of [1, num_classes] for single-sample batches.
        out = self.global_pool(inwav).flatten(1)  # [B, C, 1, 1] -> [B, C]
        return self.linear(out)                   # [B, num_classes]
class ResBlock2D(torch.nn.Module):
    """Two-conv residual block with a channel-tiling skip connection.

    If the residual branch changes the channel count but one count divides
    the other, the thinner tensor is tiled channel-wise so the skip can still
    be added; otherwise a RuntimeError is raised.

    Args:
        prev_channel: input channel count.
        channel: output channel count of both convolutions.
        conv_kernel, conv_stride, conv_pad: shared Conv2d settings.
    """

    def __init__(self, prev_channel, channel, conv_kernel, conv_stride, conv_pad):
        super().__init__()
        # Residual branch: conv -> BN -> ReLU -> conv -> BN
        self.res = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=prev_channel, out_channels=channel, kernel_size=conv_kernel, stride=conv_stride, padding=conv_pad),
            torch.nn.BatchNorm2d(channel),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=conv_kernel, stride=conv_stride, padding=conv_pad),
            torch.nn.BatchNorm2d(channel),
        )
        self.bn = torch.nn.BatchNorm2d(channel)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        identity = x
        x = self.res(x)
        out_c, in_c = x.shape[1], identity.shape[1]
        # All additions are out-of-place: the previous in-place `+=` mutated
        # the caller's input tensor in the channel-reduction branch and raised
        # a RuntimeError on leaf tensors with requires_grad=True.
        if out_c == in_c:
            x = x + identity
        elif out_c > in_c:
            if out_c % in_c != 0:
                raise RuntimeError("Dims in ResBlock needs to be divisible on the previous dims!!")
            # Tile the input channel-wise to match the wider residual output
            x = x + identity.repeat(1, out_c // in_c, 1, 1)
        else:
            if in_c % out_c != 0:
                raise RuntimeError("Dims in ResBlock needs to be divisible on the previous dims!!")
            # Tile the residual output to match the wider input.
            # NOTE(review): the result keeps in_c channels while self.bn
            # expects `channel` (= out_c) channels, as in the original code —
            # this branch will fail at self.bn; confirm intended usage.
            x = identity + x.repeat(1, in_c // out_c, 1, 1)
        x = self.bn(x)
        return self.relu(x)
class CNNRes2D(torch.nn.Module):
    """ResNet-style 2-D CNN: a conv+pool stem, then ResBlock2D stages with
    stage-wise max pooling, global average pooling, and a linear head.

    Args:
        channels: list of lists; channels[0][0] sets the stem width, and
            channels[i] (i >= 1) lists the widths of the ResBlock2D units in
            stage i.
        conv_kernels, conv_strides, conv_padding: per-stage conv settings
            (index 0 is the stem, the rest configure each ResBlock2D stage).
        pool_padding: padding for the 2x2, stride-2 MaxPool2d per stage
            (index 0 belongs to the stem pool).
        num_classes: number of output logits per sample.
    """

    def __init__(self, channels, conv_kernels, conv_strides, conv_padding, pool_padding, num_classes=15):
        assert len(conv_kernels) == len(channels) == len(conv_strides) == len(conv_padding)
        super().__init__()
        # Stem: conv -> BN -> ReLU -> max pool
        prev_channel = 1  # assumes a single-channel (grayscale) input — adjust if needed
        self.conv_block = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=prev_channel, out_channels=channels[0][0],
                            kernel_size=conv_kernels[0], stride=conv_strides[0], padding=conv_padding[0]),
            torch.nn.BatchNorm2d(channels[0][0]),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=pool_padding[0]),
        )
        # Residual stages (one Sequential of ResBlock2D units per stage)
        prev_channel = channels[0][0]
        self.res_blocks = torch.nn.ModuleList()
        for i in range(1, len(channels)):
            units = []
            for conv_channel in channels[i]:
                units.append(ResBlock2D(prev_channel, conv_channel, conv_kernels[i], conv_strides[i], conv_padding[i]))
                prev_channel = conv_channel
            self.res_blocks.append(torch.nn.Sequential(*units))
        # One max pool per residual stage (index 0 was consumed by the stem)
        self.pool_blocks = torch.nn.ModuleList(
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=p) for p in pool_padding[1:]
        )
        # Collapse remaining spatial dims to 1x1, then classify
        self.global_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = torch.nn.Linear(prev_channel, num_classes)

    def forward(self, inwav):
        """Return class logits of shape [batch_size, num_classes]."""
        inwav = self.conv_block(inwav)
        for i, stage in enumerate(self.res_blocks):
            inwav = stage(inwav)
            if i < len(self.pool_blocks):
                inwav = self.pool_blocks[i](inwav)
        # flatten(1) keeps the batch dimension even when batch_size == 1;
        # the previous squeeze() dropped it for single-sample batches.
        out = self.global_pool(inwav).flatten(1)  # [B, C, 1, 1] -> [B, C]
        return self.linear(out)                   # [B, num_classes]
# # Example instantiation of the network
# cnn2d_res = CNNRes2D(
# channels=[[128], [128]*2],
# conv_kernels=[(3, 3), (3, 3)],
# conv_strides=[(1, 1), (1, 1)],
# conv_padding=[(1, 1), (1, 1)],
# pool_padding=[(0, 0), (0, 0)]
# )
|