from __future__ import annotations import argparse import json from collections import Counter, defaultdict from dataclasses import dataclass from pathlib import Path from typing import Any, DefaultDict, Dict, Iterable, List, Mapping, Tuple from dialect_analysis.pipeline import classify_text @dataclass(frozen=True) class Sample: id: str label: str text: str strip_diacritics: bool = True synthetic: bool = False def load_samples(path: Path) -> List[Sample]: samples: List[Sample] = [] for i, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1): line = line.strip() if not line or line.startswith("#"): continue obj = json.loads(line) samples.append( Sample( id=str(obj.get("id") or f"sample_{i}"), label=str(obj["label"]), text=str(obj["text"]), strip_diacritics=bool(obj.get("strip_diacritics", True)), synthetic=bool(obj.get("synthetic", False)), ) ) return samples def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser(description="Evaluate dialect classifier against a JSONL sample set.") p.add_argument( "--samples", type=Path, default=Path(__file__).with_name("samples.jsonl"), help="Path to JSONL file with {id,label,text,strip_diacritics[,synthetic]}", ) return p.parse_args() def confusion_matrix(rows: Iterable[Tuple[str, str]]) -> Tuple[List[str], List[List[int]]]: labels = sorted({t for t, _ in rows} | {p for _, p in rows}) idx = {l: i for i, l in enumerate(labels)} mat = [[0 for _ in labels] for _ in labels] for true_label, pred_label in rows: mat[idx[true_label]][idx[pred_label]] += 1 return labels, mat def main() -> int: args = parse_args() path = Path(args.samples) if not path.exists(): print(f"Missing samples file: {path}") return 2 samples = load_samples(path) if not samples: print("No samples found.") return 2 pairs: List[Tuple[str, str]] = [] correct = 0 confidences: List[float] = [] pairs_real: List[Tuple[str, str]] = [] pairs_synth: List[Tuple[str, str]] = [] correct_real = 0 correct_synth = 0 by_label: DefaultDict[str, Counter[str]] = defaultdict(Counter) for s in samples: result: Mapping[str, Any] = classify_text(s.text, strip_diacritics=s.strip_diacritics) pred = str(result.get("dialect", "")) conf = float(result.get("confidence", 0.0) or 0.0) confidences.append(conf) pairs.append((s.label, pred)) if s.synthetic: pairs_synth.append((s.label, pred)) if pred == s.label: correct_synth += 1 else: pairs_real.append((s.label, pred)) if pred == s.label: correct_real += 1 by_label[s.label][pred] += 1 if pred == s.label: correct += 1 else: scores: Mapping[str, float] = result.get("scores", {}) or {} top2 = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:2] top2_str = ", ".join(f"{d}={pct:.1f}%" for d, pct in top2) print(f"MISS {s.id}: true={s.label} pred={pred} conf={conf*100:.1f}% top2=({top2_str})") acc = correct / max(1, len(samples)) avg_conf = sum(confidences) / max(1, len(confidences)) print("\nSummary") print(f" File: {path.name}") print(f" Samples: {len(samples)}") print(f" Accuracy: {acc*100:.1f}%") print(f" Avg confidence: {avg_conf*100:.1f}%") if pairs_real and pairs_synth: acc_real = correct_real / max(1, len(pairs_real)) acc_synth = correct_synth / max(1, len(pairs_synth)) print(f" Accuracy (real): {acc_real*100:.1f}% (n={len(pairs_real)})") print(f" Accuracy (synthetic): {acc_synth*100:.1f}% (n={len(pairs_synth)})") labels, mat = confusion_matrix(pairs) print("\nConfusion matrix (rows=true, cols=pred)") header = "".ljust(14) + " ".join(l[:10].ljust(10) for l in labels) print(header) for i, true_label in enumerate(labels): row = " ".join(str(mat[i][j]).ljust(10) for j in range(len(labels))) print(true_label[:12].ljust(14) + row) print("\nPer-label predictions") for true_label in sorted(by_label.keys()): counts = by_label[true_label] total = sum(counts.values()) ordered = sorted(counts.items(), key=lambda kv: kv[1], reverse=True) dist = ", ".join(f"{p}:{c}" for p, c in ordered) print(f" {true_label} (n={total}): {dist}") return 0 if __name__ == "__main__": raise SystemExit(main())