Upload folder using huggingface_hub
- .gitattributes +3 -33
- README.md +209 -0
- audio_preprocessor.py +191 -0
- best_model.pth +3 -0
- config.json +39 -0
- modeling_vit_emotion.py +135 -0
- requirements.txt +5 -0
.gitattributes
CHANGED
@@ -1,35 +1,5 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,209 @@
---
language: en
license: mit
tags:
- audio
- emotion-recognition
- valence-arousal
- vision-transformer
- pytorch
- music-emotion-recognition
datasets:
- custom
metrics:
- mse
- mae
pipeline_tag: audio-classification
---

# ViT for Audio Emotion Recognition (Valence-Arousal)

This model is a fine-tuned Vision Transformer (ViT) for audio emotion recognition, predicting valence and arousal values in the continuous range of -1 to 1.

## Model Description

- **Base Model**: google/vit-base-patch16-224-in21k
- **Task**: Audio emotion recognition (regression)
- **Output**: Valence and Arousal predictions (2D continuous emotion space)
- **Range**: [-1, 1] for both dimensions
- **Input**: Mel spectrogram images (224x224 RGB)

## Architecture

```
ViT Base (86M parameters)
        ↓
CLS Token Output (768-dim)
        ↓
LayerNorm + Dropout
        ↓
Linear (768 → 512) + GELU + Dropout
        ↓
Linear (512 → 128) + GELU + Dropout
        ↓
Linear (128 → 2) + Tanh
        ↓
[Valence, Arousal] ∈ [-1, 1]²
```

## Usage

### Prerequisites

```bash
pip install torch torchvision transformers librosa numpy pillow
```

### Loading the Model

```python
import torch
from transformers import ViTModel
import torch.nn as nn

class ViTForEmotionRegression(nn.Module):
    def __init__(self, model_name='google/vit-base-patch16-224-in21k', num_emotions=2, dropout=0.1):
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)
        hidden_size = self.vit.config.hidden_size

        self.head = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_emotions),
            nn.Tanh()
        )

    def forward(self, pixel_values):
        outputs = self.vit(pixel_values)
        cls_output = outputs.last_hidden_state[:, 0]
        return self.head(cls_output)

# Load the model
model = ViTForEmotionRegression()
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
model.eval()
```

### Audio Preprocessing

```python
import librosa
import numpy as np
from PIL import Image
import torch
from torchvision import transforms

def preprocess_audio(audio_path):
    # Load audio
    y, sr = librosa.load(audio_path, sr=22050, duration=30)

    # Generate mel spectrogram
    mel_spec = librosa.feature.melspectrogram(
        y=y, sr=sr, n_mels=128, hop_length=512, n_fft=2048
    )
    mel_db = librosa.power_to_db(mel_spec, ref=np.max)

    # Normalize to 0-255 for RGB conversion
    mel_normalized = ((mel_db - mel_db.min()) / (mel_db.max() - mel_db.min()) * 255).astype(np.uint8)

    # Convert to RGB image
    image = Image.fromarray(mel_normalized).convert('RGB')
    image = image.resize((224, 224))

    # Apply ImageNet normalization
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    return transform(image).unsqueeze(0)

# Process audio
audio_tensor = preprocess_audio('your_audio.mp3')

# Predict emotions
with torch.no_grad():
    predictions = model(audio_tensor)
    valence, arousal = predictions[0].tolist()

print(f"Valence: {valence:.3f}, Arousal: {arousal:.3f}")
```

### Emotion Quadrant Mapping

```python
def classify_emotion(valence, arousal):
    if valence >= 0 and arousal >= 0:
        return "HAPPY" if valence > arousal else "EXCITED"
    elif valence >= 0 and arousal < 0:
        return "CALM" if abs(arousal) > valence else "CONTENT"
    elif valence < 0 and arousal < 0:
        return "SAD" if abs(valence) > abs(arousal) else "BORED"
    else:  # valence < 0 and arousal >= 0
        return "TENSE" if arousal > abs(valence) else "ANGRY"
```

## Model Details

- **Parameters**: ~86.8M
- **Model Size**: ~331 MB
- **Framework**: PyTorch
- **Base Architecture**: ViT-Base (12 layers, 768 hidden, 12 heads)
- **Custom Head**: 3-layer MLP with GELU activations
- **Training Data**: Custom audio emotion dataset
- **Training**: Fine-tuned with MSE loss on valence-arousal targets
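
The parameter count and checkpoint size listed above can be sanity-checked directly from a loaded model; a minimal sketch, assuming `model` is the `ViTForEmotionRegression` instance created in the Usage section:

```python
# Rough sanity check of the figures above (illustrative, not part of the repo).
num_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {num_params / 1e6:.1f}M")  # expected around 86-87M

# A float32 state_dict costs ~4 bytes per parameter on disk.
print(f"Approx. checkpoint size: {num_params * 4 / 1024**2:.0f} MB")
```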

## Emotion Space

The model predicts emotions in the 2D circumplex model:

```
          High Arousal
               |
   Angry     Tense     Excited
               |
Sad -------- + -------- Happy
               |
   Bored     Calm      Content
               |
          Low Arousal
```

- **Valence**: Negative (unpleasant) ↔ Positive (pleasant)
- **Arousal**: Low (calm) ↔ High (energetic)

## Performance

The model outputs continuous predictions that can be:
- Used directly for emotion intensity analysis
- Mapped to discrete emotion categories
- Visualized on emotion quadrant plots
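
As a quick illustration of the quadrant-plot option above, here is a minimal plotting sketch; it assumes matplotlib (not listed in the requirements) and the `classify_emotion` helper defined earlier, and the helper name and example values are purely illustrative:

```python
import matplotlib.pyplot as plt

def plot_valence_arousal(valence, arousal, label=None):
    """Scatter a single prediction on the valence-arousal plane."""
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.axhline(0, color="gray", linewidth=0.8)
    ax.axvline(0, color="gray", linewidth=0.8)
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.set_xlabel("Valence")
    ax.set_ylabel("Arousal")
    ax.scatter([valence], [arousal], color="crimson")
    if label:
        ax.annotate(label, (valence, arousal), textcoords="offset points", xytext=(5, 5))
    plt.show()

# Example with made-up prediction values
plot_valence_arousal(0.42, 0.17, label=classify_emotion(0.42, 0.17))
```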

## Limitations

- Trained on music/audio; performance may vary on speech
- Requires mel spectrogram preprocessing
- Fixed 30-second audio duration (only the first 30 s of longer files is used)
- Possible cultural bias depending on the training data

## Citation

```bibtex
@misc{sentio-vit-emotion,
  title={Vision Transformer for Audio Emotion Recognition},
  author={SentioApp Team},
  year={2025},
  publisher={HuggingFace}
}
```

## License

MIT License
audio_preprocessor.py
ADDED
@@ -0,0 +1,191 @@
"""
Audio Preprocessing for Model Inference

This module handles loading audio files and converting them to spectrograms
for input to the emotion prediction models.
"""

import librosa
import numpy as np
import torch
from PIL import Image


class AudioPreprocessor:
    """
    Preprocessor for converting audio files to model-ready spectrograms.
    """

    def __init__(self,
                 sample_rate=22050,
                 duration=30,
                 n_mels=128,
                 hop_length=512,
                 n_fft=2048,
                 fmin=20,
                 fmax=8000,
                 image_size=224):
        """
        Initialize audio preprocessor.

        Args:
            sample_rate: Audio sampling rate (Hz)
            duration: Audio clip duration (seconds)
            n_mels: Number of mel-frequency bins
            hop_length: Hop length for STFT
            n_fft: FFT window size
            fmin: Minimum frequency
            fmax: Maximum frequency
            image_size: Target image size for model input (224 for ViT)
        """
        self.sample_rate = sample_rate
        self.duration = duration
        self.n_mels = n_mels
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.fmin = fmin
        self.fmax = fmax
        self.image_size = image_size

        # ImageNet normalization (used by ViT)
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def load_audio(self, audio_path):
        """
        Load audio file.

        Args:
            audio_path: Path to audio file

        Returns:
            audio: Audio waveform
            sr: Sample rate
        """
        try:
            audio, sr = librosa.load(
                audio_path,
                sr=self.sample_rate,
                duration=self.duration,
                mono=True
            )

            # Pad or truncate to exact duration
            target_length = self.sample_rate * self.duration
            if len(audio) < target_length:
                audio = np.pad(audio, (0, target_length - len(audio)))
            else:
                audio = audio[:target_length]

            return audio, sr

        except Exception as e:
            raise RuntimeError(f"Failed to load audio from {audio_path}: {e}")

    def audio_to_melspectrogram(self, audio):
        """
        Convert audio waveform to mel spectrogram.

        Args:
            audio: Audio waveform

        Returns:
            mel_spec: Mel spectrogram in dB scale
        """
        # Compute mel spectrogram
        mel_spec = librosa.feature.melspectrogram(
            y=audio,
            sr=self.sample_rate,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            fmin=self.fmin,
            fmax=self.fmax
        )

        # Convert to dB scale
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

        return mel_spec_db

    def spectrogram_to_image(self, mel_spec):
        """
        Convert mel spectrogram to RGB image tensor for ViT input.

        Args:
            mel_spec: Mel spectrogram (n_mels, time_steps)

        Returns:
            image_tensor: Tensor of shape (3, 224, 224) normalized for ViT
        """
        # Normalize to [0, 1]
        spec_min = mel_spec.min()
        spec_max = mel_spec.max()
        spec_norm = (mel_spec - spec_min) / (spec_max - spec_min + 1e-8)

        # Resize to 224x224 using PIL
        spec_pil = Image.fromarray((spec_norm * 255).astype(np.uint8))
        spec_resized = spec_pil.resize(
            (self.image_size, self.image_size),
            Image.Resampling.BILINEAR
        )

        # Convert back to numpy and normalize
        spec_array = np.array(spec_resized).astype(np.float32) / 255.0

        # Convert grayscale to RGB by replicating channels
        spec_rgb = np.stack([spec_array, spec_array, spec_array], axis=0)

        # Convert to torch tensor
        image_tensor = torch.from_numpy(spec_rgb).float()

        # Apply ImageNet normalization
        image_tensor = (image_tensor - self.imagenet_mean) / self.imagenet_std

        return image_tensor

    def preprocess(self, audio_path):
        """
        Complete preprocessing pipeline: audio file -> model-ready tensor.

        Args:
            audio_path: Path to audio file

        Returns:
            image_tensor: Tensor of shape (3, 224, 224) ready for model input
            mel_spec: Raw mel spectrogram (for visualization)
        """
        # Load audio
        audio, _ = self.load_audio(audio_path)

        # Convert to mel spectrogram
        mel_spec = self.audio_to_melspectrogram(audio)

        # Convert to image tensor
        image_tensor = self.spectrogram_to_image(mel_spec)

        return image_tensor, mel_spec

    def preprocess_batch(self, audio_paths):
        """
        Preprocess multiple audio files.

        Args:
            audio_paths: List of audio file paths

        Returns:
            batch_tensor: Tensor of shape (batch_size, 3, 224, 224)
            mel_specs: List of mel spectrograms
        """
        tensors = []
        mel_specs = []

        for audio_path in audio_paths:
            tensor, mel_spec = self.preprocess(audio_path)
            tensors.append(tensor)
            mel_specs.append(mel_spec)

        # Stack into batch
        batch_tensor = torch.stack(tensors, dim=0)

        return batch_tensor, mel_specs
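
A possible end-to-end use of this preprocessor together with the model class defined further below in modeling_vit_emotion.py; the module, class, and checkpoint names come from this repository, while the audio path and the wiring itself are illustrative:

```python
# Illustrative inference sketch; "song.mp3" is a placeholder path.
import torch
from audio_preprocessor import AudioPreprocessor
from modeling_vit_emotion import ViTForEmotionRegression

preprocessor = AudioPreprocessor()
model = ViTForEmotionRegression()
model.load_state_dict(torch.load("best_model.pth", map_location="cpu"))
model.eval()

image_tensor, mel_spec = preprocessor.preprocess("song.mp3")
with torch.no_grad():
    valence, arousal = model(image_tensor.unsqueeze(0))[0].tolist()
print(f"valence={valence:.3f}, arousal={arousal:.3f}")
```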
best_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d63bd6d99e635cb349a509259e5d20c9645135eac728f685d7204caee8890c6f
size 347485262
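
Since the checkpoint is stored via Git LFS, it can be fetched with `huggingface_hub` once the repository is published; a minimal sketch, where the repo id is a placeholder:

```python
from huggingface_hub import hf_hub_download

# Placeholder repo id; substitute the actual repository name.
checkpoint_path = hf_hub_download(
    repo_id="your-username/vit-audio-emotion",
    filename="best_model.pth",
)
print(checkpoint_path)
```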
config.json
ADDED
@@ -0,0 +1,39 @@
{
  "architectures": [
    "ViTForEmotionRegression"
  ],
  "model_type": "vit-emotion",
  "task": "audio-emotion-recognition",
  "base_model": "google/vit-base-patch16-224-in21k",
  "num_emotions": 2,
  "emotion_dimensions": ["valence", "arousal"],
  "output_range": [-1, 1],
  "input_size": [224, 224],
  "num_channels": 3,
  "patch_size": 16,
  "hidden_size": 768,
  "num_hidden_layers": 12,
  "num_attention_heads": 12,
  "intermediate_size": 3072,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "attention_probs_dropout_prob": 0.0,
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-12,
  "image_size": 224,
  "qkv_bias": true,
  "audio_processing": {
    "sample_rate": 22050,
    "n_mels": 128,
    "hop_length": 512,
    "n_fft": 2048,
    "mel_spectrogram_format": "RGB",
    "normalization": "imagenet"
  },
  "regression_head": {
    "architecture": "768 -> 512 -> 128 -> 2",
    "activation": "gelu",
    "dropout": 0.1,
    "output_activation": "tanh"
  }
}
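
This config.json is a custom, descriptive file rather than a `transformers`-compatible config, so it has to be read manually; one possible sketch for wiring its `audio_processing` block into the `AudioPreprocessor` from audio_preprocessor.py (the mapping of fields is illustrative):

```python
# Illustrative only: map the custom config fields onto AudioPreprocessor arguments.
import json
from audio_preprocessor import AudioPreprocessor

with open("config.json") as f:
    cfg = json.load(f)

audio_cfg = cfg["audio_processing"]
preprocessor = AudioPreprocessor(
    sample_rate=audio_cfg["sample_rate"],
    n_mels=audio_cfg["n_mels"],
    hop_length=audio_cfg["hop_length"],
    n_fft=audio_cfg["n_fft"],
    image_size=cfg["image_size"],
)
```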
modeling_vit_emotion.py
ADDED
@@ -0,0 +1,135 @@
"""
Vision Transformer (ViT) Model Definition for Emotion Regression

This file defines the ViT model architecture used for valence-arousal prediction.
"""

import torch
import torch.nn as nn
from transformers import ViTModel, ViTConfig


class ViTForEmotionRegression(nn.Module):
    """
    Vision Transformer for emotion regression (valence and arousal prediction).

    Architecture:
    - Pre-trained ViT backbone (google/vit-base-patch16-224-in21k)
    - Custom regression head for 2D emotion prediction
    - Dropout for regularization
    """

    def __init__(self, model_name='google/vit-base-patch16-224-in21k',
                 num_emotions=2, freeze_backbone=False, dropout=0.1):
        super().__init__()

        # Load pre-trained ViT model
        try:
            self.vit = ViTModel.from_pretrained(model_name)
            print(f"✅ Loaded pre-trained ViT from {model_name}")
        except Exception as e:
            print(f"⚠️ Could not load pre-trained model: {e}")
            print("   Initializing with random weights...")
            config = ViTConfig()
            self.vit = ViTModel(config)

        # Freeze backbone if specified
        if freeze_backbone:
            for param in self.vit.parameters():
                param.requires_grad = False
            print("❄️ Frozen ViT backbone")

        # Get hidden size from ViT config
        hidden_size = self.vit.config.hidden_size

        # Regression head for emotion prediction (named 'head' to match saved checkpoint)
        # Architecture: 768 -> 512 -> 128 -> 2
        self.head = nn.Sequential(
            nn.LayerNorm(hidden_size),        # [0] weight: [768], bias: [768]
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 512),      # [2] weight: [512, 768], bias: [512]
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),              # [5] weight: [128, 512], bias: [128]
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_emotions),     # [8] weight: [2, 128], bias: [2]
            nn.Tanh()                         # Output in range [-1, 1]
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the model.

        Args:
            pixel_values: Input images tensor of shape (batch_size, 3, 224, 224)

        Returns:
            Emotion predictions tensor of shape (batch_size, 2) [valence, arousal]
        """
        # Get ViT outputs
        outputs = self.vit(pixel_values)
        cls_output = outputs.last_hidden_state[:, 0]

        # Predict emotions
        emotion_predictions = self.head(cls_output)
        return emotion_predictions


class MobileViTStudent(nn.Module):
    """
    Lightweight MobileViT student model for emotion regression.
    Used in distilled version for faster inference.
    """

    def __init__(self, num_emotions=2, dropout=0.1):
        super().__init__()

        # Lightweight CNN backbone
        self.conv_stem = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )

        # Mobile inverted bottleneck blocks
        self.blocks = nn.Sequential(
            self._make_mb_block(32, 64, stride=2),
            self._make_mb_block(64, 128, stride=2),
            self._make_mb_block(128, 256, stride=2),
        )

        # Global pooling
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Regression head (named 'head' to match saved checkpoint)
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, num_emotions),
            nn.Tanh()
        )

    def _make_mb_block(self, in_channels, out_channels, stride=1):
        """Create Mobile Inverted Bottleneck block"""
        return nn.Sequential(
            # Depthwise
            nn.Conv2d(in_channels, in_channels, kernel_size=3,
                      stride=stride, padding=1, groups=in_channels),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            # Pointwise
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Forward pass"""
        x = self.conv_stem(x)
        x = self.blocks(x)
        x = self.global_pool(x)
        emotions = self.head(x)
        return emotions
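
No distilled weights ship in this commit, but the student model can be smoke-tested with random input; a minimal shape check, purely illustrative:

```python
import torch
from modeling_vit_emotion import MobileViTStudent

student = MobileViTStudent()
student.eval()
with torch.no_grad():
    out = student(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 2]); values bounded to [-1, 1] by the Tanh head
```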
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch>=2.0.0
transformers>=4.30.0
librosa>=0.10.0
numpy>=1.24.0
pillow>=10.0.0