File size: 3,243 Bytes
c31871d
67ff989
c31871d
 
67ff989
c31871d
67ff989
 
 
c31871d
fae294a
c31871d
 
 
730b279
c31871d
 
fae294a
c31871d
 
67ff989
 
c31871d
67ff989
 
 
 
 
 
c31871d
67ff989
 
 
 
 
c31871d
67ff989
 
c31871d
bf03a49
c31871d
 
bf03a49
c31871d
 
 
 
 
 
 
 
67ff989
 
97cfb42
c31871d
 
67ff989
c31871d
67ff989
 
 
 
730b279
fae294a
c31871d
f1522ff
c31871d
 
67ff989
 
 
 
c31871d
67ff989
c31871d
 
67ff989
 
fae294a
67ff989
c31871d
 
b813bb8
c31871d
 
 
 
 
 
 
 
 
67ff989
c31871d
67ff989
 
c31871d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# app.py — Thai ASR on faster-whisper using Thaweewat/whisper-th-medium-ct2
import os
import tempfile
from pathlib import Path
from typing import List, Tuple

import torch
import gradio as gr
from faster_whisper import WhisperModel

MODEL_ID = "Thaweewat/whisper-th-medium-ct2"

# Choose the device and the matching compute type together:
# CUDA when torch reports it as available, otherwise CPU.
if torch.cuda.is_available():
    DEVICE, COMPUTE_TYPE = "cuda", "int8_float16"
else:
    DEVICE, COMPUTE_TYPE = "cpu", "int8"

# Single shared model instance, created at import time
# (the first cold start will download the model).
model = WhisperModel(MODEL_ID, device=DEVICE, compute_type=COMPUTE_TYPE)

def _fmt_srt_time(t: float) -> str:
    """Format seconds -> SRT timestamp."""
    if t is None:
        t = 0.0
    ms = int(round(t * 1000))
    h, ms = divmod(ms, 3600000)
    m, ms = divmod(ms, 60000)
    s, ms = divmod(ms, 1000)
    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"

def _segments_to_srt(segments: List[Tuple[int, float, float, str]]) -> str:
    """[(idx, start, end, text)] -> SRT string."""
    lines = []
    for i, start, end, text in segments:
        lines.append(str(i))
        lines.append(f"{_fmt_srt_time(start)} --> {_fmt_srt_time(end)}")
        lines.append((text or "").strip())
        lines.append("")  # blank line between cues
    return "\n".join(lines).strip() + "\n"

def transcribe(audio_path: str):
    """
    Transcribe an audio file as Thai and build the three UI outputs.

    Args:
        audio_path: Path to the audio file (Gradio supplies a file path).
            May be ``None`` or empty when no audio was provided.

    Returns:
        A tuple ``(transcript, srt_path, seg_dicts)``:
        - transcript: segment texts joined with newlines.
        - srt_path: path of a freshly written ``.srt`` file, or ``None``
          when there was no audio.
        - seg_dicts: list of ``{"index", "start", "end", "text"}`` dicts.
    """
    # Nothing to transcribe: return empty outputs instead of handing a
    # missing path to the model.
    if not audio_path:
        return "", None, []

    # Thai-only decoding, with VAD to skip silence
    decode_opts = dict(language="th", task="transcribe", beam_size=5, best_of=5, temperature=[0.0, 0.2, 0.4])
    vad_opts = dict(min_silence_duration_ms=500)

    segments_iter, _info = model.transcribe(
        audio_path,
        vad_filter=True,
        vad_parameters=vad_opts,
        **decode_opts,
    )

    segs = []
    texts = []
    for idx, seg in enumerate(segments_iter, start=1):
        start = float(seg.start) if seg.start is not None else 0.0
        # A missing end time falls back to the start time.
        end = float(seg.end) if seg.end is not None else start
        text = (seg.text or "").strip()
        segs.append((idx, start, end, text))
        texts.append(text)

    # Build outputs
    transcript = "\n".join(texts).strip()

    # Write the SRT to a unique temp file per call. A fixed path would be
    # shared by every request, so concurrent calls would overwrite each
    # other's subtitles. delete=False keeps the file on disk after the
    # handle closes so the returned path stays readable.
    srt_str = _segments_to_srt(segs)
    with tempfile.NamedTemporaryFile(
        "w", suffix=".srt", prefix="transcript_", delete=False, encoding="utf-8"
    ) as f:
        f.write(srt_str)
        srt_path = f.name

    # JSON-friendly segments
    seg_dicts = [
        {"index": i, "start": start, "end": end, "text": text}
        for (i, start, end, text) in segs
    ]

    return transcript, srt_path, seg_dicts

# Gradio UI: one audio input, a button, and three outputs
# (transcript text, downloadable SRT file, per-segment JSON).
with gr.Blocks() as demo:
    # Heading text restored from mis-encoded characters (Thai flag emoji and em dash).
    gr.Markdown("## 🇹🇭 Thai ASR — faster-whisper (`Thaweewat/whisper-th-medium-ct2`)")
    with gr.Row():
        audio = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio")
    with gr.Row():
        btn = gr.Button("Transcribe", variant="primary")
    with gr.Row():
        out_text = gr.Textbox(label="Transcript", lines=12)
    with gr.Row():
        out_srt = gr.File(label="Download SRT")
    with gr.Row():
        out_json = gr.JSON(label="Segments (start/end/text)")

    # The three values returned by transcribe map, in order, onto the three outputs.
    btn.click(fn=transcribe, inputs=audio, outputs=[out_text, out_srt, out_json])

if __name__ == "__main__":
    demo.launch()