Model Card for Model ID
Model Details
Model Description
OpenAI Whisper Medium, fine-tuned on a Hindi dataset.
Achieves 4.35% WER on Hindi with a fine-tuned Whisper Medium model. Released under the Apache 2.0 license. Evaluated on the KathBath dataset from AI4Bharat Vistaar:
https://github.com/AI4Bharat/vistaar (KathBath)
import json
import soundfile as sf
from tqdm import tqdm
import jiwer
import re
import unicodedata
import torch
from transformers import pipeline
# ==========================
# 0. Hindi Text Normalization
# ==========================
def normalize_hindi_text(text: str) -> str:
    """Normalize Hindi text for ASR scoring.

    Steps: NFC Unicode normalization, punctuation removal (Devanagari
    characters are preserved), lowercasing, collapsing of *consecutive*
    repeated words (a common Whisper decoding-loop artifact), and
    whitespace cleanup.

    Args:
        text: Raw transcript (reference or hypothesis). May be None.

    Returns:
        The normalized text, or the placeholder "<unk>" when the input
        is None/blank or nothing survives normalization.
    """
    if text is None or text.strip() == "":
        return "<unk>"
    # NFC so visually identical character sequences compare equal.
    text = unicodedata.normalize("NFC", text)
    # Keep word characters, whitespace, and the Devanagari block; map the
    # rest (Latin punctuation, symbols) to spaces.
    text = re.sub(r"[^\w\s\u0900-\u097F]", " ", text)
    # The danda (।) lives inside the Devanagari block and survives the
    # regex above, so strip it explicitly.
    text = text.replace("।", " ")
    # Lowercase is a no-op for Devanagari but normalizes any Latin residue.
    text = text.lower()
    words = text.split()
    # Collapse only CONSECUTIVE repeated words (Whisper repetition loops).
    # BUG FIX: the previous dict.fromkeys(words) removed every later
    # duplicate anywhere in the sentence, corrupting legitimate text in
    # which a word genuinely recurs (e.g. "बहुत बहुत धन्यवाद").
    words = [w for i, w in enumerate(words) if i == 0 or w != words[i - 1]]
    text = " ".join(words)
    text = re.sub(r"\s+", " ", text).strip()
    if not text:
        return "<unk>"
    return text
# ==========================
# 1. Load ASR Model
# ==========================
MODEL_DIR = "path-to-model"  # HF hub id or local checkpoint directory
BATCH_SIZE = 8  # Adjust based on GPU memory

print("Loading ASR model...")
# Query CUDA availability exactly once and derive everything from it;
# previously the check ran twice and the pipeline received an int index
# while a device string already existed.
use_cuda = torch.cuda.is_available()
device = "cuda:0" if use_cuda else "cpu"
print(f"Device: {device}")

asr = pipeline(
    "automatic-speech-recognition",
    model=MODEL_DIR,
    device=device,  # pipeline() accepts a device string directly
    batch_size=BATCH_SIZE,
    chunk_length_s=30,  # long-form audio is processed in 30 s windows
)
# ==========================
# 2. Load Manifest
# ==========================
MANIFEST_PATH = "manifest.json"

print(f"Loading manifest: {MANIFEST_PATH}")
# The manifest is JSON Lines: one JSON object per line, each describing
# a sample (at minimum "audio_filepath" and "text").
manifest = []
with open(MANIFEST_PATH, "r", encoding="utf-8") as f:
    for raw_line in f:
        manifest.append(json.loads(raw_line))
print(f"Total samples: {len(manifest)}")
# ==========================
# 3. Process in Batches
# ==========================
all_predictions = []
all_references = []

print("\nRunning ASR inference...")

# Load all audio into memory first so inference can run in clean batches.
audio_data = []
for entry in tqdm(manifest, desc="Loading audio"):
    try:
        audio, sr = sf.read(entry["audio_filepath"])
        audio_data.append({
            # BUG FIX: pass the file's sampling rate along with the samples.
            # A bare numpy array makes the pipeline assume the model's
            # native rate (16 kHz for Whisper), silently mis-transcribing
            # any file recorded at a different rate. The pipeline accepts
            # {"raw": array, "sampling_rate": int} and resamples as needed.
            "audio": {"raw": audio, "sampling_rate": sr},
            "reference": entry["text"],
        })
    except Exception as e:
        # Best-effort loading: report and skip unreadable files.
        print(f"Error loading {entry['audio_filepath']}: {e}")

# Process in batches
print(f"\nProcessing {len(audio_data)} samples in batches of {BATCH_SIZE}...")
for i in tqdm(range(0, len(audio_data), BATCH_SIZE), desc="Inference"):
    batch = audio_data[i:i + BATCH_SIZE]
    audio_arrays = [item["audio"] for item in batch]

    # Run inference on batch
    results = asr(audio_arrays)

    # Results is a list of dicts when batch_size > 1
    for j, result in enumerate(results):
        # Extract text
        pred_text = result["text"] if isinstance(result, dict) else str(result)
        ref_text = batch[j]["reference"]

        # Normalize both sides identically before scoring.
        pred_norm = normalize_hindi_text(pred_text)
        ref_norm = normalize_hindi_text(ref_text)

        # Store
        all_predictions.append(pred_norm)
        all_references.append(ref_norm)
# ==========================
# 4. Compute WER/CER
# ==========================
print("\n" + "=" * 60)
print("Computing metrics...")

# Keep only Devanagari characters for final scoring. A pair is dropped
# only when the REFERENCE cleans to empty (jiwer cannot score an empty
# reference). BUG FIX: the previous filter also dropped pairs whose
# PREDICTION was empty, which silently hid every utterance the model
# failed to transcribe and biased WER downward; an empty hypothesis is
# valid and is counted as deletions.
valid_preds = []
valid_refs = []
for pred, ref in zip(all_predictions, all_references):
    pred_clean = re.sub(r'[^\u0900-\u097F\s]', '', pred).strip()
    ref_clean = re.sub(r'[^\u0900-\u097F\s]', '', ref).strip()
    if ref_clean:
        valid_preds.append(pred_clean)
        valid_refs.append(ref_clean)

print(f"Valid pairs: {len(valid_refs)} / {len(all_references)}")

if valid_refs:
    # Corpus-level word and character error rates.
    wer = jiwer.wer(valid_refs, valid_preds)
    cer = jiwer.cer(valid_refs, valid_preds)
    print("=" * 60)
    print(f"WER: {wer * 100:.2f}%")
    print(f"CER: {cer * 100:.2f}%")
    print("=" * 60)

    # Show some examples
    print("\nSample predictions:")
    for i in range(min(3, len(valid_refs))):
        print(f"\n{i + 1}. Reference: {valid_refs[i]}")
        print(f" Prediction: {valid_preds[i]}")
else:
    print("ERROR: No valid samples for evaluation!")
- Downloads last month
- -