woranit commited on
Commit
c31871d
·
verified ·
1 Parent(s): 647f8e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -436
app.py CHANGED
@@ -1,485 +1,97 @@
1
- # app.py — Thai ASR on faster-whisper (quiet-speech safe + gap-fill + rescue windows + optional WhisperX alignment)
2
- # Works on HF Spaces (CPU) and will auto-use GPU if available.
3
-
4
  import os
5
- import tempfile
6
- import subprocess
7
- from typing import List, Tuple, Optional
8
 
 
9
  import gradio as gr
10
  from faster_whisper import WhisperModel
11
 
12
- # =========================
13
- # Config / environment
14
- # =========================
15
- MODEL_ID = os.getenv("MODEL_ID", "Thaweewat/whisper-th-medium-ct2")
16
-
17
- # Try GPU if torch is present; else CPU
18
- try:
19
- import torch # optional; only used to detect GPU
20
- HAS_CUDA = torch.cuda.is_available()
21
- except Exception:
22
- HAS_CUDA = False
23
-
24
- DEVICE = "cuda" if HAS_CUDA else "cpu"
25
- COMPUTE_TYPE = os.getenv("COMPUTE_TYPE", "int8_float16" if DEVICE == "cuda" else "int8")
26
- CPU_THREADS = int(os.getenv("CPU_THREADS", os.cpu_count() or 4))
27
- NUM_WORKERS = int(os.getenv("NUM_WORKERS", 1))
28
-
29
- # Optional domain bias (proper nouns help): set in Space → Variables
30
- # e.g. "อนุทิน ชาญวีรกูล พรรคภูมิใจไทย พรรคประชาชน นายกรัฐมนตรี สภาผู้แทนราษฎร ลงมติ"
31
- BIAS_PROMPT = (os.getenv("INITIAL_PROMPT_TH") or "").strip()
32
 
33
- # Loudness normalization toggle (set LOUDNORM=1 in Space Variables to enable)
34
- LOUDNORM = os.getenv("LOUDNORM", "0") == "1"
 
35
 
36
- # ---- Gap-fill knobs ----
37
- GAP_MIN_SECONDS = float(os.getenv("GAP_MIN_SECONDS", "0.8")) # consider holes >= 0.8s
38
- GAP_MAX_SECONDS = float(os.getenv("GAP_MAX_SECONDS", "40.0")) # retry longer holes
39
- GAP_MAX_COUNT = int(os.getenv("GAP_MAX_COUNT", "20")) # allow many gap retries
40
- GAP_PAD = float(os.getenv("GAP_PAD", "2.0")) # more context around gaps
41
 
42
- # Join/merge behavior after recovery
43
- JOIN_GAP = float(os.getenv("JOIN_GAP", "0.35")) # join segments separated by <= this
44
-
45
- # =========================
46
- # Optional WhisperX import (alignment)
47
- # =========================
48
- HAS_WHISPERX = False
49
- try:
50
- import whisperx # type: ignore
51
- HAS_WHISPERX = True
52
- except Exception as _e:
53
- HAS_WHISPERX = False
54
-
55
- # =========================
56
- # Load model (one-time)
57
- # =========================
58
- model = WhisperModel(
59
- MODEL_ID,
60
- device=DEVICE,
61
- compute_type=COMPUTE_TYPE,
62
- cpu_threads=CPU_THREADS,
63
- num_workers=NUM_WORKERS,
64
- )
65
-
66
- # =========================
67
- # Helpers
68
- # =========================
69
- def _fmt_srt_time(t: Optional[float]) -> str:
70
  if t is None:
71
  t = 0.0
72
- ms = int(round(float(t) * 1000))
73
  h, ms = divmod(ms, 3600000)
74
  m, ms = divmod(ms, 60000)
75
  s, ms = divmod(ms, 1000)
76
  return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
77
 
78
  def _segments_to_srt(segments: List[Tuple[int, float, float, str]]) -> str:
 
79
  lines = []
80
  for i, start, end, text in segments:
81
  lines.append(str(i))
82
  lines.append(f"{_fmt_srt_time(start)} --> {_fmt_srt_time(end)}")
83
  lines.append((text or "").strip())
84
- lines.append("")
85
  return "\n".join(lines).strip() + "\n"
86
 
87
- def _ensure_mono16k(src_path: str) -> str:
88
  """
89
- Convert any audio to mono/16k WAV for stable timestamps.
90
- If LOUDNORM=1, apply EBU R128 loudness normalization to reduce 'no-speech' drops on quiet spans.
91
  """
92
- out = tempfile.NamedTemporaryFile(prefix="norm_", suffix=".wav", delete=False)
93
- out_path = out.name
94
- out.close()
95
- cmd = [
96
- "ffmpeg", "-nostdin", "-loglevel", "error", "-y",
97
- "-i", src_path,
98
- ]
99
- if LOUDNORM:
100
- cmd += ["-af", "loudnorm=I=-16:LRA=11:TP=-1.5"]
101
- cmd += ["-ac", "1", "-ar", "16000", out_path]
102
- subprocess.run(cmd, check=True)
103
- return out_path
104
-
105
- def _ffmpeg_trim(src_path: str, start: float, end: float) -> str:
106
- """Create a temp WAV of [start, end]."""
107
- start = max(0.0, float(start))
108
- end = max(start, float(end))
109
- out = tempfile.NamedTemporaryFile(prefix="clip_", suffix=".wav", delete=False)
110
- out_path = out.name
111
- out.close()
112
- cmd = [
113
- "ffmpeg", "-nostdin", "-loglevel", "error", "-y",
114
- "-ss", f"{start:.3f}", "-to", f"{end:.3f}",
115
- "-i", src_path, "-ac", "1", "-ar", "16000", out_path,
116
- ]
117
- subprocess.run(cmd, check=True)
118
- return out_path
119
-
120
- def _run_asr(audio_path: str, use_vad: bool, vad_opts: dict, decode_opts: dict):
121
- # Build kwargs so we can omit None-only fields safely
122
- kwargs = dict(
123
- vad_filter=use_vad,
124
- vad_parameters=vad_opts if use_vad else None,
125
- initial_prompt=BIAS_PROMPT if BIAS_PROMPT else None,
126
  **decode_opts,
127
  )
128
- # Remove keys with value None (compat for older faster-whisper)
129
- for k in ["log_prob_threshold", "compression_ratio_threshold", "patience"]:
130
- if k in kwargs and kwargs[k] is None:
131
- kwargs.pop(k)
132
-
133
- segments_iter, info = model.transcribe(audio_path, **kwargs)
134
 
135
- segs: List[Tuple[int, float, float, str]] = []
136
- texts: List[str] = []
137
- last_end = 0.0
138
  for idx, seg in enumerate(segments_iter, start=1):
139
- start = float(seg.start) if seg.start is not None else last_end
140
  end = float(seg.end) if seg.end is not None else start
141
  text = (seg.text or "").strip()
142
  segs.append((idx, start, end, text))
143
  texts.append(text)
144
- last_end = max(last_end, end)
145
- transcript = "\n".join(texts).strip()
146
- return segs, transcript, info, last_end
147
-
148
- def _find_gaps(segs: List[Tuple[int, float, float, str]], total_dur: float):
149
- """Return list of (gap_start, gap_end, left_idx, right_idx)."""
150
- gaps = []
151
- if not segs:
152
- return gaps
153
- # Gap before first
154
- if segs[0][1] >= GAP_MIN_SECONDS:
155
- gaps.append((0.0, min(segs[0][1], GAP_MAX_SECONDS), None, 0))
156
- # Gaps between
157
- for i in range(len(segs) - 1):
158
- cur_end = segs[i][2]
159
- nxt_start = segs[i + 1][1]
160
- gap = nxt_start - cur_end
161
- if gap >= GAP_MIN_SECONDS:
162
- gap_end = min(nxt_start, cur_end + GAP_MAX_SECONDS)
163
- gaps.append((cur_end, gap_end, i, i + 1))
164
- # Gap after last
165
- tail_gap = total_dur - segs[-1][2]
166
- if tail_gap >= GAP_MIN_SECONDS:
167
- gap_end = min(total_dur, segs[-1][2] + GAP_MAX_SECONDS)
168
- gaps.append((segs[-1][2], gap_end, len(segs) - 1, None))
169
- return gaps[:GAP_MAX_COUNT]
170
-
171
- def _gap_fill(audio_path: str, segs: List[Tuple[int, float, float, str]], total_dur: float, decode_opts_base: dict):
172
- """Re-decode suspicious gaps without VAD. Returns a new merged list."""
173
- if total_dur <= 0 or not segs:
174
- return segs
175
-
176
- gaps = _find_gaps(segs, total_dur)
177
- if not gaps:
178
- return segs
179
-
180
- print(f"[ASR] gap-fill: found {len(gaps)} gap(s) to retry")
181
-
182
- # Slightly stronger decode for recovery
183
- decode_opts_fb = dict(decode_opts_base)
184
- decode_opts_fb.update({
185
- "beam_size": max(2, decode_opts_base.get("beam_size", 1)),
186
- "best_of": 1,
187
- "temperature": 0.0,
188
- "no_speech_threshold": 0.02,
189
- "condition_on_previous_text": False,
190
- # keep log_prob_threshold/compression disabled
191
- "log_prob_threshold": None,
192
- "compression_ratio_threshold": None,
193
- # (no patience key)
194
- })
195
-
196
- recovered: List[Tuple[int, float, float, str]] = []
197
- for (gs, ge, _left_idx, _right_idx) in gaps:
198
- clip_start = max(0.0, gs - GAP_PAD)
199
- clip_end = min(total_dur, ge + GAP_PAD)
200
- if clip_end - clip_start <= 0.12:
201
- continue
202
-
203
- try:
204
- clip_path = _ffmpeg_trim(audio_path, clip_start, clip_end)
205
- segs_c, _, _, _ = _run_asr(clip_path, False, {}, decode_opts_fb)
206
- os.unlink(clip_path)
207
- except Exception as e:
208
- print(f"[ASR] gap-fill error on {clip_start:.2f}-{clip_end:.2f}: {e}")
209
- continue
210
-
211
- # Re-map clip-local times back to absolute, keep only inside the gap (+/- 0.2s tolerance)
212
- for _, s, e, t in segs_c:
213
- text = (t or "").strip()
214
- if not text:
215
- continue
216
- abs_s = clip_start + max(0.0, s)
217
- abs_e = clip_start + max(0.0, e)
218
- if abs_e <= gs - 0.20 or abs_s >= ge + 0.20:
219
- continue
220
- recovered.append((0, abs_s, abs_e, text))
221
-
222
- if not recovered:
223
- return segs
224
-
225
- # Merge + sort + reindex; also join tiny holes between neighbors
226
- merged = segs + recovered
227
- merged.sort(key=lambda x: x[1]) # by start
228
-
229
- deduped: List[Tuple[int, float, float, str]] = []
230
- for tup in merged:
231
- if deduped:
232
- prev = deduped[-1]
233
- gap = tup[1] - prev[2]
234
- if 0.0 <= gap <= JOIN_GAP:
235
- deduped[-1] = (prev[0], prev[1], tup[2], (prev[3] + " " + tup[3]).strip())
236
- continue
237
- if gap < 0.15:
238
- new_text = prev[3] if len(prev[3]) >= len(tup[3]) else tup[3]
239
- deduped[-1] = (prev[0], min(prev[1], tup[1]), max(prev[2], tup[2]), new_text)
240
- continue
241
- deduped.append(tup)
242
-
243
- reindexed = [(i + 1, s, e, t) for i, (_, s, e, t) in enumerate(deduped)]
244
- print(f"[ASR] gap-fill: inserted {len(recovered)} piece(s); total segs={len(reindexed)}")
245
- return reindexed
246
-
247
- # ---------- surgical rescue for specified windows ----------
248
- def _parse_windows(text: str):
249
- """
250
- Parse "20-38,60-75" -> [(20.0, 38.0), (60.0, 75.0)]
251
- """
252
- windows = []
253
- if not text:
254
- return windows
255
- for chunk in text.split(","):
256
- chunk = chunk.strip()
257
- if "-" in chunk:
258
- a, b = chunk.split("-", 1)
259
- try:
260
- a = float(a.strip()); b = float(b.strip())
261
- if b > a:
262
- windows.append((a, b))
263
- except:
264
- continue
265
- return windows
266
-
267
- def _rescue_windows(audio_path: str, windows: List[Tuple[float,float]], base_opts: dict):
268
- rescued = []
269
- if not windows:
270
- return rescued
271
- for (a, b) in windows:
272
- try:
273
- # small context around window
274
- clip = _ffmpeg_trim(audio_path, max(0.0, a - 1.0), b + 1.0)
275
- opts = dict(base_opts)
276
- opts.update({
277
- "beam_size": 2,
278
- "best_of": 1,
279
- "temperature": 0.0,
280
- "no_speech_threshold": 0.02,
281
- "condition_on_previous_text": False,
282
- "log_prob_threshold": None,
283
- "compression_ratio_threshold": None,
284
- })
285
- segs_c, _, _, _ = _run_asr(clip, False, {}, opts)
286
- os.unlink(clip)
287
- except Exception as e:
288
- print("rescue error", a, b, e);
289
- continue
290
-
291
- for _, s, e, t in segs_c:
292
- t = (t or "").strip()
293
- if not t:
294
- continue
295
- abs_s = max(0.0, (a - 1.0) + max(0.0, s or 0.0))
296
- abs_e = max(abs_s, (a - 1.0) + max(0.0, e or 0.0))
297
- # keep only inside the requested window (+/- 0.2s tolerance)
298
- if abs_e < a - 0.20 or abs_s > b + 0.20:
299
- continue
300
- rescued.append((0, abs_s, abs_e, t))
301
- return rescued
302
-
303
- def _merge_with_join(segs: List[Tuple[int,float,float,str]]):
304
- if not segs:
305
- return segs
306
- segs.sort(key=lambda x: x[1])
307
- out: List[Tuple[int,float,float,str]] = []
308
- for tup in segs:
309
- if out:
310
- prev = out[-1]
311
- gap = tup[1] - prev[2]
312
- if 0.0 <= gap <= JOIN_GAP:
313
- out[-1] = (prev[0], prev[1], tup[2], (prev[3] + " " + tup[3]).strip())
314
- continue
315
- if gap < 0.15:
316
- new_text = prev[3] if len(prev[3]) >= len(tup[3]) else tup[3]
317
- out[-1] = (prev[0], min(prev[1], tup[1]), max(prev[2], tup[2]), new_text)
318
- continue
319
- out.append(tup)
320
- return [(i+1, s, e, t) for i, (_, s, e, t) in enumerate(out)]
321
-
322
- def _squash_tail_repeats(text: str) -> str:
323
- # Common outro repeats in Thai; keep a single one
324
- import re
325
- text = text.strip()
326
- text = re.sub(r"(สวัสดีครับ|สวัสดีค่ะ)(\s*\1){1,}$", r"\1", text)
327
- return text
328
-
329
- # ---------- WhisperX alignment ----------
330
- def _align_with_whisperx(audio_path: str, segments: List[Tuple[int,float,float,str]], lang_code: str = "th"):
331
- """
332
- segments: [(idx, start, end, text), ...]
333
- returns: [(idx, start, end, text)] with refined start/end from word-level alignment.
334
- """
335
- if not segments or not HAS_WHISPERX:
336
- return segments
337
-
338
- try:
339
- device = "cuda" if HAS_CUDA else "cpu"
340
- align_model, metadata = whisperx.load_align_model(language_code=lang_code, device=device)
341
- # Convert to list[dict] for whisperx
342
- seg_dicts = [{"start": s, "end": e, "text": t} for (_i, s, e, t) in segments]
343
- aligned = whisperx.align(
344
- seg_dicts, align_model, metadata, audio_path, device,
345
- return_char_alignments=False
346
- )
347
- out = []
348
- for i, seg in enumerate(aligned.get("segments", []), start=1):
349
- s = float(seg.get("start", seg_dicts[i-1]["start"]))
350
- e = float(seg.get("end", seg_dicts[i-1]["end"]))
351
- t = seg.get("text", seg_dicts[i-1]["text"])
352
- out.append((i, s, e, t))
353
- return out if out else segments
354
- except Exception as e:
355
- print("[Align] WhisperX alignment failed:", e)
356
- return segments
357
-
358
- # =========================
359
- # Transcribe main
360
- # =========================
361
- def transcribe(audio_path: Optional[str], vad_mode: str, enable_gapfill: bool, rescue_text: str, use_alignment: bool):
362
- if not audio_path:
363
- return "", None, []
364
-
365
- # Normalize audio to mono/16k for consistent timestamps
366
- try:
367
- wav_path = _ensure_mono16k(audio_path)
368
- except Exception as e:
369
- return f"แปลงไฟล์เสียงด้วย ffmpeg ไม่สำเร็จ: {e}", None, []
370
-
371
- # ---- Quiet-speech safe decode options ----
372
- decode_opts = dict(
373
- language="th",
374
- task="transcribe",
375
- beam_size=2, # small recall bump
376
- best_of=1,
377
- temperature=0.0,
378
- # patience omitted for compatibility
379
- condition_on_previous_text=False,
380
-
381
- # ↓↓↓ Make Whisper reluctant to drop quiet spans ↓↓↓
382
- no_speech_threshold=0.05,
383
- log_prob_threshold=None, # disable hard drop by avg logprob
384
- compression_ratio_threshold=None, # disable CR gate (music/noise)
385
-
386
- chunk_length=20, # shorter chunks reduce all-or-nothing drops
387
- )
388
-
389
- # Gentler VAD (only used if we choose VAD path)
390
- vad_opts = dict(
391
- threshold=0.08,
392
- min_silence_duration_ms=420,
393
- min_speech_duration_ms=80,
394
- speech_pad_ms=1200,
395
- )
396
-
397
- # Choose whether to start with VAD or NO-VAD
398
- use_vad_first = (vad_mode == "AUTO (VAD on)")
399
-
400
- # Pass 1
401
- segs1, _text1, info1, last_end1 = _run_asr(
402
- wav_path, use_vad_first, vad_opts if use_vad_first else {}, decode_opts
403
- )
404
- dur = float(getattr(info1, "duration", 0.0) or 0.0)
405
- cov1 = (last_end1 / dur * 100.0) if dur > 0 else 0.0
406
- print(f"[ASR] pass1 ({'VAD' if use_vad_first else 'NO-VAD'}) coverage: "
407
- f"{last_end1:.2f}/{dur:.2f}s ({cov1:.1f}%) | segs={len(segs1)}")
408
-
409
- chosen_segs = segs1
410
-
411
- # Fallback: try the other mode if we obviously ended early
412
- if dur > 0 and (last_end1 < 0.98 * dur or len(segs1) == 0):
413
- decode_opts_fb = dict(decode_opts, no_speech_threshold=0.02)
414
- segs2, _text2, info2, last_end2 = _run_asr(
415
- wav_path, not use_vad_first, {} if use_vad_first else vad_opts, decode_opts_fb
416
- )
417
- cov2 = (last_end2 / dur * 100.0) if dur > 0 else 0.0
418
- print(f"[ASR] pass2 ({'NO-VAD' if use_vad_first else 'VAD'}) coverage: "
419
- f"{last_end2:.2f}/{dur:.2f}s ({cov2:.1f}%) | segs={len(segs2)}")
420
- if last_end2 > last_end1 + 0.5:
421
- chosen_segs = segs2
422
-
423
- # Gap-fill to rescue missed mid-sentences (optional)
424
- if enable_gapfill:
425
- chosen_segs = _gap_fill(wav_path, chosen_segs, dur, decode_opts)
426
-
427
- # Surgical rescue for user-provided windows (e.g., "20-38,60-75")
428
- windows = _parse_windows(rescue_text)
429
- if windows:
430
- rescued = _rescue_windows(wav_path, windows, decode_opts)
431
- if rescued:
432
- chosen_segs = _merge_with_join(chosen_segs + rescued)
433
-
434
- # Optional WhisperX alignment (refine timings & recover boundary words)
435
- if use_alignment and HAS_WHISPERX:
436
- chosen_segs = _align_with_whisperx(wav_path, chosen_segs, lang_code="th")
437
 
438
  # Build outputs
439
- transcript = "\n".join([t for (_, _, _, t) in chosen_segs]).strip()
440
- transcript = _squash_tail_repeats(transcript)
441
 
442
- # SRT file
443
- srt_str = _segments_to_srt(chosen_segs)
444
  srt_path = "/tmp/output.srt"
445
  with open(srt_path, "w", encoding="utf-8") as f:
446
  f.write(srt_str)
447
 
 
448
  seg_dicts = [
449
- {"index": i, "start": s, "end": e, "text": t}
450
- for (i, s, e, t) in chosen_segs
451
  ]
452
 
453
- # Clean temp wav
454
- try:
455
- os.unlink(wav_path)
456
- except Exception:
457
- pass
458
-
459
  return transcript, srt_path, seg_dicts
460
 
461
- # =========================
462
- # UI
463
- # =========================
464
- with gr.Blocks(title="Thai ASR — faster-whisper (quiet-speech safe)") as demo:
465
- gr.Markdown("## 🇹🇭 Thai ASR — faster-whisper (`Thaweewat/whisper-th-medium-ct2`)\n"
466
- "หลีกเลี่ยงคำหายช่วงเสียงเบา: ปิด gate เคร่ง, chunk 20s, Gap-fill + Rescue Windows\n"
467
- "มีตัวเลือกปรับปรุงเวลา/ขอบคำด้วย WhisperX (แนะนำ GPU)")
468
-
469
- audio = gr.Audio(sources=["microphone", "upload"], type="filepath", label="อัปโหลดไฟล์เสียงหรืออัดเสียง")
470
  with gr.Row():
471
- vad_mode = gr.Radio(choices=["AUTO (VAD on)", "NO-VAD"], value="NO-VAD", label="VAD โหมด")
472
- gapfill = gr.Checkbox(value=True, label="กู้หลุมกลางประโยค (Gap-fill)")
473
- align_cb = gr.Checkbox(value=False, label="WhisperX word alignment (GPU แนะนำ)")
474
- rescue = gr.Textbox(value="", lines=1, label="Rescue windows (วินาที) เช่น 20-38,60-75 — ว่างไว้ถ้าไม่ใช้")
475
-
476
- btn = gr.Button("ถอดเสียง", variant="primary")
477
-
478
- out_text = gr.Textbox(label="ผลถอดเสียง", lines=14)
479
- out_srt = gr.File(label="ดาวน์โหลดไฟล์ .srt")
480
- out_json = gr.JSON(label="Segments (index/start/end/text)")
481
 
482
- btn.click(fn=transcribe, inputs=[audio, vad_mode, gapfill, rescue, align_cb], outputs=[out_text, out_srt, out_json])
483
 
484
  if __name__ == "__main__":
485
- demo.queue(max_size=8).launch()
 
1
+ # app.py — Thai ASR on faster-whisper using Thaweewat/whisper-th-medium-ct2
 
 
2
  import os
3
+ from pathlib import Path
4
+ from typing import List, Tuple
 
5
 
6
+ import torch
7
  import gradio as gr
8
  from faster_whisper import WhisperModel
9
 
10
+ MODEL_ID = "Thaweewat/whisper-th-medium-ct2"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # Pick device/compute type
13
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
+ COMPUTE_TYPE = "int8_float16" if DEVICE == "cuda" else "int8"
15
 
16
+ # Load once at startup (first cold start will download the model)
17
+ model = WhisperModel(MODEL_ID, device=DEVICE, compute_type=COMPUTE_TYPE)
 
 
 
18
 
19
+ def _fmt_srt_time(t: float) -> str:
20
+ """Format seconds -> SRT timestamp."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  if t is None:
22
  t = 0.0
23
+ ms = int(round(t * 1000))
24
  h, ms = divmod(ms, 3600000)
25
  m, ms = divmod(ms, 60000)
26
  s, ms = divmod(ms, 1000)
27
  return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
28
 
29
  def _segments_to_srt(segments: List[Tuple[int, float, float, str]]) -> str:
30
+ """[(idx, start, end, text)] -> SRT string."""
31
  lines = []
32
  for i, start, end, text in segments:
33
  lines.append(str(i))
34
  lines.append(f"{_fmt_srt_time(start)} --> {_fmt_srt_time(end)}")
35
  lines.append((text or "").strip())
36
+ lines.append("") # blank line between cues
37
  return "\n".join(lines).strip() + "\n"
38
 
39
+ def transcribe(audio_path: str):
40
  """
41
+ audio_path: Gradio supplies a file path.
42
+ Returns: transcript text, SRT file path, and list of segment dicts
43
  """
44
+ # Thai-only decoding, with VAD to skip silence
45
+ decode_opts = dict(language="th", task="transcribe", beam_size=5, best_of=5, temperature=[0.0, 0.2, 0.4])
46
+ vad_opts = dict(min_silence_duration_ms=500)
47
+
48
+ segments_iter, info = model.transcribe(
49
+ audio_path,
50
+ vad_filter=True,
51
+ vad_parameters=vad_opts,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  **decode_opts,
53
  )
 
 
 
 
 
 
54
 
55
+ segs = []
56
+ texts = []
 
57
  for idx, seg in enumerate(segments_iter, start=1):
58
+ start = float(seg.start) if seg.start is not None else 0.0
59
  end = float(seg.end) if seg.end is not None else start
60
  text = (seg.text or "").strip()
61
  segs.append((idx, start, end, text))
62
  texts.append(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  # Build outputs
65
+ transcript = "\n".join(texts).strip()
 
66
 
67
+ # Write SRT to a temp file (Gradio will serve it)
68
+ srt_str = _segments_to_srt(segs)
69
  srt_path = "/tmp/output.srt"
70
  with open(srt_path, "w", encoding="utf-8") as f:
71
  f.write(srt_str)
72
 
73
+ # JSON-friendly segments
74
  seg_dicts = [
75
+ {"index": i, "start": start, "end": end, "text": text}
76
+ for (i, start, end, text) in segs
77
  ]
78
 
 
 
 
 
 
 
79
  return transcript, srt_path, seg_dicts
80
 
81
+ with gr.Blocks() as demo:
82
+ gr.Markdown("## 🇹🇭 Thai ASR — faster-whisper (`Thaweewat/whisper-th-medium-ct2`)")
 
 
 
 
 
 
 
83
  with gr.Row():
84
+ audio = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio")
85
+ with gr.Row():
86
+ btn = gr.Button("Transcribe", variant="primary")
87
+ with gr.Row():
88
+ out_text = gr.Textbox(label="Transcript", lines=12)
89
+ with gr.Row():
90
+ out_srt = gr.File(label="Download SRT")
91
+ with gr.Row():
92
+ out_json = gr.JSON(label="Segments (start/end/text)")
 
93
 
94
+ btn.click(fn=transcribe, inputs=audio, outputs=[out_text, out_srt, out_json])
95
 
96
  if __name__ == "__main__":
97
+ demo.launch()