Model Card for Model ID

Model Details

Model Description

OpenAI Whisper Medium, fine-tuned on a Hindi dataset

Achieves 4.35% WER on Hindi with the fine-tuned Whisper Medium model. Released under the Apache 2.0 license. Evaluated on the dataset linked below:

https://github.com/AI4Bharat/vistaar (KathBath)

import json
import soundfile as sf
from tqdm import tqdm
import jiwer
import re
import unicodedata
import torch
from transformers import pipeline


# ==========================
# 0. Hindi Text Normalization
# ==========================
def normalize_hindi_text(text: str) -> str:
    """Normalize Hindi text for ASR scoring.

    Steps: NFC Unicode normalization, punctuation removal (word characters
    and the Devanagari block are kept), lowercasing (a no-op for Devanagari;
    normalizes any embedded Latin text), collapsing of *consecutive*
    duplicate words (a common Whisper repetition artifact), and whitespace
    cleanup.

    Returns:
        The normalized text, or the sentinel "<unk>" when the input is
        None/empty so downstream metric code always has a token to score.
    """
    if text is None or text.strip() == "":
        return "<unk>"

    # Canonical Unicode form so visually identical strings compare equal.
    text = unicodedata.normalize("NFC", text)

    # Drop everything that is not a word character, whitespace, or Devanagari.
    text = re.sub(r"[^\w\s\u0900-\u097F]", " ", text)
    # The danda (U+0964) falls inside the Devanagari block and survives the
    # regex above, so strip it explicitly.
    text = text.replace("।", " ")

    text = text.lower()
    words = text.split()

    # Collapse only *consecutive* repeats (Whisper repetition artifacts).
    # The previous global de-duplication (dict.fromkeys) also deleted
    # legitimate repeated words elsewhere in the sentence, corrupting both
    # references and hypotheses and biasing the WER.
    words = [w for i, w in enumerate(words) if i == 0 or w != words[i - 1]]

    text = " ".join(words)
    text = re.sub(r"\s+", " ", text).strip()

    return text if text else "<unk>"


# ==========================
# 1. Load ASR Model
# ==========================
# ==========================
# 1. Load ASR Model
# ==========================
MODEL_DIR = "path-to-model"
BATCH_SIZE = 8  # Adjust based on GPU memory

print("Loading ASR model...")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# Pass the device string computed above instead of re-deriving a device
# index, so the printed device and the pipeline's actual device always
# agree (transformers' pipeline accepts a str/torch.device directly).
asr = pipeline(
    "automatic-speech-recognition",
    model=MODEL_DIR,
    device=device,
    batch_size=BATCH_SIZE,
    chunk_length_s=30,  # long-form audio is chunked into 30 s windows
)

# ==========================
# 2. Load Manifest
# ==========================
# ==========================
# 2. Load Manifest
# ==========================
MANIFEST_PATH = "manifest.json"

print(f"Loading manifest: {MANIFEST_PATH}")
# The manifest is in JSON-lines format: one JSON object per line.
manifest = []
with open(MANIFEST_PATH, "r", encoding="utf-8") as f:
    for raw_line in f:
        manifest.append(json.loads(raw_line))

print(f"Total samples: {len(manifest)}")

# ==========================
# 3. Process in Batches
# ==========================
# ==========================
# 3. Process in Batches
# ==========================
all_predictions = []
all_references = []

print("\nRunning ASR inference...")

# Load all audio up front, keeping each file's sample rate: a bare numpy
# array passed to the pipeline is assumed to be at the model's rate
# (16 kHz for Whisper), so files at any other rate would be transcribed
# incorrectly. The dict form lets the pipeline resample as needed.
audio_data = []
for entry in tqdm(manifest, desc="Loading audio"):
    try:
        audio, sr = sf.read(entry["audio_filepath"])
        audio_data.append({
            "audio": {"raw": audio, "sampling_rate": sr},
            "reference": entry["text"],
        })
    except Exception as e:
        # Best-effort: skip unreadable files but keep evaluating the rest.
        print(f"Error loading {entry['audio_filepath']}: {e}")

# Process in batches
print(f"\nProcessing {len(audio_data)} samples in batches of {BATCH_SIZE}...")

for i in tqdm(range(0, len(audio_data), BATCH_SIZE), desc="Inference"):
    batch = audio_data[i:i + BATCH_SIZE]
    audio_inputs = [item["audio"] for item in batch]

    # Run inference on the batch; with a list input the pipeline returns
    # a list of results in the same order.
    results = asr(audio_inputs)

    for j, result in enumerate(results):
        # Each result is normally {"text": ...}; fall back defensively.
        pred_text = result["text"] if isinstance(result, dict) else str(result)
        ref_text = batch[j]["reference"]

        # Normalize both sides identically before scoring.
        all_predictions.append(normalize_hindi_text(pred_text))
        all_references.append(normalize_hindi_text(ref_text))

# ==========================
# 4. Compute WER/CER
# ==========================
# ==========================
# 4. Compute WER/CER
# ==========================
print("\n" + "=" * 60)
print("Computing metrics...")

# Keep every pair whose *reference* is non-empty. Empty predictions are
# kept as "" so they are scored as deletions; the previous version
# silently dropped them, which removed the model's worst failures from
# the score and biased WER downward. (jiwer requires non-empty
# references but accepts empty hypotheses.)
valid_preds = []
valid_refs = []

for pred, ref in zip(all_predictions, all_references):
    # Keep only Devanagari characters for the final WER computation
    # (this also strips the "<unk>" sentinel down to an empty string).
    pred_clean = re.sub(r'[^\u0900-\u097F\s]', '', pred).strip()
    ref_clean = re.sub(r'[^\u0900-\u097F\s]', '', ref).strip()

    if ref_clean:
        valid_preds.append(pred_clean)
        valid_refs.append(ref_clean)

print(f"Valid pairs: {len(valid_refs)} / {len(all_references)}")

if valid_refs:
    # Corpus-level word and character error rates.
    wer = jiwer.wer(valid_refs, valid_preds)
    cer = jiwer.cer(valid_refs, valid_preds)

    print("=" * 60)
    print(f"WER: {wer * 100:.2f}%")
    print(f"CER: {cer * 100:.2f}%")
    print("=" * 60)

    # Show a few reference/prediction pairs as a quick sanity check.
    print("\nSample predictions:")
    for i in range(min(3, len(valid_refs))):
        print(f"\n{i + 1}. Reference: {valid_refs[i]}")
        print(f"   Prediction: {valid_preds[i]}")
else:
    print("ERROR: No valid samples for evaluation!")
Downloads last month
-
Safetensors
Model size
0.8B params
Tensor type
F32
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support