EYEDOL committed
Commit 975eb7a · verified · 1 Parent(s): 53c76c9

Update app.py

Files changed (1)
  1. app.py +98 -116
app.py CHANGED
@@ -5,215 +5,197 @@ import torch
Before (removed lines marked with -):

  import soundfile as sf
  from transformers import pipeline
  import gradio as gr
-
- # Optional: pydub helps with splitting arbitrary audio formats (mp3, m4a, etc.)
  from pydub import AudioSegment

- MODEL_ID = "EYEDOL/Yoruba-ASRNEW"

- # device for transformers pipeline
- device = 0 if torch.cuda.is_available() else -1

- # Create pipeline (automatic-speech-recognition)
- asr = pipeline("automatic-speech-recognition", model=MODEL_ID, device=device)

- # Utility: write numpy (rate, data) to wav
  def save_numpy_to_wav(np_tuple):
      samplerate, data = np_tuple
      tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
      sf.write(tmp.name, data, samplerate)
      return tmp.name

- # Utility: return audio duration in seconds (works for file paths)
  def get_duration_seconds(path):
      try:
          info = sf.info(path)
          return info.duration
      except Exception:
-         # fallback to pydub
          seg = AudioSegment.from_file(path)
          return len(seg) / 1000.0

- # Split an audio file into chunks (ms). Returns list of (chunk_path, start_ms, end_ms)
  def split_audio_file(path, chunk_length_ms=25000, overlap_ms=500):
      audio = AudioSegment.from_file(path)
      duration_ms = len(audio)
      chunks = []
      start = 0
      while start < duration_ms:
-         end = start + chunk_length_ms
-         if end > duration_ms:
-             end = duration_ms
          chunk = audio[start:end]
          tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
          chunk.export(tmp.name, format="wav")
          chunks.append((tmp.name, start, end))
-         # advance start by chunk_length - overlap
-         start += chunk_length_ms - overlap_ms
      return chunks

- # Transcribe a single file path (wraps pipeline call). Supports passing return_timestamps param optionally.
- def transcribe_file(path, return_timestamps=False):
      if return_timestamps:
-         # some pipelines accept return_timestamps=True and return timestamps tokens;
-         # exact format can vary by library version. We'll pass the kwarg and try to handle the output.
-         out = asr(path, return_timestamps=True)
      else:
-         out = asr(path)
-     return out

- # Main: handle any input (numpy tuple or path)
- def transcribe(audio_input, allow_longform_with_timestamps=False, chunk_length_seconds=25, overlap_seconds=0.5):
      """
-     audio_input: either a tuple (sr, numpy array) from gradio mic, or a filepath string from upload
-     returns: dict with 'full_text' and 'segments' list of {start_s, end_s, text}
      """
-     # Normalize input to a filepath
      if audio_input is None:
-         return "No audio provided."

      if isinstance(audio_input, tuple):
-         # Gradio microphone when type="numpy" sends (sample_rate, numpy_array)
-         audio_path = save_numpy_to_wav(audio_input)
      else:
-         audio_path = audio_input  # uploaded filepath

-     # determine duration
      duration_s = get_duration_seconds(audio_path)

-     # If short enough, just transcribe directly
      if duration_s <= 30:
-         out = transcribe_file(audio_path, return_timestamps=False)
-         text = out.get("text", out)
          segments = [{"start_s": 0.0, "end_s": duration_s, "text": text}]
          full_text = text
-         # cleanup if we created a temp file
-         if isinstance(audio_input, tuple):
-             try:
-                 os.unlink(audio_path)
-             except Exception:
-                 pass
          return {"full_text": full_text, "segments": segments}

-     # duration > 30s -> handle long audio
      if allow_longform_with_timestamps:
-         # try calling the pipeline with return_timestamps=True
          try:
-             out = transcribe_file(audio_path, return_timestamps=True)
-             # expected: out may contain 'text' and 'chunks' or 'segments' with timestamps depending on HF version.
-             # We'll try to be flexible.
-             full_text = out.get("text", None)
              segments = []

-             # If the pipeline returned timestamps in 'chunks' or 'segments':
              if isinstance(out, dict):
                  if "chunks" in out and isinstance(out["chunks"], list):
                      for c in out["chunks"]:
-                         # chunk may contain 'text', 'timestamp' or 'start'/'end'
-                         start = c.get("timestamp", [None, None])
-                         if isinstance(start, list) and len(start) == 2:
-                             start_s, end_s = start[0], start[1]
                          else:
                              start_s = c.get("start", None)
                              end_s = c.get("end", None)
-                         segments.append({
-                             "start_s": start_s,
-                             "end_s": end_s,
-                             "text": c.get("text", "")
-                         })
                  elif "words" in out and isinstance(out["words"], list):
-                     # group words into coarse segments (simple approach: group by contiguous words)
-                     # For simplicity, transform words items into tiny segments
                      for w in out["words"]:
-                         segments.append({
-                             "start_s": w.get("start", None),
-                             "end_s": w.get("end", None),
-                             "text": w.get("word", "")
-                         })
                  else:
-                     # fallback: no structured chunks return whole text as single segment
                      if full_text is None:
                          full_text = str(out)
                      segments = [{"start_s": 0.0, "end_s": duration_s, "text": full_text}]
              else:
-                 # pipeline returned a string or something else
                  full_text = str(out)
                  segments = [{"start_s": 0.0, "end_s": duration_s, "text": full_text}]

-             if isinstance(audio_input, tuple):
-                 try:
-                     os.unlink(audio_path)
-                 except Exception:
-                     pass
              return {"full_text": full_text, "segments": segments}
          except Exception as e:
-             # Fall back to chunking if long-form timestamps fail
-             print("Long-form timestamps failed, falling back to chunking:", e)

-     # Default: chunking approach
      chunk_length_ms = int(chunk_length_seconds * 1000)
      overlap_ms = int(overlap_seconds * 1000)
-
      chunks = split_audio_file(audio_path, chunk_length_ms=chunk_length_ms, overlap_ms=overlap_ms)
-     stitched_texts = []
      segments = []
      for chunk_path, start_ms, end_ms in chunks:
          try:
-             out = transcribe_file(chunk_path, return_timestamps=False)
-             text = out.get("text", out)
          except Exception as e:
-             text = f"[ERROR transcribing chunk: {e}]"
-
          start_s = start_ms / 1000.0
          end_s = end_ms / 1000.0
          segments.append({"start_s": start_s, "end_s": end_s, "text": text})
-         stitched_texts.append(text)

-         # cleanup chunk file
-         try:
-             os.unlink(chunk_path)
-         except Exception:
-             pass

-     # cleanup original temp if microphone
-     if isinstance(audio_input, tuple):
-         try:
-             os.unlink(audio_path)
-         except Exception:
-             pass
-
-     full_text = " ".join([s for s in stitched_texts if s])
      return {"full_text": full_text, "segments": segments}

- # Gradio UI
- with gr.Blocks(title="Yoruba ASR — long audio ready") as demo:
-     gr.Markdown("## Yoruba ASR Upload or use microphone. Supports long audio via chunking or long-form timestamps 🎧")

      with gr.Row():
-         with gr.Column():
-             mic = gr.Audio(label="Record from mic (use 'Record' then 'Stop')", type="numpy")
-             upload = gr.Audio(label="Or upload audio file", type="filepath")
-             mode = gr.Radio(choices=["Use microphone input", "Use uploaded file"], value="Use microphone input", label="Input source")
-             longform_checkbox = gr.Checkbox(label="Try model's long-form timestamps (may be supported by some Whisper forks)", value=False)
-             chunk_len = gr.Slider(minimum=10, maximum=60, value=25, step=5, label="Chunk length (seconds) used when chunking")
-             overlap = gr.Slider(minimum=0, maximum=5, value=0.5, step=0.5, label="Chunk overlap (seconds)")
              transcribe_btn = gr.Button("Transcribe")
-         with gr.Column():
              full_text_out = gr.Textbox(label="Full transcription", lines=8)
              segments_out = gr.JSON(label="Segments (start_s, end_s, text)")

-     def handle_transcription(mic_input, upload_input, mode_choice, use_longform, chunk_len_s, overlap_s):
-         audio_src = mic_input if mode_choice == "Use microphone input" else upload_input
-         res = transcribe(audio_src, allow_longform_with_timestamps=use_longform, chunk_length_seconds=chunk_len_s, overlap_seconds=overlap_s)
-         if isinstance(res, str):
-             return res, []
          return res["full_text"], res["segments"]

-     transcribe_btn.click(fn=handle_transcription, inputs=[mic, upload, mode, longform_checkbox, chunk_len, overlap], outputs=[full_text_out, segments_out])
-
-     gr.Markdown("**Notes:**\n\n- Chunking is robust and recommended if you experience errors. Default chunk length is 25s with 0.5s overlap. "
-                 "- If you enable long-form timestamps, the pipeline will attempt `return_timestamps=True` and return timestamps if the model supports it. "
-                 "- Ensure your Space has enough compute (GPU recommended) for faster transcription.")

  if __name__ == "__main__":
      demo.launch()
 
After (added lines marked with +):

  import soundfile as sf
  from transformers import pipeline
  import gradio as gr
  from pydub import AudioSegment

+ # ---- Models available ----
+ MODEL_CHOICES = {
+     "Yoruba (EYEDOL/Yoruba-ASRNEW)": "EYEDOL/Yoruba-ASRNEW",
+     "Naija English (EYEDOL/NAIJA_ENG-ASRNEW)": "EYEDOL/NAIJA_ENG-ASRNEW",
+ }
+
+ # Device selection for pipeline creation
+ DEVICE = 0 if torch.cuda.is_available() else -1

+ # Cache created pipelines to avoid reloading
+ PIPELINE_CACHE = {}

+ def get_asr_pipeline(model_id: str):
+     """Return a cached pipeline for model_id or create a new one."""
+     if model_id in PIPELINE_CACHE:
+         return PIPELINE_CACHE[model_id]
+     # Create and cache
+     asr = pipeline("automatic-speech-recognition", model=model_id, device=DEVICE)
+     PIPELINE_CACHE[model_id] = asr
+     return asr

+ # Utilities
  def save_numpy_to_wav(np_tuple):
      samplerate, data = np_tuple
      tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
      sf.write(tmp.name, data, samplerate)
      return tmp.name

  def get_duration_seconds(path):
      try:
          info = sf.info(path)
          return info.duration
      except Exception:
          seg = AudioSegment.from_file(path)
          return len(seg) / 1000.0

  def split_audio_file(path, chunk_length_ms=25000, overlap_ms=500):
      audio = AudioSegment.from_file(path)
      duration_ms = len(audio)
      chunks = []
      start = 0
      while start < duration_ms:
+         end = min(start + chunk_length_ms, duration_ms)
          chunk = audio[start:end]
          tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
          chunk.export(tmp.name, format="wav")
          chunks.append((tmp.name, start, end))
+         start += max(1, chunk_length_ms - overlap_ms)
      return chunks

+ def transcribe_file_with_pipeline(asr_pipeline, path, return_timestamps=False):
+     # wrapper that calls pipeline and returns its output
      if return_timestamps:
+         return asr_pipeline(path, return_timestamps=True)
      else:
+         return asr_pipeline(path)

+ def transcribe(audio_input, model_id, allow_longform_with_timestamps=False, chunk_length_seconds=25, overlap_seconds=0.5):
      """
+     audio_input: either (sr, numpy_array) from mic (type="numpy") or filepath from upload (type="filepath")
+     model_id: Hugging Face model id string
+     Returns dict: {"full_text": str, "segments": [{start_s,end_s,text}, ...]}
      """
      if audio_input is None:
+         return {"error": "No audio provided."}

+     # Normalize to a filepath
+     created_tmp_input = False
      if isinstance(audio_input, tuple):
+         audio_path = save_numpy_to_wav(audio_input)  # we created this tmp file
+         created_tmp_input = True
      else:
+         audio_path = audio_input

      duration_s = get_duration_seconds(audio_path)
+     asr = get_asr_pipeline(model_id)

+     # Short audio: direct call
      if duration_s <= 30:
+         out = transcribe_file_with_pipeline(asr, audio_path, return_timestamps=False)
+         text = out.get("text", out) if isinstance(out, dict) else str(out)
          segments = [{"start_s": 0.0, "end_s": duration_s, "text": text}]
          full_text = text
+         if created_tmp_input:
+             try: os.unlink(audio_path)
+             except: pass
          return {"full_text": full_text, "segments": segments}

+     # Long audio (>30s)
      if allow_longform_with_timestamps:
          try:
+             out = transcribe_file_with_pipeline(asr, audio_path, return_timestamps=True)
+             # Attempt to parse common structures
+             full_text = out.get("text", None) if isinstance(out, dict) else str(out)
              segments = []

              if isinstance(out, dict):
                  if "chunks" in out and isinstance(out["chunks"], list):
                      for c in out["chunks"]:
+                         # chunk may contain 'timestamp' e.g. [start, end] or 'start'/'end'
+                         ts = c.get("timestamp", None)
+                         if isinstance(ts, list) and len(ts) == 2:
+                             start_s, end_s = ts[0], ts[1]
                          else:
                              start_s = c.get("start", None)
                              end_s = c.get("end", None)
+                         segments.append({"start_s": start_s, "end_s": end_s, "text": c.get("text", "")})
+                 elif "segments" in out and isinstance(out["segments"], list):
+                     for s in out["segments"]:
+                         segments.append({"start_s": s.get("start", None), "end_s": s.get("end", None), "text": s.get("text", "")})
                  elif "words" in out and isinstance(out["words"], list):
                      for w in out["words"]:
+                         segments.append({"start_s": w.get("start", None), "end_s": w.get("end", None), "text": w.get("word", "")})
                  else:
+                     # no detailed structure -> fall back to full text
                      if full_text is None:
                          full_text = str(out)
                      segments = [{"start_s": 0.0, "end_s": duration_s, "text": full_text}]
              else:
+                 # pipeline returned just a string
                  full_text = str(out)
                  segments = [{"start_s": 0.0, "end_s": duration_s, "text": full_text}]

+             if created_tmp_input:
+                 try: os.unlink(audio_path)
+                 except: pass
              return {"full_text": full_text, "segments": segments}
          except Exception as e:
+             # fallback to chunking
+             print("Long-form timestamps failed; falling back to chunking:", e)

+     # Chunking fallback
      chunk_length_ms = int(chunk_length_seconds * 1000)
      overlap_ms = int(overlap_seconds * 1000)
      chunks = split_audio_file(audio_path, chunk_length_ms=chunk_length_ms, overlap_ms=overlap_ms)
+     stitched = []
      segments = []
      for chunk_path, start_ms, end_ms in chunks:
          try:
+             out = transcribe_file_with_pipeline(asr, chunk_path, return_timestamps=False)
+             text = out.get("text", out) if isinstance(out, dict) else str(out)
          except Exception as e:
+             text = f"[ERROR on chunk: {e}]"
          start_s = start_ms / 1000.0
          end_s = end_ms / 1000.0
          segments.append({"start_s": start_s, "end_s": end_s, "text": text})
+         stitched.append(text)
+         try: os.unlink(chunk_path)
+         except: pass

+     if created_tmp_input:
+         try: os.unlink(audio_path)
+         except: pass

+     full_text = " ".join([s for s in stitched if s])
      return {"full_text": full_text, "segments": segments}

+ # ---- Gradio UI ----
+ with gr.Blocks(title="EYEDOL ASR — Multi-model (Yoruba + Naija English)") as demo:
+     gr.Markdown("## EYEDOL ASR Demo\nSelect model, upload audio or use the microphone. Supports long audio via chunking or model long-form timestamps.")

      with gr.Row():
+         with gr.Column(scale=2):
+             model_choice = gr.Dropdown(list(MODEL_CHOICES.keys()), value=list(MODEL_CHOICES.keys())[0], label="Choose model")
+             mic_input = gr.Audio(label="Record (click Record → Stop)", type="numpy")
+             file_input = gr.Audio(label="Or upload audio file", type="filepath")
+             source = gr.Radio(["Use microphone input", "Use uploaded file"], value="Use microphone input", label="Input source")
+             longform = gr.Checkbox(label="Try model's built-in long-form timestamps (if supported)", value=False)
+             chunk_len = gr.Slider(minimum=10, maximum=120, value=25, step=5, label="Chunk length (seconds)")
+             overlap = gr.Slider(minimum=0.0, maximum=5.0, value=0.5, step=0.5, label="Chunk overlap (seconds)")
              transcribe_btn = gr.Button("Transcribe")
+             gr.Markdown("**Note:** If a model is private add `HF_TOKEN` as a secret in Space settings. GPU recommended for best performance.")
+         with gr.Column(scale=3):
              full_text_out = gr.Textbox(label="Full transcription", lines=8)
              segments_out = gr.JSON(label="Segments (start_s, end_s, text)")

+     def handle_transcription(mic_input, file_input, source_choice, model_label, use_longform, chunk_len_s, overlap_s):
+         model_id = MODEL_CHOICES.get(model_label)
+         audio_src = mic_input if source_choice == "Use microphone input" else file_input
+         res = transcribe(audio_src, model_id=model_id, allow_longform_with_timestamps=use_longform, chunk_length_seconds=chunk_len_s, overlap_seconds=overlap_s)
+         if "error" in res:
+             return res["error"], []
          return res["full_text"], res["segments"]

+     transcribe_btn.click(
+         fn=handle_transcription,
+         inputs=[mic_input, file_input, source, model_choice, longform, chunk_len, overlap],
+         outputs=[full_text_out, segments_out],
+     )

  if __name__ == "__main__":
      demo.launch()
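
For reference, a minimal sketch of exercising the updated helpers outside the Gradio UI. This is not part of the commit: `sample.wav` is a placeholder path, and it assumes app.py and its dependencies (torch, transformers, soundfile, pydub, gradio) are importable from the working directory.

# Hypothetical local usage sketch; "sample.wav" is a placeholder, not a file in the repo.
from app import MODEL_CHOICES, transcribe

model_id = MODEL_CHOICES["Yoruba (EYEDOL/Yoruba-ASRNEW)"]

# Upload-style input is a filepath; mic-style input would be an (sample_rate, numpy_array) tuple.
result = transcribe(
    "sample.wav",                          # placeholder audio file
    model_id=model_id,
    allow_longform_with_timestamps=False,  # audio longer than 30 s goes through the chunking path
    chunk_length_seconds=25,
    overlap_seconds=0.5,
)

if "error" in result:
    print(result["error"])
else:
    print(result["full_text"])
    for seg in result["segments"]:
        print(seg["start_s"], seg["end_s"], seg["text"])

Importing app.py this way also builds the gr.Blocks object, but `demo.launch()` only runs under the `if __name__ == "__main__":` guard, so the import just defines the UI without starting a server.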