# Emotion prediction inference script (Hugging Face Spaces app).
# (Removed web-scrape residue "Spaces: / Sleeping" that was not valid Python.)
# Standard library
import argparse
import os

# Third-party
import librosa
import numpy as np
import torch
from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification
class EmotionPredictor:
    """Loads a fine-tuned Wav2Vec2 classifier and predicts emotion from audio files."""

    def __init__(self, model_path="models/wav2vec2-finetuned"):
        """
        Initialize the predictor with the fine-tuned model.

        Args:
            model_path: Directory containing the saved feature extractor and
                Wav2Vec2ForSequenceClassification weights/config.

        Raises:
            Exception: Re-raised after logging if loading fails.
        """
        print(f"Loading model from: {model_path}...")
        # Load config and model
        try:
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
            self.model = Wav2Vec2ForSequenceClassification.from_pretrained(model_path)
            self.model.eval()  # inference mode (disables dropout, etc.)
            # id -> label mapping saved in the fine-tuned model's config
            self.id2label = self.model.config.id2label
            print(f"Model loaded. Labels: {list(self.id2label.values())}")
        except Exception as e:
            print(f"Failed to load model: {e}")
            raise

    def predict(self, audio_path):
        """
        Predict the emotion expressed in a single audio file.

        Args:
            audio_path: Path to an audio file readable by librosa.

        Returns:
            dict with keys "emotion", "confidence", and "probabilities"
            (per-label percentages), or {"error": ...} if the audio could
            not be loaded.
        """
        # 1. Load audio, resampled to the 16 kHz rate Wav2Vec2 expects.
        try:
            speech, sr = librosa.load(audio_path, sr=16000)
        except Exception as e:
            return {"error": f"Could not load audio: {e}"}

        # 2. Preprocess into model-ready tensors.
        inputs = self.feature_extractor(
            speech,
            sampling_rate=16000,
            padding=True,
            return_tensors="pt",
        )

        # 3. Inference. Unpack ALL extractor outputs: the original passed only
        # input_values, which silently dropped the attention_mask when the
        # extractor produces one (relevant with padding=True).
        with torch.no_grad():
            logits = self.model(**inputs).logits

        # 4. Post-processing: softmax over the label dimension.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        predicted_id = torch.argmax(logits, dim=-1).item()
        confidence = probabilities[0][predicted_id].item()
        predicted_label = self.id2label[predicted_id]

        return {
            "emotion": predicted_label,
            "confidence": f"{confidence:.2%}",
            "probabilities": {
                self.id2label[i]: f"{probabilities[0][i].item():.2%}"
                for i in range(len(self.id2label))
            },
        }
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Predict emotion from audio file")
    parser.add_argument("--file", type=str, help="Path to .wav file")
    args = parser.parse_args()

    if args.file:
        predictor = EmotionPredictor()
        result = predictor.predict(args.file)
        print("\nPrediction Result:")
        if "error" in result:
            # Surface load failures explicitly instead of printing the
            # misleading "Emotion: Error" / raw dict the original emitted.
            print(f"Error: {result['error']}")
        else:
            print(f"Emotion: {result.get('emotion', 'Error')}")
            print(f"Confidence: {result.get('confidence', 'N/A')}")
            print(result)
    else:
        print("Please provide a file path using --file")