File size: 5,164 Bytes
507ebe9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import json, os
import numpy as np
import torch
import torch.nn as nn
import librosa
import opensmile
import joblib
from transformers import WhisperModel, WhisperFeatureExtractor

# Loaded from metadata.json at startup — do not hardcode here
EMOTION_LABELS = None  # list of class names, index-aligned with model logits
NUM_EMOTIONS   = None  # number of output classes
GEMAPS_DIM     = None  # eGeMAPS functional-vector size (from metadata; code elsewhere expects 88)
WHISPER_DIM    = None  # Whisper encoder hidden size (from metadata)
SAMPLE_RATE    = None  # audio sample rate used by both feature paths
MAX_DURATION   = None  # max seconds of audio kept per clip
MAX_SAMPLES    = None  # derived in load_models(): int(SAMPLE_RATE * MAX_DURATION)

# Lazily-initialized process-wide singletons, populated by load_models()
# and the get_*() accessors below.
_smile      = None  # opensmile.Smile feature extractor
_whisper_fe = None  # WhisperFeatureExtractor for whisper-tiny
_scalers    = None  # per-language scalers (joblib-loaded); presumably sklearn scalers — verify
_fusion     = None  # FusionSER model, eval mode
_mlp        = None  # GeMAPS_MLP model, eval mode


def get_smile():
    """Return the cached openSMILE extractor, building it on first use.

    Configured for eGeMAPSv02 functionals (one fixed-length vector per file).
    """
    global _smile
    if _smile is not None:
        return _smile
    _smile = opensmile.Smile(
        feature_set=opensmile.FeatureSet.eGeMAPSv02,
        feature_level=opensmile.FeatureLevel.Functionals,
    )
    return _smile


def get_whisper_fe():
    """Return the cached Whisper-tiny feature extractor, building it on first use.

    NOTE(review): reads the module-level SAMPLE_RATE, which load_models()
    sets from metadata — calling this before load_models() would pass
    sampling_rate=None; confirm callers respect that ordering.
    """
    global _whisper_fe
    if _whisper_fe is not None:
        return _whisper_fe
    _whisper_fe = WhisperFeatureExtractor.from_pretrained(
        "openai/whisper-tiny", sampling_rate=SAMPLE_RATE
    )
    return _whisper_fe


class GeMAPS_MLP(nn.Module):
    """Feed-forward classifier over eGeMAPS functional features.

    Two hidden blocks (hidden -> hidden // 2), each Linear + ReLU +
    BatchNorm1d + Dropout, followed by a linear projection to logits.
    """

    def __init__(self, in_dim, hidden=128, num_classes=5, dropout=0.3):
        super().__init__()
        blocks = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.BatchNorm1d(hidden),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden // 2),
            nn.Dropout(dropout),
            nn.Linear(hidden // 2, num_classes),
        ]
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        """Return raw class logits for a (batch, in_dim) feature tensor."""
        return self.net(x)


class FusionSER(nn.Module):
    """Bimodal SER model: Whisper-tiny encoder fused with eGeMAPS features.

    The encoder output is mean-pooled over time, each modality is projected
    to a fixed width, and the concatenation is classified by a small MLP.
    Reads the module-level WHISPER_DIM / GEMAPS_DIM at construction time,
    so load_models() must populate them first.
    """

    def __init__(self, num_classes=5, dropout=0.3, gemaps_proj=64, whisper_proj=256):
        super().__init__()
        # Pretrained acoustic encoder; only the encoder half of Whisper is kept.
        self.whisper_enc = WhisperModel.from_pretrained("openai/whisper-tiny").encoder

        def _projection(in_dim, out_dim):
            # Identical shape for both modality projections.
            return nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(), nn.Dropout(dropout))

        # Submodule names (w_proj, g_proj, classifier) must stay stable for
        # state_dict loading in load_models().
        self.w_proj = _projection(WHISPER_DIM, whisper_proj)
        self.g_proj = _projection(GEMAPS_DIM, gemaps_proj)
        self.classifier = nn.Sequential(
            nn.Linear(whisper_proj + gemaps_proj, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

    def forward(self, whisper_inp, gemaps):
        """Fuse log-mel input features and an eGeMAPS vector into class logits."""
        pooled = self.whisper_enc(whisper_inp).last_hidden_state.mean(dim=1)
        fused = torch.cat([self.w_proj(pooled), self.g_proj(gemaps)], dim=-1)
        return self.classifier(fused)


def load_models(model_dir="."):
    """Load metadata, model weights and per-language scalers from *model_dir*.

    Must be called once before predict(). Populates the module-level
    configuration globals from metadata.json, then restores both models
    (in eval mode) and the scaler dict, and pre-warms the feature
    extractor singletons.

    Raises whatever open/json/torch/joblib raise if an artifact is
    missing or malformed (e.g. FileNotFoundError, KeyError).
    """
    global _fusion, _mlp, _scalers
    global EMOTION_LABELS, NUM_EMOTIONS, GEMAPS_DIM, WHISPER_DIM
    global SAMPLE_RATE, MAX_DURATION, MAX_SAMPLES

    with open(os.path.join(model_dir, "metadata.json")) as f:
        meta = json.load(f)

    # Metadata must be applied BEFORE building the models below:
    # FusionSER reads WHISPER_DIM / GEMAPS_DIM at construction time.
    EMOTION_LABELS = meta["emotion_labels"]
    NUM_EMOTIONS   = meta["num_emotions"]
    GEMAPS_DIM     = meta["gemaps_dim"]
    WHISPER_DIM    = meta["whisper_dim"]
    SAMPLE_RATE    = meta["sample_rate"]
    MAX_DURATION   = meta["max_duration"]
    MAX_SAMPLES    = int(SAMPLE_RATE * MAX_DURATION)

    _fusion = FusionSER(num_classes=NUM_EMOTIONS)
    _fusion.load_state_dict(
        torch.load(os.path.join(model_dir, "fusion_ser.pt"), map_location="cpu")
    )
    _fusion.eval()

    _mlp = GeMAPS_MLP(in_dim=GEMAPS_DIM, num_classes=NUM_EMOTIONS)
    _mlp.load_state_dict(
        torch.load(os.path.join(model_dir, "gemaps_mlp.pt"), map_location="cpu")
    )
    _mlp.eval()

    # NOTE(review): joblib/pickle deserialization — only load artifacts
    # from a trusted model_dir.
    _scalers = joblib.load(os.path.join(model_dir, "language_scalers.pkl"))

    # Pre-warm feature extractors
    get_smile()
    get_whisper_fe()
    print("All models loaded.")


def extract_gemaps(audio_path, language):
    """Extract a per-language-scaled eGeMAPS functional vector from a file.

    Best-effort on the audio side: if openSMILE fails on the file, a zero
    vector is used so the pipeline still produces a prediction.

    Args:
        audio_path: path to the input audio file.
        language: key into the per-language scaler dict; unknown languages
            fall back to the "english" scaler (matching notebook 1).

    Returns:
        torch.FloatTensor of shape (1, GEMAPS_DIM).

    Raises:
        RuntimeError: if load_models() has not been called, or if neither
            the requested language nor the "english" fallback has a scaler.
    """
    if _scalers is None:
        # Matches predict()'s guard; also protects GEMAPS_DIM below.
        raise RuntimeError("Call load_models() first.")

    try:
        feats = get_smile().process_file(audio_path)
        arr   = feats.values[0].astype(np.float32).reshape(1, -1)
    except Exception:
        # Deliberate best-effort: unreadable/corrupt audio degrades to a
        # zero vector instead of aborting the request.
        arr = np.zeros((1, GEMAPS_DIM), dtype=np.float32)

    # Apply the same per-language scaler fitted in notebook 1
    scaler = _scalers.get(language)
    if scaler is None:
        scaler = _scalers.get("english")
    if scaler is None:
        # Previously this fell through and crashed later with an opaque
        # AttributeError on None.transform — fail fast with a clear message.
        raise RuntimeError(
            f"No scaler for language {language!r} and no 'english' fallback."
        )
    arr = scaler.transform(arr)
    return torch.from_numpy(arr.astype(np.float32))  # (1, GEMAPS_DIM)


def extract_whisper(audio_path):
    """Load audio and return Whisper log-mel input features, shape (1, 80, 3000).

    Best-effort: any decode or extraction failure yields an all-zero tensor
    of the same shape so downstream inference still runs.
    """
    try:
        waveform, _sr = librosa.load(audio_path, sr=SAMPLE_RATE, mono=True)
        waveform = waveform[:MAX_SAMPLES]
        batch = get_whisper_fe()(
            waveform, sampling_rate=SAMPLE_RATE, return_tensors="pt"
        )
    except Exception:
        # Fallback matches the extractor's output shape for whisper-tiny.
        return torch.zeros(1, 80, 3000)
    else:
        return batch.input_features  # (1, 80, 3000)


@torch.no_grad()
def predict(audio_path, language="english", mode="fusion"):
    """Predict emotion probabilities for an audio file.

    Args:
        audio_path: path to the input audio file.
        language: scaler key for eGeMAPS normalization ("english" fallback).
        mode: "fusion" (Whisper+GeMAPS model), "gemaps" (MLP only), or
            "ensemble" (0.6 * fusion + 0.4 * gemaps).

    Returns:
        dict mapping each emotion label to its probability (floats).

    Raises:
        ValueError: if mode is not one of the three supported values.
        RuntimeError: if load_models() has not been called.
    """
    if mode not in ("fusion", "gemaps", "ensemble"):
        # Previously an unknown mode fell into the ensemble branch and
        # crashed with an opaque TypeError (0.6 * None) — fail fast instead.
        raise ValueError(
            f"mode must be 'fusion', 'gemaps' or 'ensemble', got {mode!r}"
        )
    if _fusion is None or _mlp is None:
        raise RuntimeError("Call load_models() first.")

    gemaps = extract_gemaps(audio_path, language)

    probs_f = probs_m = None
    if mode in ("fusion", "ensemble"):
        # Whisper features are only needed (and only computed) for these modes.
        whisper = extract_whisper(audio_path)
        probs_f = torch.softmax(_fusion(whisper, gemaps), -1).squeeze(0).numpy()
    if mode in ("gemaps", "ensemble"):
        probs_m = torch.softmax(_mlp(gemaps), -1).squeeze(0).numpy()

    if mode == "fusion":
        probs = probs_f
    elif mode == "gemaps":
        probs = probs_m
    else:
        # Fixed ensemble weights; kept as in the original implementation.
        probs = 0.6 * probs_f + 0.4 * probs_m

    return {label: float(probs[i]) for i, label in enumerate(EMOTION_LABELS)}