""" Chichewa Speech2Text — Diff Viewer (Single-model inference + Ground-truth diff) What this app does: - Lets the user upload/record an audio clip (≤30s recommended). - Lets the user paste a human-verified "ground truth" reference transcript. - Lets the user choose ONE system to run (Base / Fine-tuned / OpenAI). - Produces a 2-column, word-level highlighted diff: Reference vs Hypothesis. Notes: - This keeps inference lightweight (runs only one model at a time). - The demo audio is preloaded on page load, but inference is NOT auto-run to reduce break risk. """ import os import html import re import time import urllib.request from difflib import SequenceMatcher from pathlib import Path from typing import Optional, Tuple import gradio as gr import librosa import numpy as np import torch from transformers import WhisperForConditionalGeneration, WhisperProcessor from openai import OpenAI # ----------------------------- # Demo audio # ----------------------------- DEMO_URL = "https://github.com/dmatekenya/AI-seminars-malawi/releases/download/v1.0/WAU12.wav" DEMO_AUDIO_PATH = Path("/tmp/demo.wav") def ensure_demo_audio() -> str: """ Ensure demo audio exists on disk and return the path as a string. Raises: RuntimeError: If download fails or file is empty. """ DEMO_AUDIO_PATH.parent.mkdir(parents=True, exist_ok=True) if DEMO_AUDIO_PATH.exists() and DEMO_AUDIO_PATH.stat().st_size > 0: print(f"[demo] Using cached audio: {DEMO_AUDIO_PATH} ({DEMO_AUDIO_PATH.stat().st_size} bytes)", flush=True) return str(DEMO_AUDIO_PATH) print(f"[demo] Downloading demo audio from: {DEMO_URL}", flush=True) try: tmp_path = DEMO_AUDIO_PATH.with_suffix(".wav.tmp") urllib.request.urlretrieve(DEMO_URL, tmp_path) os.replace(tmp_path, DEMO_AUDIO_PATH) except Exception as e: raise RuntimeError(f"[demo] Failed to download demo audio from {DEMO_URL}. Error: {e}") if not DEMO_AUDIO_PATH.exists() or DEMO_AUDIO_PATH.stat().st_size == 0: raise RuntimeError(f"[demo] Download completed but file is missing/empty at {DEMO_AUDIO_PATH}") print(f"[demo] Downloaded demo audio: {DEMO_AUDIO_PATH} ({DEMO_AUDIO_PATH.stat().st_size} bytes)", flush=True) return str(DEMO_AUDIO_PATH) # ----------------------------- # Models / Config # ----------------------------- BASE_REPO = "openai/whisper-large-v3" FINETUNED_REPO = "dmatekenya/whisper-large-v3-chichewa" FINETUNED_REVISION = "bff60fb08ba9f294e05bfcab4306f30b6a0cfc0a" # pinned commit hash # Keep this consistent with how you evaluated to avoid surprises. # (You can change later; for tomorrow, stability > perfection.) LOCAL_LANGUAGE = "shona" TARGET_SR = 16000 MAX_SECONDS = 30.0 # recommended, not enforced here OPENAI_MODEL = "gpt-4o-transcribe" # ----------------------------- # UI Text / Styling # ----------------------------- LOGO_HTML = """
Paste a human-verified reference transcript and compare it to one ASR system at a time.
Record or upload a short voice note (≤30 seconds recommended).
Read more about the ChichewaSpeech2Text project and sign up for our voice note donation event: Google Form.
""" # ----------------------------- # Load models once # ----------------------------- DEVICE = "cuda" if torch.cuda.is_available() else "cpu" DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32 print(f"Using device: {DEVICE}", flush=True) PROCESSOR = WhisperProcessor.from_pretrained( BASE_REPO, language=LOCAL_LANGUAGE, task="transcribe", ) MODEL_BASE = WhisperForConditionalGeneration.from_pretrained(BASE_REPO).to(DEVICE).eval() MODEL_FT = WhisperForConditionalGeneration.from_pretrained( FINETUNED_REPO, revision=FINETUNED_REVISION, ).to(DEVICE).eval() if DEVICE == "cuda": MODEL_BASE = MODEL_BASE.to(dtype=DTYPE) MODEL_FT = MODEL_FT.to(dtype=DTYPE) OPENAI_CLIENT = OpenAI() # ----------------------------- # Helpers: audio + transcription # ----------------------------- def load_audio(audio_path: str) -> Tuple[np.ndarray, int, float]: y, sr = librosa.load(audio_path, sr=TARGET_SR, mono=True) dur = float(len(y) / sr) if sr else 0.0 return y, sr, dur @torch.inference_mode() def transcribe_local(model: WhisperForConditionalGeneration, audio_16k: np.ndarray) -> str: input_features = PROCESSOR( audio_16k, return_tensors="pt", sampling_rate=TARGET_SR, ).input_features input_features = input_features.to(DEVICE) if DEVICE == "cuda": input_features = input_features.to(dtype=DTYPE) generated_ids = model.generate(input_features=input_features) transcription = PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0] return transcription.strip() def transcribe_openai(audio_path: str) -> str: if not os.getenv("OPENAI_API_KEY"): return "OpenAI ASR disabled: OPENAI_API_KEY not set in Space Secrets." prompt = "Chichewa transcription. Malawi names like Lilongwe, Blantyre, Zomba. Keep local names as spoken." with open(audio_path, "rb") as f: resp = OPENAI_CLIENT.audio.transcriptions.create( file=f, model=OPENAI_MODEL, prompt=prompt, temperature=0.0, response_format="json", ) return (resp.text or "").strip() def transcribe_selected(audio_path: Optional[str], which: str) -> Tuple[str, str]: """ Transcribe using a single selected system. Parameters: audio_path: filepath from Gradio audio component which: "Base" | "Fine-tuned" | "OpenAI" Returns: status, hypothesis_text """ if not audio_path: return "Please record or upload an audio clip.", "" # Load audio only for local models y = None if which in ["Base", "Fine-tuned"]: try: y, sr, dur = load_audio(audio_path) except Exception as e: return f"❌ Failed to load audio: {e}", "" t0 = time.time() try: if which == "Base": hyp = transcribe_local(MODEL_BASE, y) elif which == "Fine-tuned": hyp = transcribe_local(MODEL_FT, y) elif which == "OpenAI": hyp = transcribe_openai(audio_path) else: return f"Unknown model selection: {which}", "" except Exception as e: return f"❌ {which} failed: {e}", "" return f"✅ {which} done in {time.time() - t0:.2f}s", (hyp or "").strip() # ----------------------------- # Helpers: diff visualization # ----------------------------- def _tokenize_words(s: str): return re.findall(r"\w+|[^\w\s]", s, flags=re.UNICODE) def diff_highlight_html(ref: str, hyp: str, title_ref="Reference", title_hyp="Hypothesis") -> str: """ Returns HTML showing a word-level diff between ref and hyp. - deletions (in ref not in hyp): red + strikethrough (shown on reference side) - insertions (in hyp not in ref): green (shown on hypothesis side) - replacements: red struck old (ref) + green new (hyp) """ ref_toks = _tokenize_words(ref or "") hyp_toks = _tokenize_words(hyp or "") sm = SequenceMatcher(a=ref_toks, b=hyp_toks) ref_out, hyp_out = [], [] for tag, i1, i2, j1, j2 in sm.get_opcodes(): a = ref_toks[i1:i2] b = hyp_toks[j1:j2] if tag == "equal": ref_out += [html.escape(t) for t in a] hyp_out += [html.escape(t) for t in b] elif tag == "delete": ref_out += [ f"{html.escape(t)}" for t in a ] elif tag == "insert": hyp_out += [ f"{html.escape(t)}" for t in b ] elif tag == "replace": ref_out += [ f"{html.escape(t)}" for t in a ] hyp_out += [ f"{html.escape(t)}" for t in b ] def _join(tokens): s = " ".join(tokens) s = re.sub(r"\s+([,.;:!?])", r"\1", s) s = re.sub(r"\(\s+", "(", s) s = re.sub(r"\s+\)", ")", s) return s ref_html = _join(ref_out) hyp_html = _join(hyp_out) return f"""