import json
import os

import numpy as np
import torch
import torch.nn as nn
import librosa
import opensmile
import joblib
from transformers import WhisperModel, WhisperFeatureExtractor
# Loaded from metadata.json at startup; do not hardcode here.
EMOTION_LABELS = None
NUM_EMOTIONS = None
GEMAPS_DIM = None
WHISPER_DIM = None
SAMPLE_RATE = None
MAX_DURATION = None
MAX_SAMPLES = None

# Lazily initialized singletons (feature extractors, scalers, models).
_smile = None
_whisper_fe = None
_scalers = None
_fusion = None
_mlp = None

def get_smile():
    """Lazily build the openSMILE eGeMAPSv02 functionals extractor."""
    global _smile
    if _smile is None:
        _smile = opensmile.Smile(
            feature_set=opensmile.FeatureSet.eGeMAPSv02,
            feature_level=opensmile.FeatureLevel.Functionals,
        )
    return _smile

def get_whisper_fe():
    """Lazily build the Whisper log-mel feature extractor."""
    global _whisper_fe
    if _whisper_fe is None:
        _whisper_fe = WhisperFeatureExtractor.from_pretrained(
            "openai/whisper-tiny", sampling_rate=SAMPLE_RATE
        )
    return _whisper_fe

class GeMAPS_MLP(nn.Module):
    """Baseline classifier over the 88-dim eGeMAPS functionals vector."""

    def __init__(self, in_dim, hidden=128, num_classes=5, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(),
            nn.BatchNorm1d(hidden), nn.Dropout(dropout),
            nn.Linear(hidden, hidden // 2), nn.ReLU(),
            nn.BatchNorm1d(hidden // 2), nn.Dropout(dropout),
            nn.Linear(hidden // 2, num_classes),
        )

    def forward(self, x):
        return self.net(x)
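
# A minimal shape sanity check for GeMAPS_MLP (a sketch; in_dim=88 and
# num_classes=5 are this file's defaults, not a guarantee about the trained
# checkpoint). It is never called at import time.
def _sanity_check_mlp(in_dim=88, num_classes=5):
    model = GeMAPS_MLP(in_dim=in_dim, num_classes=num_classes).eval()
    with torch.no_grad():
        logits = model(torch.randn(2, in_dim))  # batch of 2 dummy GeMAPS vectors
    assert logits.shape == (2, num_classes)
    return logits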

class FusionSER(nn.Module):
    """Late-fusion model: pooled Whisper encoder states + projected GeMAPS."""

    def __init__(self, num_classes=5, dropout=0.3, gemaps_proj=64, whisper_proj=256):
        super().__init__()
        self.whisper_enc = WhisperModel.from_pretrained("openai/whisper-tiny").encoder
        self.w_proj = nn.Sequential(
            nn.Linear(WHISPER_DIM, whisper_proj), nn.ReLU(), nn.Dropout(dropout)
        )
        self.g_proj = nn.Sequential(
            nn.Linear(GEMAPS_DIM, gemaps_proj), nn.ReLU(), nn.Dropout(dropout)
        )
        self.classifier = nn.Sequential(
            nn.Linear(whisper_proj + gemaps_proj, 128), nn.ReLU(),
            nn.BatchNorm1d(128), nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

    def forward(self, whisper_inp, gemaps):
        # Mean-pool encoder states over time, then project and concatenate.
        w = self.whisper_enc(whisper_inp).last_hidden_state.mean(dim=1)
        w = self.w_proj(w)
        g = self.g_proj(gemaps)
        return self.classifier(torch.cat([w, g], dim=-1))
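
# Shape walk-through for FusionSER.forward (a sketch, assuming whisper-tiny
# defaults): whisper_inp (B, 80, 3000) -> encoder -> (B, 1500, 384), mean-pooled
# over time to (B, 384) = (B, WHISPER_DIM); gemaps (B, 88) -> (B, gemaps_proj).
# The concatenated (B, whisper_proj + gemaps_proj) vector feeds the classifier.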

def load_models(model_dir="."):
    """Load metadata, both model checkpoints, and the per-language scalers."""
    global _fusion, _mlp, _scalers
    global EMOTION_LABELS, NUM_EMOTIONS, GEMAPS_DIM, WHISPER_DIM
    global SAMPLE_RATE, MAX_DURATION, MAX_SAMPLES

    with open(os.path.join(model_dir, "metadata.json")) as f:
        meta = json.load(f)
    EMOTION_LABELS = meta["emotion_labels"]
    NUM_EMOTIONS = meta["num_emotions"]
    GEMAPS_DIM = meta["gemaps_dim"]
    WHISPER_DIM = meta["whisper_dim"]
    SAMPLE_RATE = meta["sample_rate"]
    MAX_DURATION = meta["max_duration"]
    MAX_SAMPLES = int(SAMPLE_RATE * MAX_DURATION)

    _fusion = FusionSER(num_classes=NUM_EMOTIONS)
    _fusion.load_state_dict(
        torch.load(os.path.join(model_dir, "fusion_ser.pt"), map_location="cpu")
    )
    _fusion.eval()

    _mlp = GeMAPS_MLP(in_dim=GEMAPS_DIM, num_classes=NUM_EMOTIONS)
    _mlp.load_state_dict(
        torch.load(os.path.join(model_dir, "gemaps_mlp.pt"), map_location="cpu")
    )
    _mlp.eval()

    _scalers = joblib.load(os.path.join(model_dir, "language_scalers.pkl"))

    # Pre-warm feature extractors
    get_smile()
    get_whisper_fe()
    print("All models loaded.")
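
# For reference, a minimal metadata.json this loader expects could look like
# the following (the keys are exactly those read above; the values are
# illustrative, not the trained model's actual configuration):
# {
#   "emotion_labels": ["anger", "happiness", "neutral", "sadness", "fear"],
#   "num_emotions": 5,
#   "gemaps_dim": 88,
#   "whisper_dim": 384,
#   "sample_rate": 16000,
#   "max_duration": 10.0
# }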

def extract_gemaps(audio_path, language):
    """Extract eGeMAPS functionals and apply the language-matched scaler."""
    try:
        feats = get_smile().process_file(audio_path)
        arr = feats.values[0].astype(np.float32).reshape(1, -1)
    except Exception:
        # openSMILE can fail on very short or corrupt clips; fall back to zeros.
        arr = np.zeros((1, GEMAPS_DIM), dtype=np.float32)
    # Apply the same per-language scaler fitted in notebook 1, with the
    # English scaler as the fallback for unseen language keys.
    scaler = _scalers.get(language) or _scalers.get("english")
    arr = scaler.transform(arr)
    return torch.from_numpy(arr.astype(np.float32))  # (1, 88)

def extract_whisper(audio_path):
    """Load audio, truncate to MAX_SAMPLES, and compute Whisper log-mel input."""
    try:
        audio, _ = librosa.load(audio_path, sr=SAMPLE_RATE, mono=True)
        audio = audio[:MAX_SAMPLES]
        inp = get_whisper_fe()(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt")
        return inp.input_features  # (1, 80, 3000)
    except Exception:
        # Zero log-mel features in whisper-tiny's padded shape (80 mels, 30 s).
        return torch.zeros(1, 80, 3000)

def predict(audio_path, language="english", mode="fusion"):
    """Return {emotion: probability}; mode is "fusion", "gemaps", or "ensemble"."""
    if _fusion is None:
        raise RuntimeError("Call load_models() first.")
    if mode not in ("fusion", "gemaps", "ensemble"):
        raise ValueError(f"Unknown mode: {mode!r}")

    gemaps = extract_gemaps(audio_path, language)
    whisper = extract_whisper(audio_path) if mode in ("fusion", "ensemble") else None

    probs_f = probs_m = None
    # no_grad: inference only, and .numpy() requires grad-free tensors.
    with torch.no_grad():
        if mode in ("fusion", "ensemble"):
            probs_f = torch.softmax(_fusion(whisper, gemaps), -1).squeeze(0).numpy()
        if mode in ("gemaps", "ensemble"):
            probs_m = torch.softmax(_mlp(gemaps), -1).squeeze(0).numpy()

    if mode == "fusion":
        probs = probs_f
    elif mode == "gemaps":
        probs = probs_m
    else:
        # Ensemble: fixed 60/40 weighting of fusion and GeMAPS-only probabilities.
        probs = 0.6 * probs_f + 0.4 * probs_m
    return {label: float(probs[i]) for i, label in enumerate(EMOTION_LABELS)}
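
if __name__ == "__main__":
    # Smoke test (a sketch): "clip.wav" is a placeholder path, and model_dir="."
    # assumes metadata.json and both checkpoints sit next to this file.
    load_models(".")
    scores = predict("clip.wav", language="english", mode="fusion")
    print(max(scores, key=scores.get), scores)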