import torch
import librosa
from transformers import Wav2Vec2Processor

from model import Wav2Vec2_LSTM_MultiTask

# Load the checkpoint; it stores the label mapping and label count
# alongside the model weights.
checkpoint = torch.load("model.pt", map_location="cpu")
emotion2id = checkpoint["emotion2id"]
id2emotion = {v: k for k, v in emotion2id.items()}
NUM_EMOTIONS = checkpoint["num_emotions"]

# The processor applies the same feature normalization used at training time.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

model = Wav2Vec2_LSTM_MultiTask(NUM_EMOTIONS)
model.load_state_dict(checkpoint["model_state"])
model.eval()


def predict(audio_path):
    # Resample to 16 kHz, the rate wav2vec 2.0 expects.
    audio, _ = librosa.load(audio_path, sr=16000)
    inputs = processor(
        audio,
        sampling_rate=16000,
        return_tensors="pt",
    ).input_values
    # Inference only: no gradients needed.
    with torch.no_grad():
        emotion_logits, stress_pred = model(inputs)
    return {
        "emotion": id2emotion[emotion_logits.argmax(dim=1).item()],
        "stress": round(stress_pred.item(), 3),
    }