MetiMiester committed on
Commit
5dc2f59
·
verified ·
1 Parent(s): ebc93d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -14
app.py CHANGED
@@ -2,34 +2,53 @@ import joblib
2
  import torch
3
  import numpy as np
4
  import soundfile as sf
 
5
  from transformers import pipeline
6
  import gradio as gr
7
 
8
  # Load your text classifier
9
  text_clf = joblib.load("text_pipeline_balanced.joblib")
10
 
11
- # Initialize Whisper ASR
12
  device = 0 if torch.cuda.is_available() else -1
 
 
13
  asr = pipeline(
14
  "automatic-speech-recognition",
15
- model="openai/whisper-base",
16
- chunk_length_s=30,
17
  device=device,
18
- generate_kwargs={"language": "en", "task": "transcribe"}
 
 
 
 
 
 
19
  )
20
 
21
  def classify(audio_path):
22
  """
23
- audio_path: str path to uploaded/recorded file
24
  returns: transcript (str), safety probabilities (dict), unsafe probability (str)
25
  """
26
- # Read file from disk
27
  audio, sr = sf.read(audio_path, dtype="float32")
28
- # If stereo, convert to mono
 
 
 
 
 
 
29
  if audio.ndim > 1:
30
  audio = audio.mean(axis=1)
31
 
32
- # ASR
 
 
 
 
33
  result = asr({"array": audio, "sampling_rate": sr})
34
  txt = result["text"].strip()
35
 
@@ -37,18 +56,19 @@ def classify(audio_path):
37
  proba = float(text_clf.predict_proba([txt])[0][1])
38
  label_probs = {"safe": 1 - proba, "unsafe": proba}
39
  unsafe_str = f"{proba:.2f}"
 
40
  return txt, label_probs, unsafe_str
41
 
42
- # Use filepath-based Audio component
43
- audio_input = gr.components.Audio(label="Upload or record audio", type="filepath")
44
- transcript_out = gr.components.Textbox(label="Transcript")
45
- probs_out = gr.components.Label(num_top_classes=2, label="Safety Probabilities")
46
- unsafe_out = gr.components.Textbox(label="Unsafe Probability")
47
 
48
  iface = gr.Interface(
49
  fn=classify,
50
  inputs=audio_input,
51
- outputs=[transcript_out, probs_out, unsafe_out],
52
  title="BubbleGuard Audio Safety Checker",
53
  description="Upload or record audio; get ASR transcript plus safe/unsafe probabilities."
54
  )
 
2
  import torch
3
  import numpy as np
4
  import soundfile as sf
5
+ import torchaudio
6
  from transformers import pipeline
7
  import gradio as gr
8
 
# Load the text classifier used to score transcripts.
# NOTE: joblib.load unpickles arbitrary Python objects; only load a model file
# that ships with (and is trusted by) this app.
text_clf = joblib.load("text_pipeline_balanced.joblib")

# Choose GPU if available (device index 0), otherwise -1.
# presumably -1 selects the CPU in transformers pipelines; confirm against docs
device = 0 if torch.cuda.is_available() else -1

# Initialize Whisper-Large ASR with beam search.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v2",
    chunk_length_s=10,
    device=device,
    generate_kwargs={
        "language": "en",
        "task": "transcribe",
        "num_beams": 5,
        # "best_of" was dropped on purpose: it is an option of the openai-whisper
        # package, not of Hugging Face `generate()`, which rejects unknown model
        # kwargs when the pipeline is called. Beam search is set by "num_beams".
    },
    # Kept from the original configuration.
    # NOTE(review): presumably silences the chunking warning for seq2seq models;
    # confirm it is still accepted by the installed transformers version.
    ignore_warning=True,
)
29
 
def classify(audio_path):
    """
    Transcribe an audio file and score the transcript with the text classifier.

    audio_path: str path to the uploaded/recorded file
    returns: transcript (str), safety probabilities (dict), unsafe probability (str)
    raises: ValueError if no file was provided or the file contains no samples
    """
    # Fail with a clear message instead of an opaque error from soundfile/numpy.
    if not audio_path:
        raise ValueError("No audio provided: upload or record a clip first.")

    # Read the file from disk as float32 samples.
    audio, sr = sf.read(audio_path, dtype="float32")
    if audio.size == 0:
        raise ValueError("The audio file contains no samples.")

    # Stereo → mono FIRST. Multi-channel data is averaged over axis 1 (channels),
    # so time is axis 0, while torchaudio's resample works on the last axis.
    # Resampling before the mix-down would therefore resample across channels.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    # Resample the (now 1-D) signal to 16 kHz when needed.
    if sr != 16000:
        audio = torchaudio.functional.resample(
            torch.from_numpy(audio), sr, 16000
        ).numpy()
        sr = 16000

    # Normalize peak amplitude; `or 1.0` avoids dividing by zero on pure silence.
    peak = np.abs(audio).max() or 1.0
    audio = audio / peak

    # ASR transcription
    result = asr({"array": audio, "sampling_rate": sr})
    txt = result["text"].strip()

    # Text classification.
    # NOTE(review): assumes column 1 of predict_proba is the "unsafe" class;
    # confirm against the classifier's classes_ ordering.
    proba = float(text_clf.predict_proba([txt])[0][1])
    label_probs = {"safe": 1 - proba, "unsafe": proba}
    unsafe_str = f"{proba:.2f}"

    return txt, label_probs, unsafe_str
61
 
# Gradio components: one audio input (passed to classify() as a file path)
# and three outputs, listed in the order classify() returns its values.
audio_input = gr.components.Audio(
    type="filepath",
    label="Upload or record audio",
)
transcript_out = gr.components.Textbox(label="Transcript")
probs_out = gr.components.Label(
    label="Safety Probabilities",
    num_top_classes=2,
)
unsafe_prob_out = gr.components.Textbox(label="Unsafe Probability")

# Output widgets in the same order as classify()'s return tuple.
output_components = [transcript_out, probs_out, unsafe_prob_out]

iface = gr.Interface(
    title="BubbleGuard Audio Safety Checker",
    description="Upload or record audio; get ASR transcript plus safe/unsafe probabilities.",
    fn=classify,
    inputs=audio_input,
    outputs=output_components,
)