winzerprince committed
Commit 3b5844e · verified · 1 Parent(s): c6f4158

Upload folder using huggingface_hub

Files changed (7)
  1. .gitattributes +3 -33
  2. README.md +209 -0
  3. audio_preprocessor.py +191 -0
  4. best_model.pth +3 -0
  5. config.json +39 -0
  6. modeling_vit_emotion.py +135 -0
  7. requirements.txt +5 -0
.gitattributes CHANGED
@@ -1,35 +1,5 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
  *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,209 @@
---
language: en
license: mit
tags:
- audio
- emotion-recognition
- valence-arousal
- vision-transformer
- pytorch
- music-emotion-recognition
datasets:
- custom
metrics:
- mse
- mae
pipeline_tag: audio-classification
---

# ViT for Audio Emotion Recognition (Valence-Arousal)

This model is a fine-tuned Vision Transformer (ViT) for audio emotion recognition. It predicts valence and arousal values in the continuous range [-1, 1].

## Model Description

- **Base Model**: google/vit-base-patch16-224-in21k
- **Task**: Audio emotion recognition (regression)
- **Output**: Valence and arousal predictions (2D continuous emotion space)
- **Range**: [-1, 1] for both dimensions
- **Input**: Mel spectrogram images (224x224 RGB)

## Architecture

```
ViT Base (86M parameters)

CLS Token Output (768-dim)

LayerNorm + Dropout

Linear (768 → 512) + GELU + Dropout

Linear (512 → 128) + GELU + Dropout

Linear (128 → 2) + Tanh

[Valence, Arousal] ∈ [-1, 1]²
```

## Usage

### Prerequisites

```bash
pip install torch torchvision transformers librosa numpy pillow
```

### Loading the Model

```python
import torch
from transformers import ViTModel
import torch.nn as nn


class ViTForEmotionRegression(nn.Module):
    def __init__(self, model_name='google/vit-base-patch16-224-in21k', num_emotions=2, dropout=0.1):
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)
        hidden_size = self.vit.config.hidden_size

        self.head = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_emotions),
            nn.Tanh()
        )

    def forward(self, pixel_values):
        outputs = self.vit(pixel_values)
        cls_output = outputs.last_hidden_state[:, 0]
        return self.head(cls_output)


# Load the model
model = ViTForEmotionRegression()
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
model.eval()
```
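
This repository also ships the same class in `modeling_vit_emotion.py`, so the definition above does not have to be copied by hand. A minimal sketch, assuming the repository files (including `best_model.pth`) have been downloaded into the working directory, for example with `huggingface_hub.snapshot_download`:

```python
# Sketch: reuse the bundled model definition instead of redefining the class inline.
# Assumes modeling_vit_emotion.py and best_model.pth are in the current directory.
import torch

from modeling_vit_emotion import ViTForEmotionRegression

model = ViTForEmotionRegression()
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
model.eval()
```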

### Audio Preprocessing

```python
import librosa
import numpy as np
from PIL import Image
import torch
from torchvision import transforms


def preprocess_audio(audio_path):
    # Load audio
    y, sr = librosa.load(audio_path, sr=22050, duration=30)

    # Generate mel spectrogram
    mel_spec = librosa.feature.melspectrogram(
        y=y, sr=sr, n_mels=128, hop_length=512, n_fft=2048
    )
    mel_db = librosa.power_to_db(mel_spec, ref=np.max)

    # Normalize to 0-255 for RGB conversion
    mel_normalized = ((mel_db - mel_db.min()) / (mel_db.max() - mel_db.min()) * 255).astype(np.uint8)

    # Convert to RGB image
    image = Image.fromarray(mel_normalized).convert('RGB')
    image = image.resize((224, 224))

    # Apply ImageNet normalization
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    return transform(image).unsqueeze(0)


# Process audio
audio_tensor = preprocess_audio('your_audio.mp3')

# Predict emotions
with torch.no_grad():
    predictions = model(audio_tensor)
    valence, arousal = predictions[0].tolist()

print(f"Valence: {valence:.3f}, Arousal: {arousal:.3f}")
```
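
To score more than one clip, the per-file tensors produced by `preprocess_audio` can be concatenated into a single batch before the forward pass. A small sketch reusing the function above (the file names are placeholders):

```python
# Sketch: run several clips through the model in one forward pass.
# 'song_a.mp3' and 'song_b.mp3' are placeholder file names.
paths = ['song_a.mp3', 'song_b.mp3']
batch = torch.cat([preprocess_audio(p) for p in paths], dim=0)  # (N, 3, 224, 224)

with torch.no_grad():
    outputs = model(batch)  # (N, 2): one [valence, arousal] pair per clip

for path, (valence, arousal) in zip(paths, outputs.tolist()):
    print(f"{path}: valence={valence:.3f}, arousal={arousal:.3f}")
```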

### Emotion Quadrant Mapping

```python
def classify_emotion(valence, arousal):
    if valence >= 0 and arousal >= 0:
        return "HAPPY" if valence > arousal else "EXCITED"
    elif valence >= 0 and arousal < 0:
        return "CALM" if abs(arousal) > valence else "CONTENT"
    elif valence < 0 and arousal < 0:
        return "SAD" if abs(valence) > abs(arousal) else "BORED"
    else:  # valence < 0 and arousal >= 0
        return "TENSE" if arousal > abs(valence) else "ANGRY"
```
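
Applied to the prediction example above, the helper turns the continuous output into a quadrant label:

```python
# Map the continuous prediction from the preprocessing example to a label
label = classify_emotion(valence, arousal)
print(f"Predicted emotion: {label}")  # e.g. (0.6, 0.2) maps to HAPPY
```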

## Model Details

- **Parameters**: ~86.8M
- **Model Size**: ~331 MB
- **Framework**: PyTorch
- **Base Architecture**: ViT-Base (12 layers, 768 hidden, 12 heads)
- **Custom Head**: 3-layer MLP with GELU activations
- **Training Data**: Custom audio emotion dataset
- **Training**: Fine-tuned with MSE loss on valence-arousal targets
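
The parameter count is easy to sanity-check once the model is instantiated; the figure above covers the ViT backbone plus the regression head:

```python
# Count parameters of the instantiated ViTForEmotionRegression model
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params / 1e6:.1f}M")  # roughly 86-87M expected
```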

## Emotion Space

The model predicts emotions in the 2D circumplex model:

```
              High Arousal
                   |
    Angry        Tense      Excited
                   |
  Sad ------------ + ------------ Happy
                   |
    Bored        Calm       Content
                   |
              Low Arousal
```

- **Valence**: Negative (unpleasant) ↔ Positive (pleasant)
- **Arousal**: Low (calm) ↔ High (energetic)

## Performance

The model outputs continuous predictions that can be:
- Used directly for emotion intensity analysis
- Mapped to discrete emotion categories
- Visualized on emotion quadrant plots (see the sketch below)
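
A minimal plotting sketch for the quadrant view, assuming `matplotlib` is installed in addition to the listed requirements and that `outputs` is an (N, 2) tensor of predictions:

```python
import matplotlib.pyplot as plt

# outputs: (N, 2) tensor of [valence, arousal] predictions
va = outputs.numpy()
plt.figure(figsize=(5, 5))
plt.scatter(va[:, 0], va[:, 1])
plt.axhline(0, color='gray', linewidth=0.5)  # valence axis
plt.axvline(0, color='gray', linewidth=0.5)  # arousal axis
plt.xlim(-1, 1)
plt.ylim(-1, 1)
plt.xlabel('Valence')
plt.ylabel('Arousal')
plt.title('Predictions on the valence-arousal plane')
plt.show()
```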

## Limitations

- Trained on music/audio; performance may vary on speech
- Requires mel spectrogram preprocessing
- Fixed 30-second audio duration (longer clips are truncated to the first 30 seconds)
- Possible cultural bias depending on the training data

## Citation

```bibtex
@misc{sentio-vit-emotion,
  title={Vision Transformer for Audio Emotion Recognition},
  author={SentioApp Team},
  year={2025},
  publisher={HuggingFace}
}
```

## License

MIT License
audio_preprocessor.py ADDED
@@ -0,0 +1,191 @@
"""
Audio Preprocessing for Model Inference

This module handles loading audio files and converting them to spectrograms
for input to the emotion prediction models.
"""

import librosa
import numpy as np
import torch
from PIL import Image


class AudioPreprocessor:
    """
    Preprocessor for converting audio files to model-ready spectrograms.
    """

    def __init__(self,
                 sample_rate=22050,
                 duration=30,
                 n_mels=128,
                 hop_length=512,
                 n_fft=2048,
                 fmin=20,
                 fmax=8000,
                 image_size=224):
        """
        Initialize audio preprocessor.

        Args:
            sample_rate: Audio sampling rate (Hz)
            duration: Audio clip duration (seconds)
            n_mels: Number of mel-frequency bins
            hop_length: Hop length for STFT
            n_fft: FFT window size
            fmin: Minimum frequency
            fmax: Maximum frequency
            image_size: Target image size for model input (224 for ViT)
        """
        self.sample_rate = sample_rate
        self.duration = duration
        self.n_mels = n_mels
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.fmin = fmin
        self.fmax = fmax
        self.image_size = image_size

        # ImageNet normalization (used by ViT)
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def load_audio(self, audio_path):
        """
        Load audio file.

        Args:
            audio_path: Path to audio file

        Returns:
            audio: Audio waveform
            sr: Sample rate
        """
        try:
            audio, sr = librosa.load(
                audio_path,
                sr=self.sample_rate,
                duration=self.duration,
                mono=True
            )

            # Pad or truncate to exact duration
            target_length = self.sample_rate * self.duration
            if len(audio) < target_length:
                audio = np.pad(audio, (0, target_length - len(audio)))
            else:
                audio = audio[:target_length]

            return audio, sr

        except Exception as e:
            raise RuntimeError(f"Failed to load audio from {audio_path}: {e}")

    def audio_to_melspectrogram(self, audio):
        """
        Convert audio waveform to mel spectrogram.

        Args:
            audio: Audio waveform

        Returns:
            mel_spec: Mel spectrogram in dB scale
        """
        # Compute mel spectrogram
        mel_spec = librosa.feature.melspectrogram(
            y=audio,
            sr=self.sample_rate,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            fmin=self.fmin,
            fmax=self.fmax
        )

        # Convert to dB scale
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

        return mel_spec_db

    def spectrogram_to_image(self, mel_spec):
        """
        Convert mel spectrogram to RGB image tensor for ViT input.

        Args:
            mel_spec: Mel spectrogram (n_mels, time_steps)

        Returns:
            image_tensor: Tensor of shape (3, 224, 224) normalized for ViT
        """
        # Normalize to [0, 1]
        spec_min = mel_spec.min()
        spec_max = mel_spec.max()
        spec_norm = (mel_spec - spec_min) / (spec_max - spec_min + 1e-8)

        # Resize to 224x224 using PIL
        spec_pil = Image.fromarray((spec_norm * 255).astype(np.uint8))
        spec_resized = spec_pil.resize(
            (self.image_size, self.image_size),
            Image.Resampling.BILINEAR
        )

        # Convert back to numpy and normalize
        spec_array = np.array(spec_resized).astype(np.float32) / 255.0

        # Convert grayscale to RGB by replicating channels
        spec_rgb = np.stack([spec_array, spec_array, spec_array], axis=0)

        # Convert to torch tensor
        image_tensor = torch.from_numpy(spec_rgb).float()

        # Apply ImageNet normalization
        image_tensor = (image_tensor - self.imagenet_mean) / self.imagenet_std

        return image_tensor

    def preprocess(self, audio_path):
        """
        Complete preprocessing pipeline: audio file -> model-ready tensor.

        Args:
            audio_path: Path to audio file

        Returns:
            image_tensor: Tensor of shape (3, 224, 224) ready for model input
            mel_spec: Raw mel spectrogram (for visualization)
        """
        # Load audio
        audio, _ = self.load_audio(audio_path)

        # Convert to mel spectrogram
        mel_spec = self.audio_to_melspectrogram(audio)

        # Convert to image tensor
        image_tensor = self.spectrogram_to_image(mel_spec)

        return image_tensor, mel_spec

    def preprocess_batch(self, audio_paths):
        """
        Preprocess multiple audio files.

        Args:
            audio_paths: List of audio file paths

        Returns:
            batch_tensor: Tensor of shape (batch_size, 3, 224, 224)
            mel_specs: List of mel spectrograms
        """
        tensors = []
        mel_specs = []

        for audio_path in audio_paths:
            tensor, mel_spec = self.preprocess(audio_path)
            tensors.append(tensor)
            mel_specs.append(mel_spec)

        # Stack into batch
        batch_tensor = torch.stack(tensors, dim=0)

        return batch_tensor, mel_specs
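
A short usage sketch tying this preprocessor to the model files above. Note that `preprocess` returns a single (3, 224, 224) tensor, so a batch dimension has to be added before inference; the audio path is a placeholder:

```python
import torch

from audio_preprocessor import AudioPreprocessor
from modeling_vit_emotion import ViTForEmotionRegression

# Build the preprocessor with its default mel-spectrogram settings
preprocessor = AudioPreprocessor()
image_tensor, mel_spec = preprocessor.preprocess('your_audio.mp3')  # (3, 224, 224)

model = ViTForEmotionRegression()
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
model.eval()

with torch.no_grad():
    valence, arousal = model(image_tensor.unsqueeze(0))[0].tolist()
print(f"Valence: {valence:.3f}, Arousal: {arousal:.3f}")
```
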
best_model.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d63bd6d99e635cb349a509259e5d20c9645135eac728f685d7204caee8890c6f
size 347485262
config.json ADDED
@@ -0,0 +1,39 @@
{
  "architectures": [
    "ViTForEmotionRegression"
  ],
  "model_type": "vit-emotion",
  "task": "audio-emotion-recognition",
  "base_model": "google/vit-base-patch16-224-in21k",
  "num_emotions": 2,
  "emotion_dimensions": ["valence", "arousal"],
  "output_range": [-1, 1],
  "input_size": [224, 224],
  "num_channels": 3,
  "patch_size": 16,
  "hidden_size": 768,
  "num_hidden_layers": 12,
  "num_attention_heads": 12,
  "intermediate_size": 3072,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "attention_probs_dropout_prob": 0.0,
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-12,
  "image_size": 224,
  "qkv_bias": true,
  "audio_processing": {
    "sample_rate": 22050,
    "n_mels": 128,
    "hop_length": 512,
    "n_fft": 2048,
    "mel_spectrogram_format": "RGB",
    "normalization": "imagenet"
  },
  "regression_head": {
    "architecture": "768 -> 512 -> 128 -> 2",
    "activation": "gelu",
    "dropout": 0.1,
    "output_activation": "tanh"
  }
}
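
The `audio_processing` block mirrors the defaults of `AudioPreprocessor`, so the preprocessor can be configured from this file rather than hard-coded; a small sketch, assuming `config.json` sits in the working directory:

```python
import json

from audio_preprocessor import AudioPreprocessor

with open('config.json') as f:
    config = json.load(f)

# Build the preprocessor from the fields stored in config.json
audio_cfg = config['audio_processing']
preprocessor = AudioPreprocessor(
    sample_rate=audio_cfg['sample_rate'],
    n_mels=audio_cfg['n_mels'],
    hop_length=audio_cfg['hop_length'],
    n_fft=audio_cfg['n_fft'],
    image_size=config['image_size'],
)
```
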
modeling_vit_emotion.py ADDED
@@ -0,0 +1,135 @@
"""
Vision Transformer (ViT) Model Definition for Emotion Regression

This file defines the ViT model architecture used for valence-arousal prediction.
"""

import torch
import torch.nn as nn
from transformers import ViTModel, ViTConfig


class ViTForEmotionRegression(nn.Module):
    """
    Vision Transformer for emotion regression (valence and arousal prediction).

    Architecture:
        - Pre-trained ViT backbone (google/vit-base-patch16-224-in21k)
        - Custom regression head for 2D emotion prediction
        - Dropout for regularization
    """

    def __init__(self, model_name='google/vit-base-patch16-224-in21k',
                 num_emotions=2, freeze_backbone=False, dropout=0.1):
        super().__init__()

        # Load pre-trained ViT model
        try:
            self.vit = ViTModel.from_pretrained(model_name)
            print(f"✅ Loaded pre-trained ViT from {model_name}")
        except Exception as e:
            print(f"⚠️ Could not load pre-trained model: {e}")
            print("   Initializing with random weights...")
            config = ViTConfig()
            self.vit = ViTModel(config)

        # Freeze backbone if specified
        if freeze_backbone:
            for param in self.vit.parameters():
                param.requires_grad = False
            print("❄️ Frozen ViT backbone")

        # Get hidden size from ViT config
        hidden_size = self.vit.config.hidden_size

        # Regression head for emotion prediction (named 'head' to match saved checkpoint)
        # Architecture: 768 -> 512 -> 128 -> 2
        self.head = nn.Sequential(
            nn.LayerNorm(hidden_size),     # [0] weight: [768], bias: [768]
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 512),   # [2] weight: [512, 768], bias: [512]
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),           # [5] weight: [128, 512], bias: [128]
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_emotions),  # [8] weight: [2, 128], bias: [2]
            nn.Tanh()                      # Output in range [-1, 1]
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the model.

        Args:
            pixel_values: Input images tensor of shape (batch_size, 3, 224, 224)

        Returns:
            Emotion predictions tensor of shape (batch_size, 2) [valence, arousal]
        """
        # Get ViT outputs
        outputs = self.vit(pixel_values)
        cls_output = outputs.last_hidden_state[:, 0]

        # Predict emotions
        emotion_predictions = self.head(cls_output)
        return emotion_predictions


class MobileViTStudent(nn.Module):
    """
    Lightweight MobileViT student model for emotion regression.
    Used in distilled version for faster inference.
    """

    def __init__(self, num_emotions=2, dropout=0.1):
        super().__init__()

        # Lightweight CNN backbone
        self.conv_stem = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )

        # Mobile inverted bottleneck blocks
        self.blocks = nn.Sequential(
            self._make_mb_block(32, 64, stride=2),
            self._make_mb_block(64, 128, stride=2),
            self._make_mb_block(128, 256, stride=2),
        )

        # Global pooling
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Regression head (named 'head' to match saved checkpoint)
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, num_emotions),
            nn.Tanh()
        )

    def _make_mb_block(self, in_channels, out_channels, stride=1):
        """Create Mobile Inverted Bottleneck block"""
        return nn.Sequential(
            # Depthwise
            nn.Conv2d(in_channels, in_channels, kernel_size=3,
                      stride=stride, padding=1, groups=in_channels),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            # Pointwise
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Forward pass"""
        x = self.conv_stem(x)
        x = self.blocks(x)
        x = self.global_pool(x)
        emotions = self.head(x)
        return emotions
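
A quick shape check for both classes defined above; a minimal sketch with a random input (this commit ships no checkpoint for the student model, so it runs with untrained weights here):

```python
import torch

from modeling_vit_emotion import ViTForEmotionRegression, MobileViTStudent

dummy = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed spectrogram image

teacher = ViTForEmotionRegression()
student = MobileViTStudent()

with torch.no_grad():
    print(teacher(dummy).shape)  # torch.Size([1, 2]) -> [valence, arousal]
    print(student(dummy).shape)  # torch.Size([1, 2])
```
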
requirements.txt ADDED
@@ -0,0 +1,5 @@
torch>=2.0.0
transformers>=4.30.0
librosa>=0.10.0
numpy>=1.24.0
pillow>=10.0.0