Raemih committed
Commit 507ebe9 · verified · 1 Parent(s): 728a4a9

Upload model_utils.py

Files changed (1): model_utils.py +162 -0
model_utils.py ADDED
@@ -0,0 +1,162 @@
+ import json, os
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import librosa
+ import opensmile
+ import joblib
+ from transformers import WhisperModel, WhisperFeatureExtractor
+
+ # Loaded from metadata.json at startup; do not hardcode here
+ EMOTION_LABELS = None
+ NUM_EMOTIONS = None
+ GEMAPS_DIM = None
+ WHISPER_DIM = None
+ SAMPLE_RATE = None
+ MAX_DURATION = None
+ MAX_SAMPLES = None
+
+ # Lazily initialized singletons (see get_smile / get_whisper_fe / load_models)
+ _smile = None
+ _whisper_fe = None
+ _scalers = None
+ _fusion = None
+ _mlp = None
+
+
+ def get_smile():
+     """Lazily construct the openSMILE eGeMAPSv02 functionals extractor."""
+     global _smile
+     if _smile is None:
+         _smile = opensmile.Smile(
+             feature_set=opensmile.FeatureSet.eGeMAPSv02,
+             feature_level=opensmile.FeatureLevel.Functionals,
+         )
+     return _smile
+
+
+ def get_whisper_fe():
+     """Lazily construct the Whisper log-mel feature extractor."""
+     global _whisper_fe
+     if _whisper_fe is None:
+         _whisper_fe = WhisperFeatureExtractor.from_pretrained(
+             "openai/whisper-tiny", sampling_rate=SAMPLE_RATE
+         )
+     return _whisper_fe
+
+
+ class GeMAPS_MLP(nn.Module):
+     """Baseline classifier over the 88 eGeMAPS functionals."""
+
+     def __init__(self, in_dim, hidden=128, num_classes=5, dropout=0.3):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(in_dim, hidden), nn.ReLU(),
+             nn.BatchNorm1d(hidden), nn.Dropout(dropout),
+             nn.Linear(hidden, hidden // 2), nn.ReLU(),
+             nn.BatchNorm1d(hidden // 2), nn.Dropout(dropout),
+             nn.Linear(hidden // 2, num_classes),
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ class FusionSER(nn.Module):
+     """Late-fusion model: pooled Whisper encoder embedding + projected GeMAPS."""
+
+     def __init__(self, num_classes=5, dropout=0.3, gemaps_proj=64, whisper_proj=256):
+         super().__init__()
+         self.whisper_enc = WhisperModel.from_pretrained("openai/whisper-tiny").encoder
+         self.w_proj = nn.Sequential(
+             nn.Linear(WHISPER_DIM, whisper_proj), nn.ReLU(), nn.Dropout(dropout)
+         )
+         self.g_proj = nn.Sequential(
+             nn.Linear(GEMAPS_DIM, gemaps_proj), nn.ReLU(), nn.Dropout(dropout)
+         )
+         self.classifier = nn.Sequential(
+             nn.Linear(whisper_proj + gemaps_proj, 128), nn.ReLU(),
+             nn.BatchNorm1d(128), nn.Dropout(dropout),
+             nn.Linear(128, num_classes),
+         )
+
+     def forward(self, whisper_inp, gemaps):
+         # Mean-pool the encoder states over time, project both views, concatenate
+         w = self.whisper_enc(whisper_inp).last_hidden_state.mean(dim=1)
+         w = self.w_proj(w)
+         g = self.g_proj(gemaps)
+         return self.classifier(torch.cat([w, g], dim=-1))
+
+
+ def load_models(model_dir="."):
+     global _fusion, _mlp, _scalers
+     global EMOTION_LABELS, NUM_EMOTIONS, GEMAPS_DIM, WHISPER_DIM
+     global SAMPLE_RATE, MAX_DURATION, MAX_SAMPLES
+
+     with open(os.path.join(model_dir, "metadata.json")) as f:
+         meta = json.load(f)
+
+     EMOTION_LABELS = meta["emotion_labels"]
+     NUM_EMOTIONS = meta["num_emotions"]
+     GEMAPS_DIM = meta["gemaps_dim"]
+     WHISPER_DIM = meta["whisper_dim"]
+     SAMPLE_RATE = meta["sample_rate"]
+     MAX_DURATION = meta["max_duration"]
+     MAX_SAMPLES = int(SAMPLE_RATE * MAX_DURATION)
+
+     _fusion = FusionSER(num_classes=NUM_EMOTIONS)
+     _fusion.load_state_dict(
+         torch.load(os.path.join(model_dir, "fusion_ser.pt"), map_location="cpu")
+     )
+     _fusion.eval()
+
+     _mlp = GeMAPS_MLP(in_dim=GEMAPS_DIM, num_classes=NUM_EMOTIONS)
+     _mlp.load_state_dict(
+         torch.load(os.path.join(model_dir, "gemaps_mlp.pt"), map_location="cpu")
+     )
+     _mlp.eval()
+
+     _scalers = joblib.load(os.path.join(model_dir, "language_scalers.pkl"))
+
+     # Pre-warm feature extractors
+     get_smile()
+     get_whisper_fe()
+     print("All models loaded.")
+
+
+ def extract_gemaps(audio_path, language):
+     try:
+         feats = get_smile().process_file(audio_path)
+         arr = feats.values[0].astype(np.float32).reshape(1, -1)
+     except Exception:
+         arr = np.zeros((1, GEMAPS_DIM), dtype=np.float32)
+     # Apply the same per-language scaler fitted in notebook 1
+     scaler = _scalers.get(language) or _scalers.get("english")
+     arr = scaler.transform(arr)
+     return torch.from_numpy(arr.astype(np.float32))  # (1, 88)
+
+
+ def extract_whisper(audio_path):
+     try:
+         audio, _ = librosa.load(audio_path, sr=SAMPLE_RATE, mono=True)
+         audio = audio[:MAX_SAMPLES]
+         inp = get_whisper_fe()(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt")
+         return inp.input_features  # (1, 80, 3000)
+     except Exception:
+         return torch.zeros(1, 80, 3000)
+
+
+ @torch.no_grad()
+ def predict(audio_path, language="english", mode="fusion"):
+     if _fusion is None:
+         raise RuntimeError("Call load_models() first.")
+     if mode not in ("fusion", "gemaps", "ensemble"):
+         raise ValueError(f"mode must be 'fusion', 'gemaps' or 'ensemble', got {mode!r}")
+
+     gemaps = extract_gemaps(audio_path, language)
+     whisper = extract_whisper(audio_path) if mode in ("fusion", "ensemble") else None
+
+     probs_f = probs_m = None
+
+     if mode in ("fusion", "ensemble"):
+         probs_f = torch.softmax(_fusion(whisper, gemaps), -1).squeeze(0).numpy()
+     if mode in ("gemaps", "ensemble"):
+         probs_m = torch.softmax(_mlp(gemaps), -1).squeeze(0).numpy()
+
+     if mode == "fusion":
+         probs = probs_f
+     elif mode == "gemaps":
+         probs = probs_m
+     else:
+         # Ensemble: fixed-weight average of the two models' probabilities
+         probs = 0.6 * probs_f + 0.4 * probs_m
+
+     return {label: float(probs[i]) for i, label in enumerate(EMOTION_LABELS)}
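
For reference, load_models() reads a fixed set of keys from metadata.json. A minimal sketch of that contract, written as a Python snippet that produces such a file: the field names come from the code above, but every value here is an illustrative assumption, not the shipped metadata (88 matches the (1, 88) shape comment in extract_gemaps, 384 is whisper-tiny's encoder width, and 16000 Hz is Whisper's expected sampling rate; the labels and max_duration are made up):

    import json

    # Illustrative metadata.json; values are assumptions, not the shipped ones
    meta = {
        "emotion_labels": ["angry", "happy", "neutral", "sad", "surprised"],
        "num_emotions": 5,
        "gemaps_dim": 88,       # eGeMAPSv02 functionals count
        "whisper_dim": 384,     # whisper-tiny encoder hidden size
        "sample_rate": 16000,   # Hz; Whisper's expected input rate
        "max_duration": 10.0,   # seconds kept per clip (assumed)
    }
    with open("metadata.json", "w") as f:
        json.dump(meta, f, indent=2)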
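
End-to-end use is then two calls: load_models() once, predict() per utterance. A minimal usage sketch, assuming the checkpoints and metadata.json sit in the current directory and "clip.wav" is a hypothetical input recording:

    import model_utils

    # One-time setup: loads metadata, both models, scalers, and feature extractors
    model_utils.load_models(model_dir=".")

    # Per-utterance inference; mode is one of "fusion", "gemaps", "ensemble"
    scores = model_utils.predict("clip.wav", language="english", mode="ensemble")
    top = max(scores, key=scores.get)
    print(f"{top}: {scores[top]:.3f}")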