# EMOTIA/models/audio.py
import torch
import torch.nn as nn
from transformers import Wav2Vec2Model, Wav2Vec2Config
import librosa
import numpy as np
class AudioEmotionModel(nn.Module):
"""
CNN + Transformer for audio emotion recognition.
Uses Wav2Vec2 backbone for feature extraction.
"""
def __init__(self, num_emotions=7, pretrained=True):
super().__init__()
self.num_emotions = num_emotions
# Load pre-trained Wav2Vec2
if pretrained:
self.wav2vec = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')
else:
config = Wav2Vec2Config()
self.wav2vec = Wav2Vec2Model(config)
        # Freeze all Wav2Vec2 parameters (only the new heads are trained)
for param in self.wav2vec.parameters():
param.requires_grad = False
hidden_size = self.wav2vec.config.hidden_size
# CNN for local feature extraction
self.cnn = nn.Sequential(
nn.Conv1d(hidden_size, 256, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv1d(256, 128, kernel_size=3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool1d(1)
)
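        # Note: AdaptiveAvgPool1d(1) collapses the time axis, so the
        # transformer below receives a length-1 sequence per sample.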
# Transformer for sequence modeling
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=128, nhead=8, dim_feedforward=512,
                batch_first=True  # inputs arrive as (B, S, E)
            ),
            num_layers=4
        )
# Emotion classification
self.emotion_classifier = nn.Sequential(
nn.Linear(128, 64),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(64, num_emotions)
)
# Stress/confidence estimation
self.stress_head = nn.Sequential(
nn.Linear(128, 32),
nn.ReLU(),
nn.Linear(32, 1),
nn.Sigmoid()
)
def forward(self, input_values):
"""
input_values: batch of audio waveforms (B, T)
Returns: emotion_logits, stress_score
"""
# Extract features with Wav2Vec2
outputs = self.wav2vec(input_values)
hidden_states = outputs.last_hidden_state # (B, T, hidden_size)
# Transpose for CNN (B, hidden_size, T)
hidden_states = hidden_states.transpose(1, 2)
# CNN feature extraction
cnn_features = self.cnn(hidden_states).squeeze(-1) # (B, 128)
# Add sequence dimension for transformer
cnn_features = cnn_features.unsqueeze(1) # (B, 1, 128)
# Transformer
transformer_out = self.transformer(cnn_features) # (B, 1, 128)
pooled_features = transformer_out.mean(dim=1) # (B, 128)
emotion_logits = self.emotion_classifier(pooled_features)
stress_score = self.stress_head(pooled_features)
        return emotion_logits, stress_score.squeeze(-1)  # (B,); keeps the batch dim when B == 1
def preprocess_audio(self, audio_path, sample_rate=16000, duration=3.0):
"""
Load and preprocess audio file.
"""
# Load audio
audio, sr = librosa.load(audio_path, sr=sample_rate, duration=duration)
# Pad/truncate to fixed length
target_length = int(sample_rate * duration)
if len(audio) < target_length:
audio = np.pad(audio, (0, target_length - len(audio)))
else:
audio = audio[:target_length]
return torch.tensor(audio, dtype=torch.float32)
    def extract_prosody_features(self, audio):
        """
        Extract additional prosody features: mean pitch, RMS energy,
        and zero-crossing rate.
        """
        y = audio.numpy()
        # Pitch: average over voiced frames; fall back to 0.0 when piptrack
        # finds no voiced frames (avoids a NaN from an empty mean)
        pitches, _ = librosa.piptrack(y=y, sr=16000)
        voiced = pitches[pitches > 0]
        pitch = float(voiced.mean()) if voiced.size > 0 else 0.0
        # RMS energy
        rms = librosa.feature.rms(y=y)[0].mean()
        # Zero-crossing rate
        zcr = librosa.feature.zero_crossing_rate(y)[0].mean()
        return torch.tensor([pitch, rms, zcr], dtype=torch.float32)
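

# Minimal usage sketch (illustrative addition, not part of the original module).
# A random waveform stands in for a real clip; with a real file you would call
# model.preprocess_audio("path/to/clip.wav") and unsqueeze a batch dimension.
if __name__ == "__main__":
    model = AudioEmotionModel(num_emotions=7, pretrained=True)
    model.eval()

    batch = torch.randn(1, 48000)  # 3 s of dummy audio at 16 kHz, shape (B, T)
    with torch.no_grad():
        emotion_logits, stress_score = model(batch)

    print(emotion_logits.shape)  # torch.Size([1, 7])
    print(stress_score.shape)    # torch.Size([1]); values in [0, 1]

    # Prosody features for a single (unbatched) waveform
    prosody = model.extract_prosody_features(batch[0])  # (3,): pitch, RMS, ZCR
    print(prosody)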