# Dialect-classifier evaluation script (JSONL sample set -> accuracy report).
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from collections import Counter, defaultdict | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, DefaultDict, Dict, Iterable, List, Mapping, Tuple | |
| from dialect_analysis.pipeline import classify_text | |
@dataclass
class Sample:
    """One labeled evaluation sample loaded from the JSONL file.

    The ``@dataclass`` decorator is required: ``load_samples`` constructs
    instances with keyword arguments, which plain annotated fields do not
    support (``dataclass`` is imported at the top of the file for this).
    """

    id: str
    label: str
    text: str
    # Whether diacritics should be stripped before classification.
    strip_diacritics: bool = True
    # Marks machine-generated samples so accuracy can be split real/synthetic.
    synthetic: bool = False
def load_samples(path: Path) -> List[Sample]:
    """Parse *path* as JSONL, skipping blank lines and ``#`` comments.

    Each kept line must be a JSON object with at least ``label`` and
    ``text``; ``id`` defaults to ``sample_<lineno>`` when absent or falsy.
    """
    loaded: List[Sample] = []
    content = path.read_text(encoding="utf-8")
    for lineno, raw_line in enumerate(content.splitlines(), start=1):
        stripped = raw_line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        record = json.loads(stripped)
        sample = Sample(
            id=str(record.get("id") or f"sample_{lineno}"),
            label=str(record["label"]),
            text=str(record["text"]),
            strip_diacritics=bool(record.get("strip_diacritics", True)),
            synthetic=bool(record.get("synthetic", False)),
        )
        loaded.append(sample)
    return loaded
def parse_args() -> argparse.Namespace:
    """Build the CLI parser and parse ``sys.argv``.

    The only option is ``--samples``, defaulting to ``samples.jsonl``
    next to this script.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate dialect classifier against a JSONL sample set."
    )
    parser.add_argument(
        "--samples",
        type=Path,
        default=Path(__file__).with_name("samples.jsonl"),
        help="Path to JSONL file with {id,label,text,strip_diacritics[,synthetic]}",
    )
    return parser.parse_args()
| def confusion_matrix(rows: Iterable[Tuple[str, str]]) -> Tuple[List[str], List[List[int]]]: | |
| labels = sorted({t for t, _ in rows} | {p for _, p in rows}) | |
| idx = {l: i for i, l in enumerate(labels)} | |
| mat = [[0 for _ in labels] for _ in labels] | |
| for true_label, pred_label in rows: | |
| mat[idx[true_label]][idx[pred_label]] += 1 | |
| return labels, mat | |
def main() -> int:
    """Run the evaluation: classify every sample, print misses as they occur,
    then a summary, a confusion matrix, and per-label prediction counts.

    Returns a process exit code: 0 on success, 2 when the samples file is
    missing or contains no usable samples.
    """
    args = parse_args()
    path = Path(args.samples)
    if not path.exists():
        print(f"Missing samples file: {path}")
        return 2
    samples = load_samples(path)
    if not samples:
        print("No samples found.")
        return 2
    # (true_label, predicted_label) pairs for all samples, plus the same
    # pairs split by real vs. synthetic origin.
    pairs: List[Tuple[str, str]] = []
    correct = 0
    confidences: List[float] = []
    pairs_real: List[Tuple[str, str]] = []
    pairs_synth: List[Tuple[str, str]] = []
    correct_real = 0
    correct_synth = 0
    # For each true label, a Counter of predicted labels.
    by_label: DefaultDict[str, Counter[str]] = defaultdict(Counter)
    for s in samples:
        # NOTE(review): classify_text is assumed to return a mapping with
        # "dialect", "confidence" (0..1) and "scores" (percent) keys —
        # confirm against dialect_analysis.pipeline.
        result: Mapping[str, Any] = classify_text(s.text, strip_diacritics=s.strip_diacritics)
        pred = str(result.get("dialect", ""))
        # "or 0.0" guards against an explicit None confidence value.
        conf = float(result.get("confidence", 0.0) or 0.0)
        confidences.append(conf)
        pairs.append((s.label, pred))
        if s.synthetic:
            pairs_synth.append((s.label, pred))
            if pred == s.label:
                correct_synth += 1
        else:
            pairs_real.append((s.label, pred))
            if pred == s.label:
                correct_real += 1
        by_label[s.label][pred] += 1
        if pred == s.label:
            correct += 1
        else:
            # On a miss, show the two highest-scoring dialects for debugging.
            scores: Mapping[str, float] = result.get("scores", {}) or {}
            top2 = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:2]
            top2_str = ", ".join(f"{d}={pct:.1f}%" for d, pct in top2)
            print(f"MISS {s.id}: true={s.label} pred={pred} conf={conf*100:.1f}% top2=({top2_str})")
    # max(1, ...) is defensive against division by zero; samples is known
    # non-empty at this point.
    acc = correct / max(1, len(samples))
    avg_conf = sum(confidences) / max(1, len(confidences))
    print("\nSummary")
    print(f" File: {path.name}")
    print(f" Samples: {len(samples)}")
    print(f" Accuracy: {acc*100:.1f}%")
    print(f" Avg confidence: {avg_conf*100:.1f}%")
    # Only report the split when both real and synthetic samples are present.
    if pairs_real and pairs_synth:
        acc_real = correct_real / max(1, len(pairs_real))
        acc_synth = correct_synth / max(1, len(pairs_synth))
        print(f" Accuracy (real): {acc_real*100:.1f}% (n={len(pairs_real)})")
        print(f" Accuracy (synthetic): {acc_synth*100:.1f}% (n={len(pairs_synth)})")
    labels, mat = confusion_matrix(pairs)
    print("\nConfusion matrix (rows=true, cols=pred)")
    # Header row: labels truncated to 10 chars and padded for alignment
    # under a 14-char row-label gutter.
    header = "".ljust(14) + " ".join(l[:10].ljust(10) for l in labels)
    print(header)
    for i, true_label in enumerate(labels):
        row = " ".join(str(mat[i][j]).ljust(10) for j in range(len(labels)))
        print(true_label[:12].ljust(14) + row)
    print("\nPer-label predictions")
    for true_label in sorted(by_label.keys()):
        counts = by_label[true_label]
        total = sum(counts.values())
        # Most frequent prediction first.
        ordered = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
        dist = ", ".join(f"{p}:{c}" for p, c in ordered)
        print(f" {true_label} (n={total}): {dist}")
    return 0
# Script entry point: propagate main()'s return value as the exit code.
if __name__ == "__main__":
    raise SystemExit(main())