BissakaAI committed
Commit c2ca4c0 · verified · 1 Parent(s): c4fa7f2

Update app.py

Files changed (1)
  1. app.py +87 -27
app.py CHANGED
@@ -1,46 +1,106 @@
+import os
+import torch
 import gradio as gr
+import librosa
 import numpy as np
-from faster_whisper import WhisperModel
+from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
 
-# Load model (small = fast, medium = better accuracy)
-model = WhisperModel(
-    "small",
-    device="cpu",
-    compute_type="float16" if torch.cuda.is_available() else "int8"
+# ----------------------------
+# Config
+# ----------------------------
+ASR_MODEL_ID = "openai/whisper-small"
+HF_TOKEN = os.getenv("HF_TOKEN")
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
+
+# ----------------------------
+# Load processor & model
+# ----------------------------
+processor = AutoProcessor.from_pretrained(
+    ASR_MODEL_ID,
+    token=HF_TOKEN
 )
 
-def transcribe_stream(audio):
+model = AutoModelForSpeechSeq2Seq.from_pretrained(
+    ASR_MODEL_ID,
+    torch_dtype=DTYPE,
+    low_cpu_mem_usage=True,
+    use_safetensors=True,
+    token=HF_TOKEN
+).to(DEVICE)
+
+model.eval()
+
+# ----------------------------
+# Audio preprocessing
+# ----------------------------
+def preprocess_audio(audio):
     if audio is None:
-        return ""
+        return None
+
+    # Gradio returns (sr, np.ndarray)
+    sr, speech = audio
+
+    # Stereo → mono
+    if speech.ndim > 1:
+        speech = np.mean(speech, axis=1)
 
-    sr, data = audio
+    speech = speech.astype(np.float32)
 
-    # Convert to mono
-    if data.ndim > 1:
-        data = np.mean(data, axis=1)
+    # Force 16kHz
+    if sr != 16000:
+        speech = librosa.resample(
+            speech,
+            orig_sr=sr,
+            target_sr=16000
+        )
 
-    segments, info = model.transcribe(
-        data,
-        language="yo",  # Yoruba (use None for auto-detect)
-        beam_size=5
+    return speech
+
+# ----------------------------
+# Transcription
+# ----------------------------
+def transcribe_audio(audio):
+    speech = preprocess_audio(audio)
+
+    if speech is None or len(speech) == 0:
+        return "No audio provided."
+
+    inputs = processor(
+        speech,
+        sampling_rate=16000,
+        return_tensors="pt"
     )
 
-    text = ""
-    for seg in segments:
-        text += seg.text + " "
+    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        generated_ids = model.generate(
+            **inputs,
+            max_new_tokens=256
+        )
+
+    transcription = processor.batch_decode(
+        generated_ids,
+        skip_special_tokens=True
+    )[0]
 
-    return text.strip()
+    return transcription.strip()
 
+# ----------------------------
+# Gradio UI (REAL-TIME MIC)
+# ----------------------------
 demo = gr.Interface(
-    fn=transcribe_stream,
+    fn=transcribe_audio,
     inputs=gr.Audio(
-        source="microphone",
+        sources=["microphone", "upload"],
         type="numpy",
-        streaming=True
+        label="Speak or Upload Audio"
    ),
-    outputs=gr.Textbox(),
-    title="Real-Time Streaming ASR (Whisper)",
-    description="Low-latency live speech recognition"
+    outputs=gr.Textbox(label="Transcription"),
+    title="HealthAtlas ASR (Whisper)",
+    description="Real-time multilingual speech-to-text with automatic language detection"
 )
 
-demo.launch()
+if __name__ == "__main__":
+    demo.launch()
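
For local testing outside the Gradio UI, the following is a minimal standalone sketch of the same transformers pipeline the updated app.py relies on. It is a sketch under assumptions, not part of the commit: the file name sample.wav is a placeholder, the explicit cast of the log-mel features to the model dtype is only needed on the float16/GPU path, and the commented language/task arguments require a recent transformers release.

# Standalone sketch (not part of the commit): exercises the same Whisper
# processor/model path that app.py wires into Gradio.
import torch
import librosa
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

MODEL_ID = "openai/whisper-small"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_ID, torch_dtype=DTYPE).to(DEVICE)
model.eval()

# "sample.wav" is a placeholder file; Whisper expects 16 kHz mono audio.
speech, _ = librosa.load("sample.wav", sr=16000, mono=True)

features = processor(speech, sampling_rate=16000, return_tensors="pt").input_features
features = features.to(DEVICE, dtype=DTYPE)  # assumption: match the model dtype when it is float16 on GPU

with torch.no_grad():
    ids = model.generate(
        features,
        max_new_tokens=256,
        # language="yo", task="transcribe",  # optional: force a language on recent transformers versions
    )

print(processor.batch_decode(ids, skip_special_tokens=True)[0])

On CPU the dtype cast is a no-op. Forcing language="yo" mirrors the removed faster-whisper version, while omitting it keeps the automatic language detection described in the new UI.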