ashutoshroy02's picture
Upload folder using huggingface_hub
18b3a2d verified
raw
history blame contribute delete
891 Bytes
import torch
import librosa
from transformers import Wav2Vec2Processor
from model import Wav2Vec2_LSTM_MultiTask
# --- One-time module setup: restore the trained model for CPU inference. ---

# NOTE(review): torch.load unpickles the file, so "model.pt" must come from a
# trusted source. map_location="cpu" remaps stored tensors onto the CPU.
checkpoint = torch.load("model.pt", map_location="cpu")

# The checkpoint is read as a dict with the keys "emotion2id",
# "num_emotions" and "model_state".
emotion2id = checkpoint["emotion2id"]
# Inverse lookup (class id -> label) used to decode the model's prediction.
id2emotion = dict(zip(emotion2id.values(), emotion2id.keys()))
NUM_EMOTIONS = checkpoint["num_emotions"]

# Processor that turns a raw waveform into model input values.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

# Rebuild the network with the stored class count, restore its weights, and
# switch it to evaluation mode before any prediction is made.
model = Wav2Vec2_LSTM_MultiTask(NUM_EMOTIONS)
model.load_state_dict(checkpoint["model_state"])
model.eval()
def predict(audio_path):
    """Predict the emotion label and stress score for one audio file.

    The file is loaded with librosa at a 16 kHz sampling rate, converted to
    model inputs by the module-level ``processor``, and passed through the
    module-level ``model`` with gradient tracking disabled.

    Args:
        audio_path: Path of the audio file to analyse (passed to
            ``librosa.load``).

    Returns:
        A dict with two keys:
            "emotion": the label looked up in ``id2emotion`` for the
                highest-scoring class.
            "stress": the model's second output as a Python float rounded
                to 3 decimal places.
    """
    # librosa returns (waveform, sample_rate); only the waveform is needed.
    waveform, _sample_rate = librosa.load(audio_path, sr=16000)

    encoded = processor(waveform, sampling_rate=16000, return_tensors="pt")
    input_values = encoded.input_values

    # Inference only, so no autograd graph is required.
    with torch.no_grad():
        emotion_logits, stress_pred = model(input_values)

    # .item() requires single-element tensors, i.e. this assumes a batch of
    # one clip and a scalar stress output -- TODO confirm against the model.
    predicted_id = torch.argmax(emotion_logits, dim=1).item()
    stress_value = stress_pred.item()

    return {
        "emotion": id2emotion[predicted_id],
        "stress": round(stress_value, 3)
    }