| |
| """ |
| Evaluate ByT5 + XLM-RoBERTa reranker on internal_test_set_500.csv. |
| CSV columns: id, category, input, code_mixed_reference |
| |
| Usage: |
| python misc/eval_internal_500.py # full 500 |
| python misc/eval_internal_500.py --max 10 # dry run |
| python misc/eval_internal_500.py --max 10 --cat colloquial |
| """ |
|
|
| import sys |
| import os |
| import argparse |
| import csv |
| import time |
| from pathlib import Path |
| from collections import defaultdict |
|
|
# Repository root: two levels above this script (<root>/misc/<this file>).
# resolve() makes the path absolute first; without it a bare relative
# __file__ (script launched from inside its own directory on interpreters that
# keep __file__ relative) would make both .parent hops collapse to ".".
project_root = Path(__file__).resolve().parent.parent
# Put the root first on sys.path so the project-local `core` package imports.
sys.path.insert(0, str(project_root))
|
|
| from core.decoder import BeamSearchDecoder |
|
|
|
|
| |
|
|
def levenshtein(a, b):
    """Return the Levenshtein (edit) distance between two sequences.

    Both arguments may be strings (compared character by character) or
    lists of tokens (compared item by item).  Insertions, deletions and
    substitutions each cost 1.
    """
    # With one side empty, the distance is simply the other side's length.
    if not a:
        return len(b)
    if not b:
        return len(a)

    # Only two rows of the dynamic-programming table are kept at any time.
    above = list(range(len(b) + 1))
    for row_idx, item_a in enumerate(a, start=1):
        current = [row_idx]
        for col_idx, item_b in enumerate(b, start=1):
            deletion = above[col_idx] + 1
            insertion = current[col_idx - 1] + 1
            substitution = above[col_idx - 1] + (0 if item_a == item_b else 1)
            current.append(min(deletion, insertion, substitution))
        above = current
    return above[-1]
|
|
def cer(pred, ref):
    """Return the character error rate of ``pred`` against ``ref``.

    The rate is the character-level edit distance divided by the length of
    the reference.  For an empty reference it is 0.0 when the prediction is
    empty as well and 1.0 otherwise.
    """
    if ref:
        return levenshtein(pred, ref) / len(ref)
    return 1.0 if pred else 0.0
|
|
def wer(pred, ref):
    """Return the word error rate of ``pred`` against ``ref``.

    Both strings are split on whitespace; the rate is the word-level edit
    distance divided by the number of reference words.  When the reference
    has no words it is 0.0 for a word-less prediction and 1.0 otherwise.
    """
    hyp_words = pred.split()
    ref_words = ref.split()
    if ref_words:
        return levenshtein(hyp_words, ref_words) / len(ref_words)
    return 1.0 if hyp_words else 0.0
|
|
def bleu1(pred, ref):
    """Return a clipped unigram-overlap score between ``pred`` and ``ref``.

    Both strings are split on whitespace.  Each reference token can be
    matched at most as many times as it occurs in the reference, so a
    prediction that repeats a word cannot be credited for it repeatedly and
    the score always stays within [0.0, 1.0].

    The number of clipped matches is divided by the reference length, as
    this function has always done.  There is no brevity penalty.

    Returns 1.0 when both sides have no tokens and 0.0 when only the
    reference is empty.
    """
    # Local import keeps this function self-contained.
    from collections import Counter

    pred_tokens = pred.split()
    ref_tokens = ref.split()
    if not ref_tokens:
        return 1.0 if not pred_tokens else 0.0

    # Counter intersection keeps, per token, the smaller of the two counts,
    # which is exactly the clipped match count.
    overlap = Counter(pred_tokens) & Counter(ref_tokens)
    return sum(overlap.values()) / len(ref_tokens)
|
|
def exact_match(pred, ref):
    """Return True when ``pred`` and ``ref`` are equal after trimming.

    Only leading and trailing whitespace is ignored; the comparison is
    otherwise exact and case-sensitive.
    """
    trimmed_pred = pred.strip()
    trimmed_ref = ref.strip()
    return trimmed_pred == trimmed_ref
|
|
|
|
| |
|
|
def _summary_line(label, rows):
    """Format one summary-table row holding the averages over ``rows``.

    ``rows`` must be a non-empty list of per-sample result dicts with the
    keys ``cer``, ``wer``, ``bleu`` and ``exact_match``.
    """
    n = len(rows)
    avg_cer = sum(r["cer"] for r in rows) / n
    avg_wer = sum(r["wer"] for r in rows) / n
    avg_bleu = sum(r["bleu"] for r in rows) / n
    em_pct = sum(1 for r in rows if r["exact_match"]) / n * 100
    return f"{label:<18} {n:>4} {avg_cer:>7.4f} {avg_wer:>7.4f} {avg_bleu:>7.4f} {em_pct:>6.1f}%"


def main():
    """Evaluate the decoder on the internal test set and report metrics.

    Reads ``misc/internal_test_set_500.csv`` under ``project_root`` (columns
    used: ``id``, ``category``, ``input``, ``code_mixed_reference``),
    optionally restricted to one category (``--cat``) and to the first
    ``--max`` matching rows.  Every input is decoded with
    ``BeamSearchDecoder``; CER, WER, BLEU and exact match are computed per
    sample, written to the ``--out`` CSV (relative to ``project_root``) and
    summarised per category and overall on stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--max", type=int, default=None, help="Max samples to evaluate")
    parser.add_argument("--cat", type=str, default=None, help="Filter to one category")
    parser.add_argument("--out", type=str, default="misc/internal_500_results.csv", help="Output CSV path")
    args = parser.parse_args()

    csv_path = project_root / "misc" / "internal_test_set_500.csv"

    # utf-8-sig transparently drops a leading BOM if the CSV has one.
    samples = []
    with open(csv_path, newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            if args.cat and row["category"] != args.cat:
                continue
            samples.append(row)
            # A --max of 0 is falsy and therefore means "no limit".
            if args.max and len(samples) >= args.max:
                break

    print(f"Loaded {len(samples)} samples" + (f" (category={args.cat})" if args.cat else ""))

    # With nothing selected (e.g. an unknown --cat) there is no header row to
    # write and no average to compute, so stop before loading the model.
    if not samples:
        print("No samples to evaluate; nothing written.")
        return

    # Imported here rather than at module level, so argument parsing and the
    # CSV read above happen before torch is loaded.
    import torch
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading BeamSearchDecoder on {device}...")
    decoder = BeamSearchDecoder(device=device)
    print("Ready.\n")

    results = []
    category_stats = defaultdict(list)

    for i, row in enumerate(samples):
        inp = row["input"].strip()
        ref = row["code_mixed_reference"].strip()
        cat = row["category"]
        sid = row["id"]

        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by wall-clock adjustments.
        t0 = time.perf_counter()
        try:
            # decode() returns a 3-tuple; only the first element is used.
            pred, _, _ = decoder.decode(inp)
        except Exception as e:
            # Best effort: one failing sample must not abort the whole run.
            # It is scored against the placeholder text instead.
            print(f" ERROR id={sid}: {e}")
            pred = "[ERROR]"
        elapsed = time.perf_counter() - t0

        c = cer(pred, ref)
        w = wer(pred, ref)
        b = bleu1(pred, ref)
        em = exact_match(pred, ref)

        result = {
            "id": sid,
            "category": cat,
            "input": inp,
            "reference": ref,
            "prediction": pred,
            "cer": round(c, 4),
            "wer": round(w, 4),
            "bleu": round(b, 4),
            "exact_match": em,
            "time_s": round(elapsed, 3),
        }
        results.append(result)
        category_stats[cat].append(result)

        status = "PASS" if em else "FAIL"
        print(f"[{i+1:>4}/{len(samples)}] {status} id={sid:>4} cat={cat:<15} CER={c:.3f} WER={w:.3f} BLEU={b:.3f}")
        if not em:
            print(f" IN: {inp}")
            print(f" REF: {ref}")
            print(f" GOT: {pred}")

    # Per-sample results; create the target directory if --out points into a
    # folder that does not exist yet.
    out_path = project_root / args.out
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(results[0]))
        writer.writeheader()
        writer.writerows(results)
    print(f"\nResults saved to: {out_path}")

    # Summary table: one row per category, then the overall row.
    print(f"\n{'='*70}")
    print(f"{'CATEGORY':<18} {'n':>4} {'CER':>7} {'WER':>7} {'BLEU':>7} {'EM%':>7}")
    print(f"{'='*70}")

    for cat in sorted(category_stats):
        print(_summary_line(cat, category_stats[cat]))

    # Plain ASCII rule; the previous separator character was mis-encoded.
    print(f"{'-'*70}")
    print(_summary_line("OVERALL", results))
    print(f"{'='*70}")
|
|
|
|
# Run the evaluation only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
|
|