BissakaAI committed on
Commit
5db1504
·
verified ·
1 Parent(s): 42081cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -28
app.py CHANGED
@@ -15,25 +15,22 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
16
 
17
  # ----------------------------
18
- # Load processor & model
19
  # ----------------------------
20
- print("Loading Whisper processor...")
21
  processor = AutoProcessor.from_pretrained(
22
  ASR_MODEL_ID,
23
- token=HF_TOKEN
24
  )
25
 
26
- print("Loading Whisper model...")
27
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
28
  ASR_MODEL_ID,
29
  torch_dtype=DTYPE,
30
  low_cpu_mem_usage=True,
31
  use_safetensors=True,
32
- token=HF_TOKEN
33
  ).to(DEVICE)
34
 
35
  model.eval()
36
- print("✅ Whisper Large v3 loaded")
37
 
38
  # ----------------------------
39
  # Audio preprocessing
@@ -42,22 +39,26 @@ def preprocess_audio(audio):
42
  if audio is None:
43
  return None
44
 
45
- # Gradio returns (sample_rate, waveform)
46
  sr, speech = audio
47
 
48
  # Stereo → mono
49
  if speech.ndim > 1:
50
  speech = np.mean(speech, axis=1)
51
 
 
52
  speech = speech.astype(np.float32)
53
 
 
 
 
 
 
 
 
 
54
  # Force 16kHz
55
  if sr != 16000:
56
- speech = librosa.resample(
57
- speech,
58
- orig_sr=sr,
59
- target_sr=16000
60
- )
61
 
62
  return speech
63
 
@@ -67,48 +68,51 @@ def preprocess_audio(audio):
67
  def transcribe_audio(audio):
68
  speech = preprocess_audio(audio)
69
 
70
- if speech is None or len(speech) == 0:
71
- return "No audio provided."
72
 
73
  inputs = processor(
74
  speech,
75
  sampling_rate=16000,
76
  return_tensors="pt"
77
  )
78
-
79
  inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
80
 
81
  with torch.no_grad():
 
82
  generated_ids = model.generate(
83
  **inputs,
84
- max_new_tokens=256,
85
- task="transcribe" # 🔑 THIS IS ALL YOU NEED
 
 
 
86
  )
87
 
88
- transcription = processor.batch_decode(
89
  generated_ids,
90
  skip_special_tokens=True
91
- )[0]
92
 
93
- return transcription.strip()
 
 
 
94
 
95
  # ----------------------------
96
- # Gradio UI (Mic + Upload)
97
  # ----------------------------
98
  demo = gr.Interface(
99
  fn=transcribe_audio,
100
  inputs=gr.Audio(
101
  sources=["microphone", "upload"],
102
  type="numpy",
103
- label="Speak or Upload Audio"
104
  ),
105
  outputs=gr.Textbox(label="Transcription"),
106
- title="HealthAtlas ASR (Whisper Large v3)",
107
- description="Real-time multilingual speech-to-text with automatic language detection",
108
  )
109
 
110
- # ----------------------------
111
- # Launch
112
- # ----------------------------
113
  if __name__ == "__main__":
114
- demo.launch()
 
15
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
16
 
17
  # ----------------------------
18
+ # Load Whisper
19
  # ----------------------------
 
20
  processor = AutoProcessor.from_pretrained(
21
  ASR_MODEL_ID,
22
+ use_auth_token=HF_TOKEN
23
  )
24
 
 
25
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
26
  ASR_MODEL_ID,
27
  torch_dtype=DTYPE,
28
  low_cpu_mem_usage=True,
29
  use_safetensors=True,
30
+ use_auth_token=HF_TOKEN
31
  ).to(DEVICE)
32
 
33
  model.eval()
 
34
 
35
  # ----------------------------
36
  # Audio preprocessing
 
39
  if audio is None:
40
  return None
41
 
 
42
  sr, speech = audio
43
 
44
  # Stereo → mono
45
  if speech.ndim > 1:
46
  speech = np.mean(speech, axis=1)
47
 
48
+ # Convert to float32
49
  speech = speech.astype(np.float32)
50
 
51
+ # Normalize volume
52
+ rms = np.sqrt(np.mean(speech ** 2))
53
+ if rms > 0:
54
+ speech = speech / rms
55
+
56
+ # Trim silence
57
+ speech, _ = librosa.effects.trim(speech, top_db=25)
58
+
59
  # Force 16kHz
60
  if sr != 16000:
61
+ speech = librosa.resample(speech, orig_sr=sr, target_sr=16000).astype(np.float32)
 
 
 
 
62
 
63
  return speech
64
 
 
68
  def transcribe_audio(audio):
69
  speech = preprocess_audio(audio)
70
 
71
+ if speech is None or len(speech) < 16000:
72
+ return "Audio too short or unclear. Please speak clearly and try again."
73
 
74
  inputs = processor(
75
  speech,
76
  sampling_rate=16000,
77
  return_tensors="pt"
78
  )
 
79
  inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
80
 
81
  with torch.no_grad():
82
+ # Force Yoruba transcription
83
  generated_ids = model.generate(
84
  **inputs,
85
+ task="transcribe",
86
+ language="yo", # Yoruba ISO-639-1 code
87
+ max_new_tokens=512,
88
+ temperature=0.0,
89
+ no_repeat_ngram_size=3
90
  )
91
 
92
+ text = processor.batch_decode(
93
  generated_ids,
94
  skip_special_tokens=True
95
+ )[0].strip()
96
 
97
+ if len(text.split()) < 2:
98
+ return "Speech unclear. Please repeat slowly in Yoruba."
99
+
100
+ return text
101
 
102
  # ----------------------------
103
+ # Gradio UI
104
  # ----------------------------
105
  demo = gr.Interface(
106
  fn=transcribe_audio,
107
  inputs=gr.Audio(
108
  sources=["microphone", "upload"],
109
  type="numpy",
110
+ label="Speak clearly or upload audio in Yoruba"
111
  ),
112
  outputs=gr.Textbox(label="Transcription"),
113
+ title="Yoruba ASR (Whisper)",
114
+ description="Speech-to-text system that transcribes only Yoruba"
115
  )
116
 
 
 
 
117
  if __name__ == "__main__":
118
+ demo.launch(share=True)