---
tags:
- audio-classification
- speech-emotion-recognition
- automatic-speech-recognition
- emotion-recognition
- wav2vec2
- toronto-emotional-speech-dataset
datasets:
- toronto-emotional-speech-dataset
metrics:
- accuracy: 0.85
base_model: facebook/wav2vec2-base
model-index:
- name: dynann/emotion-speech-recognition
  results: []
---

# Wav2Vec2 for Emotion Recognition

This model is a fine-tuned version of [facebook/wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) on the Toronto Emotional Speech Dataset (TESS).

It achieves the following results on the evaluation set:

- **Accuracy**: 85%
- **Loss**: ~3.76

## Model Description

The model classifies audio input into 7 discrete emotions:

- Angry
- Disgust
- Fear
- Happy
- Neutral
- Pleasant Surprise (`ps`)
- Sad

It uses a custom classification head on top of the frozen Wav2Vec2 base model.

## Usage

**Note**: The checkpoint uses a custom `Wav2Vec2ForEmotionClassification` class, so you must define that class yourself before loading the weights.

```python
import torch
import torch.nn as nn
from transformers import Wav2Vec2Processor, Wav2Vec2Model, Wav2Vec2Config


# Define the Custom Model Class
class Wav2Vec2ForEmotionClassification(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wav2vec2 = Wav2Vec2Model(config)
        # Three-layer MLP head applied to the pooled encoder output
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, config.num_labels),
        )

    def forward(self, input_values, attention_mask=None, labels=None, **kwargs):
        outputs = self.wav2vec2(input_values, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        # Mean-pool over the time dimension: (batch, time, hidden) -> (batch, hidden)
        pooled_output = torch.mean(hidden_states, dim=1)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels.view(-1))

        return {"loss": loss, "logits": logits}


# Load Model
model_id = "dynann/emotion-speech-recognition"
config = Wav2Vec2Config.from_pretrained(model_id)
model = Wav2Vec2ForEmotionClassification(config)
# map_location="cpu" lets the weights load on machines without a GPU
state_dict = torch.hub.load_state_dict_from_url(
    f"https://huggingface.co/{model_id}/resolve/main/pytorch_model.bin",
    map_location="cpu",
)
model.load_state_dict(state_dict)
model.eval()  # disable dropout for inference
processor = Wav2Vec2Processor.from_pretrained(model_id)
```

## Training Procedure

- **Epochs**: 10
- **Batch Size**: 32 (optimized for P100) / 8 (local)
- **Learning Rate**: 3e-4
- **Feature Encoder**: Frozen
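
The pooling and classification steps can be tried without downloading any weights. The sketch below is a minimal illustration, not the author's inference code. It feeds a random tensor shaped like Wav2Vec2 encoder output (hidden size 768, as in `wav2vec2-base`) through an untrained head with the same layer sizes as the classifier above, then turns the logits into a label. The `LABELS` order and the helper names `build_head` and `predict` are hypothetical; the real index-to-label mapping should be taken from the checkpoint's config.

```python
import torch
import torch.nn as nn

# Hypothetical label order, for illustration only; verify against the checkpoint.
LABELS = ["angry", "disgust", "fear", "happy", "neutral", "ps", "sad"]
HIDDEN_SIZE = 768  # hidden size of wav2vec2-base


def build_head(hidden_size, num_labels):
    """Build an MLP with the same layer sizes as the model card's classifier."""
    return nn.Sequential(
        nn.Linear(hidden_size, 256), nn.ReLU(), nn.Dropout(0.3),
        nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3),
        nn.Linear(128, num_labels),
    )


def predict(head, hidden_states):
    """Mean-pool over time, classify, and return labels plus probabilities."""
    head.eval()  # disable dropout
    with torch.no_grad():
        pooled = hidden_states.mean(dim=1)           # (batch, hidden)
        probs = torch.softmax(head(pooled), dim=-1)  # (batch, num_labels)
    return [LABELS[i] for i in probs.argmax(dim=-1).tolist()], probs


torch.manual_seed(0)
head = build_head(HIDDEN_SIZE, len(LABELS))
# Random stand-in for encoder output: 2 clips, 49 frames each
# (roughly one second of 16 kHz audio for wav2vec2-base).
dummy_hidden = torch.randn(2, 49, HIDDEN_SIZE)
labels, probs = predict(head, dummy_hidden)
print(labels, probs.shape)
```

With the real model, `hidden_states` would be `outputs.last_hidden_state` from the encoder. Because the head here is untrained, the predicted labels are arbitrary; only the shapes and the flow of data are meaningful.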