import json
import os

import numpy as np
import torch
import torch.nn as nn
import librosa
import opensmile
import joblib
from transformers import WhisperModel, WhisperFeatureExtractor

# Loaded from metadata.json at startup; do not hardcode here
EMOTION_LABELS = None
NUM_EMOTIONS = None
GEMAPS_DIM = None
WHISPER_DIM = None
SAMPLE_RATE = None
MAX_DURATION = None
MAX_SAMPLES = None

# Lazily-initialized singletons, populated by load_models()
_smile = None
_whisper_fe = None
_scalers = None
_fusion = None
_mlp = None


def get_smile():
    # Lazily build the openSMILE extractor for eGeMAPSv02 functionals
    global _smile
    if _smile is None:
        _smile = opensmile.Smile(
            feature_set=opensmile.FeatureSet.eGeMAPSv02,
            feature_level=opensmile.FeatureLevel.Functionals,
        )
    return _smile


def get_whisper_fe():
    # Lazily build the Whisper feature extractor; needs SAMPLE_RATE,
    # so this must run after load_models() has read metadata.json
    global _whisper_fe
    if _whisper_fe is None:
        _whisper_fe = WhisperFeatureExtractor.from_pretrained(
            "openai/whisper-tiny", sampling_rate=SAMPLE_RATE
        )
    return _whisper_fe


class GeMAPS_MLP(nn.Module):
    # Small MLP over the 88 eGeMAPS functionals
    def __init__(self, in_dim, hidden=128, num_classes=5, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.BatchNorm1d(hidden),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden // 2),
            nn.Dropout(dropout),
            nn.Linear(hidden // 2, num_classes),
        )

    def forward(self, x):
        return self.net(x)


class FusionSER(nn.Module):
    # Late-fusion model: mean-pooled Whisper encoder states + projected GeMAPS.
    # Reads the WHISPER_DIM / GEMAPS_DIM globals, so it must be instantiated
    # after load_models() has populated them.
    def __init__(self, num_classes=5, dropout=0.3, gemaps_proj=64, whisper_proj=256):
        super().__init__()
        self.whisper_enc = WhisperModel.from_pretrained("openai/whisper-tiny").encoder
        self.w_proj = nn.Sequential(
            nn.Linear(WHISPER_DIM, whisper_proj), nn.ReLU(), nn.Dropout(dropout)
        )
        self.g_proj = nn.Sequential(
            nn.Linear(GEMAPS_DIM, gemaps_proj), nn.ReLU(), nn.Dropout(dropout)
        )
        self.classifier = nn.Sequential(
            nn.Linear(whisper_proj + gemaps_proj, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

    def forward(self, whisper_inp, gemaps):
        # Mean-pool the encoder hidden states over time before projecting
        w = self.whisper_enc(whisper_inp).last_hidden_state.mean(dim=1)
        w = self.w_proj(w)
        g = self.g_proj(gemaps)
        return self.classifier(torch.cat([w, g], dim=-1))


def load_models(model_dir="."):
    global _fusion, _mlp, _scalers
    global EMOTION_LABELS, NUM_EMOTIONS, GEMAPS_DIM, WHISPER_DIM
    global SAMPLE_RATE, MAX_DURATION, MAX_SAMPLES

    with open(os.path.join(model_dir, "metadata.json")) as f:
        meta = json.load(f)
    EMOTION_LABELS = meta["emotion_labels"]
    NUM_EMOTIONS = meta["num_emotions"]
    GEMAPS_DIM = meta["gemaps_dim"]
    WHISPER_DIM = meta["whisper_dim"]
    SAMPLE_RATE = meta["sample_rate"]
    MAX_DURATION = meta["max_duration"]
    MAX_SAMPLES = int(SAMPLE_RATE * MAX_DURATION)

    _fusion = FusionSER(num_classes=NUM_EMOTIONS)
    _fusion.load_state_dict(
        torch.load(os.path.join(model_dir, "fusion_ser.pt"), map_location="cpu")
    )
    _fusion.eval()

    _mlp = GeMAPS_MLP(in_dim=GEMAPS_DIM, num_classes=NUM_EMOTIONS)
    _mlp.load_state_dict(
        torch.load(os.path.join(model_dir, "gemaps_mlp.pt"), map_location="cpu")
    )
    _mlp.eval()

    _scalers = joblib.load(os.path.join(model_dir, "language_scalers.pkl"))

    # Pre-warm feature extractors
    get_smile()
    get_whisper_fe()
    print("All models loaded.")


def extract_gemaps(audio_path, language):
    try:
        feats = get_smile().process_file(audio_path)
        arr = feats.values[0].astype(np.float32).reshape(1, -1)
    except Exception:
        # openSMILE can fail on very short or corrupt clips; fall back to zeros
        arr = np.zeros((1, GEMAPS_DIM), dtype=np.float32)
    # Apply the same per-language scaler fitted in notebook 1,
    # defaulting to the English scaler for unknown languages
    scaler = _scalers.get(language) or _scalers.get("english")
    arr = scaler.transform(arr)
    return torch.from_numpy(arr.astype(np.float32))  # (1, 88)


def extract_whisper(audio_path):
    try:
        audio, _ = librosa.load(audio_path, sr=SAMPLE_RATE, mono=True)
        audio = audio[:MAX_SAMPLES]
        inp = get_whisper_fe()(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt")
        return inp.input_features  # (1, 80, 3000)
    except Exception:
        # Last-resort silent input, matching whisper-tiny's log-mel shape
        return torch.zeros(1, 80, 3000)


@torch.no_grad()
def predict(audio_path, language="english", mode="fusion"):
    if _fusion is None:
        raise RuntimeError("Call load_models() first.")
    if mode not in ("fusion", "gemaps", "ensemble"):
        raise ValueError(f"Unknown mode: {mode!r}")

    gemaps = extract_gemaps(audio_path, language)
    whisper = extract_whisper(audio_path) if mode in ("fusion", "ensemble") else None

    probs_f = probs_m = None
    if mode in ("fusion", "ensemble"):
        probs_f = torch.softmax(_fusion(whisper, gemaps), -1).squeeze(0).numpy()
    if mode in ("gemaps", "ensemble"):
        probs_m = torch.softmax(_mlp(gemaps), -1).squeeze(0).numpy()

    if mode == "fusion":
        probs = probs_f
    elif mode == "gemaps":
        probs = probs_m
    else:
        # Fixed-weight ensemble; the fusion model is weighted higher
        probs = 0.6 * probs_f + 0.4 * probs_m
    return {label: float(probs[i]) for i, label in enumerate(EMOTION_LABELS)}
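

# Usage sketch, not part of the original module: assumes metadata.json,
# fusion_ser.pt, gemaps_mlp.pt, and language_scalers.pkl sit in the working
# directory, and "sample.wav" is a hypothetical test clip.
if __name__ == "__main__":
    load_models(model_dir=".")
    scores = predict("sample.wav", language="english", mode="ensemble")
    # Print emotions sorted by predicted probability, highest first
    for label, p in sorted(scores.items(), key=lambda kv: kv[1], reverse=True):
        print(f"{label:>10s}  {p:.3f}")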