SinCode / misc /eval_internal_500.py
KalanaPabasara
SinCode v3 — seq2seq pipeline, evaluation scripts, IndoNLP benchmark data
1fed70a
#!/usr/bin/env python3
"""
Evaluate ByT5 + XLM-RoBERTa reranker on internal_test_set_500.csv.
CSV columns: id, category, input, code_mixed_reference
Usage:
python misc/eval_internal_500.py # full 500
python misc/eval_internal_500.py --max 10 # dry run
python misc/eval_internal_500.py --max 10 --cat colloquial
"""
import argparse
import csv
import math
import os
import sys
import time
from collections import Counter, defaultdict
from pathlib import Path
# Make the repository root importable so `core.decoder` resolves regardless of
# the directory the script is launched from. resolve() turns __file__ into an
# absolute path even on interpreters where it may be relative (pre-3.9
# __main__), so the computed root never depends on the current working dir.
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
from core.decoder import BeamSearchDecoder
# ── Metrics ─────────────────────────────────────────────────────────────────
def levenshtein(a, b):
    """Return the edit distance between two sequences.

    Counts the minimum number of single-element insertions, deletions and
    substitutions needed to turn *a* into *b*. Elements are compared with
    ``==``, so this works on characters of strings and on token lists alike.
    Only one previous row of the DP table is kept, so memory is O(len(b)).
    """
    if not a:
        return len(b)
    if not b:
        return len(a)

    previous = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        current = [i]
        for j, y in enumerate(b, start=1):
            cost = 0 if x == y else 1
            current.append(
                min(
                    previous[j] + 1,         # delete x
                    current[j - 1] + 1,      # insert y
                    previous[j - 1] + cost,  # substitute / keep
                )
            )
        previous = current
    return previous[-1]
def cer(pred, ref):
    """Character error rate: character edit distance divided by ``len(ref)``.

    An empty reference scores 0.0 if *pred* is also empty and 1.0 otherwise.
    The value is unbounded above: a prediction much longer than the reference
    can score more than 1.0.
    """
    if ref:
        return levenshtein(pred, ref) / len(ref)
    return 1.0 if pred else 0.0
def wer(pred, ref):
    """Word error rate: token edit distance divided by the reference word count.

    Both strings are tokenized with ``str.split()`` (runs of whitespace). An
    empty reference scores 0.0 if the prediction has no words and 1.0 otherwise.
    """
    hyp_words = pred.split()
    ref_words = ref.split()
    if ref_words:
        return levenshtein(hyp_words, ref_words) / len(ref_words)
    return 1.0 if hyp_words else 0.0
def bleu1(pred, ref):
    """Sentence-level BLEU-1: clipped unigram precision times a brevity penalty.

    Both strings are tokenized with ``str.split()``. Each predicted token is
    credited at most as many times as it appears in the reference (clipping),
    so repeating a correct word cannot inflate the score, and the result is
    always in [0, 1]. When the prediction (c tokens) is not longer than the
    reference (r tokens), the precision is scaled by exp(1 - r / c).

    Edge cases: an empty reference scores 1.0 if the prediction is also empty
    and 0.0 otherwise; an empty prediction against a non-empty reference
    scores 0.0.
    """
    pt, rt = pred.split(), ref.split()
    if not rt:
        return 1.0 if not pt else 0.0
    if not pt:
        return 0.0
    # Counter intersection keeps min(count in pred, count in ref) per token,
    # which is exactly BLEU's clipped match count.
    matches = sum((Counter(pt) & Counter(rt)).values())
    c, r = len(pt), len(rt)
    precision = matches / c
    brevity_penalty = 1.0 if c > r else math.exp(1 - r / c)
    return precision * brevity_penalty
def exact_match(pred, ref):
    """Return True when *pred* and *ref* are identical after trimming outer whitespace.

    Interior whitespace is significant: "a b" and "a  b" do not match.
    """
    stripped_pred, stripped_ref = pred.strip(), ref.strip()
    return stripped_pred == stripped_ref
# ── Main ─────────────────────────────────────────────────────────────────────
# Decimal places used when metric columns are written to the results CSV.
_CSV_DIGITS = {"cer": 4, "wer": 4, "bleu": 4, "time_s": 3}


def _load_samples(csv_path, category=None, limit=None):
    """Read evaluation rows from *csv_path*.

    Args:
        csv_path: CSV with columns id, category, input, code_mixed_reference.
        category: If truthy, keep only rows whose ``category`` equals it.
        limit: If not None, stop after this many kept rows (0 keeps none).

    Returns:
        List of row dicts as produced by ``csv.DictReader``.
    """
    samples = []
    # utf-8-sig drops a leading BOM so the first header reads "id".
    with open(csv_path, newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            # Check the cap before appending and compare against None, so
            # `--max 0` means zero rows instead of "no limit".
            if limit is not None and len(samples) >= limit:
                break
            if category and row["category"] != category:
                continue
            samples.append(row)
    return samples


def _evaluate_sample(decoder, row):
    """Decode one row's input and score it against its reference.

    Returns a result dict with unrounded metric values. If decoding raises,
    the error is printed and the prediction is scored as the literal
    string "[ERROR]", so one bad input does not abort the whole run.
    """
    inp = row["input"].strip()
    ref = row["code_mixed_reference"].strip()
    sid = row["id"]
    t0 = time.perf_counter()  # monotonic clock, suited to measuring intervals
    try:
        # NOTE(review): assumes decode() returns a 3-tuple whose first item is
        # the output string — confirm against core.decoder.BeamSearchDecoder.
        pred, _, _ = decoder.decode(inp)
    except Exception as e:
        print(f" ERROR id={sid}: {e}")
        pred = "[ERROR]"
    elapsed = time.perf_counter() - t0
    return {
        "id": sid,
        "category": row["category"],
        "input": inp,
        "reference": ref,
        "prediction": pred,
        "cer": cer(pred, ref),
        "wer": wer(pred, ref),
        "bleu": bleu1(pred, ref),
        "exact_match": exact_match(pred, ref),
        "time_s": elapsed,
    }


def _write_results(out_path, results):
    """Write per-sample *results* (non-empty) to *out_path* as CSV.

    Metric columns are rounded only here, so the in-memory values used for
    the summary table stay unrounded. The parent directory is created if
    it is missing.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(results[0]))
        writer.writeheader()
        for r in results:
            writer.writerow({
                key: round(value, _CSV_DIGITS[key]) if key in _CSV_DIGITS else value
                for key, value in r.items()
            })


def _summary_line(label, rows):
    """Format one summary-table row: label, count, mean CER/WER/BLEU, exact-match %.

    *rows* must be non-empty. Means are taken over unrounded per-sample values.
    """
    n = len(rows)
    mean_cer = sum(r["cer"] for r in rows) / n
    mean_wer = sum(r["wer"] for r in rows) / n
    mean_bleu = sum(r["bleu"] for r in rows) / n
    em_pct = sum(1 for r in rows if r["exact_match"]) / n * 100
    return f"{label:<18} {n:>4} {mean_cer:>7.4f} {mean_wer:>7.4f} {mean_bleu:>7.4f} {em_pct:>6.1f}%"


def main():
    """Evaluate the decoder on the internal test set and report metrics.

    Parses --max/--cat/--out, decodes every selected sample, prints
    per-sample progress, writes a per-sample CSV under the project root,
    and prints per-category and overall summary tables. Exits with an error
    message (status 1) if no samples are selected.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--max", type=int, default=None, help="Max samples to evaluate")
    parser.add_argument("--cat", type=str, default=None, help="Filter to one category")
    parser.add_argument("--out", type=str, default="misc/internal_500_results.csv", help="Output CSV path")
    args = parser.parse_args()

    csv_path = project_root / "misc" / "internal_test_set_500.csv"
    samples = _load_samples(csv_path, category=args.cat, limit=args.max)
    print(f"Loaded {len(samples)} samples" + (f" (category={args.cat})" if args.cat else ""))
    if not samples:
        # Fail before loading the model. Otherwise the CSV header lookup
        # (results[0]) and the overall mean (division by n) would crash.
        sys.exit("No samples to evaluate; check --cat / --max.")

    # torch is imported lazily, after the sample check, so an empty
    # selection fails fast without importing it.
    import torch
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading BeamSearchDecoder on {device}...")
    decoder = BeamSearchDecoder(device=device)
    print("Ready.\n")

    results = []
    by_category = defaultdict(list)
    total = len(samples)
    for i, row in enumerate(samples, 1):
        result = _evaluate_sample(decoder, row)
        results.append(result)
        by_category[result["category"]].append(result)
        em = result["exact_match"]
        status = "PASS" if em else "FAIL"
        print(f"[{i:>4}/{total}] {status} id={result['id']:>4} cat={result['category']:<15} "
              f"CER={result['cer']:.3f} WER={result['wer']:.3f} BLEU={result['bleu']:.3f}")
        if not em:
            print(f" IN: {result['input']}")
            print(f" REF: {result['reference']}")
            print(f" GOT: {result['prediction']}")

    out_path = project_root / args.out
    _write_results(out_path, results)
    print(f"\nResults saved to: {out_path}")

    # Summary by category, then overall.
    print(f"\n{'='*70}")
    print(f"{'CATEGORY':<18} {'n':>4} {'CER':>7} {'WER':>7} {'BLEU':>7} {'EM%':>7}")
    print(f"{'='*70}")
    for cat in sorted(by_category):
        print(_summary_line(cat, by_category[cat]))
    print(f"{'─'*70}")
    print(_summary_line("OVERALL", results))
    print(f"{'='*70}")


if __name__ == "__main__":
    main()