# HF Space page residue (not code): uploaded by michsethowusu,
# commit c266968 (verified) — "Rename app(6).py to app.py".
"""
ASR Benchmark — Gradio HuggingFace Space
Evaluates Wav2Vec2-BERT and Gemma-3n models on Ghanaian language datasets.
"""
import os
import gc
import warnings
import tempfile
import torch
import numpy as np
import pandas as pd
import torchaudio
import jiwer
import gradio as gr
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from datasets import load_dataset
from transformers import (
Wav2Vec2ForCTC,
Wav2Vec2Processor,
Wav2Vec2BertForCTC,
AutoProcessor,
)
warnings.filterwarnings("ignore")
# ── Device ──────────────────────────────────────────────────────────────────
# Run on GPU when one is visible; models and input tensors are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ── Model & dataset registries ───────────────────────────────────────────────
# UI label -> Hugging Face Hub id of a fine-tuned Wav2Vec2-BERT checkpoint.
WAV2VEC_MODELS = {
    "Dagbani (w2v-bert-2.0)": "FarmerlineML/w2v-bert-2.0_2026_dagbani_ASR",
    "Ewe (w2v-bert-2.0)": "FarmerlineML/w2v-bert-2.0_ewe_2",
    "Twi (w2v-bert-2.0)": "FarmerlineML/w2v-bert-2.0_twi_alpha_v1",
}
# Hub id of the multimodal Gemma-3n checkpoint loaded via Unsloth.
GEMMA_MODEL_ID = "unsloth/gemma-3n-E2B-it"
# UI label -> Hub id of the evaluation dataset.
DATASETS = {
    "UNICEF Dagbani": "ghananlpcommunity/UNICEF-Ghana-Dagbani-ASR",
    "UNICEF Twi": "ghananlpcommunity/UNICEF-Ghana-Twi-ASR",
    "UNICEF Ewe": "ghananlpcommunity/UNICEF-Ghana-Ewe-ASR",
    "Bible Ewe": "ghananlpcommunity/ewe-bible-audio-text-tts",
    "Bible Dagbani": "ghananlpcommunity/dagbani-bible-audio-text-tts",
    "Bible Twi": "ghananlpcommunity/asante-twi-bible-speech-text",
}
# Which wav2vec model is compatible with which dataset (language match).
# Gemma-3n is multilingual, so it is listed as compatible with every dataset.
DATASET_MODEL_COMPAT = {
    "UNICEF Dagbani": ["Dagbani (w2v-bert-2.0)", "Gemma-3n (E2B-it)"],
    "UNICEF Twi": ["Twi (w2v-bert-2.0)", "Gemma-3n (E2B-it)"],
    "UNICEF Ewe": ["Ewe (w2v-bert-2.0)", "Gemma-3n (E2B-it)"],
    "Bible Dagbani": ["Dagbani (w2v-bert-2.0)", "Gemma-3n (E2B-it)"],
    "Bible Twi": ["Twi (w2v-bert-2.0)", "Gemma-3n (E2B-it)"],
    "Bible Ewe": ["Ewe (w2v-bert-2.0)", "Gemma-3n (E2B-it)"],
}
# Brand colors used in the result charts.
BLUE = "#185FA5"
ORANGE = "#D85A30"
# ── Helpers ──────────────────────────────────────────────────────────────────
def load_dataset_samples(dataset_name: str, n_samples: int) -> list[dict]:
    """Stream the first *n_samples* records of a registered dataset.

    Returns a list of dicts with keys ``audio`` (waveform array),
    ``sampling_rate`` (int) and ``text`` (lower-cased, stripped reference
    transcript). Records missing audio or a transcript are skipped, but
    still count against the *n_samples* window, so fewer items may be
    returned than requested.
    """
    stream = load_dataset(
        DATASETS[dataset_name],
        split="train",
        streaming=True,
        trust_remote_code=True,
    )
    collected: list[dict] = []
    for position, record in enumerate(stream):
        if position >= n_samples:
            break
        clip = record.get("audio")
        if clip is None:
            continue
        # Datasets in the registry label the reference either "text" or "transcript".
        reference = record.get("text") or record.get("transcript")
        if not reference:
            continue
        collected.append(
            {
                "audio": clip["array"],
                "sampling_rate": clip["sampling_rate"],
                "text": reference.lower().strip(),
            }
        )
    return collected
def calculate_metrics(reference: str, hypothesis: str):
    """Return ``(CER, WER)`` via jiwer, or ``(None, None)`` when either string is empty."""
    if reference and hypothesis:
        return jiwer.cer(reference, hypothesis), jiwer.wer(reference, hypothesis)
    return None, None
def resample_to_16k(audio_array: np.ndarray, sr: int) -> np.ndarray:
    """Resample a waveform from *sr* Hz to 16 kHz; no-op when already 16 kHz."""
    if sr != 16000:
        waveform = torch.from_numpy(audio_array).float()
        audio_array = torchaudio.transforms.Resample(sr, 16000)(waveform).numpy()
    return audio_array
# ── Wav2Vec2 inference ───────────────────────────────────────────────────────
class Wav2VecModel:
    """Wrapper around a fine-tuned Wav2Vec2-BERT CTC checkpoint.

    Loads the model in float16 on the module-level ``device`` and exposes
    greedy-decoded transcription plus explicit unloading so checkpoints can
    be cycled one at a time with minimal VRAM usage.
    """

    def __init__(self, model_id: str):
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = Wav2Vec2BertForCTC.from_pretrained(
            model_id,
            ignore_mismatched_sizes=True,
            torch_dtype=torch.float16,
        ).to(device)
        self.model.eval()

    def transcribe(self, audio_array: np.ndarray, sampling_rate: int) -> str:
        """Greedy CTC transcription of one waveform; returns lower-cased text."""
        audio_16k = resample_to_16k(audio_array, sampling_rate)
        inputs = self.processor(
            audio_16k,
            sampling_rate=16000,
            return_tensors="pt",
            padding=True,
        )
        # BUGFIX: the processor emits float32 features, but the model was
        # loaded in float16 — cast floating tensors to the model's dtype
        # (mirrors the cast Gemma3nModel.transcribe performs) to avoid a
        # dtype-mismatch error inside the half-precision forward pass.
        inputs = {
            k: v.to(device=device, dtype=self.model.dtype)
            if v.is_floating_point()
            else v.to(device)
            for k, v in inputs.items()
        }
        with torch.no_grad():
            logits = self.model(**inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        return self.processor.decode(predicted_ids[0]).lower().strip()

    def unload(self):
        """Drop model/processor references and release cached GPU memory."""
        del self.model, self.processor
        # Collect first so freed tensors are actually returned to the cache,
        # and only touch the CUDA allocator when CUDA exists (CPU-safe).
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# ── Gemma-3n inference (Unsloth) ─────────────────────────────────────────────
class Gemma3nModel:
    """Gemma-3n audio transcription via Unsloth's 4-bit FastModel loader."""

    def __init__(self):
        # Imported lazily so the Space can start (and run wav2vec-only
        # evaluations) even when unsloth is unavailable.
        from unsloth import FastModel
        self.model, self.processor = FastModel.from_pretrained(
            model_name=GEMMA_MODEL_ID,
            dtype=None,
            max_seq_length=1024,
            load_in_4bit=True,
            full_finetuning=False,
        )
        self.model.eval()

    def transcribe(self, audio_array: np.ndarray, sampling_rate: int) -> str:
        """Prompt Gemma-3n with the audio clip and return its transcription."""
        audio_16k = resample_to_16k(audio_array, sampling_rate)
        messages = [
            {"role": "system", "content": [{"type": "text",
                "text": "You are an assistant that transcribes speech accurately."}]},
            {"role": "user", "content": [
                {"type": "audio", "audio": audio_16k},
                {"type": "text", "text": "Please transcribe this audio."},
            ]},
        ]
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        ).to(device)  # BUGFIX: was hardcoded "cuda" — crashed on CPU-only hosts
        # Cast floating tensors to bfloat16 for the quantized model's
        # compute dtype; integer tensors (token ids, masks) are left as-is.
        inputs = {
            k: v.to(torch.bfloat16) if v.dtype in (torch.float32, torch.float64) else v
            for k, v in inputs.items()
        }
        with torch.inference_mode():
            gen = self.model.generate(**inputs, max_new_tokens=200, do_sample=False)
        # Strip the prompt tokens so only newly generated text is decoded.
        tokens = gen[0][inputs["input_ids"].shape[-1]:]
        return self.processor.decode(tokens, skip_special_tokens=True).lower().strip()

    def unload(self):
        """Drop model/processor references and release cached GPU memory."""
        del self.model, self.processor
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# ── Chart helpers ─────────────────────────────────────────────────────────────
def make_cer_wer_chart(df: pd.DataFrame) -> str:
    """Render grouped CER/WER bars, one subplot per model.

    Expects columns "Model", "Dataset", "Avg CER", "Avg WER" (rates in
    [0, 1]); rows with missing averages are dropped. Returns the path of
    the saved PNG.
    """
    df = df.dropna(subset=["Avg WER", "Avg CER"]).copy()
    df["CER %"] = df["Avg CER"] * 100
    df["WER %"] = df["Avg WER"] * 100
    # Shorten "org/model-name" Hub ids to the model name for subplot titles.
    df["Model Short"] = df["Model"].apply(lambda x: x.split("/")[-1] if "/" in x else x)
    models = df["Model"].unique()
    ncols = min(len(models), 2)
    nrows = (len(models) + 1) // 2
    fig, axes = plt.subplots(nrows, ncols, figsize=(8 * ncols, 5 * nrows))
    # plt.subplots returns a bare Axes for a 1x1 grid — normalize to a list.
    axes = np.array(axes).flatten() if len(models) > 1 else [axes]
    fig.suptitle("CER & WER per Model / Dataset", fontsize=16, fontweight="bold")
    for i, model in enumerate(models):
        ax = axes[i]
        sub = df[df["Model"] == model].sort_values("Dataset")
        x = np.arange(len(sub))
        w = 0.35
        ax.bar(x - w / 2, sub["CER %"], w, label="CER %", color=BLUE, alpha=0.88)
        ax.bar(x + w / 2, sub["WER %"], w, label="WER %", color=ORANGE, alpha=0.88)
        ax.set_title(sub["Model Short"].iloc[0], fontsize=12, fontweight="bold")
        ax.set_xticks(x)
        ax.set_xticklabels(sub["Dataset"], rotation=30, ha="right", fontsize=9)
        ax.set_ylabel("Error Rate (%)")
        # Headroom above the tallest bar; fall back to 100 when all bars are 0.
        mx = max(sub["CER %"].max(), sub["WER %"].max()) * 1.15 or 100
        ax.set_ylim(0, mx)
        for j in range(len(x)):
            ax.text(j - w/2, sub["CER %"].iloc[j] + mx*0.01, f"{sub['CER %'].iloc[j]:.1f}", ha="center", fontsize=8)
            ax.text(j + w/2, sub["WER %"].iloc[j] + mx*0.01, f"{sub['WER %'].iloc[j]:.1f}", ha="center", fontsize=8)
        ax.legend(fontsize=9)
        ax.spines[["top","right"]].set_visible(False)
        ax.grid(axis="y", linestyle="--", alpha=0.3)
    # Blank any unused grid cells (odd model counts).
    for j in range(len(models), len(axes)):
        axes[j].axis("off")
    plt.tight_layout()
    # BUGFIX: tempfile.mktemp is deprecated and race-prone; mkstemp actually
    # creates the file atomically. savefig then overwrites it by path.
    fd, path = tempfile.mkstemp(suffix="_cer_wer.png")
    os.close(fd)
    plt.savefig(path, dpi=130, bbox_inches="tight")
    plt.close(fig)
    return path
def make_domain_chart(df: pd.DataFrame) -> str:
    """Render the UNICEF-vs-Bible domain comparison (avg WER + domain gap).

    Expects columns "Model", "Dataset", "Avg WER"; datasets are classified
    by the "UNICEF" prefix in their label. Returns the path of the saved PNG.
    """
    df = df.dropna(subset=["Avg WER"]).copy()
    # Dataset labels are "UNICEF <lang>" or "Bible <lang>" (see DATASETS).
    df["Domain"] = df["Dataset"].apply(lambda x: "UNICEF" if "unicef" in x.lower() else "Bible")
    df["Language"] = df["Dataset"].apply(
        lambda x: x.replace("UNICEF ","").replace("Bible ","").strip()
    )
    df["WER %"] = df["Avg WER"] * 100
    df["Model Short"] = df["Model"].apply(lambda x: x.split("/")[-1] if "/" in x else x)
    models = df["Model Short"].unique()
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle("Domain Comparison — UNICEF vs Bible", fontsize=15, fontweight="bold")
    # Left panel: average WER per model × domain.
    ax = axes[0]
    u_avg = [df[(df["Model Short"]==m) & (df["Domain"]=="UNICEF")]["WER %"].mean() for m in models]
    b_avg = [df[(df["Model Short"]==m) & (df["Domain"]=="Bible")]["WER %"].mean() for m in models]
    x = np.arange(len(models))
    w = 0.35
    ax.bar(x - w/2, u_avg, w, color=BLUE, alpha=0.88, label="UNICEF")
    ax.bar(x + w/2, b_avg, w, color=ORANGE, alpha=0.88, label="Bible")
    ax.set_title("Avg WER per Model × Domain", fontsize=11)
    ax.set_xticks(x)
    ax.set_xticklabels(models, rotation=20, ha="right", fontsize=9)
    ax.set_ylabel("WER %"); ax.set_ylim(0, 100)
    ax.legend(); ax.spines[["top","right"]].set_visible(False)
    # Right panel: per-model gap (Bible − UNICEF) in percentage points;
    # orange bars mean Bible is harder, blue means UNICEF is harder.
    ax2 = axes[1]
    gaps = [b - u for u, b in zip(u_avg, b_avg)]
    colors = [ORANGE if g > 0 else BLUE for g in gaps]
    ax2.bar(range(len(models)), gaps, color=colors, alpha=0.88)
    ax2.axhline(0, color="#888", linewidth=0.8, linestyle="--")
    ax2.set_title("Domain gap: Bible − UNICEF (pp)", fontsize=11)
    ax2.set_xticks(range(len(models)))
    ax2.set_xticklabels(models, rotation=20, ha="right", fontsize=9)
    ax2.set_ylabel("Percentage points"); ax2.spines[["top","right"]].set_visible(False)
    plt.tight_layout()
    # BUGFIX: tempfile.mktemp is deprecated and race-prone; mkstemp actually
    # creates the file atomically. savefig then overwrites it by path.
    fd, path = tempfile.mkstemp(suffix="_domain.png")
    os.close(fd)
    plt.savefig(path, dpi=130, bbox_inches="tight")
    plt.close(fig)
    return path
# ── Core evaluation function ──────────────────────────────────────────────────
def run_evaluation(
    selected_models: list[str],
    selected_datasets: list[str],
    n_samples: int,
    progress=gr.Progress(track_tqdm=True),
):
    """Generator that evaluates each selected model on each compatible dataset.

    Yields ``(log_text, results_df, cer_wer_chart_path, domain_chart_path)``
    tuples so Gradio can stream progress; intermediate yields carry ``None``
    for the non-log outputs. Models are loaded one at a time and unloaded
    afterwards to bound VRAM usage.
    """
    # BUGFIX: this function is a generator (it yields below), so a bare
    # `return value` would set StopIteration.value instead of updating the
    # UI — validation messages must be yielded, then the generator returned.
    if not selected_models:
        yield "⚠️ Please select at least one model.", None, None, None
        return
    if not selected_datasets:
        yield "⚠️ Please select at least one dataset.", None, None, None
        return
    results_data = []
    log_lines = []
    # Run models in registry order, restricted to the user's selection.
    all_model_labels = list(WAV2VEC_MODELS.keys()) + ["Gemma-3n (E2B-it)"]
    models_to_run = [m for m in all_model_labels if m in selected_models]
    for model_label in progress.tqdm(models_to_run, desc="Models"):
        log_lines.append(f"\n🔄 Loading model: {model_label}")
        yield "\n".join(log_lines), None, None, None
        # Load model; a failed load skips the model rather than aborting the run.
        try:
            if model_label == "Gemma-3n (E2B-it)":
                model = Gemma3nModel()
                model_id = GEMMA_MODEL_ID
            else:
                model_id = WAV2VEC_MODELS[model_label]
                model = Wav2VecModel(model_id)
            log_lines.append(f"   ✅ Loaded: {model_id}")
        except Exception as e:
            log_lines.append(f"   ❌ Failed to load {model_label}: {e}")
            yield "\n".join(log_lines), None, None, None
            continue
        # Evaluate on each selected compatible dataset.
        for ds_label in selected_datasets:
            compat = DATASET_MODEL_COMPAT.get(ds_label, [])
            if model_label not in compat:
                log_lines.append(f"   ⏭️ Skipping {ds_label} (incompatible language)")
                continue
            log_lines.append(f"   📂 Dataset: {ds_label} ({n_samples} samples)")
            yield "\n".join(log_lines), None, None, None
            try:
                samples = load_dataset_samples(ds_label, n_samples)
            except Exception as e:
                log_lines.append(f"   ❌ Failed to load dataset: {e}")
                yield "\n".join(log_lines), None, None, None
                continue
            cers, wers, details = [], [], []
            for j, sample in enumerate(samples):
                # Per-sample errors are logged and skipped so one bad clip
                # cannot sink the whole model/dataset combination.
                try:
                    hyp = model.transcribe(sample["audio"], sample["sampling_rate"])
                    cer, wer = calculate_metrics(sample["text"], hyp)
                    if cer is not None:
                        cers.append(cer); wers.append(wer)
                        details.append({
                            "ref": sample["text"][:80],
                            "hyp": hyp[:80],
                            "CER": round(cer, 4),
                            "WER": round(wer, 4),
                        })
                except Exception as e:
                    log_lines.append(f"      ⚠️ Sample {j} error: {e}")
            if cers:
                avg_cer = float(np.mean(cers))
                avg_wer = float(np.mean(wers))
                results_data.append({
                    "Model": model_id,
                    "Dataset": ds_label,
                    "Avg CER": round(avg_cer, 4),
                    "Avg WER": round(avg_wer, 4),
                    "Num Samples": len(cers),
                })
                log_lines.append(f"   ✅ CER={avg_cer*100:.1f}%  WER={avg_wer*100:.1f}%")
            else:
                log_lines.append("   ⚠️ No valid results for this combo.")
        # Unload model to free memory before loading the next one.
        model.unload()
        log_lines.append(f"   🧹 {model_label} unloaded.\n")
        yield "\n".join(log_lines), None, None, None
    # Build final outputs: summary table plus the two charts.
    if not results_data:
        yield "\n".join(log_lines) + "\n\n❌ No results collected.", None, None, None
        return
    df = pd.DataFrame(results_data)
    log_lines.append("✅ Evaluation complete!")
    try:
        chart1 = make_cer_wer_chart(df)
        chart2 = make_domain_chart(df)
    except Exception as e:
        log_lines.append(f"⚠️ Chart generation failed: {e}")
        chart1, chart2 = None, None
    yield "\n".join(log_lines), df, chart1, chart2
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Choice lists for the two checkbox groups (wav2vec checkpoints + Gemma-3n).
ALL_MODELS = list(WAV2VEC_MODELS.keys()) + ["Gemma-3n (E2B-it)"]
ALL_DATASETS = list(DATASETS.keys())

with gr.Blocks(title="ASR Benchmark — Ghanaian Languages", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 🎙️ ASR Benchmark — Ghanaian Languages
Evaluate **Wav2Vec2-BERT** fine-tuned models and **Gemma-3n** on Dagbani, Twi, and Ewe speech datasets.
Models are loaded one at a time and unloaded after evaluation to minimise VRAM usage.
"""
    )
    with gr.Row():
        # Left column: run configuration.
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Configuration")
            model_selector = gr.CheckboxGroup(
                choices=ALL_MODELS,
                value=list(WAV2VEC_MODELS.keys()),
                label="Models to evaluate",
            )
            dataset_selector = gr.CheckboxGroup(
                choices=ALL_DATASETS,
                value=ALL_DATASETS,
                label="Datasets to include",
            )
            n_samples_slider = gr.Slider(
                minimum=5, maximum=200, value=50, step=5,
                label="Samples per dataset",
                info="More samples = better estimates but longer runtime.",
            )
            run_btn = gr.Button("▶ Run Evaluation", variant="primary")
        # Right column: streaming log output.
        with gr.Column(scale=2):
            gr.Markdown("### 📋 Evaluation Log")
            log_box = gr.Textbox(
                label="",
                lines=20,
                max_lines=40,
                interactive=False,
                placeholder="Logs will appear here once evaluation starts...",
            )
    gr.Markdown("### 📊 Results")
    results_table = gr.DataFrame(
        label="Summary (CER & WER per model × dataset)",
        interactive=False,
        wrap=True,
    )
    with gr.Row():
        # Charts are passed back as file paths from the matplotlib helpers.
        chart_cer_wer = gr.Image(label="CER & WER per Model", type="filepath")
        chart_domain = gr.Image(label="Domain Comparison", type="filepath")
    # run_evaluation is a generator, so these outputs stream progressively.
    run_btn.click(
        fn=run_evaluation,
        inputs=[model_selector, dataset_selector, n_samples_slider],
        outputs=[log_box, results_table, chart_cer_wer, chart_domain],
    )
    gr.Markdown(
        """
---
**Models:** [FarmerlineML](https://huggingface.co/FarmerlineML) |
**Datasets:** [ghananlpcommunity](https://huggingface.co/ghananlpcommunity) |
**Framework:** [Unsloth](https://github.com/unslothai/unsloth) + [🤗 Transformers](https://huggingface.co/transformers)
"""
    )

if __name__ == "__main__":
    demo.launch()