#!/usr/bin/env python3
"""Exact-match span F1 evaluation for Arcspan NER models.

Computes both exact-boundary span F1 (CoNLL/seqeval style) and OPF's native
containment-based span F1, printing them side by side for comparison.

Usage:
    python3 scripts/eval_exact_match.py \
        --checkpoint checkpoints/r8_5class/ \
        --test-data data/processed/aptner_5class_test_clean.jsonl \
        --device cuda
"""

from __future__ import annotations

import argparse
import json
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Sequence

# ---------------------------------------------------------------------------
# Span type: (label: str, start: int, end: int) — character offsets
# ---------------------------------------------------------------------------

Span = tuple[str, int, int]


def parse_gold_spans(record: dict) -> list[Span]:
    """Extract gold spans from a JSONL record with OPF 'spans' format.

    Format: {"spans": {"Label: text": [[start, end], ...], ...}}

    Args:
        record: One decoded JSONL record. A missing, empty or null
            ``"spans"`` field yields an empty list.

    Returns:
        One ``(label, start, end)`` tuple per offset pair, in the order the
        record lists them. The label is the text before the first ``":"``
        of the key, stripped of surrounding whitespace; offsets are coerced
        with ``int``.
    """
    spans_field = record.get("spans") or {}
    result: list[Span] = []
    for key, offsets in spans_field.items():
        # Key format is "Label: matched_text"; only the label part is kept.
        label = key.partition(":")[0].strip()
        result.extend((label, int(pair[0]), int(pair[1])) for pair in offsets)
    return result


def predict_spans(redactor, text: str) -> list[Span]:
    """Run OPF inference and return predicted spans as (label, start, end).

    Args:
        redactor: Object exposing ``redact(text)`` whose result has a
            ``detected_spans`` iterable of items with ``label``, ``start``
            and ``end`` attributes.
        text: Input text to run inference on.

    Returns:
        One tuple per detected span, in the order the redactor reports them.
    """
    result = redactor.redact(text)
    return [(det.label, det.start, det.end) for det in result.detected_spans]


# ---------------------------------------------------------------------------
# Exact-match metrics
# ---------------------------------------------------------------------------

def _span_set(spans: list[Span]) -> set[Span]:
    """Return the spans as a set, collapsing exact duplicates."""
    return set(spans)


def _precision_recall_f1(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    """Return (precision, recall, F1) for the counts; empty denominators give 0.0."""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    return precision, recall, f1


def compute_exact_match_metrics(
    all_gold: list[list[Span]],
    all_pred: list[list[Span]],
) -> dict:
    """Compute exact-match span-level P/R/F1 (micro + macro + per-class).

    A predicted span is a true positive only when its label, start and end
    all equal those of a gold span in the same example. Spans are compared as
    sets per example, so exact duplicates within one example count once.

    Args:
        all_gold: Gold spans, one list per example.
        all_pred: Predicted spans, one list per example, aligned with
            ``all_gold``.

    Returns:
        A dict with ``"per_class"`` (label -> precision, recall, f1, tp, fp,
        fn, support), ``"micro"`` and ``"macro"`` (precision, recall, f1) and
        the overall ``"total_tp"``, ``"total_fp"``, ``"total_fn"`` counts.
        The macro average runs over every label seen in gold or predictions.

    Raises:
        ValueError: If the two inputs do not have one entry per example each.
    """
    # zip() would silently drop the unmatched tail and report wrong metrics.
    if len(all_gold) != len(all_pred):
        raise ValueError(
            "all_gold and all_pred must have one entry per example; "
            f"got {len(all_gold)} gold vs {len(all_pred)} predicted"
        )

    # Per-class accumulators
    class_tp: defaultdict[str, int] = defaultdict(int)
    class_fp: defaultdict[str, int] = defaultdict(int)
    class_fn: defaultdict[str, int] = defaultdict(int)

    for gold_spans, pred_spans in zip(all_gold, all_pred):
        gold_set = _span_set(gold_spans)
        pred_set = _span_set(pred_spans)

        # TP: in both gold and pred (exact label + start + end)
        for label, _, _ in gold_set & pred_set:
            class_tp[label] += 1
        for label, _, _ in pred_set - gold_set:
            class_fp[label] += 1
        for label, _, _ in gold_set - pred_set:
            class_fn[label] += 1

    all_labels = sorted(set(class_tp) | set(class_fp) | set(class_fn))

    # Per-class metrics
    per_class = {}
    for label in all_labels:
        tp, fp, fn = class_tp[label], class_fp[label], class_fn[label]
        p, r, f1 = _precision_recall_f1(tp, fp, fn)
        per_class[label] = {
            "precision": p,
            "recall": r,
            "f1": f1,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "support": tp + fn,
        }

    # Micro-average
    total_tp = sum(class_tp.values())
    total_fp = sum(class_fp.values())
    total_fn = sum(class_fn.values())
    micro_p, micro_r, micro_f1 = _precision_recall_f1(total_tp, total_fp, total_fn)

    # Macro-average
    if all_labels:
        macro_p = sum(per_class[l]["precision"] for l in all_labels) / len(all_labels)
        macro_r = sum(per_class[l]["recall"] for l in all_labels) / len(all_labels)
        macro_f1 = sum(per_class[l]["f1"] for l in all_labels) / len(all_labels)
    else:
        macro_p = macro_r = macro_f1 = 0.0

    return {
        "per_class": per_class,
        "micro": {"precision": micro_p, "recall": micro_r, "f1": micro_f1},
        "macro": {"precision": macro_p, "recall": macro_r, "f1": macro_f1},
        "total_tp": total_tp,
        "total_fp": total_fp,
        "total_fn": total_fn,
    }
# ---------------------------------------------------------------------------
# Containment-match metrics (OPF style)
# ---------------------------------------------------------------------------

def _is_enclosed(span: Span, candidates: list[Span]) -> bool:
    """Return True if a same-label candidate starts at/before and ends at/after *span*."""
    label, start, end = span
    return any(
        c_label == label and c_start <= start and c_end >= end
        for c_label, c_start, c_end in candidates
    )


def _safe_ratio(numerator: float, denominator: float) -> float:
    """Return numerator / denominator, or 0.0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0.0


def compute_containment_metrics(
    all_gold: list[list[Span]],
    all_pred: list[list[Span]],
) -> dict:
    """Compute containment-based span P/R/F1.

    Precision: predicted span is TP if it is *contained within* a gold span
    with the same label.
    Recall: gold span is TP if it is *contained within* a predicted span
    with the same label.

    Precision and recall therefore use separate true-positive counts.

    Args:
        all_gold: Gold spans, one list per example.
        all_pred: Predicted spans, one list per example, aligned with
            ``all_gold``.

    Returns:
        A dict with ``"per_class"`` (label -> precision, recall, f1,
        support), ``"micro"`` and ``"macro"`` (precision, recall, f1).

    Raises:
        ValueError: If the two inputs do not have one entry per example each.
    """
    # zip() would silently drop the unmatched tail and report wrong metrics.
    if len(all_gold) != len(all_pred):
        raise ValueError(
            "all_gold and all_pred must have one entry per example; "
            f"got {len(all_gold)} gold vs {len(all_pred)} predicted"
        )

    class_tp_p: defaultdict[str, int] = defaultdict(int)  # for precision
    class_fp: defaultdict[str, int] = defaultdict(int)
    class_tp_r: defaultdict[str, int] = defaultdict(int)  # for recall
    class_fn: defaultdict[str, int] = defaultdict(int)

    for gold_spans, pred_spans in zip(all_gold, all_pred):
        # Precision direction: pred contained in gold
        for pred in pred_spans:
            if _is_enclosed(pred, gold_spans):
                class_tp_p[pred[0]] += 1
            else:
                class_fp[pred[0]] += 1

        # Recall direction: gold contained in pred
        for gold in gold_spans:
            if _is_enclosed(gold, pred_spans):
                class_tp_r[gold[0]] += 1
            else:
                class_fn[gold[0]] += 1

    all_labels = sorted(set(class_tp_p) | set(class_fp) | set(class_tp_r) | set(class_fn))

    per_class = {}
    for label in all_labels:
        tp_p = class_tp_p[label]
        fp = class_fp[label]
        tp_r = class_tp_r[label]
        fn = class_fn[label]
        p = _safe_ratio(tp_p, tp_p + fp)
        r = _safe_ratio(tp_r, tp_r + fn)
        f1 = _safe_ratio(2 * p * r, p + r)
        per_class[label] = {"precision": p, "recall": r, "f1": f1, "support": tp_r + fn}

    total_tp_p = sum(class_tp_p.values())
    total_fp = sum(class_fp.values())
    total_tp_r = sum(class_tp_r.values())
    total_fn = sum(class_fn.values())
    micro_p = _safe_ratio(total_tp_p, total_tp_p + total_fp)
    micro_r = _safe_ratio(total_tp_r, total_tp_r + total_fn)
    micro_f1 = _safe_ratio(2 * micro_p * micro_r, micro_p + micro_r)

    if all_labels:
        macro_p = sum(per_class[l]["precision"] for l in all_labels) / len(all_labels)
        macro_r = sum(per_class[l]["recall"] for l in all_labels) / len(all_labels)
        macro_f1 = sum(per_class[l]["f1"] for l in all_labels) / len(all_labels)
    else:
        macro_p = macro_r = macro_f1 = 0.0

    return {
        "per_class": per_class,
        "micro": {"precision": micro_p, "recall": micro_r, "f1": micro_f1},
        "macro": {"precision": macro_p, "recall": macro_r, "f1": macro_f1},
    }


# ---------------------------------------------------------------------------
# Printing
# ---------------------------------------------------------------------------

def print_metrics_table(title: str, metrics: dict, show_counts: bool = False) -> None:
    """Print a per-class table followed by micro/macro averages to stdout.

    Args:
        title: Heading printed between two rules of ``=`` characters.
        metrics: Dict with ``"per_class"``, ``"micro"`` and ``"macro"``
            entries. Each per-class entry needs precision, recall, f1 and
            support.
        show_counts: When true, also print TP/FP/FN columns; a count missing
            from a per-class entry is shown as ``-``.
    """
    print(f"\n{'=' * 72}")
    print(f" {title}")
    print(f"{'=' * 72}")

    per_class = metrics["per_class"]
    if per_class:
        if show_counts:
            header = (f" {'Label':<20s} {'Prec':>7s} {'Rec':>7s} {'F1':>7s} "
                      f"{'TP':>5s} {'FP':>5s} {'FN':>5s} {'Sup':>5s}")
        else:
            header = f" {'Label':<20s} {'Prec':>7s} {'Rec':>7s} {'F1':>7s} {'Sup':>5s}"
        print(header)
        print(f" {'-' * (len(header) - 2)}")

        for label in sorted(per_class):
            m = per_class[label]
            if show_counts:
                print(f" {label:<20s} {m['precision']:7.4f} {m['recall']:7.4f} {m['f1']:7.4f} "
                      f"{m.get('tp', '-'):>5} {m.get('fp', '-'):>5} {m.get('fn', '-'):>5} "
                      f"{m['support']:>5}")
            else:
                print(f" {label:<20s} {m['precision']:7.4f} {m['recall']:7.4f} "
                      f"{m['f1']:7.4f} {m['support']:>5}")
        print()

    micro = metrics["micro"]
    macro = metrics["macro"]
    print(f" {'micro-avg':<20s} {micro['precision']:7.4f} {micro['recall']:7.4f} "
          f"{micro['f1']:7.4f}")
    print(f" {'macro-avg':<20s} {macro['precision']:7.4f} {macro['recall']:7.4f} "
          f"{macro['f1']:7.4f}")


def print_comparison(exact: dict, containment: dict) -> None:
    """Print exact-match and containment metrics side by side with their delta.

    The delta is containment minus exact. A label present in only one of the
    two per-class dicts is shown with an F1 of 0.0 for the other.

    Args:
        exact: Metrics dict with ``"micro"``, ``"macro"`` and ``"per_class"``.
        containment: Metrics dict with the same three entries.
    """
    print(f"\n{'=' * 72}")
    print(" COMPARISON: Exact-Match vs Containment Span F1")
    print(f"{'=' * 72}")
    print(f" {'Metric':<25s} {'Exact':>10s} {'Contain':>10s} {'Delta':>10s}")
    print(f" {'-' * 55}")

    for agg in ["micro", "macro"]:
        for m in ["precision", "recall", "f1"]:
            e = exact[agg][m]
            c = containment[agg][m]
            delta = c - e
            print(f" {agg + '-' + m:<25s} {e:10.4f} {c:10.4f} {delta:+10.4f}")

    # Per-class F1 comparison
    all_labels = sorted(set(exact["per_class"]) | set(containment["per_class"]))
    if all_labels:
        print(f"\n {'Label':<20s} {'Exact-F1':>10s} {'Cont-F1':>10s} {'Delta':>10s}")
        print(f" {'-' * 50}")
        for label in all_labels:
            e_f1 = exact["per_class"].get(label, {}).get("f1", 0.0)
            c_f1 = containment["per_class"].get(label, {}).get("f1", 0.0)
            print(f" {label:<20s} {e_f1:10.4f} {c_f1:10.4f} {c_f1 - e_f1:+10.4f}")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> None:
    """Run the evaluation from the command line.

    Parses the arguments, loads the OPF model, reads the JSONL test set, runs
    inference per record, prints exact-match and containment metrics plus
    their comparison, and optionally writes the metrics as JSON.

    Exits through ``parser.error`` when ``--max-examples`` is negative or the
    test data path does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Exact-match span F1 evaluation for Arcspan NER models"
    )
    parser.add_argument("--checkpoint", required=True, help="Path to OPF checkpoint dir")
    parser.add_argument("--test-data", required=True, help="Path to test JSONL file")
    parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
    parser.add_argument("--max-examples", type=int, default=None,
                        help="Limit number of examples (for quick testing)")
    parser.add_argument("--decode-mode", default="viterbi", choices=["viterbi", "argmax"])
    parser.add_argument("--json-out", default=None, help="Write metrics JSON to this path")
    args = parser.parse_args()

    # A negative limit would slice records off the END of the list instead of
    # limiting the count, so reject it up front.
    if args.max_examples is not None and args.max_examples < 0:
        parser.error("--max-examples must be non-negative")

    # Fail before the model load rather than after it.
    test_path = Path(args.test_data)
    if not test_path.exists():
        parser.error(f"--test-data path does not exist: {test_path}")

    # Load model
    print(f"Loading model from {args.checkpoint} on {args.device}...")
    from opf import OPF

    redactor = OPF(
        model=args.checkpoint,
        device=args.device,
        output_mode="typed",
        decode_mode=args.decode_mode,
        trim_whitespace=True,
        discard_overlapping_predicted_spans=False,
    )
    if args.decode_mode == "viterbi":
        redactor.set_viterbi_decoder()
    print("Model loaded.")

    # Load test data (one JSON object per non-blank line)
    records: list[dict] = []
    with open(test_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            records.append(json.loads(line))

    if args.max_examples is not None:
        records = records[:args.max_examples]
    print(f"Loaded {len(records)} test examples from {test_path}")

    # Run inference
    all_gold: list[list[Span]] = []
    all_pred: list[list[Span]] = []
    n_gold_spans = 0
    n_pred_spans = 0

    start_time = time.perf_counter()
    for i, record in enumerate(records):
        text = record["text"]
        gold = parse_gold_spans(record)
        pred = predict_spans(redactor, text)

        all_gold.append(gold)
        all_pred.append(pred)
        n_gold_spans += len(gold)
        n_pred_spans += len(pred)

        if (i + 1) % 100 == 0:
            elapsed = time.perf_counter() - start_time
            # Progress goes to stderr so stdout carries only the report.
            print(f" [{i+1}/{len(records)}] {elapsed:.1f}s "
                  f"({_safe_ratio(i + 1, elapsed):.1f} ex/s)", file=sys.stderr)

    elapsed = time.perf_counter() - start_time
    print(f"\nInference complete: {len(records)} examples, "
          f"{n_gold_spans} gold spans, {n_pred_spans} predicted spans, "
          f"{elapsed:.1f}s ({_safe_ratio(len(records), elapsed):.1f} ex/s)")

    # Compute metrics
    exact_metrics = compute_exact_match_metrics(all_gold, all_pred)
    containment_metrics = compute_containment_metrics(all_gold, all_pred)

    print_metrics_table("EXACT-MATCH Span Metrics (CoNLL/seqeval style)",
                        exact_metrics, show_counts=True)
    print_metrics_table("CONTAINMENT Span Metrics (OPF native style)", containment_metrics)
    print_comparison(exact_metrics, containment_metrics)

    # Optional JSON output
    if args.json_out:
        output = {
            "exact_match": {
                "micro": exact_metrics["micro"],
                "macro": exact_metrics["macro"],
                "per_class": exact_metrics["per_class"],
            },
            "containment": {
                "micro": containment_metrics["micro"],
                "macro": containment_metrics["macro"],
                "per_class": containment_metrics["per_class"],
            },
            "n_examples": len(records),
            "n_gold_spans": n_gold_spans,
            "n_pred_spans": n_pred_spans,
            "checkpoint": args.checkpoint,
            "test_data": args.test_data,
            "decode_mode": args.decode_mode,
        }
        out_path = Path(args.json_out)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(output, indent=2), encoding="utf-8")
        print(f"\nMetrics written to {args.json_out}")

    print()


if __name__ == "__main__":
    main()