| |
"""Exact-match span F1 evaluation for Arcspan NER models.

Computes both exact-boundary span F1 (CoNLL/seqeval style) and OPF's native
containment-based span F1, printing them side by side for comparison.

Usage:
    python3 scripts/eval_exact_match.py \
        --checkpoint checkpoints/r8_5class/ \
        --test-data data/processed/aptner_5class_test_clean.jsonl \
        --device cuda
"""
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import sys |
| import time |
| from collections import defaultdict |
| from pathlib import Path |
| from typing import Sequence |
|
|
| |
| |
| |
|
|
# A labelled span as used throughout this script: (label, start, end).
Span = tuple[str, int, int]
|
|
|
|
def parse_gold_spans(record: dict) -> list[Span]:
    """Flatten a record's OPF-style ``"spans"`` mapping into a span list.

    The mapping has the shape ``{"Label: text": [[start, end], ...], ...}``.
    The label is whatever precedes the first ``":"`` of each key, stripped of
    surrounding whitespace.  Every offset pair under a key yields one
    ``(label, start, end)`` tuple with both offsets coerced to ``int``.

    A missing or empty ``"spans"`` field produces an empty list.
    """
    spans_by_key = record.get("spans", {})
    if not spans_by_key:
        return []

    flattened: list[Span] = []
    for key, pairs in spans_by_key.items():
        # Only the text before the first colon is the label; the rest of the
        # key is the surface text and is not needed here.
        label = key.partition(":")[0].strip()
        flattened.extend((label, int(pair[0]), int(pair[1])) for pair in pairs)
    return flattened
|
|
|
|
def predict_spans(redactor, text: str) -> list[Span]:
    """Run ``redactor.redact(text)`` and return its detections as spans.

    One ``(label, start, end)`` tuple is produced per entry of the result's
    ``detected_spans``, in the order the redactor reports them.
    """
    detections = redactor.redact(text).detected_spans
    return [(hit.label, hit.start, hit.end) for hit in detections]
|
|
|
|
| |
| |
| |
|
|
def _span_set(spans: list[Span]) -> set[Span]:
    """Return the distinct spans of *spans* as a set (duplicates collapse)."""
    return {*spans}
|
|
|
|
def compute_exact_match_metrics(
    all_gold: list[list[Span]],
    all_pred: list[list[Span]],
) -> dict:
    """Compute exact-match span-level P/R/F1 (micro + macro + per-class).

    A predicted span is a true positive only when an identical
    ``(label, start, end)`` tuple occurs in the gold spans of the same
    example.  Spans are de-duplicated per example before counting.

    Args:
        all_gold: Gold spans, one list per example.
        all_pred: Predicted spans, aligned with ``all_gold`` by position.

    Returns:
        A dict with ``"per_class"`` (label -> precision / recall / f1 / tp /
        fp / fn / support), ``"micro"`` and ``"macro"`` (precision / recall /
        f1 each) and the pooled ``"total_tp"``, ``"total_fp"`` and
        ``"total_fn"`` counts.  Any ratio with a zero denominator is reported
        as 0.0.  The macro average runs over every label seen in either the
        gold or the predicted spans.

    Raises:
        ValueError: If ``all_gold`` and ``all_pred`` differ in length.
    """
    if len(all_gold) != len(all_pred):
        # zip() would silently drop the unmatched tail and skew every metric.
        raise ValueError(
            f"all_gold has {len(all_gold)} examples but all_pred has "
            f"{len(all_pred)}; the two must be aligned one-to-one"
        )

    def prf(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
        # Precision, recall and F1 from raw counts; empty denominators -> 0.0.
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        total = precision + recall
        f1 = 2 * precision * recall / total if total > 0 else 0.0
        return precision, recall, f1

    tp_by_label: defaultdict[str, int] = defaultdict(int)
    fp_by_label: defaultdict[str, int] = defaultdict(int)
    fn_by_label: defaultdict[str, int] = defaultdict(int)

    for gold_spans, pred_spans in zip(all_gold, all_pred):
        gold_set = set(gold_spans)
        pred_set = set(pred_spans)
        for label, _, _ in gold_set & pred_set:
            tp_by_label[label] += 1
        for label, _, _ in pred_set - gold_set:
            fp_by_label[label] += 1
        for label, _, _ in gold_set - pred_set:
            fn_by_label[label] += 1

    labels = sorted(set(tp_by_label) | set(fp_by_label) | set(fn_by_label))

    per_class = {}
    for label in labels:
        tp = tp_by_label[label]
        fp = fp_by_label[label]
        fn = fn_by_label[label]
        precision, recall, f1 = prf(tp, fp, fn)
        per_class[label] = {
            "precision": precision, "recall": recall, "f1": f1,
            "tp": tp, "fp": fp, "fn": fn,
            "support": tp + fn,
        }

    # Micro average: pool the counts over all labels, then take the ratios.
    total_tp = sum(tp_by_label.values())
    total_fp = sum(fp_by_label.values())
    total_fn = sum(fn_by_label.values())
    micro_p, micro_r, micro_f1 = prf(total_tp, total_fp, total_fn)

    def macro(key: str) -> float:
        # Unweighted mean of a per-class score; 0.0 when no label was seen.
        if not labels:
            return 0.0
        return sum(per_class[label][key] for label in labels) / len(labels)

    return {
        "per_class": per_class,
        "micro": {"precision": micro_p, "recall": micro_r, "f1": micro_f1},
        "macro": {
            "precision": macro("precision"),
            "recall": macro("recall"),
            "f1": macro("f1"),
        },
        "total_tp": total_tp, "total_fp": total_fp, "total_fn": total_fn,
    }
|
|
|
|
| |
| |
| |
|
|
def compute_containment_metrics(
    all_gold: list[list[Span]],
    all_pred: list[list[Span]],
) -> dict:
    """Compute containment-based span P/R/F1.

    Precision: a predicted span is a TP if it is *contained within* a gold
        span of the same example with the same label.
    Recall: a gold span is a TP if it is *contained within* a predicted span
        of the same example with the same label.

    Containment is inclusive on both ends, so identical boundaries match.
    Precision and recall use separate TP counts, because the two containment
    directions are tested independently.

    Args:
        all_gold: Gold spans, one list per example.
        all_pred: Predicted spans, aligned with ``all_gold`` by position.

    Returns:
        A dict with ``"per_class"`` (label -> precision / recall / f1 /
        support), ``"micro"`` and ``"macro"`` (precision / recall / f1 each).
        Any ratio with a zero denominator is reported as 0.0.

    Raises:
        ValueError: If ``all_gold`` and ``all_pred`` differ in length.
    """
    if len(all_gold) != len(all_pred):
        # zip() would silently drop the unmatched tail and skew every metric.
        raise ValueError(
            f"all_gold has {len(all_gold)} examples but all_pred has "
            f"{len(all_pred)}; the two must be aligned one-to-one"
        )

    def ratio(hits: int, misses: int) -> float:
        # hits / (hits + misses), or 0.0 when nothing was counted.
        return hits / (hits + misses) if (hits + misses) > 0 else 0.0

    def harmonic(precision: float, recall: float) -> float:
        # F1 of the two scores, or 0.0 when both are zero.
        total = precision + recall
        return 2 * precision * recall / total if total > 0 else 0.0

    tp_precision: defaultdict[str, int] = defaultdict(int)
    fp_by_label: defaultdict[str, int] = defaultdict(int)
    tp_recall: defaultdict[str, int] = defaultdict(int)
    fn_by_label: defaultdict[str, int] = defaultdict(int)

    for gold_spans, pred_spans in zip(all_gold, all_pred):
        # Precision side: is each prediction inside some same-label gold span?
        for p_label, p_start, p_end in pred_spans:
            inside_gold = any(
                p_label == g_label and g_start <= p_start and g_end >= p_end
                for g_label, g_start, g_end in gold_spans
            )
            if inside_gold:
                tp_precision[p_label] += 1
            else:
                fp_by_label[p_label] += 1

        # Recall side: is each gold span inside some same-label prediction?
        for g_label, g_start, g_end in gold_spans:
            covered = any(
                g_label == p_label and p_start <= g_start and p_end >= g_end
                for p_label, p_start, p_end in pred_spans
            )
            if covered:
                tp_recall[g_label] += 1
            else:
                fn_by_label[g_label] += 1

    labels = sorted(
        set(tp_precision) | set(fp_by_label) | set(tp_recall) | set(fn_by_label)
    )

    per_class = {}
    for label in labels:
        precision = ratio(tp_precision[label], fp_by_label[label])
        recall = ratio(tp_recall[label], fn_by_label[label])
        per_class[label] = {
            "precision": precision,
            "recall": recall,
            "f1": harmonic(precision, recall),
            "support": tp_recall[label] + fn_by_label[label],
        }

    # Micro average: pool the counts over all labels, then take the ratios.
    micro_p = ratio(sum(tp_precision.values()), sum(fp_by_label.values()))
    micro_r = ratio(sum(tp_recall.values()), sum(fn_by_label.values()))
    micro_f1 = harmonic(micro_p, micro_r)

    def macro(key: str) -> float:
        # Unweighted mean of a per-class score; 0.0 when no label was seen.
        if not labels:
            return 0.0
        return sum(per_class[label][key] for label in labels) / len(labels)

    return {
        "per_class": per_class,
        "micro": {"precision": micro_p, "recall": micro_r, "f1": micro_f1},
        "macro": {
            "precision": macro("precision"),
            "recall": macro("recall"),
            "f1": macro("f1"),
        },
    }
|
|
|
|
| |
| |
| |
|
|
def print_metrics_table(title: str, metrics: dict, show_counts: bool = False) -> None:
    """Print one metrics dict as a titled plain-text table on stdout.

    Args:
        title: Heading shown between two 72-character ``=`` rules.
        metrics: Dict with ``"per_class"``, ``"micro"`` and ``"macro"``
            entries, as returned by the ``compute_*_metrics`` functions.
        show_counts: When true, add TP / FP / FN columns; a per-class entry
            lacking one of those keys shows ``-`` in that column.

    The per-class section (header, rule, one row per label in sorted order,
    trailing blank line) is skipped entirely when ``"per_class"`` is empty.
    The micro and macro averages are always printed.
    """
    banner = "=" * 72
    print(f"\n{banner}")
    print(f" {title}")
    print(banner)

    rows = metrics["per_class"]
    if rows:
        count_heads = f" {'TP':>5s} {'FP':>5s} {'FN':>5s}" if show_counts else ""
        header = (
            f" {'Label':<20s} {'Prec':>7s} {'Rec':>7s} {'F1':>7s}"
            f"{count_heads} {'Sup':>5s}"
        )
        print(header)
        print(" " + "-" * (len(header) - 2))
        for label in sorted(rows):
            row = rows[label]
            scores = (
                f" {label:<20s} {row['precision']:7.4f} "
                f"{row['recall']:7.4f} {row['f1']:7.4f}"
            )
            counts = ""
            if show_counts:
                counts = "".join(
                    f" {row.get(key, '-'):>5}" for key in ("tp", "fp", "fn")
                )
            print(f"{scores}{counts} {row['support']:>5}")
        print()

    # Look both aggregates up before printing either of them.
    micro, macro = metrics["micro"], metrics["macro"]
    for name, avg in (("micro-avg", micro), ("macro-avg", macro)):
        print(f" {name:<20s} {avg['precision']:7.4f} {avg['recall']:7.4f} {avg['f1']:7.4f}")
|
|
|
|
def print_comparison(exact: dict, containment: dict) -> None:
    """Print exact-match and containment scores side by side on stdout.

    The first table lists micro and macro precision / recall / F1 from both
    metric dicts with ``containment - exact`` as the delta.  A second table
    of per-label F1 follows when either dict has per-class entries; a label
    missing from one side is shown as 0.0 on that side.
    """
    def emit(name: str, width: int, exact_value: float, contain_value: float) -> None:
        # One row: left-aligned name, both scores, signed difference.
        print(
            f" {name:<{width}s} {exact_value:10.4f} {contain_value:10.4f} "
            f"{contain_value - exact_value:+10.4f}"
        )

    banner = "=" * 72
    print(f"\n{banner}")
    print(" COMPARISON: Exact-Match vs Containment Span F1")
    print(banner)
    print(f" {'Metric':<25s} {'Exact':>10s} {'Contain':>10s} {'Delta':>10s}")
    print(" " + "-" * 55)
    for aggregate in ("micro", "macro"):
        for score in ("precision", "recall", "f1"):
            emit(f"{aggregate}-{score}", 25,
                 exact[aggregate][score], containment[aggregate][score])

    exact_rows = exact["per_class"]
    contain_rows = containment["per_class"]
    labels = sorted(set(exact_rows) | set(contain_rows))
    if not labels:
        return

    print(f"\n {'Label':<20s} {'Exact-F1':>10s} {'Cont-F1':>10s} {'Delta':>10s}")
    print(" " + "-" * 50)
    for label in labels:
        emit(label, 20,
             exact_rows.get(label, {}).get("f1", 0.0),
             contain_rows.get(label, {}).get("f1", 0.0))
|
|
|
|
| |
| |
| |
|
|
def _examples_per_second(count: int, seconds: float) -> float:
    """Return ``count / seconds``, or 0.0 when no measurable time elapsed."""
    return count / seconds if seconds > 0 else 0.0


def main() -> None:
    """Command-line entry point: run inference and report both span metrics.

    Steps, in order:
      1. Parse the command-line arguments.
      2. Load the OPF model from ``--checkpoint``.
      3. Read ``--test-data`` (one JSON object per non-blank line), keeping
         at most ``--max-examples`` records when that option is given.
      4. Predict spans for each record's ``"text"`` and collect them next to
         the gold spans parsed from the same record.
      5. Print the exact-match table, the containment table and their
         comparison.
      6. Optionally write the metrics plus run metadata to ``--json-out``.

    Progress is reported on stderr every 100 examples; everything else goes
    to stdout.
    """
    parser = argparse.ArgumentParser(
        description="Exact-match span F1 evaluation for Arcspan NER models"
    )
    parser.add_argument("--checkpoint", required=True, help="Path to OPF checkpoint dir")
    parser.add_argument("--test-data", required=True, help="Path to test JSONL file")
    parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
    parser.add_argument("--max-examples", type=int, default=None,
                        help="Limit number of examples (for quick testing)")
    parser.add_argument("--decode-mode", default="viterbi", choices=["viterbi", "argmax"])
    parser.add_argument("--json-out", default=None, help="Write metrics JSON to this path")
    args = parser.parse_args()

    # --- Model -------------------------------------------------------------
    print(f"Loading model from {args.checkpoint} on {args.device}...")
    # Imported here rather than at module top, so argument parsing
    # (including --help and usage errors) completes without it.
    from opf import OPF
    redactor = OPF(
        model=args.checkpoint,
        device=args.device,
        output_mode="typed",
        decode_mode=args.decode_mode,
        trim_whitespace=True,
        discard_overlapping_predicted_spans=False,
    )
    if args.decode_mode == "viterbi":
        redactor.set_viterbi_decoder()
    print("Model loaded.")

    # --- Test data ---------------------------------------------------------
    test_path = Path(args.test_data)
    records: list[dict] = []
    with open(test_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            records.append(json.loads(line))
    if args.max_examples is not None:
        records = records[:args.max_examples]
    print(f"Loaded {len(records)} test examples from {test_path}")

    # --- Inference ---------------------------------------------------------
    all_gold: list[list[Span]] = []
    all_pred: list[list[Span]] = []
    n_gold_spans = 0
    n_pred_spans = 0

    start_time = time.perf_counter()
    for done, record in enumerate(records, start=1):
        text = record["text"]
        gold = parse_gold_spans(record)
        pred = predict_spans(redactor, text)

        all_gold.append(gold)
        all_pred.append(pred)
        n_gold_spans += len(gold)
        n_pred_spans += len(pred)

        if done % 100 == 0:
            elapsed = time.perf_counter() - start_time
            rate = _examples_per_second(done, elapsed)
            print(f" [{done}/{len(records)}] {elapsed:.1f}s "
                  f"({rate:.1f} ex/s)", file=sys.stderr)

    elapsed = time.perf_counter() - start_time
    # With no records the loop does no work, so on a coarse timer the elapsed
    # time can be exactly 0.0; the helper keeps the summary from dividing by it.
    rate = _examples_per_second(len(records), elapsed)
    print(f"\nInference complete: {len(records)} examples, "
          f"{n_gold_spans} gold spans, {n_pred_spans} predicted spans, "
          f"{elapsed:.1f}s ({rate:.1f} ex/s)")

    # --- Metrics -----------------------------------------------------------
    exact_metrics = compute_exact_match_metrics(all_gold, all_pred)
    containment_metrics = compute_containment_metrics(all_gold, all_pred)

    print_metrics_table("EXACT-MATCH Span Metrics (CoNLL/seqeval style)",
                        exact_metrics, show_counts=True)
    print_metrics_table("CONTAINMENT Span Metrics (OPF native style)",
                        containment_metrics)
    print_comparison(exact_metrics, containment_metrics)

    # --- Optional JSON report ----------------------------------------------
    if args.json_out:
        output = {
            "exact_match": {
                "micro": exact_metrics["micro"],
                "macro": exact_metrics["macro"],
                "per_class": exact_metrics["per_class"],
            },
            "containment": {
                "micro": containment_metrics["micro"],
                "macro": containment_metrics["macro"],
                "per_class": containment_metrics["per_class"],
            },
            "n_examples": len(records),
            "n_gold_spans": n_gold_spans,
            "n_pred_spans": n_pred_spans,
            "checkpoint": args.checkpoint,
            "test_data": args.test_data,
            "decode_mode": args.decode_mode,
        }
        out_path = Path(args.json_out)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(output, indent=2), encoding="utf-8")
        print(f"\nMetrics written to {args.json_out}")

    print()
|
|
|
|
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|