MetiMiester committed on
Commit
84897bd
·
verified ·
1 Parent(s): 18d8681

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -18
app.py CHANGED
@@ -2,38 +2,43 @@ import gradio as gr
2
  import joblib
3
  import torch
4
  import numpy as np
5
- from transformers import pipeline
 
6
 
7
  # 1) Load your trained text classifier
8
  text_clf = joblib.load("text_pipeline_balanced.joblib")
9
 
10
- # 2) Initialize Whisper ASR
11
- device = 0 if torch.cuda.is_available() else -1
12
- asr = pipeline(
13
- "automatic-speech-recognition",
14
- model="openai/whisper-base",
15
- chunk_length_s=30,
16
- device=device,
17
- ignore_warning=True,
18
- generate_kwargs={"language": "en", "task": "transcribe"}
19
- )
 
 
 
 
20
 
21
  def classify_audio(audio_np, sr):
22
- """
23
- audio_np: np.ndarray
24
- sr: sampling rate
25
- """
26
  if audio_np is None or sr is None:
27
  return "", "❌ Unsafe", 0.0
28
 
 
29
  audio = audio_np.astype("float32")
30
  if audio.ndim > 1:
31
  audio = audio.mean(axis=1)
32
 
33
- # Transcribe
34
- transcript = asr({"array": audio, "sampling_rate": sr})["text"].strip()
 
 
 
35
 
36
- # Classify
37
  proba = text_clf.predict_proba([transcript])[0][1]
38
  label = "❌ Unsafe" if proba > 0.5 else "✅ Safe"
39
 
 
2
  import joblib
3
  import torch
4
  import numpy as np
5
+ import whisper # openai-whisper
6
+ import soundfile as sf
7
 
8
  # 1) Load your trained text classifier
9
  text_clf = joblib.load("text_pipeline_balanced.joblib")
10
 
11
+ # 2) Load Whisper model
12
+ # (Requires `pip install openai-whisper`)
13
+ asr_model = whisper.load_model("base")
14
+
15
def transcribe_with_whisper(audio_np, sr):
    """Transcribe a mono audio signal with the preloaded Whisper model.

    Parameters
    ----------
    audio_np : np.ndarray
        1-D mono waveform samples (any float dtype; cast internally).
    sr : int
        Sampling rate of ``audio_np`` in Hz.

    Returns
    -------
    str
        The transcript text with surrounding whitespace stripped.
    """
    # Whisper models expect 16 kHz input, so resample anything else.
    if sr != 16000:
        import resampy  # local import: only needed on the resample path
        audio_np = resampy.resample(audio_np, sr, 16000)
        sr = 16000
    # Whisper's transcribe() requires float32 samples; the caller (or
    # resampy) may hand us float64, so cast defensively before inference.
    audio_np = np.ascontiguousarray(audio_np, dtype=np.float32)
    # fp16=False keeps inference safe on CPU-only machines.
    result = asr_model.transcribe(audio_np, fp16=False)
    return result["text"].strip()
25
 
26
  def classify_audio(audio_np, sr):
 
 
 
 
27
  if audio_np is None or sr is None:
28
  return "", "❌ Unsafe", 0.0
29
 
30
+ # ensure float32 & mono
31
  audio = audio_np.astype("float32")
32
  if audio.ndim > 1:
33
  audio = audio.mean(axis=1)
34
 
35
+ # 3) Transcribe via openai-whisper
36
+ try:
37
+ transcript = transcribe_with_whisper(audio, sr)
38
+ except Exception as e:
39
+ transcript = f"[Transcription error: {e}]"
40
 
41
+ # 4) Classify
42
  proba = text_clf.predict_proba([transcript])[0][1]
43
  label = "❌ Unsafe" if proba > 0.5 else "✅ Safe"
44