Komal133 committed on
Commit
e176a37
·
verified ·
1 Parent(s): ae87240

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -114
app.py CHANGED
@@ -1,117 +1,53 @@
1
- import gradio as gr
2
- import sounddevice as sd
3
- import numpy as np
4
- import librosa
5
  import torch
6
- from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification
7
- import logging
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # Configure logging
10
- logging.basicConfig(level=logging.INFO)
11
- logger = logging.getLogger(__name__)
12
-
13
- # Load Hugging Face model
14
- MODEL_NAME = "ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition"
15
- feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
16
- model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME)
17
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
- model.to(device)
19
- model.eval()
20
- logger.info(f"Loaded model {MODEL_NAME} on {device}")
21
-
22
- # Audio settings
23
- SAMPLE_RATE = 16000 # Model expects 16kHz
24
- DURATION = 5 # Seconds for real-time audio chunks
25
- recording = None
26
- is_recording = False
27
-
28
- # Function to process audio and detect screams
29
- def process_audio(audio_data, sample_rate=SAMPLE_RATE):
30
- try:
31
- inputs = feature_extractor(audio_data, sampling_rate=sample_rate, return_tensors="pt", padding=True)
32
- inputs = {key: val.to(device) for key, val in inputs.items()}
33
-
34
- with torch.no_grad():
35
- outputs = model(**inputs)
36
- probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
37
- confidence, predicted_label = torch.max(probabilities, dim=-1)
38
- confidence = confidence.item() * 100
39
- label = model.config.id2label[predicted_label.item()]
40
-
41
- # Check for scream-like emotions (e.g., fear, surprise)
42
- scream_detected = label in ["fear", "surprise"]
43
- risk_level = None
44
- if scream_detected:
45
- if confidence > 80:
46
- risk_level = "High-Risk"
47
- elif 50 <= confidence <= 80:
48
- risk_level = "Medium-Risk"
49
-
50
- return scream_detected, confidence, label, risk_level
51
- except Exception as e:
52
- logger.error(f"Error processing audio: {e}")
53
- return False, 0, "error", None
54
-
55
- # Real-time audio capture
56
- def start_recording():
57
- global recording, is_recording
58
- is_recording = True
59
- recording = []
60
-
61
- def callback(indata, frames, time, status):
62
- if status:
63
- logger.error(f"Recording error: {status}")
64
- recording.append(indata.copy())
65
-
66
- logger.info("Starting real-time audio capture")
67
- with sd.InputStream(samplerate=SAMPLE_RATE, channels=1, callback=callback, blocksize=int(SAMPLE_RATE * DURATION)):
68
- while is_recording:
69
- sd.sleep(1000)
70
-
71
- return "Recording started"
72
-
73
- def stop_recording():
74
- global is_recording, recording
75
- is_recording = False
76
- if recording:
77
- audio_data = np.concatenate(recording, axis=0).flatten()
78
- scream_detected, confidence, label, risk_level = process_audio(audio_data)
79
- return f"Detection: {scream_detected}, Confidence: {confidence:.2f}%, Label: {label}, Risk: {risk_level}"
80
- return "No audio recorded"
81
-
82
- # Process uploaded audio file
83
- def process_uploaded_audio(audio_file):
84
- try:
85
- audio_data, sr = librosa.load(audio_file, sr=SAMPLE_RATE)
86
- scream_detected, confidence, label, risk_level = process_audio(audio_data, sr)
87
- return f"Detection: {scream_detected}, Confidence: {confidence:.2f}%, Label: {label}, Risk: {risk_level}"
88
- except Exception as e:
89
- logger.error(f"Error processing uploaded audio: {e}")
90
- return f"Error: {e}"
91
-
92
- # Gradio interface
93
- def create_interface():
94
- with gr.Blocks() as demo:
95
- gr.Markdown("# Scream Detection System")
96
-
97
- with gr.Row():
98
- start_btn = gr.Button("Start Recording")
99
- stop_btn = gr.Button("Stop Recording")
100
-
101
- # Audio upload component (no 'source' parameter)
102
- upload = gr.Audio(type="filepath", label="Upload Audio File")
103
- output = gr.Textbox(label="Detection Results")
104
-
105
- with gr.Accordion("Settings"):
106
- confidence_threshold = gr.Slider(50, 100, value=80, label="High-Risk Confidence Threshold")
107
-
108
- start_btn.click(start_recording, outputs=output)
109
- stop_btn.click(stop_recording, outputs=output)
110
- upload.change(process_uploaded_audio, inputs=upload, outputs=output)
111
-
112
- return demo
113
-
114
- # Launch the interface
115
  if __name__ == "__main__":
116
- demo = create_interface()
117
- demo.launch()
 
1
from transformers import pipeline
import torch
import soundfile as sf
from datetime import datetime
import requests
# NOTE(review): torch and requests are not referenced anywhere in the visible
# code below — confirm they are needed before removing.

# Initialize the audio-classification pipeline once at import time.
# NOTE(review): this downloads the model weights on first run and requires
# network access; the module cannot be imported offline without a local cache.
classifier = pipeline(
    "audio-classification",
    model="padmalcom/wav2vec2-large-nonverbalvocalization-classification",
)
12
+
13
def detect_scream(audio_path: str):
    """Classify an audio file and map the top prediction to an alert level.

    Parameters
    ----------
    audio_path:
        Path to an audio file readable by soundfile.

    Returns
    -------
    tuple[str, float, str]
        (label, score, alert): the top predicted class name, its confidence
        as a percentage (0-100), and one of "High-Risk" / "Medium-Risk" /
        "None".
    """
    audio, sr = sf.read(audio_path)
    # sf.read returns a (frames, channels) 2-D array for multi-channel files;
    # the pipeline expects a 1-D mono signal, so mix channels down first.
    if getattr(audio, "ndim", 1) > 1:
        audio = audio.mean(axis=1)
    # Resample to the model's expected sampling rate if needed.
    target_sr = classifier.feature_extractor.sampling_rate
    if sr != target_sr:
        import librosa  # local import: only needed on the resample path
        audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
    results = classifier(audio)
    top = results[0]  # pipeline returns predictions ordered by score
    label = top["label"]
    score = float(top["score"]) * 100  # as percentage
    # Map to Detection_Result and Alert_Level.
    is_scream = label.lower() == "scream"
    if is_scream and score > 80:
        alert = "High-Risk"
    elif is_scream and score > 50:
        alert = "Medium-Risk"
    else:
        alert = "None"
    return label, score, alert
31
+
32
def log_to_salesforce(sf_instance, audio_url, label, score, alert):
    """Create a Scream_Detection__c record on the given Salesforce client.

    Parameters
    ----------
    sf_instance:
        Salesforce client exposing the simple-salesforce style object API
        (``sf_instance.Scream_Detection__c.create(dict)``) — TODO confirm
        client type against the caller.
    audio_url:
        Where the classified audio file is hosted.
    label, score, alert:
        The values produced by ``detect_scream``.
    """
    # Local import so the file-level import block stays untouched.
    from datetime import timezone

    sf_instance.Scream_Detection__c.create({
        "Audio_File_URL__c": audio_url,
        "Detection_Result__c": label,
        "Confidence_Score__c": score,
        "Alert_Level__c": alert,
        # Timezone-aware UTC timestamp (includes the +00:00 offset);
        # datetime.utcnow() is naive and deprecated since Python 3.12.
        "Timestamp__c": datetime.now(timezone.utc).isoformat(),
        # add User__c if available
    })
    # trigger Salesforce alert automation (email/SMS/in-app)
44
def main(audio_path: str = "input.wav",
         audio_url: str = "https://my.blob/storage/input.wav"):
    """Run one detection over *audio_path* and print the result.

    Parameters
    ----------
    audio_path:
        Local audio file to classify (defaults preserved from the original
        hard-coded demo values).
    audio_url:
        Public location of the audio, used only for optional Salesforce
        logging.
    """
    label, score, alert = detect_scream(audio_path)
    print(f"Detected: {label}, {score:.1f}%, Level: {alert}")
    # Optional: push to Salesforce using simple-salesforce, requests, etc.
    # log_to_salesforce(sf, audio_url, label, score, alert)
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
# Run the demo detection only when executed as a script, not on import.
if __name__ == "__main__":
    main()