"""Load the trained CLSTM checkpoint once at import time and expose ``predict``."""

import torch
from model import CLSTMModel
from config import CONFIG, EMOTION_CONFIG
from audio_utils import preprocess_audio

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): depending on the PyTorch version / ``weights_only`` setting,
# torch.load may unpickle arbitrary objects. Only point CONFIG["model_path"]
# at checkpoint files from a trusted source.
checkpoint = torch.load(CONFIG["model_path"], map_location=device)

if "label_map" in checkpoint:
    # The checkpoint maps label -> class index; invert it so that the list
    # position of each label equals its class index. This lookup raises
    # KeyError unless the indices are exactly 0..n-1.
    inv = {v: k for k, v in checkpoint["label_map"].items()}
    emotions = [inv[i] for i in range(len(inv))]
else:
    # No label map stored with the weights: use the key order of
    # EMOTION_CONFIG as the class order.
    emotions = list(EMOTION_CONFIG.keys())

model = CLSTMModel(
    n_mels=CONFIG["n_mels"],
    n_classes=len(emotions),
).to(device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()


def predict(path):
    """Classify the audio at *path* and return the top emotion.

    Args:
        path: Audio location, passed unchanged to ``preprocess_audio``
            together with the module-level ``device``.

    Returns:
        A dict with two keys: ``"emotion"`` (the entry of ``emotions`` at
        the arg-max class index) and ``"confidence"`` (the softmax
        probability of that class for the first row, as a Python float).

    Note:
        ``.item()`` on the arg-max result requires exactly one element, so
        this assumes ``preprocess_audio`` yields a single-item batch —
        TODO confirm against ``audio_utils``.
    """
    x = preprocess_audio(path, device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(x)
        probs = torch.softmax(logits, dim=1)

    idx = torch.argmax(probs, dim=1).item()

    return {
        "emotion": emotions[idx],
        "confidence": float(probs[0][idx]),
    }