Komal133 committed on
Commit
b453cec
·
verified ·
1 Parent(s): f76d519

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -91
app.py CHANGED
@@ -1,94 +1,116 @@
1
- import torch
2
- import librosa
3
  import numpy as np
4
- from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
5
- from fastapi import FastAPI, UploadFile, File
6
- from pydantic import BaseModel
7
- from datetime import datetime
8
- import requests
9
-
10
- # Initialize FastAPI application
11
- app = FastAPI()
12
-
13
- # Hugging Face Model
14
- model_name = "facebook/wav2vec2-large-960h"
15
- processor = Wav2Vec2Processor.from_pretrained(model_name)
16
-
17
- # Initialize model for sequence classification (you will need a custom fine-tuned model for panic/scream detection)
18
- model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name, num_labels=2) # Assuming 2 classes: 'not panic' and 'panic'
19
-
20
- # Salesforce API Configuration (You should configure Salesforce API integration here)
21
- SF_URL = "https://your-salesforce-instance.com"
22
- SF_API_KEY = "your-api-key" # Placeholder, use secure methods for handling API keys
23
-
24
- # Helper functions
25
- def process_audio(file_path):
26
- """
27
- Function to process and predict the class of the audio file.
28
- Fine-tuning may be required for accurate panic/scream detection.
29
- """
30
- # Load audio file using librosa (resample to 16 kHz)
31
- audio_input, _ = librosa.load(file_path, sr=16000)
32
-
33
- # Pre-process audio using Hugging Face's Wav2Vec2Processor
34
- inputs = processor(audio_input, return_tensors="pt", padding=True)
35
-
36
- # Predict emotions (this requires fine-tuning for scream/panic detection)
37
- with torch.no_grad():
38
- logits = model(**inputs).logits
39
-
40
- # Get predicted label (0: no panic, 1: panic/scream)
41
- predicted_class = torch.argmax(logits, dim=-1).item()
42
-
43
- return predicted_class
44
-
45
- def send_salesforce_alert(alert_data):
46
- """
47
- Sends alert data to the Salesforce system (via REST API)
48
- """
49
- headers = {'Authorization': f'Bearer {SF_API_KEY}'}
50
- response = requests.post(f"{SF_URL}/alerts", json=alert_data, headers=headers)
51
- return response.json()
52
-
53
- def log_alert_to_salesforce(audio_metadata, alert_type):
54
- """
55
- Logs the detected alert in Salesforce with the metadata and alert type.
56
- """
57
- alert_data = {
58
- "timestamp": datetime.now().isoformat(),
59
- "audio_metadata": audio_metadata,
60
- "alert_type": alert_type
61
- }
62
- return send_salesforce_alert(alert_data)
63
-
64
- # API Endpoint for uploading audio file and processing it
65
- @app.post("/upload-audio/")
66
- async def upload_audio(file: UploadFile = File(...)):
67
- """
68
- Handles audio file upload, processes the audio for panic detection,
69
- and triggers an alert if necessary.
70
- """
71
- # Save the uploaded file temporarily
72
- file_location = f"./temp_audio/{file.filename}"
73
- with open(file_location, "wb") as audio_file:
74
- audio_file.write(file.file.read())
75
-
76
- # Process audio to detect panic/scream
77
- detection_result = process_audio(file_location)
78
-
79
- # Set alert type based on confidence
80
- alert_type = "High-Risk" if detection_result == 1 else "Medium-Risk"
81
 
82
- # Log detection and send alert to Salesforce
83
- audio_metadata = {"filename": file.filename, "file_size": len(file.file.read())}
84
- log_alert_to_salesforce(audio_metadata, alert_type)
85
-
86
- return {"message": f"Alert triggered: {alert_type}", "alert_type": alert_type}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
- # API Endpoint to start/stop detection system (you can implement start/stop functionality if needed)
89
- @app.get("/toggle-detection/{status}")
90
- async def toggle_detection(status: str):
91
- """
92
- Start or stop the detection system.
93
- """
94
- return {"message": f"Detection system {status}"}
 
1
import logging
import threading

import numpy as np
import librosa
import torch
import sounddevice as sd
import gradio as gr
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification
8
+
9
+ # Configure logging
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
+
13
+ # Load Hugging Face model
14
+ MODEL_NAME = "ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition"
15
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
16
+ model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME)
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+ model.to(device)
19
+ model.eval()
20
+ logger.info(f"Loaded model {MODEL_NAME} on {device}")
21
+
22
+ # Audio settings
23
+ SAMPLE_RATE = 16000 # Model expects 16kHz
24
+ DURATION = 5 # Seconds for real-time audio chunks
25
+ recording = None
26
+ is_recording = False
27
+
28
+ # Function to process audio and detect screams
29
+ def process_audio(audio_data, sample_rate=SAMPLE_RATE):
30
+ try:
31
+ inputs = feature_extractor(audio_data, sampling_rate=sample_rate, return_tensors="pt", padding=True)
32
+ inputs = {key: val.to(device) for key, val in inputs.items()}
33
+
34
+ with torch.no_grad():
35
+ outputs = model(**inputs)
36
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
37
+ confidence, predicted_label = torch.max(probabilities, dim=-1)
38
+ confidence = confidence.item() * 100
39
+ label = model.config.id2label[predicted_label.item()]
40
+
41
+ # Check for scream-like emotions (e.g., fear, surprise)
42
+ scream_detected = label in ["fear", "surprise"]
43
+ risk_level = None
44
+ if scream_detected:
45
+ if confidence > 80:
46
+ risk_level = "High-Risk"
47
+ elif 50 <= confidence <= 80:
48
+ risk_level = "Medium-Risk"
49
+
50
+ return scream_detected, confidence, label, risk_level
51
+ except Exception as e:
52
+ logger.error(f"Error processing audio: {e}")
53
+ return False, 0, "error", None
54
+
55
+ # Real-time audio capture
56
+ def start_recording():
57
+ global recording, is_recording
58
+ is_recording = True
59
+ recording = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
+ def callback(indata, frames, time, status):
62
+ if status:
63
+ logger.error(f"Recording error: {status}")
64
+ recording.append(indata.copy())
65
+
66
+ logger.info("Starting real-time audio capture")
67
+ with sd.InputStream(samplerate=SAMPLE_RATE, channels=1, callback=callback, blocksize=int(SAMPLE_RATE * DURATION)):
68
+ while is_recording:
69
+ sd.sleep(1000)
70
+
71
+ return "Recording started"
72
+
73
+ def stop_recording():
74
+ global is_recording, recording
75
+ is_recording = False
76
+ if recording:
77
+ audio_data = np.concatenate(recording, axis=0).flatten()
78
+ scream_detected, confidence, label, risk_level = process_audio(audio_data)
79
+ return f"Detection: {scream_detected}, Confidence: {confidence:.2f}%, Label: {label}, Risk: {risk_level}"
80
+ return "No audio recorded"
81
+
82
+ # Process uploaded audio file
83
+ def process_uploaded_audio(audio_file):
84
+ try:
85
+ audio_data, sr = librosa.load(audio_file, sr=SAMPLE_RATE)
86
+ scream_detected, confidence, label, risk_level = process_audio(audio_data, sr)
87
+ return f"Detection: {scream_detected}, Confidence: {confidence:.2f}%, Label: {label}, Risk: {risk_level}"
88
+ except Exception as e:
89
+ logger.error(f"Error processing uploaded audio: {e}")
90
+ return f"Error: {e}"
91
+
92
+ # Gradio interface
93
+ def create_interface():
94
+ with gr.Blocks() as demo:
95
+ gr.Markdown("# Scream Detection System")
96
+
97
+ with gr.Row():
98
+ start_btn = gr.Button("Start Recording")
99
+ stop_btn = gr.Button("Stop Recording")
100
+
101
+ upload = gr.Audio(source="upload", type="filepath", label="Upload Audio File")
102
+ output = gr.Textbox(label="Detection Results")
103
+
104
+ with gr.Accordion("Settings"):
105
+ confidence_threshold = gr.Slider(50, 100, value=80, label="High-Risk Confidence Threshold")
106
+
107
+ start_btn.click(start_recording, outputs=output)
108
+ stop_btn.click(stop_recording, outputs=output)
109
+ upload.change(process_uploaded_audio, inputs=upload, outputs=output)
110
+
111
+ return demo
112
 
113
+ # Launch the interface
114
+ if __name__ == "__main__":
115
+ demo = create_interface()
116
+ demo.launch()