BissakaAI committed on
Commit 42081cb · verified · 1 parent: 5d41ed7

Update app.py

Files changed (1): app.py (+15 −13)
app.py CHANGED
@@ -10,17 +10,20 @@ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
 # ----------------------------
 ASR_MODEL_ID = "openai/whisper-large-v3"
 HF_TOKEN = os.getenv("HF_TOKEN")
+
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
 
 # ----------------------------
 # Load processor & model
 # ----------------------------
+print("Loading Whisper processor...")
 processor = AutoProcessor.from_pretrained(
     ASR_MODEL_ID,
     token=HF_TOKEN
 )
 
+print("Loading Whisper model...")
 model = AutoModelForSpeechSeq2Seq.from_pretrained(
     ASR_MODEL_ID,
     torch_dtype=DTYPE,
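For context, a minimal sketch of how the processor/model pair loaded above is used end to end. The helper name `transcribe_once` and the `waveform` argument are illustrative, not part of app.py; it assumes mono float32 audio already resampled to 16 kHz, Whisper's expected rate.

```python
import numpy as np
import torch

def transcribe_once(waveform: np.ndarray) -> str:
    # Turn raw audio into Whisper's log-mel input features
    inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
    input_features = inputs.input_features.to(DEVICE, dtype=DTYPE)

    with torch.no_grad():
        generated_ids = model.generate(input_features, max_new_tokens=256)

    # Map token ids back to text, dropping Whisper's special tokens
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
```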
@@ -30,6 +33,7 @@ model = AutoModelForSpeechSeq2Seq.from_pretrained(
 ).to(DEVICE)
 
 model.eval()
+print("✅ Whisper Large v3 loaded")
 
 # ----------------------------
 # Audio preprocessing
@@ -38,7 +42,7 @@ def preprocess_audio(audio):
     if audio is None:
         return None
 
-    # Gradio returns (sr, np.ndarray)
+    # Gradio returns (sample_rate, waveform)
     sr, speech = audio
 
     # Stereo → mono
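The updated comment matches Gradio's numpy audio contract: the callback receives a `(sample_rate, waveform)` tuple, typically int16 and possibly stereo. A sketch of the normalization this function needs before Whisper sees the audio; the resampler (librosa here) is an assumption, since the diff does not show that part of the file.

```python
import numpy as np
import librosa  # assumption: any 16 kHz resampler would do

def preprocess_audio_sketch(audio):
    if audio is None:
        return None
    sr, speech = audio                # Gradio numpy audio: (sample_rate, waveform)
    speech = speech.astype(np.float32)
    if speech.ndim > 1:               # stereo → mono by averaging channels
        speech = speech.mean(axis=1)
    if np.abs(speech).max() > 1.0:    # int16 range → [-1, 1]
        speech = speech / 32768.0
    if sr != 16000:                   # Whisper expects 16 kHz input
        speech = librosa.resample(speech, orig_sr=sr, target_sr=16000)
    return speech
```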
@@ -76,15 +80,10 @@ def transcribe_audio(audio):
 
     with torch.no_grad():
         generated_ids = model.generate(
-            **inputs,
-            max_new_tokens=256,
-            language=None,
-            task="transcribe",
-            prompt_ids=processor.get_prompt_ids(
-                text="This audio may be in Yoruba, Hausa, Igbo, Nigerian Pidgin or English."
-            )
-        )
-
+            **inputs,
+            max_new_tokens=256,
+            task="transcribe"  # 🔑 THIS IS ALL YOU NEED
+        )
 
     transcription = processor.batch_decode(
         generated_ids,
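This is the substantive fix: the old call set `language=None` while also conditioning decoding on an English sentence via `prompt_ids`, which biases the output rather than helping detection. Passing only `task="transcribe"` lets Whisper's built-in language identification run on the input. If a deployment ever needs to pin the language instead, `generate()` accepts one directly; a minimal sketch, where `"yo"` (Whisper's Yoruba code) is an example choice, not something from this diff:

```python
with torch.no_grad():
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=256,
        task="transcribe",
        language="yo",  # force Yoruba instead of auto-detecting
    )
```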
@@ -94,7 +93,7 @@ def transcribe_audio(audio):
     return transcription.strip()
 
 # ----------------------------
-# Gradio UI (REAL-TIME MIC)
+# Gradio UI (Mic + Upload)
 # ----------------------------
 demo = gr.Interface(
     fn=transcribe_audio,
@@ -104,9 +103,12 @@ demo = gr.Interface(
         label="Speak or Upload Audio"
     ),
     outputs=gr.Textbox(label="Transcription"),
-    title="HealthAtlas ASR (Whisper)",
-    description="Real-time multilingual speech-to-text with automatic language detection"
+    title="HealthAtlas ASR (Whisper Large v3)",
+    description="Real-time multilingual speech-to-text with automatic language detection",
 )
 
+# ----------------------------
+# Launch
+# ----------------------------
 if __name__ == "__main__":
     demo.launch()
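A deployment note rather than part of the commit: on Hugging Face Spaces the bare `demo.launch()` is all that is needed, while self-hosting in a container usually means binding explicitly. The values below are conventional assumptions, not taken from app.py.

```python
# Self-hosting sketch (not needed on Spaces): bind to all interfaces on
# Gradio's default port so the app is reachable from outside the container.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
```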
 