import json
import os
import numpy as np
import torch
import torch.nn as nn
import librosa
import opensmile
import joblib
from transformers import WhisperModel, WhisperFeatureExtractor
# Loaded from metadata.json at startup — do not hardcode here
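# Expected metadata.json keys, inferred from load_models() below (the values
# shown are illustrative, not the shipped ones):
#   {"emotion_labels": ["angry", ...], "num_emotions": 5, "gemaps_dim": 88,
#    "whisper_dim": 384, "sample_rate": 16000, "max_duration": 30.0}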
EMOTION_LABELS = None
NUM_EMOTIONS = None
GEMAPS_DIM = None
WHISPER_DIM = None
SAMPLE_RATE = None
MAX_DURATION = None
MAX_SAMPLES = None
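
# Lazily constructed singletons, populated on first use or by load_models()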
_smile = None
_whisper_fe = None
_scalers = None
_fusion = None
_mlp = None


def get_smile():
    global _smile
    if _smile is None:
        _smile = opensmile.Smile(
            feature_set=opensmile.FeatureSet.eGeMAPSv02,
            feature_level=opensmile.FeatureLevel.Functionals,
        )
    return _smile


def get_whisper_fe():
    global _whisper_fe
    if _whisper_fe is None:
        _whisper_fe = WhisperFeatureExtractor.from_pretrained(
            "openai/whisper-tiny", sampling_rate=SAMPLE_RATE
        )
    return _whisper_fe


class GeMAPS_MLP(nn.Module):
    def __init__(self, in_dim, hidden=128, num_classes=5, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(),
            nn.BatchNorm1d(hidden), nn.Dropout(dropout),
            nn.Linear(hidden, hidden // 2), nn.ReLU(),
            nn.BatchNorm1d(hidden // 2), nn.Dropout(dropout),
            nn.Linear(hidden // 2, num_classes),
        )

    def forward(self, x):
        return self.net(x)
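
# Shape sanity check for GeMAPS_MLP (illustrative values; 88 is the
# eGeMAPSv02 functional count, 5 the default number of classes):
#   >>> GeMAPS_MLP(in_dim=88)(torch.randn(4, 88)).shape
#   torch.Size([4, 5])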


class FusionSER(nn.Module):
    # WHISPER_DIM and GEMAPS_DIM come from metadata.json, so load_models()
    # must have run before this class is instantiated.
    def __init__(self, num_classes=5, dropout=0.3, gemaps_proj=64, whisper_proj=256):
        super().__init__()
        self.whisper_enc = WhisperModel.from_pretrained("openai/whisper-tiny").encoder
        self.w_proj = nn.Sequential(
            nn.Linear(WHISPER_DIM, whisper_proj), nn.ReLU(), nn.Dropout(dropout)
        )
        self.g_proj = nn.Sequential(
            nn.Linear(GEMAPS_DIM, gemaps_proj), nn.ReLU(), nn.Dropout(dropout)
        )
        self.classifier = nn.Sequential(
            nn.Linear(whisper_proj + gemaps_proj, 128), nn.ReLU(),
            nn.BatchNorm1d(128), nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

    def forward(self, whisper_inp, gemaps):
        # Mean-pool encoder states over time: (B, frames, WHISPER_DIM) -> (B, WHISPER_DIM)
        w = self.whisper_enc(whisper_inp).last_hidden_state.mean(dim=1)
        w = self.w_proj(w)
        g = self.g_proj(gemaps)
        return self.classifier(torch.cat([w, g], dim=-1))
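
# Tensor flow through FusionSER, assuming whisper-tiny (WHISPER_DIM = 384,
# 80 mel bins, 3000 input frames) and eGeMAPSv02 (GEMAPS_DIM = 88):
#   whisper_inp (B, 80, 3000) -> encoder -> (B, 1500, 384) -> mean-pool -> (B, 384)
#   w_proj: (B, 384) -> (B, 256); g_proj: (B, 88) -> (B, 64)
#   concat -> (B, 320) -> classifier -> (B, num_classes)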


def load_models(model_dir="."):
    global _fusion, _mlp, _scalers
    global EMOTION_LABELS, NUM_EMOTIONS, GEMAPS_DIM, WHISPER_DIM
    global SAMPLE_RATE, MAX_DURATION, MAX_SAMPLES

    with open(os.path.join(model_dir, "metadata.json")) as f:
        meta = json.load(f)
    EMOTION_LABELS = meta["emotion_labels"]
    NUM_EMOTIONS = meta["num_emotions"]
    GEMAPS_DIM = meta["gemaps_dim"]
    WHISPER_DIM = meta["whisper_dim"]
    SAMPLE_RATE = meta["sample_rate"]
    MAX_DURATION = meta["max_duration"]
    MAX_SAMPLES = int(SAMPLE_RATE * MAX_DURATION)

    _fusion = FusionSER(num_classes=NUM_EMOTIONS)
    _fusion.load_state_dict(
        torch.load(os.path.join(model_dir, "fusion_ser.pt"), map_location="cpu")
    )
    _fusion.eval()

    _mlp = GeMAPS_MLP(in_dim=GEMAPS_DIM, num_classes=NUM_EMOTIONS)
    _mlp.load_state_dict(
        torch.load(os.path.join(model_dir, "gemaps_mlp.pt"), map_location="cpu")
    )
    _mlp.eval()

    # Dict mapping language name -> fitted sklearn scaler (used in extract_gemaps)
    _scalers = joblib.load(os.path.join(model_dir, "language_scalers.pkl"))

    # Pre-warm feature extractors
    get_smile()
    get_whisper_fe()
    print("All models loaded.")


def extract_gemaps(audio_path, language):
    try:
        feats = get_smile().process_file(audio_path)
        arr = feats.values[0].astype(np.float32).reshape(1, -1)
    except Exception:
        # openSMILE failed on this file: fall back to an all-zeros vector
        arr = np.zeros((1, GEMAPS_DIM), dtype=np.float32)
    # Apply the same per-language scaler fitted in notebook 1
    scaler = _scalers.get(language) or _scalers.get("english")
    arr = scaler.transform(arr)
    return torch.from_numpy(arr.astype(np.float32))  # (1, GEMAPS_DIM); 88 for eGeMAPSv02


def extract_whisper(audio_path):
    try:
        audio, _ = librosa.load(audio_path, sr=SAMPLE_RATE, mono=True)
        audio = audio[:MAX_SAMPLES]
        inp = get_whisper_fe()(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt")
        return inp.input_features  # (1, 80, 3000)
    except Exception:
        # Loading/extraction failed: return zero features of the expected shape
        return torch.zeros(1, 80, 3000)


@torch.no_grad()
def predict(audio_path, language="english", mode="fusion"):
    if _fusion is None:
        raise RuntimeError("Call load_models() first.")
    if mode not in ("fusion", "gemaps", "ensemble"):
        raise ValueError(f"Unknown mode: {mode!r}")
    gemaps = extract_gemaps(audio_path, language)
    whisper = extract_whisper(audio_path) if mode in ("fusion", "ensemble") else None

    probs_f = probs_m = None
    if mode in ("fusion", "ensemble"):
        probs_f = torch.softmax(_fusion(whisper, gemaps), -1).squeeze(0).numpy()
    if mode in ("gemaps", "ensemble"):
        probs_m = torch.softmax(_mlp(gemaps), -1).squeeze(0).numpy()

    if mode == "fusion":
        probs = probs_f
    elif mode == "gemaps":
        probs = probs_m
    else:
        # "ensemble": fixed-weight blend of the two heads
        probs = 0.6 * probs_f + 0.4 * probs_m
    return {label: float(probs[i]) for i, label in enumerate(EMOTION_LABELS)}
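

# Minimal usage sketch. Assumes the model artifacts (metadata.json,
# fusion_ser.pt, gemaps_mlp.pt, language_scalers.pkl) sit in the working
# directory; "sample.wav" is a placeholder path, not a shipped file.
if __name__ == "__main__":
    load_models(".")
    scores = predict("sample.wav", language="english", mode="ensemble")
    for label, p in sorted(scores.items(), key=lambda kv: -kv[1]):
        print(f"{label:>12s}  {p:.3f}")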