| """ |
| ASR Benchmark — Gradio HuggingFace Space |
| Evaluates Wav2Vec2-BERT and Gemma-3n models on Ghanaian language datasets. |
| """ |
|
|
| import os |
| import gc |
| import warnings |
| import tempfile |
|
|
| import torch |
| import numpy as np |
| import pandas as pd |
| import torchaudio |
| import jiwer |
| import gradio as gr |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import matplotlib.patches as mpatches |
|
|
| from datasets import load_dataset |
| from transformers import ( |
| Wav2Vec2ForCTC, |
| Wav2Vec2Processor, |
| Wav2Vec2BertForCTC, |
| AutoProcessor, |
| ) |
|
|
| warnings.filterwarnings("ignore") |
|
|
| |
# Use the GPU when one is available; everything below runs on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Fine-tuned Wav2Vec2-BERT CTC checkpoints, one per Ghanaian language.
# UI label -> HuggingFace model repo id.
WAV2VEC_MODELS = {
    "Dagbani (w2v-bert-2.0)": "FarmerlineML/w2v-bert-2.0_2026_dagbani_ASR",
    "Ewe (w2v-bert-2.0)": "FarmerlineML/w2v-bert-2.0_ewe_2",
    "Twi (w2v-bert-2.0)": "FarmerlineML/w2v-bert-2.0_twi_alpha_v1",
}


# Multimodal LLM used as a zero-shot ASR baseline (loaded via Unsloth).
GEMMA_MODEL_ID = "unsloth/gemma-3n-E2B-it"


# Evaluation datasets. UI label -> HuggingFace dataset repo id.
# Two domains per language: UNICEF (conversational) and Bible (read speech).
DATASETS = {
    "UNICEF Dagbani": "ghananlpcommunity/UNICEF-Ghana-Dagbani-ASR",
    "UNICEF Twi": "ghananlpcommunity/UNICEF-Ghana-Twi-ASR",
    "UNICEF Ewe": "ghananlpcommunity/UNICEF-Ghana-Ewe-ASR",
    "Bible Ewe": "ghananlpcommunity/ewe-bible-audio-text-tts",
    "Bible Dagbani": "ghananlpcommunity/dagbani-bible-audio-text-tts",
    "Bible Twi": "ghananlpcommunity/asante-twi-bible-speech-text",
}


# Which models are meaningful to run on which dataset (language match).
# Gemma-3n is language-agnostic, so it is compatible with every dataset.
DATASET_MODEL_COMPAT = {
    "UNICEF Dagbani": ["Dagbani (w2v-bert-2.0)", "Gemma-3n (E2B-it)"],
    "UNICEF Twi": ["Twi (w2v-bert-2.0)", "Gemma-3n (E2B-it)"],
    "UNICEF Ewe": ["Ewe (w2v-bert-2.0)", "Gemma-3n (E2B-it)"],
    "Bible Dagbani": ["Dagbani (w2v-bert-2.0)", "Gemma-3n (E2B-it)"],
    "Bible Twi": ["Twi (w2v-bert-2.0)", "Gemma-3n (E2B-it)"],
    "Bible Ewe": ["Ewe (w2v-bert-2.0)", "Gemma-3n (E2B-it)"],
}


# Chart colours (CER bars / WER bars).
BLUE = "#185FA5"
ORANGE = "#D85A30"
|
|
| |
|
|
def load_dataset_samples(dataset_name: str, n_samples: int) -> list[dict]:
    """Stream up to ``n_samples`` usable audio/transcript pairs from a dataset.

    Args:
        dataset_name: A key of ``DATASETS`` (UI label, not the repo id).
        n_samples: Number of valid samples to collect.

    Returns:
        List of dicts with keys ``audio`` (np.ndarray waveform),
        ``sampling_rate`` (int) and ``text`` (lower-cased, stripped transcript).

    Raises:
        KeyError: If ``dataset_name`` is not in ``DATASETS``.
    """
    ds_path = DATASETS[dataset_name]
    # Streaming avoids downloading the full dataset on the Space.
    ds = load_dataset(ds_path, split="train", streaming=True, trust_remote_code=True)
    samples: list[dict] = []
    for sample in ds:
        # BUG FIX: count *collected* samples, not iterated ones. The old
        # `enumerate`-based cutoff charged skipped rows (missing audio or
        # transcript) against the quota, returning fewer than n_samples.
        if len(samples) >= n_samples:
            break
        audio = sample.get("audio")
        if audio is None:
            continue
        # Column name varies across the datasets ("text" vs "transcript").
        transcript = sample.get("text") or sample.get("transcript")
        if not transcript:
            continue
        samples.append({
            "audio": audio["array"],
            "sampling_rate": audio["sampling_rate"],
            "text": transcript.lower().strip(),
        })
    return samples
|
|
|
|
def calculate_metrics(reference: str, hypothesis: str):
    """Return ``(CER, WER)`` for a reference/hypothesis pair via jiwer.

    Returns ``(None, None)`` when either string is empty, so callers can
    skip samples with no usable transcription.
    """
    if reference and hypothesis:
        return jiwer.cer(reference, hypothesis), jiwer.wer(reference, hypothesis)
    return None, None
|
|
|
|
def resample_to_16k(audio_array: np.ndarray, sr: int) -> np.ndarray:
    """Resample a waveform to 16 kHz; returns the input unchanged if already 16 kHz."""
    if sr != 16000:
        waveform = torch.from_numpy(audio_array).float()
        resampler = torchaudio.transforms.Resample(sr, 16000)
        audio_array = resampler(waveform).numpy()
    return audio_array
|
|
|
|
| |
|
|
class Wav2VecModel:
    """Wrapper around a fine-tuned Wav2Vec2-BERT CTC checkpoint.

    Loads the model in float16 on the global ``device`` and exposes a
    simple ``transcribe``/``unload`` lifecycle so models can be swapped
    in and out of limited VRAM.
    """

    def __init__(self, model_id: str):
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = Wav2Vec2BertForCTC.from_pretrained(
            model_id,
            ignore_mismatched_sizes=True,
            torch_dtype=torch.float16,  # halve VRAM for inference
        ).to(device)
        self.model.eval()

    def transcribe(self, audio_array: np.ndarray, sampling_rate: int) -> str:
        """Greedy-decode a single waveform; returns a lower-cased transcript."""
        audio_16k = resample_to_16k(audio_array, sampling_rate)
        inputs = self.processor(
            audio_16k,
            sampling_rate=16000,
            return_tensors="pt",
            padding=True,
        )
        # BUG FIX: the processor emits float32 features while the model
        # weights are float16; cast floating tensors to the model's dtype
        # to avoid a dtype mismatch at the first matmul.
        model_dtype = next(self.model.parameters()).dtype
        inputs = {
            k: v.to(device=device, dtype=model_dtype)
            if torch.is_floating_point(v)
            else v.to(device)
            for k, v in inputs.items()
        }
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Greedy CTC decode: argmax over the vocab at each frame.
        predicted_ids = torch.argmax(logits, dim=-1)
        return self.processor.decode(predicted_ids[0]).lower().strip()

    def unload(self):
        """Drop references and release cached GPU memory."""
        del self.model, self.processor
        # Collect Python garbage first so CUDA tensors are actually freed
        # before the allocator cache is flushed.
        gc.collect()
        torch.cuda.empty_cache()
|
|
|
|
| |
|
|
class Gemma3nModel:
    """Zero-shot ASR via the Gemma-3n multimodal model, loaded 4-bit with Unsloth."""

    def __init__(self):
        # Imported lazily so the rest of the app works even if unsloth
        # is missing or slow to initialise.
        from unsloth import FastModel
        self.model, self.processor = FastModel.from_pretrained(
            model_name = GEMMA_MODEL_ID,
            dtype = None,            # presumably auto-selects dtype — TODO confirm with Unsloth docs
            max_seq_length = 1024,
            load_in_4bit = True,     # 4-bit quantisation to fit small GPUs
            full_finetuning = False,
        )
        self.model.eval()

    def transcribe(self, audio_array: np.ndarray, sampling_rate: int) -> str:
        """Prompt Gemma-3n with the audio and return its lower-cased transcription."""
        audio_16k = resample_to_16k(audio_array, sampling_rate)
        # Chat-format request: system instruction + user turn carrying the audio.
        messages = [
            {"role": "system", "content": [{"type": "text",
                "text": "You are an assistant that transcribes speech accurately."}]},
            {"role": "user", "content": [
                {"type": "audio", "audio": audio_16k},
                {"type": "text", "text": "Please transcribe this audio."},
            ]},
        ]
        # NOTE(review): hard-coded "cuda" ignores the module-level `device`;
        # the 4-bit Unsloth path appears GPU-only — confirm before CPU use.
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        ).to("cuda")
        # Cast floating tensors (e.g. audio features) to bfloat16 for the model;
        # integer tensors such as input_ids are left untouched.
        inputs = {
            k: v.to(torch.bfloat16) if v.dtype in (torch.float32, torch.float64) else v
            for k, v in inputs.items()
        }
        with torch.inference_mode():
            gen = self.model.generate(**inputs, max_new_tokens=200, do_sample=False)
        # Strip the prompt tokens; decode only the newly generated continuation.
        tokens = gen[0][inputs["input_ids"].shape[-1]:]
        return self.processor.decode(tokens, skip_special_tokens=True).lower().strip()

    def unload(self):
        """Drop references and release cached GPU memory."""
        del self.model, self.processor
        torch.cuda.empty_cache()
        gc.collect()
|
|
|
|
| |
|
|
def make_cer_wer_chart(df: pd.DataFrame) -> str:
    """Render grouped CER/WER (%) bars per dataset, one subplot per model.

    Args:
        df: Results table with "Model", "Dataset", "Avg CER", "Avg WER" columns.

    Returns:
        Path to the saved PNG file.
    """
    df = df.dropna(subset=["Avg WER", "Avg CER"]).copy()
    df["CER %"] = df["Avg CER"] * 100
    df["WER %"] = df["Avg WER"] * 100
    # Shorten HF repo ids ("org/name" -> "name") for subplot titles.
    df["Model Short"] = df["Model"].apply(lambda x: x.split("/")[-1] if "/" in x else x)

    models = df["Model"].unique()
    ncols = min(len(models), 2)
    nrows = (len(models) + 1) // 2
    fig, axes = plt.subplots(nrows, ncols, figsize=(8 * ncols, 5 * nrows))
    # plt.subplots returns a bare Axes for a 1x1 grid; normalise to a flat list.
    axes = np.array(axes).flatten() if len(models) > 1 else [axes]

    fig.suptitle("CER & WER per Model / Dataset", fontsize=16, fontweight="bold")

    for i, model in enumerate(models):
        ax = axes[i]
        sub = df[df["Model"] == model].sort_values("Dataset")
        x = np.arange(len(sub))
        w = 0.35
        ax.bar(x - w / 2, sub["CER %"], w, label="CER %", color=BLUE, alpha=0.88)
        ax.bar(x + w / 2, sub["WER %"], w, label="WER %", color=ORANGE, alpha=0.88)
        ax.set_title(sub["Model Short"].iloc[0], fontsize=12, fontweight="bold")
        ax.set_xticks(x)
        ax.set_xticklabels(sub["Dataset"], rotation=30, ha="right", fontsize=9)
        ax.set_ylabel("Error Rate (%)")
        # 15% headroom over the tallest bar; fall back to 100 if all bars are 0.
        mx = max(sub["CER %"].max(), sub["WER %"].max()) * 1.15 or 100
        ax.set_ylim(0, mx)
        # Numeric labels just above each bar.
        for j in range(len(x)):
            ax.text(j - w/2, sub["CER %"].iloc[j] + mx*0.01, f"{sub['CER %'].iloc[j]:.1f}", ha="center", fontsize=8)
            ax.text(j + w/2, sub["WER %"].iloc[j] + mx*0.01, f"{sub['WER %'].iloc[j]:.1f}", ha="center", fontsize=8)
        ax.legend(fontsize=9)
        ax.spines[["top","right"]].set_visible(False)
        ax.grid(axis="y", linestyle="--", alpha=0.3)

    # Blank out any unused grid cells.
    for j in range(len(models), len(axes)):
        axes[j].axis("off")

    plt.tight_layout()
    # BUG FIX: tempfile.mktemp is deprecated and race-prone; mkstemp
    # creates the file atomically. Close the fd — savefig reopens by path.
    fd, path = tempfile.mkstemp(suffix="_cer_wer.png")
    os.close(fd)
    plt.savefig(path, dpi=130, bbox_inches="tight")
    plt.close(fig)
    return path
|
|
|
|
def make_domain_chart(df: pd.DataFrame) -> str:
    """Compare average WER between the UNICEF and Bible domains per model.

    Left subplot: average WER per model split by domain. Right subplot:
    the per-model domain gap (Bible minus UNICEF, in percentage points).

    Args:
        df: Results table with "Model", "Dataset", "Avg WER" columns.

    Returns:
        Path to the saved PNG file.
    """
    df = df.dropna(subset=["Avg WER"]).copy()
    # Dataset labels start with "UNICEF " or "Bible " (see DATASETS keys);
    # the case-insensitive check alone covers both spellings.
    df["Domain"] = df["Dataset"].apply(lambda x: "UNICEF" if "unicef" in x.lower() else "Bible")
    df["Language"] = df["Dataset"].apply(
        lambda x: x.replace("UNICEF ","").replace("Bible ","").strip()
    )
    df["WER %"] = df["Avg WER"] * 100
    df["Model Short"] = df["Model"].apply(lambda x: x.split("/")[-1] if "/" in x else x)

    models = df["Model Short"].unique()
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle("Domain Comparison — UNICEF vs Bible", fontsize=15, fontweight="bold")

    # Left: average WER per model for each domain (NaN if a model has no
    # results in a domain — matplotlib simply draws no bar).
    ax = axes[0]
    u_avg = [df[(df["Model Short"]==m) & (df["Domain"]=="UNICEF")]["WER %"].mean() for m in models]
    b_avg = [df[(df["Model Short"]==m) & (df["Domain"]=="Bible")]["WER %"].mean() for m in models]
    x = np.arange(len(models))
    w = 0.35
    ax.bar(x - w/2, u_avg, w, color=BLUE, alpha=0.88, label="UNICEF")
    ax.bar(x + w/2, b_avg, w, color=ORANGE, alpha=0.88, label="Bible")
    ax.set_title("Avg WER per Model × Domain", fontsize=11)
    ax.set_xticks(x)
    ax.set_xticklabels(models, rotation=20, ha="right", fontsize=9)
    ax.set_ylabel("WER %"); ax.set_ylim(0, 100)
    ax.legend(); ax.spines[["top","right"]].set_visible(False)

    # Right: positive gap (orange) means the model is worse on Bible speech.
    ax2 = axes[1]
    gaps = [b - u for u, b in zip(u_avg, b_avg)]
    colors = [ORANGE if g > 0 else BLUE for g in gaps]
    ax2.bar(range(len(models)), gaps, color=colors, alpha=0.88)
    ax2.axhline(0, color="#888", linewidth=0.8, linestyle="--")
    ax2.set_title("Domain gap: Bible − UNICEF (pp)", fontsize=11)
    ax2.set_xticks(range(len(models)))
    ax2.set_xticklabels(models, rotation=20, ha="right", fontsize=9)
    ax2.set_ylabel("Percentage points"); ax2.spines[["top","right"]].set_visible(False)

    plt.tight_layout()
    # BUG FIX: tempfile.mktemp is deprecated and race-prone; mkstemp
    # creates the file atomically. Close the fd — savefig reopens by path.
    fd, path = tempfile.mkstemp(suffix="_domain.png")
    os.close(fd)
    plt.savefig(path, dpi=130, bbox_inches="tight")
    plt.close(fig)
    return path
|
|
|
|
| |
|
|
def run_evaluation(
    selected_models: list[str],
    selected_datasets: list[str],
    n_samples: int,
    progress=gr.Progress(track_tqdm=True),
):
    """Benchmark the selected models on the selected datasets.

    Generator consumed by Gradio streaming: yields tuples of
    ``(log_text, results_dataframe, cer_wer_chart_path, domain_chart_path)``
    as evaluation progresses. Models are loaded one at a time and unloaded
    afterwards to cap VRAM usage.
    """
    # BUG FIX: this function is a generator (it contains `yield`), so a
    # bare `return value` is swallowed — it only sets StopIteration.value
    # and Gradio never displays it. Validation messages must be *yielded*.
    if not selected_models:
        yield "⚠️ Please select at least one model.", None, None, None
        return
    if not selected_datasets:
        yield "⚠️ Please select at least one dataset.", None, None, None
        return

    results_data = []
    log_lines = []

    # Run models in canonical order regardless of UI click order.
    all_model_labels = list(WAV2VEC_MODELS.keys()) + ["Gemma-3n (E2B-it)"]
    models_to_run = [m for m in all_model_labels if m in selected_models]

    for model_label in progress.tqdm(models_to_run, desc="Models"):
        log_lines.append(f"\n🔄 Loading model: {model_label}")
        yield "\n".join(log_lines), None, None, None

        try:
            if model_label == "Gemma-3n (E2B-it)":
                model = Gemma3nModel()
                model_id = GEMMA_MODEL_ID
            else:
                model_id = WAV2VEC_MODELS[model_label]
                model = Wav2VecModel(model_id)
            log_lines.append(f" ✅ Loaded: {model_id}")
        except Exception as e:
            # A model that fails to load is skipped, not fatal.
            log_lines.append(f" ❌ Failed to load {model_label}: {e}")
            yield "\n".join(log_lines), None, None, None
            continue

        try:
            for ds_label in selected_datasets:
                # Skip language-mismatched (model, dataset) pairs.
                compat = DATASET_MODEL_COMPAT.get(ds_label, [])
                if model_label not in compat:
                    log_lines.append(f" ⏭️ Skipping {ds_label} (incompatible language)")
                    continue

                log_lines.append(f" 📂 Dataset: {ds_label} ({n_samples} samples)")
                yield "\n".join(log_lines), None, None, None

                try:
                    samples = load_dataset_samples(ds_label, n_samples)
                except Exception as e:
                    log_lines.append(f" ❌ Failed to load dataset: {e}")
                    continue

                # NOTE(review): `details` is collected but not yet surfaced
                # in the UI; kept for a future per-sample drill-down.
                cers, wers, details = [], [], []
                for j, sample in enumerate(samples):
                    try:
                        hyp = model.transcribe(sample["audio"], sample["sampling_rate"])
                        cer, wer = calculate_metrics(sample["text"], hyp)
                        if cer is not None:
                            cers.append(cer)
                            wers.append(wer)
                            details.append({
                                "ref": sample["text"][:80],
                                "hyp": hyp[:80],
                                "CER": round(cer, 4),
                                "WER": round(wer, 4),
                            })
                    except Exception as e:
                        # One bad sample must not abort the whole run.
                        log_lines.append(f" ⚠️ Sample {j} error: {e}")

                if cers:
                    avg_cer = float(np.mean(cers))
                    avg_wer = float(np.mean(wers))
                    results_data.append({
                        "Model": model_id,
                        "Dataset": ds_label,
                        "Avg CER": round(avg_cer, 4),
                        "Avg WER": round(avg_wer, 4),
                        "Num Samples": len(cers),
                    })
                    log_lines.append(f" ✅ CER={avg_cer*100:.1f}% WER={avg_wer*100:.1f}%")
                else:
                    log_lines.append(" ⚠️ No valid results for this combo.")
        finally:
            # Always free VRAM before loading the next model, even if the
            # dataset loop raised unexpectedly.
            model.unload()

        log_lines.append(f" 🧹 {model_label} unloaded.\n")
        yield "\n".join(log_lines), None, None, None

    if not results_data:
        yield "\n".join(log_lines) + "\n\n❌ No results collected.", None, None, None
        return

    df = pd.DataFrame(results_data)
    log_lines.append("✅ Evaluation complete!")

    # Charts are best-effort; a plotting failure must not hide the table.
    try:
        chart1 = make_cer_wer_chart(df)
        chart2 = make_domain_chart(df)
    except Exception as e:
        log_lines.append(f"⚠️ Chart generation failed: {e}")
        chart1, chart2 = None, None

    yield "\n".join(log_lines), df, chart1, chart2
|
|
|
|
| |
|
|
# Choice lists for the UI checkboxes: all fine-tuned models plus Gemma-3n,
# and every dataset defined above.
ALL_MODELS = list(WAV2VEC_MODELS.keys()) + ["Gemma-3n (E2B-it)"]
ALL_DATASETS = list(DATASETS.keys())


# ---- Gradio UI layout -------------------------------------------------------
with gr.Blocks(title="ASR Benchmark — Ghanaian Languages", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🎙️ ASR Benchmark — Ghanaian Languages
        Evaluate **Wav2Vec2-BERT** fine-tuned models and **Gemma-3n** on Dagbani, Twi, and Ewe speech datasets.
        Models are loaded one at a time and unloaded after evaluation to minimise VRAM usage.
        """
    )

    with gr.Row():
        # Left column: run configuration.
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Configuration")

            model_selector = gr.CheckboxGroup(
                choices=ALL_MODELS,
                # Gemma-3n is off by default; only the w2v models are pre-selected.
                value=list(WAV2VEC_MODELS.keys()),
                label="Models to evaluate",
            )
            dataset_selector = gr.CheckboxGroup(
                choices=ALL_DATASETS,
                value=ALL_DATASETS,
                label="Datasets to include",
            )
            n_samples_slider = gr.Slider(
                minimum=5, maximum=200, value=50, step=5,
                label="Samples per dataset",
                info="More samples = better estimates but longer runtime.",
            )
            run_btn = gr.Button("▶ Run Evaluation", variant="primary")

        # Right column: streaming log output.
        with gr.Column(scale=2):
            gr.Markdown("### 📋 Evaluation Log")
            log_box = gr.Textbox(
                label="",
                lines=20,
                max_lines=40,
                interactive=False,
                placeholder="Logs will appear here once evaluation starts...",
            )

    gr.Markdown("### 📊 Results")
    results_table = gr.DataFrame(
        label="Summary (CER & WER per model × dataset)",
        interactive=False,
        wrap=True,
    )

    with gr.Row():
        # type="filepath" because the chart functions return PNG paths.
        chart_cer_wer = gr.Image(label="CER & WER per Model", type="filepath")
        chart_domain = gr.Image(label="Domain Comparison", type="filepath")

    # run_evaluation is a generator, so these outputs update incrementally.
    run_btn.click(
        fn=run_evaluation,
        inputs=[model_selector, dataset_selector, n_samples_slider],
        outputs=[log_box, results_table, chart_cer_wer, chart_domain],
    )

    gr.Markdown(
        """
        ---
        **Models:** [FarmerlineML](https://huggingface.co/FarmerlineML) |
        **Datasets:** [ghananlpcommunity](https://huggingface.co/ghananlpcommunity) |
        **Framework:** [Unsloth](https://github.com/unslothai/unsloth) + [🤗 Transformers](https://huggingface.co/transformers)
        """
    )

if __name__ == "__main__":
    demo.launch()
|
|