# Source: speech-emotion-recognition / emotion_cnn_v2.py
# Uploaded by saadmannan with huggingface_hub (commit 093d5c6, verified).
"""
Enhanced CNN Architecture with Residual Connections and Attention
for Speech Emotion Recognition
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm stages joined by an additive skip connection.

    The first convolution applies ``stride``; the second keeps the spatial
    size. If the stride or the channel count changes, the skip path becomes a
    strided 1x1 convolution followed by BatchNorm so both branches have the
    same shape; otherwise the input is added back unchanged.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first 3x3 convolution (and of the projection
            shortcut, when one is created).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            # 1x1 projection so the skip branch matches the main branch.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # An empty Sequential returns its input untouched (identity skip).
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return F.relu(h + self.shortcut(x))
class ChannelAttention(nn.Module):
    """Channel attention: re-weights each channel of a 4-D feature map.

    Average- and max-pooled channel descriptors are passed through a shared
    bottleneck MLP (``channels -> hidden -> channels``). The two results are
    summed, squashed with a sigmoid, and multiplied back onto the input
    channel-wise.

    Args:
        channels: Number of channels ``C`` of the input.
        reduction: Bottleneck reduction ratio. The hidden width is
            ``channels // reduction``, clamped to at least 1 so that
            ``channels < reduction`` still yields a working MLP.
    """

    def __init__(self, channels, reduction=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Without the clamp, channels < reduction gives a zero-width hidden
        # layer: the MLP then always outputs 0, so every channel is scaled by
        # the constant sigmoid(0) = 0.5 and attention is silently disabled.
        hidden = max(1, channels // reduction)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(),
            nn.Linear(hidden, channels, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Scale ``x`` of shape ``(B, C, H, W)`` by per-channel weights in (0, 1)."""
        b, c, _, _ = x.size()
        # Average pooling -> (B, C) descriptor -> shared MLP
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        # Max pooling -> (B, C) descriptor -> same shared MLP
        max_out = self.fc(self.max_pool(x).view(b, c))
        weights = self.sigmoid(avg_out + max_out).view(b, c, 1, 1)
        # (B, C, 1, 1) broadcasts over the spatial dimensions.
        return x * weights
class ImprovedEmotionCNN(nn.Module):
    """ResNet-style CNN with channel attention for speech emotion recognition.

    Architecture:
    - stem: 7x7 stride-2 convolution + BatchNorm + ReLU, then a 3x3 stride-2
      max pool
    - four residual stages of two ``ResidualBlock``s each (64, 128, 256, 512
      channels; stages 2-4 halve the spatial size), each followed by
      ``ChannelAttention``
    - global average- and max-pooling, concatenated to 1024 features
    - an MLP head 1024 -> 512 -> 256 -> ``num_classes`` with BatchNorm1d and
      ReLU between layers and dropout before every linear layer

    Note: the head uses ``BatchNorm1d``, which cannot normalise a batch of a
    single sample in training mode.

    Args:
        num_classes: Number of output scores (logits).
        dropout_rate: Dropout probability applied before each linear layer.
        in_channels: Number of input channels (keyword-only). Defaults to 1,
            matching the original single-channel stem.
    """

    def __init__(self, num_classes=8, dropout_rate=0.4, *, in_channels=1):
        super(ImprovedEmotionCNN, self).__init__()
        # Stem: stride-2 conv and stride-2 max pool each halve H and W.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages (2 blocks each); stages 2-4 downsample by 2.
        self.layer1 = self._make_layer(64, 64, 2, stride=1)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)
        self.layer4 = self._make_layer(256, 512, 2, stride=2)

        # Channel attention applied after each residual stage.
        self.attention1 = ChannelAttention(64)
        self.attention2 = ChannelAttention(128)
        self.attention3 = ChannelAttention(256)
        self.attention4 = ChannelAttention(512)

        # Global pooling; both outputs are concatenated in forward().
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.maxpool_global = nn.AdaptiveMaxPool2d((1, 1))

        # Classifier head. A single Dropout module is reused before each Linear.
        self.dropout = nn.Dropout(dropout_rate)
        self.fc1 = nn.Linear(512 * 2, 512)  # *2: avg- and max-pooled features are concatenated
        self.bn_fc1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn_fc2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        """Stack residual blocks into one stage.

        Only the first block changes the channel count and applies ``stride``;
        the remaining ``num_blocks - 1`` blocks keep the shape. At least one
        block is always created.
        """
        layers = [ResidualBlock(in_channels, out_channels, stride)]
        for _ in range(1, num_blocks):
            layers.append(ResidualBlock(out_channels, out_channels, 1))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a ``(B, in_channels, H, W)`` tensor to ``(B, num_classes)`` raw scores.

        No softmax is applied to the returned scores.
        """
        # Stem
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        # Residual stages, each followed by channel attention
        x = self.attention1(self.layer1(x))
        x = self.attention2(self.layer2(x))
        x = self.attention3(self.layer3(x))
        x = self.attention4(self.layer4(x))

        # Global pooling (avg and max) -> (B, 1024)
        x = torch.cat([self.avgpool(x), self.maxpool_global(x)], dim=1)
        x = torch.flatten(x, 1)

        # Classifier head
        x = self.dropout(x)
        x = F.relu(self.bn_fc1(self.fc1(x)))
        x = self.dropout(x)
        x = F.relu(self.bn_fc2(self.fc2(x)))
        x = self.dropout(x)
        return self.fc3(x)

    def get_num_params(self):
        """Return the number of trainable parameters (those with ``requires_grad=True``)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
def test_model():
    """Smoke-test ``ImprovedEmotionCNN`` on a random batch and print a summary.

    Builds an 8-class model and runs one forward pass on a random
    ``(4, 1, 196, 128)`` input. It then prints the input and output shapes,
    the trainable-parameter count and the module tree.
    """
    model = ImprovedEmotionCNN(num_classes=8)
    # A shape check should neither update BatchNorm running statistics
    # (eval mode) nor record an autograd graph (no_grad).
    model.eval()

    # Random input of shape (batch, 1, 196, 128). The 196 presumably matches
    # the feature dimension of the extractor; confirm against the pipeline.
    batch_size = 4
    x = torch.randn(batch_size, 1, 196, 128)

    # Forward pass
    with torch.no_grad():
        output = model(x)

    print("=" * 60)
    print("Enhanced Model Architecture Test")
    print("=" * 60)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Total parameters: {model.get_num_params():,}")
    print("=" * 60)

    # Print model summary
    print("\nModel Summary:")
    print(model)
if __name__ == "__main__":
    # Executing this file directly runs the architecture smoke test.
    test_model()