# ASR / classification_network.py
# (file-viewer metadata from the original upload: SIDD2201, "Upload 363 files", commit f2688f7)
import torch
class CNN2D(torch.nn.Module):
    """Plain 2-D CNN classifier: stacked Conv2d→BatchNorm→ReLU blocks with
    interleaved max-pooling, global average pooling, and a linear head.

    Args:
        channels: list of lists; channels[i] holds the output channel count of
            each stacked conv layer inside block i.
        conv_kernels: per-block kernel size, shared by the stacked convs of
            that block.
        conv_strides: per-block stride.
        conv_padding: per-block padding.
        pool_padding: padding for each MaxPool2d; block i is followed by a
            pool only when i < len(pool_padding).
        num_classes: size of the output logits (default 15).
    """

    def __init__(self, channels, conv_kernels, conv_strides, conv_padding, pool_padding, num_classes=15):
        assert len(conv_kernels) == len(channels) == len(conv_strides) == len(conv_padding)
        super(CNN2D, self).__init__()
        # Create convolutional blocks
        self.conv_blocks = torch.nn.ModuleList()
        prev_channel = 1  # Assuming the input is a grayscale image, modify if using more channels
        for i in range(len(channels)):
            # Each block: stacked convs, then one BatchNorm + ReLU
            block = []
            for conv_channel in channels[i]:
                block.append(torch.nn.Conv2d(in_channels=prev_channel, out_channels=conv_channel,
                                             kernel_size=conv_kernels[i], stride=conv_strides[i],
                                             padding=conv_padding[i]))
                prev_channel = conv_channel
            block.append(torch.nn.BatchNorm2d(prev_channel))
            block.append(torch.nn.ReLU())
            self.conv_blocks.append(torch.nn.Sequential(*block))
        # Pooling blocks (kernel 2, stride 2 halves each spatial dim)
        self.pool_blocks = torch.nn.ModuleList()
        for i in range(len(pool_padding)):
            self.pool_blocks.append(torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=pool_padding[i]))
        # Global pooling + classification head
        self.global_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = torch.nn.Linear(prev_channel, num_classes)

    def forward(self, inwav):
        """Return logits of shape [batch_size, num_classes].

        inwav: input tensor, expected [batch_size, 1, H, W] unless the
        in_channels assumption in __init__ was changed.
        """
        for i in range(len(self.conv_blocks)):
            # Apply convolutional block
            inwav = self.conv_blocks[i](inwav)
            # Apply max pooling where a pool block exists for this index
            if i < len(self.pool_blocks):
                inwav = self.pool_blocks[i](inwav)
        # Global pooling: [B, C, 1, 1] -> [B, C]. flatten(1) (not squeeze())
        # so a batch of size 1 keeps its batch dimension.
        out = self.global_pool(inwav).flatten(1)
        out = self.linear(out)  # [batch_size, num_classes]
        return out
class ResBlock2D(torch.nn.Module):
    """Residual block: two Conv2d+BatchNorm layers plus a skip connection,
    followed by BatchNorm and ReLU.

    When input and output channel counts differ, they must divide evenly:
    a thinner identity is tiled up to the output width, and a wider identity
    is averaged down in channel groups.
    """

    def __init__(self, prev_channel, channel, conv_kernel, conv_stride, conv_pad):
        super(ResBlock2D, self).__init__()
        self.res = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=prev_channel, out_channels=channel,
                            kernel_size=conv_kernel, stride=conv_stride, padding=conv_pad),
            torch.nn.BatchNorm2d(channel),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=channel, out_channels=channel,
                            kernel_size=conv_kernel, stride=conv_stride, padding=conv_pad),
            torch.nn.BatchNorm2d(channel),
        )
        self.bn = torch.nn.BatchNorm2d(channel)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        identity = x
        x = self.res(x)
        in_c, out_c = identity.shape[1], x.shape[1]
        if out_c == in_c:
            x = x + identity
        elif out_c > in_c:
            if out_c % in_c != 0:
                raise RuntimeError("Dims in ResBlock needs to be divisible on the previous dims!!")
            # Tile the thinner identity across the channel dim to match x.
            x = x + identity.repeat(1, out_c // in_c, 1, 1)
        else:
            if in_c % out_c != 0:
                raise RuntimeError("Dims in ResBlock needs to be divisible on the previous dims!!")
            # Average the wider identity in channel groups so the sum keeps
            # out_c channels. (The original mutated the input in place and
            # then fed an in_c-channel tensor to a BatchNorm built for out_c
            # channels, which raised at runtime.)
            n, _, h, w = identity.shape
            x = x + identity.view(n, out_c, in_c // out_c, h, w).mean(2)
        x = self.bn(x)
        x = self.relu(x)
        return x
class CNNRes2D(torch.nn.Module):
    """CNN classifier built from an initial conv block followed by stacks of
    ResBlock2D, interleaved max-pooling, global average pooling, and a linear
    head.

    Args:
        channels: list of lists; channels[0][0] is the stem's output width,
            channels[i] (i >= 1) holds the output channels of each residual
            block in stack i.
        conv_kernels: per-stack kernel size (index 0 is the stem).
        conv_strides: per-stack stride.
        conv_padding: per-stack padding.
        pool_padding: padding for each MaxPool2d; index 0 belongs to the stem,
            stack i (i >= 1) is followed by a pool only when i < len(pool_padding).
        num_classes: size of the output logits (default 15).
    """

    def __init__(self, channels, conv_kernels, conv_strides, conv_padding, pool_padding, num_classes=15):
        assert len(conv_kernels) == len(channels) == len(conv_strides) == len(conv_padding)
        super(CNNRes2D, self).__init__()
        # Stem: conv + BN + ReLU + pool
        prev_channel = 1  # Assuming input has 1 channel, modify if needed
        self.conv_block = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=prev_channel, out_channels=channels[0][0],
                            kernel_size=conv_kernels[0], stride=conv_strides[0],
                            padding=conv_padding[0]),
            torch.nn.BatchNorm2d(channels[0][0]),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=pool_padding[0]),
        )
        # Residual stacks (config index 0 is consumed by the stem)
        prev_channel = channels[0][0]
        self.res_blocks = torch.nn.ModuleList()
        for i in range(1, len(channels)):
            block = []
            for conv_channel in channels[i]:
                block.append(ResBlock2D(prev_channel, conv_channel,
                                        conv_kernels[i], conv_strides[i], conv_padding[i]))
                prev_channel = conv_channel
            self.res_blocks.append(torch.nn.Sequential(*block))
        # Pooling after each residual stack
        self.pool_blocks = torch.nn.ModuleList()
        for i in range(1, len(pool_padding)):
            self.pool_blocks.append(torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=pool_padding[i]))
        # Global pooling + classification head
        self.global_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = torch.nn.Linear(prev_channel, num_classes)

    def forward(self, inwav):
        """Return logits of shape [batch_size, num_classes].

        inwav: input tensor, expected [batch_size, 1, H, W] unless the
        in_channels assumption in __init__ was changed.
        """
        inwav = self.conv_block(inwav)
        for i in range(len(self.res_blocks)):
            inwav = self.res_blocks[i](inwav)
            if i < len(self.pool_blocks):
                inwav = self.pool_blocks[i](inwav)
        # flatten(1) (not squeeze()) so a batch of size 1 keeps its batch dim
        out = self.global_pool(inwav).flatten(1)
        out = self.linear(out)
        return out
# # Example instantiation of the network
# cnn2d_res = CNNRes2D(
# channels=[[128], [128]*2],
# conv_kernels=[(3, 3), (3, 3)],
# conv_strides=[(1, 1), (1, 1)],
# conv_padding=[(1, 1), (1, 1)],
# pool_padding=[(0, 0), (0, 0)]
# )