Spaces:
Sleeping
Sleeping
| from flask import Flask, request, jsonify | |
| import torch | |
| import librosa | |
| import numpy as np | |
| from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor | |
| import os | |
app = Flask(__name__)

# --- Model setup -------------------------------------------------------------
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = 'amiriparian/ExHuBERT'

# Input preprocessing is taken from the base HuBERT checkpoint, while the
# classification weights come from `model_name`.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/hubert-base-ls960")

# NOTE(review): trust_remote_code=True executes Python code shipped in the model
# repository; consider pinning a `revision=` so the loaded code is auditable.
model = AutoModelForAudioClassification.from_pretrained(
    model_name,
    trust_remote_code=True,
).to(device)

# Emotion names; the request handler maps the argmax class index i to labels[i].
labels = [
    'disgust',
    'neutral',
    'kind',
    'anger',
    'surprise',
    'joy',
]
def _allowed_suffix(filename):
    """Return the supported audio suffix of *filename*, or None.

    The match is case-insensitive, so ``CLIP.WAV`` is accepted. The returned
    suffix is always lower-case (``'.wav'`` or ``'.mp3'``). A missing or empty
    filename yields None.
    """
    if not filename:
        return None
    lowered = filename.lower()
    for suffix in ('.wav', '.mp3'):
        if lowered.endswith(suffix):
            return suffix
    return None


def _alert_level(confidence: float) -> str:
    """Map the top-class probability to an alert tier string.

    Returns 'High-Risk' above 0.8, 'Medium-Risk' above 0.5, and otherwise the
    literal string 'None'.
    """
    # NOTE(review): the tier depends only on confidence, not on which emotion
    # was predicted, so a confident 'joy' is also 'High-Risk'. Confirm intended.
    if confidence > 0.8:
        return 'High-Risk'
    if confidence > 0.5:
        return 'Medium-Risk'
    return 'None'


def _classify_file(path: str):
    """Run the emotion classifier on the audio file at *path*.

    Returns a ``(label, confidence)`` pair for the highest-probability class,
    where ``confidence`` is a Python float softmax probability.
    """
    # librosa resamples to the requested 16 kHz rate on load.
    waveform, _ = librosa.load(path, sr=16000)
    inputs = feature_extractor(
        waveform,
        sampling_rate=16000,
        # Pads short clips up to 48000 samples (3 s at 16 kHz). Truncation is
        # not requested, so longer clips are presumably passed through at full
        # length. Confirm this is intended.
        padding="max_length",
        max_length=48000,
        return_tensors="pt",
    )
    input_values = inputs['input_values'].to(device)
    with torch.no_grad():
        logits = model(input_values).logits
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    confidence, predicted = torch.max(probabilities, 1)
    return labels[predicted.item()], float(confidence.item())


def detect_scream():
    """Classify the emotion in an uploaded audio clip and return it as JSON.

    The clip is read from ``request.files['file']``. Its filename must end in
    ``.wav`` or ``.mp3``, and the check is case-insensitive.

    Returns a ``(response, status)`` pair:
      * 200: ``{'label', 'confidence', 'alert_level'}``
      * 400: ``{'error'}`` when the file is missing or has an unsupported
        extension
      * 500: ``{'error'}`` carrying the exception message on any other failure

    NOTE(review): no ``@app.route`` decorator is visible in this source. As
    shown, this view is never registered with ``app``, so confirm that the
    decorator (and its URL and methods) was not lost.
    """
    import contextlib
    import tempfile

    temp_path = None
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No audio file provided'}), 400
        audio_file = request.files['file']

        suffix = _allowed_suffix(audio_file.filename)
        if suffix is None:
            return jsonify({'error': 'Unsupported file format. Use WAV or MP3'}), 400

        # Never derive the path from the client-supplied filename. It may contain
        # "../" (path traversal), and concurrent uploads with the same name would
        # overwrite each other. mkstemp creates a fresh, uniquely named file. The
        # suffix is kept so the audio decoder can see the container type.
        fd, temp_path = tempfile.mkstemp(suffix=suffix)
        with os.fdopen(fd, 'wb') as handle:
            audio_file.save(handle)

        label, confidence = _classify_file(temp_path)
        result = {
            'label': label,
            'confidence': confidence,
            'alert_level': _alert_level(confidence),
        }
        return jsonify(result), 200
    except Exception as e:
        # Top-level request boundary: record the traceback server-side.
        # NOTE(review): str(e) is echoed to the client and may expose internals.
        app.logger.exception('detect_scream failed')
        return jsonify({'error': str(e)}), 500
    finally:
        # Remove the upload on every path, including failures during decoding or
        # inference. This is best-effort: a failed unlink must not mask the
        # response.
        if temp_path is not None:
            with contextlib.suppress(OSError):
                os.remove(temp_path)
def health_check():
    """Report liveness as ``{"status": "healthy"}`` with HTTP 200.

    The reply is static: it does not inspect the model or the device.

    NOTE(review): no ``@app.route`` decorator is visible in this source, so as
    shown this view is not registered with ``app``. Confirm upstream.
    """
    body = jsonify(status='healthy')
    return body, 200
if __name__ == '__main__':
    # Bind on all interfaces (not just localhost). Port 7860 is presumably the
    # port the hosting Space expects.
    bind_host, bind_port = '0.0.0.0', 7860
    app.run(host=bind_host, port=bind_port)