# Hugging Face Spaces page residue (hardware banner): this Space runs on a T4 GPU.
| import os | |
| import html | |
| import re | |
| from difflib import SequenceMatcher | |
| from pathlib import Path | |
| import urllib.request | |
| import time | |
| from typing import Optional, Tuple | |
| import gradio as gr | |
| import librosa | |
| import numpy as np | |
| import torch | |
| from transformers import WhisperForConditionalGeneration, WhisperProcessor | |
| from openai import OpenAI | |
# Public URL of a short (~15 s) sample clip used to pre-populate the demo.
DEMO_URL = "https://github.com/dmatekenya/AI-seminars-malawi/releases/download/v1.1/test_15_secs.wav"
# Local cache path for the downloaded demo clip (reused across page loads).
DEMO_AUDIO_PATH = Path("/tmp/demo.wav")
def ensure_demo_audio() -> str:
    """
    Ensure demo audio exists on disk and return the path as a string.

    Downloads DEMO_URL to DEMO_AUDIO_PATH on first use; later calls reuse
    the cached file.

    Returns:
        str: filesystem path of the demo audio file.

    Raises:
        RuntimeError: if the download fails or yields a missing/empty file.
    """
    DEMO_AUDIO_PATH.parent.mkdir(parents=True, exist_ok=True)

    # If already downloaded, reuse it
    if DEMO_AUDIO_PATH.exists() and DEMO_AUDIO_PATH.stat().st_size > 0:
        print(f"[demo] Using cached audio: {DEMO_AUDIO_PATH} ({DEMO_AUDIO_PATH.stat().st_size} bytes)", flush=True)
        return str(DEMO_AUDIO_PATH)

    print(f"[demo] Downloading demo audio from: {DEMO_URL}", flush=True)
    # Download to a temp file first, then rename (avoids partial files)
    tmp_path = DEMO_AUDIO_PATH.with_suffix(".wav.tmp")
    try:
        urllib.request.urlretrieve(DEMO_URL, tmp_path)
        os.replace(tmp_path, DEMO_AUDIO_PATH)
    except Exception as e:
        # Remove any partial temp file so a retry starts clean
        # (previously it was left behind on failure).
        try:
            tmp_path.unlink()
        except OSError:
            pass
        # Chain the original exception for easier debugging.
        raise RuntimeError(f"[demo] Failed to download demo audio from {DEMO_URL}. Error: {e}") from e

    if not DEMO_AUDIO_PATH.exists() or DEMO_AUDIO_PATH.stat().st_size == 0:
        raise RuntimeError(f"[demo] Download completed but file is missing/empty at {DEMO_AUDIO_PATH}")

    print(f"[demo] Downloaded demo audio: {DEMO_AUDIO_PATH} ({DEMO_AUDIO_PATH.stat().st_size} bytes)", flush=True)
    return str(DEMO_AUDIO_PATH)
# -----------------------------
# Models / Config
# -----------------------------
# Stock Whisper checkpoint used as the "before" baseline.
BASE_REPO = "openai/whisper-large-v3"
# Chichewa fine-tune of the same architecture.
FINETUNED_REPO = "dmatekenya/whisper-large-v3-chichewa"
FINETUNED_REVISION = "bff60fb08ba9f294e05bfcab4306f30b6a0cfc0a"  # pinned commit hash (reproducible loads)

# Local WhisperProcessor language hint (keep consistent with how you evaluated)
# NOTE(review): "shona" is not Chichewa — presumably a deliberate proxy used
# during fine-tuning/evaluation; confirm against the training setup before
# changing this value.
LOCAL_LANGUAGE = "shona"

# Audio constraints
TARGET_SR = 16000    # sample rate (Hz) used when loading audio for the local models
MAX_SECONDS = 30.0   # advertised maximum clip length in the UI

# OpenAI transcription model (commercial)
OPENAI_MODEL = "gpt-4o-transcribe"  # simple + stable
# -----------------------------
# UI Text / Styling
# -----------------------------

# Centered project logo at the top of the page.
LOGO_HTML = """
<div style="text-align:center; margin-bottom: 25px;">
    <img src="https://i.ibb.co/5nQdGSs/logo.png"
         style="max-width: 100%; height: auto; border-radius: 12px;">
</div>
"""

# Page title, highlight line, and usage instructions (single header block).
HEADER_HTML = """
<div style="text-align:center; max-width:900px; margin:0 auto;">
    <h1 style="font-size:36px; margin-bottom:12px;">
        Chichewa Speech2Text: How Custom Data Improves Transcription Performance
    </h1>
    <p style="font-size:22px; font-weight:700; color:#1F3A5F; margin-bottom:12px;">
        Observe how the fine-tuned model provides better transcription quality.
    </p>
    <p style="font-size:18px; color:#444; margin-bottom:25px;">
        Upload or record a short Chichewa voice note (&#8804;30 seconds).
    </p>
</div>
"""

# Thin horizontal rule between sections.
DIVIDER = """
<div style="max-width:900px; margin:10px auto;">
    <hr style="border:0; border-top:1px solid #ddd;">
</div>
"""

# The three snippets below were superseded by HEADER_HTML; kept for reference.
# TITLE_HTML = """
# <h1 style="text-align:center; font-size:34px; margin-bottom:10px;">
#     Chichewa Speech2Text: How Custom Data Improves Performance
# </h1>
# """
# HIGHLIGHT_TEXT = """
# <p style="text-align:center; font-size:20px; font-weight:600; color:#1F3A5F; margin-bottom:20px;">
#     Observe how the fine-tuned model provides better transcription quality.
# </p>
# """
# DESCRIPTION_HTML = """
# <p style="text-align:center; font-size:18px; margin-bottom: 18px;">
#     Upload or record a short Chichewa voice note (≤30 seconds). The same audio will be transcribed by three systems.
# </p>
# """

# Footer link to the project write-up.
ARTICLE_HTML = """
<p style="text-align:center; margin-top: 10px;">
    Read more about the <a href="https://dmatekenya.github.io/Chichewa-Speech2Text/README.html" target="_blank">ChichewaSpeech2Text</a> project
</p>
"""
# -----------------------------
# Load local models once
# -----------------------------
# Prefer GPU; use fp16 on GPU to halve memory, fp32 on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
print(f"Using device: {DEVICE}", flush=True)

# Shared feature extractor / tokenizer for BOTH local models (base processor
# is reused for the fine-tuned model as well).
PROCESSOR = WhisperProcessor.from_pretrained(
    BASE_REPO,
    language=LOCAL_LANGUAGE,
    task="transcribe",
)

# Stock Whisper large-v3 (the "before" model), eval mode for inference.
MODEL_BASE = WhisperForConditionalGeneration.from_pretrained(BASE_REPO).to(DEVICE).eval()

# Chichewa fine-tune, pinned to a specific commit for reproducibility.
MODEL_FT = WhisperForConditionalGeneration.from_pretrained(
    FINETUNED_REPO,
    revision=FINETUNED_REVISION,
).to(DEVICE).eval()

# Cast to half precision only on GPU.
if DEVICE == "cuda":
    MODEL_BASE = MODEL_BASE.to(dtype=DTYPE)
    MODEL_FT = MODEL_FT.to(dtype=DTYPE)

# Client reads OPENAI_API_KEY from the environment; presence is checked
# later in transcribe_openai().
OPENAI_CLIENT = OpenAI()
| # ----------------------------- | |
| # Helpers | |
| # ----------------------------- | |
| def _tokenize_words(s: str): | |
| # words + punctuation as separate tokens | |
| return re.findall(r"\w+|[^\w\s]", s, flags=re.UNICODE) | |
def diff_highlight_html(ref: str, hyp: str, title_ref="Reference", title_hyp="Hypothesis") -> str:
    """
    Render a word-level diff between *ref* and *hyp* as side-by-side HTML panels.

    Markup conventions:
      - tokens only in ref (deletions): red + strikethrough
      - tokens only in hyp (insertions): green
      - replacements: struck-out old token plus green new token
    """
    def _mark_deleted(tok):
        # Red strikethrough span for a token missing from the hypothesis.
        return f"<span style='color:#b00020;text-decoration:line-through;background:#ffe6e6;padding:1px 3px;border-radius:4px;'>{html.escape(tok)}</span>"

    def _mark_inserted(tok):
        # Green highlight span for a token added in the hypothesis.
        return f"<span style='color:#0a7a0a;background:#e6ffe6;padding:1px 3px;border-radius:4px;'>{html.escape(tok)}</span>"

    def _tidy(tokens):
        # Join with spaces, then drop spaces before punctuation and inside parens.
        joined = " ".join(tokens)
        joined = re.sub(r"\s+([,.;:!?])", r"\1", joined)
        joined = re.sub(r"\(\s+", "(", joined)
        return re.sub(r"\s+\)", ")", joined)

    left_tokens = _tokenize_words(ref or "")
    right_tokens = _tokenize_words(hyp or "")
    matcher = SequenceMatcher(a=left_tokens, b=right_tokens)

    left_parts, right_parts = [], []
    for op, a1, a2, b1, b2 in matcher.get_opcodes():
        src = left_tokens[a1:a2]
        dst = right_tokens[b1:b2]
        if op == "equal":
            left_parts.extend(html.escape(t) for t in src)
            right_parts.extend(html.escape(t) for t in dst)
            continue
        # "replace" contributes to both sides; "delete"/"insert" to one each.
        if op in ("delete", "replace"):
            left_parts.extend(_mark_deleted(t) for t in src)
        if op in ("insert", "replace"):
            right_parts.extend(_mark_inserted(t) for t in dst)

    ref_html = _tidy(left_parts)
    hyp_html = _tidy(right_parts)
    return f"""
<div style="display:grid;grid-template-columns:1fr 1fr;gap:14px;">
  <div style="padding:12px;border:1px solid #ddd;border-radius:10px;">
    <div style="font-weight:700;margin-bottom:8px;">{html.escape(title_ref)}</div>
    <div style="line-height:1.6;">{ref_html}</div>
  </div>
  <div style="padding:12px;border:1px solid #ddd;border-radius:10px;">
    <div style="font-weight:700;margin-bottom:8px;">{html.escape(title_hyp)}</div>
    <div style="line-height:1.6;">{hyp_html}</div>
  </div>
</div>
"""
def make_diffs(base_text: str, ft_text: str, openai_text: str, ref_choice: str):
    """
    Build diff panels for the base and OpenAI transcripts against a chosen
    reference transcript.

    ref_choice selects the reference: "Fine-tuned", "OpenAI", or anything
    else (treated as Base). Returns (base_diff_html, openai_diff_html).
    """
    # Map the UI choice onto (reference text, display label); default to Base.
    reference_by_choice = {
        "Fine-tuned": (ft_text, "Fine-tuned (Reference)"),
        "OpenAI": (openai_text, "OpenAI (Reference)"),
    }
    ref, ref_name = reference_by_choice.get(ref_choice, (base_text, "Base (Reference)"))

    base_diff = diff_highlight_html(ref, base_text, title_ref=ref_name, title_hyp="Base")
    openai_diff = diff_highlight_html(ref, openai_text, title_ref=ref_name, title_hyp="OpenAI")
    return base_diff, openai_diff
def load_audio(audio_path: str) -> Tuple[np.ndarray, int, float]:
    """Load audio as mono at TARGET_SR; return (samples, sample_rate, duration_s)."""
    samples, rate = librosa.load(audio_path, sr=TARGET_SR, mono=True)
    # Guard against a zero sample rate to avoid division by zero.
    duration = len(samples) / rate if rate else 0.0
    return samples, rate, float(duration)
def transcribe_local(model: WhisperForConditionalGeneration, audio_16k: np.ndarray) -> str:
    """
    Transcribe a mono waveform (sampled at TARGET_SR) with a local Whisper model.

    Args:
        model: loaded Whisper model (MODEL_BASE or MODEL_FT).
        audio_16k: 1-D float waveform at TARGET_SR Hz.

    Returns:
        The decoded transcript, stripped of surrounding whitespace.
    """
    # Convert the waveform to log-mel input features via the shared processor.
    feats = PROCESSOR(
        audio_16k,
        return_tensors="pt",
        sampling_rate=TARGET_SR
    ).input_features
    # Match the model's device AND dtype (fp16 on CUDA, fp32 on CPU) so the
    # feature tensor is compatible with the model weights.
    model_device = next(model.parameters()).device
    model_dtype = next(model.parameters()).dtype
    feats = feats.to(device=model_device, dtype=model_dtype)
    generated_ids = model.generate(input_features=feats)
    # batch_decode returns one string per batch item; we pass a single clip.
    text = PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return text.strip()
def transcribe_openai(audio_path: str) -> str:
    """
    Transcribe the audio file with OpenAI's hosted ASR model.

    Returns a friendly notice string (instead of raising) when
    OPENAI_API_KEY is not configured.
    """
    if not os.getenv("OPENAI_API_KEY"):
        return "OpenAI ASR disabled: OPENAI_API_KEY not set in Space Secrets."

    # Bias decoding toward Chichewa and Malawian place names.
    prompt = "Chichewa transcription. Malawi names like Lilongwe, Blantyre, Zomba. Keep local names as spoken."
    with open(audio_path, "rb") as audio_file:
        response = OPENAI_CLIENT.audio.transcriptions.create(
            model=OPENAI_MODEL,
            file=audio_file,
            prompt=prompt,
            temperature=0.0,
            response_format="json",
        )

    transcript = response.text or ""
    return transcript.strip()
def transcribe_all(audio_path: Optional[str]) -> Tuple[str, str, str, str]:
    """
    Transcribe one audio clip with all three ASR systems.

    Args:
        audio_path: filesystem path of the recorded/uploaded clip, or None.

    Returns:
        Tuple of (status, base_text, finetuned_text, openai_text). `status`
        is a newline-joined per-system timing/failure summary. On a
        per-system failure the corresponding text slot carries an
        "[ERROR] ..." message so the other panels still render.
    """
    if not audio_path:
        return "Please record or upload an audio clip.", "", "", ""

    # Load audio once and share the waveform across both local models.
    try:
        y, sr, dur = load_audio(audio_path)
    except Exception as e:
        return f"❌ Failed to load audio: {e}", "", "", ""

    # Enforce the advertised clip-length limit. (MAX_SECONDS was previously
    # defined but never checked, so over-long clips were accepted despite
    # the UI stating a 30-second maximum.)
    if dur > MAX_SECONDS:
        return (
            f"❌ Clip is {dur:.1f}s long; please use a clip of at most {MAX_SECONDS:.0f} seconds.",
            "", "", "",
        )

    status = []

    # 1) Base (local)
    t0 = time.time()
    try:
        base_text = transcribe_local(MODEL_BASE, y)
        status.append(f"1. Open Source (base) {time.time()-t0:.2f}s")
    except Exception as e:
        base_text = f"[ERROR] Base failed: {e}"
        status.append("❌ Base failed")

    # 2) Fine-tuned (local)
    t1 = time.time()
    try:
        ft_text = transcribe_local(MODEL_FT, y)
        status.append(f"2. Fine-tuned {time.time()-t1:.2f}s")
    except Exception as e:
        ft_text = f"[ERROR] Fine-tuned failed: {e}"
        status.append("❌ Fine-tuned failed")

    # 3) OpenAI (commercial)
    t2 = time.time()
    try:
        openai_text = transcribe_openai(audio_path)
        status.append(f"3. OpenAI ({OPENAI_MODEL}) {time.time()-t2:.2f}s")
    except Exception as e:
        openai_text = f"[ERROR] OpenAI failed: {e}"
        status.append("❌ OpenAI failed")

    return "\n".join(status), base_text, ft_text, openai_text
def init_demo():
    """Download the demo clip (if needed) and transcribe it for the initial page load."""
    demo_path = ensure_demo_audio()
    results = transcribe_all(demo_path)
    # (audio_path, status, base_text, ft_text, openai_text) — matches the
    # five UI outputs wired to demo.load().
    return (demo_path,) + results
| # ----------------------------- | |
| # Warm-up (local models only) | |
| # ----------------------------- | |
| # def warmup(): | |
| # try: | |
| # dummy = np.zeros(int(TARGET_SR * 1.0), dtype=np.float32) | |
| # _ = transcribe_local(MODEL_BASE, dummy) | |
| # _ = transcribe_local(MODEL_FT, dummy) | |
| # print("Warm-up complete.", flush=True) | |
| # except Exception as e: | |
| # print(f"Warm-up skipped/failed: {e}", flush=True) | |
| # warmup() | |
| # ----------------------------- | |
| # UI | |
| # ----------------------------- | |
with gr.Blocks(theme="grass", title="Chichewa Speech2Text") as demo:
    # Branding and header section.
    gr.HTML(LOGO_HTML)
    gr.HTML(DIVIDER)
    gr.HTML(HEADER_HTML)

    # Single audio input shared by all three ASR systems.
    # NOTE(review): value=DEMO_AUDIO_PATH is a pathlib.Path pointing to a file
    # that may not exist yet at UI-construction time (init_demo downloads it
    # on page load) — confirm Gradio tolerates a missing initial filepath.
    audio_in = gr.Audio(
        sources=["microphone", "upload"],
        type="filepath",
        label="Audio Input (Record or Upload)",
        value=DEMO_AUDIO_PATH,
    )
    run_btn = gr.Button("Transcribe & Compare", variant="primary")
    status_out = gr.Textbox(label="Status / timing", lines=3)

    # Three side-by-side transcript panels: base, fine-tuned, commercial.
    with gr.Row(equal_height=True):
        base_out = gr.Textbox(label="Open Source ASR Model", lines=12)
        ft_out = gr.Textbox(label="Open Source Model Fine-Tuned with Custom Chichewa Speech", lines=12)
        commercial_out = gr.Textbox(label="Frontier Commercial ASR Model (OpenAI)", lines=12)

    # One click transcribes the same clip with all three systems.
    run_btn.click(
        fn=transcribe_all,
        inputs=[audio_in],
        outputs=[status_out, base_out, ft_out, commercial_out],
    )

    # Preload audio + transcripts immediately on page load
    demo.load(
        fn=init_demo,
        inputs=None,
        outputs=[audio_in, status_out, base_out, ft_out, commercial_out],
    )

    gr.Markdown(ARTICLE_HTML)

if __name__ == "__main__":
    # Queue caps concurrent transcriptions (shared GPU), then serve the app.
    demo.queue(default_concurrency_limit=2).launch()