# Source: Hugging Face Space by "woranit" (commit c31871d, verified)
# app.py — Thai ASR on faster-whisper using Thaweewat/whisper-th-medium-ct2
import os
import tempfile
from pathlib import Path
from typing import List, Tuple

import gradio as gr
import torch
from faster_whisper import WhisperModel
# CTranslate2 conversion of Whisper medium fine-tuned for Thai.
MODEL_ID = "Thaweewat/whisper-th-medium-ct2"

# Pick device/compute type: int8 quantization everywhere, with float16
# mixed precision when a CUDA GPU is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
COMPUTE_TYPE = "int8_float16" if DEVICE == "cuda" else "int8"

# Load once at startup (first cold start will download the model)
model = WhisperModel(MODEL_ID, device=DEVICE, compute_type=COMPUTE_TYPE)
def _fmt_srt_time(t: float) -> str:
"""Format seconds -> SRT timestamp."""
if t is None:
t = 0.0
ms = int(round(t * 1000))
h, ms = divmod(ms, 3600000)
m, ms = divmod(ms, 60000)
s, ms = divmod(ms, 1000)
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
def _segments_to_srt(segments: List[Tuple[int, float, float, str]]) -> str:
    """Serialize ``(index, start, end, text)`` tuples into an SRT document."""
    cues = [
        "\n".join(
            [
                str(idx),
                f"{_fmt_srt_time(start)} --> {_fmt_srt_time(end)}",
                (text or "").strip(),
            ]
        )
        for idx, start, end, text in segments
    ]
    # Cues are separated by one blank line; the document ends with a newline.
    return "\n\n".join(cues).strip() + "\n"
def transcribe(audio_path: str):
    """Transcribe a Thai audio file with faster-whisper.

    Parameters
    ----------
    audio_path : str
        Filesystem path to the audio clip (Gradio supplies a filepath).

    Returns
    -------
    tuple
        ``(transcript, srt_path, seg_dicts)`` — the full transcript text,
        the path of a freshly written SRT file, and a JSON-friendly list of
        ``{"index", "start", "end", "text"}`` segment dicts.
    """
    # Thai-only decoding, with VAD to skip silence. The temperature list is
    # a fallback schedule: higher values are tried only if greedy/beam
    # decoding fails quality checks at the lower ones.
    decode_opts = dict(
        language="th",
        task="transcribe",
        beam_size=5,
        best_of=5,
        temperature=[0.0, 0.2, 0.4],
    )
    vad_opts = dict(min_silence_duration_ms=500)
    segments_iter, _info = model.transcribe(
        audio_path,
        vad_filter=True,
        vad_parameters=vad_opts,
        **decode_opts,
    )

    segs = []
    texts = []
    for idx, seg in enumerate(segments_iter, start=1):
        # Guard against missing timestamps; a missing end collapses to start.
        start = float(seg.start) if seg.start is not None else 0.0
        end = float(seg.end) if seg.end is not None else start
        text = (seg.text or "").strip()
        segs.append((idx, start, end, text))
        texts.append(text)

    transcript = "\n".join(texts).strip()

    # Write the SRT to a unique temp file so concurrent requests cannot
    # clobber each other (a fixed /tmp/output.srt would be shared by all
    # users and is not portable). Gradio serves the file from this path.
    srt_str = _segments_to_srt(segs)
    with tempfile.NamedTemporaryFile(
        "w", suffix=".srt", delete=False, encoding="utf-8"
    ) as f:
        f.write(srt_str)
        srt_path = f.name

    # JSON-friendly segments for the gr.JSON output.
    seg_dicts = [
        {"index": i, "start": start, "end": end, "text": text}
        for (i, start, end, text) in segs
    ]
    return transcript, srt_path, seg_dicts
# --- Gradio UI -----------------------------------------------------------
# One column: audio input, transcribe button, then transcript / SRT / JSON.
with gr.Blocks() as demo:
    # Heading string repaired: the scraped file held mojibake
    # (UTF-8 bytes decoded as cp1252) for the flag emoji and em-dash.
    gr.Markdown("## 🇹🇭 Thai ASR — faster-whisper (`Thaweewat/whisper-th-medium-ct2`)")
    with gr.Row():
        audio = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio")
    with gr.Row():
        btn = gr.Button("Transcribe", variant="primary")
    with gr.Row():
        out_text = gr.Textbox(label="Transcript", lines=12)
    with gr.Row():
        out_srt = gr.File(label="Download SRT")
    with gr.Row():
        out_json = gr.JSON(label="Segments (start/end/text)")
    # Wire the button to transcribe(); outputs map 1:1 to its return tuple.
    btn.click(fn=transcribe, inputs=audio, outputs=[out_text, out_srt, out_json])

if __name__ == "__main__":
    demo.launch()