arcspan / scripts /eval_exact_match.py
chairulridjal's picture
Add files using upload-large-folder tool
3dac39e verified
#!/usr/bin/env python3
"""Exact-match span F1 evaluation for Arcspan NER models.
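
Computes both exact-boundary span F1 (CoNLL/seqeval style: a prediction counts
only if its label, start and end all match a gold span) and OPF's native
containment-based span F1 (a prediction counts toward precision if it lies
inside a same-label gold span; a gold span counts toward recall if it lies
inside a same-label prediction), and prints them side by side for comparison.

Usage:
    python3 scripts/eval_exact_match.py \
        --checkpoint checkpoints/r8_5class/ \
        --test-data data/processed/aptner_5class_test_clean.jsonl \
        --device cuda

Optional flags: --max-examples N (evaluate only the first N examples),
--decode-mode {viterbi,argmax} (default: viterbi), and --json-out PATH
(also write the metrics as JSON).
"""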
from __future__ import annotations
import argparse
import json
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Sequence
# ---------------------------------------------------------------------------
# Span type: (label: str, start: int, end: int) — character offsets
# ---------------------------------------------------------------------------
# A span is identified by its entity label plus start/end character offsets
# into a record's "text". Exact-match scoring compares all three fields.
# NOTE(review): `end` is presumably exclusive (Python slice convention) —
# confirm against the OPF data format.
Span = tuple[str, int, int]
def parse_gold_spans(record: dict) -> list[Span]:
    """Flatten the OPF-style ``spans`` mapping of one JSONL record into gold spans.

    The record is expected to look like
    ``{"spans": {"Label: text": [[start, end], ...], ...}}``. The entity label
    is the part of each key before the first ``":"`` (whitespace-stripped); the
    matched text after it is ignored. Offsets are coerced to ``int`` and only
    the first two elements of each offset pair are used.

    Returns an empty list when ``spans`` is missing or empty. Spans are emitted
    in key order, then offset order.
    """
    spans_by_key = record.get("spans", {})
    if not spans_by_key:
        return []
    gold: list[Span] = []
    for key, pairs in spans_by_key.items():
        label = key.partition(":")[0].strip()
        gold.extend((label, int(pair[0]), int(pair[1])) for pair in pairs)
    return gold
def predict_spans(redactor, text: str) -> list[Span]:
    """Run the model on *text* and return its detections as ``(label, start, end)``.

    Calls ``redactor.redact(text)`` once and reads ``label``, ``start`` and
    ``end`` from each entry of the result's ``detected_spans``, preserving
    their order.
    """
    detections = redactor.redact(text).detected_spans
    return [(det.label, det.start, det.end) for det in detections]
# ---------------------------------------------------------------------------
# Exact-match metrics
# ---------------------------------------------------------------------------
def _span_set(spans: list[Span]) -> set[Span]:
return set(spans)
def compute_exact_match_metrics(
    all_gold: list[list[Span]],
    all_pred: list[list[Span]],
) -> dict:
    """Compute exact-boundary span-level P/R/F1 (micro, macro and per-class).

    A predicted span is a true positive only if the *same* document contains a
    gold span with an identical ``(label, start, end)`` triple. Spans are
    de-duplicated per document before matching, so repeated identical spans
    are counted once.

    Args:
        all_gold: Gold spans, one list per document.
        all_pred: Predicted spans, one list per document, aligned with
            ``all_gold``.

    Returns:
        A dict with:
          * ``per_class``: label -> ``precision``/``recall``/``f1`` plus raw
            ``tp``/``fp``/``fn`` counts and ``support`` (= tp + fn), for every
            label seen in gold or predictions (sorted by label);
          * ``micro`` / ``macro``: ``precision``/``recall``/``f1`` averages
            (macro is the unweighted mean over ``per_class`` labels);
          * ``total_tp`` / ``total_fp`` / ``total_fn``: summed counts.
        Ratios with a zero denominator are reported as ``0.0``.

    Raises:
        ValueError: If ``all_gold`` and ``all_pred`` contain a different
            number of documents (zipping them would otherwise silently drop
            the unmatched tail and skew the metrics).
    """
    gold_docs, pred_docs = list(all_gold), list(all_pred)
    if len(gold_docs) != len(pred_docs):
        raise ValueError(
            f"all_gold has {len(gold_docs)} documents but all_pred has "
            f"{len(pred_docs)}; they must be aligned one-to-one"
        )

    def _ratio(num: float, den: float) -> float:
        # Zero denominators (no predictions / no gold for a class) score 0.0.
        return num / den if den > 0 else 0.0

    class_tp: defaultdict[str, int] = defaultdict(int)
    class_fp: defaultdict[str, int] = defaultdict(int)
    class_fn: defaultdict[str, int] = defaultdict(int)

    for gold_spans, pred_spans in zip(gold_docs, pred_docs):
        # Sets de-duplicate repeated identical spans within one document.
        gold_set = set(gold_spans)
        pred_set = set(pred_spans)
        for label, _, _ in gold_set & pred_set:  # exact label + start + end
            class_tp[label] += 1
        for label, _, _ in pred_set - gold_set:
            class_fp[label] += 1
        for label, _, _ in gold_set - pred_set:
            class_fn[label] += 1

    labels = sorted(set(class_tp) | set(class_fp) | set(class_fn))

    per_class: dict[str, dict] = {}
    for label in labels:
        tp, fp, fn = class_tp[label], class_fp[label], class_fn[label]
        precision = _ratio(tp, tp + fp)
        recall = _ratio(tp, tp + fn)
        f1 = _ratio(2 * precision * recall, precision + recall)
        per_class[label] = {"precision": precision, "recall": recall, "f1": f1,
                            "tp": tp, "fp": fp, "fn": fn,
                            "support": tp + fn}

    total_tp = sum(class_tp.values())
    total_fp = sum(class_fp.values())
    total_fn = sum(class_fn.values())
    micro_p = _ratio(total_tp, total_tp + total_fp)
    micro_r = _ratio(total_tp, total_tp + total_fn)
    micro_f1 = _ratio(2 * micro_p * micro_r, micro_p + micro_r)

    if per_class:
        n_labels = len(per_class)
        macro_p = sum(stats["precision"] for stats in per_class.values()) / n_labels
        macro_r = sum(stats["recall"] for stats in per_class.values()) / n_labels
        macro_f1 = sum(stats["f1"] for stats in per_class.values()) / n_labels
    else:
        macro_p = macro_r = macro_f1 = 0.0

    return {
        "per_class": per_class,
        "micro": {"precision": micro_p, "recall": micro_r, "f1": micro_f1},
        "macro": {"precision": macro_p, "recall": macro_r, "f1": macro_f1},
        "total_tp": total_tp, "total_fp": total_fp, "total_fn": total_fn,
    }
# ---------------------------------------------------------------------------
# Containment-match metrics (OPF style)
# ---------------------------------------------------------------------------
def compute_containment_metrics(
    all_gold: list[list[Span]],
    all_pred: list[list[Span]],
) -> dict:
    """Compute containment-based span P/R/F1 (micro, macro and per-class).

    Matching is asymmetric and done per document:

    * Precision: a predicted span is a true positive if it lies *within* some
      same-label gold span (``g_start <= p_start`` and ``p_end <= g_end``);
      otherwise it is a false positive.
    * Recall: a gold span is a true positive if it lies within some same-label
      predicted span; otherwise it is a false negative.

    Because the two directions are counted independently, several predicted
    fragments inside one gold span all count as precision hits, while that gold
    span is recalled only if a single prediction covers it entirely. Duplicate
    spans are not collapsed.

    Args:
        all_gold: Gold spans, one list per document.
        all_pred: Predicted spans, one list per document, aligned with
            ``all_gold``.

    Returns:
        A dict with ``per_class`` (label -> ``precision``/``recall``/``f1``/
        ``support``, where support is the number of gold spans of that label),
        and ``micro`` / ``macro`` (``precision``/``recall``/``f1``; macro is the
        unweighted mean over labels). Ratios with a zero denominator are
        reported as ``0.0``.

    Raises:
        ValueError: If ``all_gold`` and ``all_pred`` contain a different
            number of documents (zipping them would otherwise silently drop
            the unmatched tail and skew the metrics).
    """
    gold_docs, pred_docs = list(all_gold), list(all_pred)
    if len(gold_docs) != len(pred_docs):
        raise ValueError(
            f"all_gold has {len(gold_docs)} documents but all_pred has "
            f"{len(pred_docs)}; they must be aligned one-to-one"
        )

    def _ratio(num: float, den: float) -> float:
        # Zero denominators (no predictions / no gold for a class) score 0.0.
        return num / den if den > 0 else 0.0

    class_tp_p: defaultdict[str, int] = defaultdict(int)  # preds inside a gold span
    class_fp: defaultdict[str, int] = defaultdict(int)
    class_tp_r: defaultdict[str, int] = defaultdict(int)  # gold spans covered by a pred
    class_fn: defaultdict[str, int] = defaultdict(int)

    for gold_spans, pred_spans in zip(gold_docs, pred_docs):
        # Precision direction: is each prediction contained in a same-label gold span?
        for p_label, p_start, p_end in pred_spans:
            if any(g_label == p_label and g_start <= p_start and p_end <= g_end
                   for g_label, g_start, g_end in gold_spans):
                class_tp_p[p_label] += 1
            else:
                class_fp[p_label] += 1
        # Recall direction: is each gold span contained in a same-label prediction?
        for g_label, g_start, g_end in gold_spans:
            if any(q_label == g_label and q_start <= g_start and g_end <= q_end
                   for q_label, q_start, q_end in pred_spans):
                class_tp_r[g_label] += 1
            else:
                class_fn[g_label] += 1

    labels = sorted(set(class_tp_p) | set(class_fp) | set(class_tp_r) | set(class_fn))

    per_class: dict[str, dict] = {}
    for label in labels:
        tp_p, fp = class_tp_p[label], class_fp[label]
        tp_r, fn = class_tp_r[label], class_fn[label]
        precision = _ratio(tp_p, tp_p + fp)
        recall = _ratio(tp_r, tp_r + fn)
        f1 = _ratio(2 * precision * recall, precision + recall)
        per_class[label] = {"precision": precision, "recall": recall, "f1": f1,
                            "support": tp_r + fn}

    total_tp_p = sum(class_tp_p.values())
    total_fp = sum(class_fp.values())
    total_tp_r = sum(class_tp_r.values())
    total_fn = sum(class_fn.values())
    micro_p = _ratio(total_tp_p, total_tp_p + total_fp)
    micro_r = _ratio(total_tp_r, total_tp_r + total_fn)
    micro_f1 = _ratio(2 * micro_p * micro_r, micro_p + micro_r)

    if per_class:
        n_labels = len(per_class)
        macro_p = sum(stats["precision"] for stats in per_class.values()) / n_labels
        macro_r = sum(stats["recall"] for stats in per_class.values()) / n_labels
        macro_f1 = sum(stats["f1"] for stats in per_class.values()) / n_labels
    else:
        macro_p = macro_r = macro_f1 = 0.0

    return {
        "per_class": per_class,
        "micro": {"precision": micro_p, "recall": micro_r, "f1": micro_f1},
        "macro": {"precision": macro_p, "recall": macro_r, "f1": macro_f1},
    }
# ---------------------------------------------------------------------------
# Printing
# ---------------------------------------------------------------------------
def print_metrics_table(title: str, metrics: dict, show_counts: bool = False) -> None:
    """Print a per-class P/R/F1 table followed by micro- and macro-average rows.

    Args:
        title: Heading printed between two 72-character ``=`` rules.
        metrics: Dict with ``per_class`` (label -> stats), ``micro`` and
            ``macro`` entries. Each per-class entry must provide
            ``precision``, ``recall``, ``f1`` and ``support``.
        show_counts: Also print TP/FP/FN columns; entries without those keys
            show ``-`` instead.
    """
    rule = "=" * 72
    print(f"\n{rule}")
    print(f" {title}")
    print(rule)

    per_class = metrics["per_class"]
    if per_class:
        header = f" {'Label':<20s} {'Prec':>7s} {'Rec':>7s} {'F1':>7s}"
        if show_counts:
            header += f" {'TP':>5s} {'FP':>5s} {'FN':>5s}"
        header += f" {'Sup':>5s}"
        print(header)
        print(f" {'-' * (len(header) - 2)}")
        for label in sorted(per_class):
            stats = per_class[label]
            row = (f" {label:<20s} {stats['precision']:7.4f} "
                   f"{stats['recall']:7.4f} {stats['f1']:7.4f}")
            if show_counts:
                row += (f" {stats.get('tp', '-'):>5} {stats.get('fp', '-'):>5}"
                        f" {stats.get('fn', '-'):>5}")
            row += f" {stats['support']:>5}"
            print(row)
        print()

    # Both averages are looked up before anything is printed for them.
    averages = (("micro-avg", metrics["micro"]), ("macro-avg", metrics["macro"]))
    for name, avg in averages:
        print(f" {name:<20s} {avg['precision']:7.4f} {avg['recall']:7.4f} {avg['f1']:7.4f}")
def print_comparison(exact: dict, containment: dict) -> None:
    """Print exact-match and containment metrics side by side.

    The first table lists micro/macro precision, recall and F1 for both
    schemes together with ``containment - exact``. The second table lists
    per-class F1 for every label present in either ``per_class`` mapping; a
    label absent from one side is shown with F1 0.0 for that side. The second
    table is skipped when neither side has any labels.
    """
    rule = "=" * 72
    print(f"\n{rule}")
    print(" COMPARISON: Exact-Match vs Containment Span F1")
    print(rule)
    print(f" {'Metric':<25s} {'Exact':>10s} {'Contain':>10s} {'Delta':>10s}")
    print(" " + "-" * 55)

    rows = [(agg, metric)
            for agg in ("micro", "macro")
            for metric in ("precision", "recall", "f1")]
    for agg, metric in rows:
        exact_val = exact[agg][metric]
        cont_val = containment[agg][metric]
        name = f"{agg}-{metric}"
        print(f" {name:<25s} {exact_val:10.4f} {cont_val:10.4f} {cont_val - exact_val:+10.4f}")

    # Per-class F1 comparison over the union of labels from both schemes.
    exact_pc = exact["per_class"]
    cont_pc = containment["per_class"]
    labels = sorted(set(exact_pc).union(cont_pc))
    if not labels:
        return
    print(f"\n {'Label':<20s} {'Exact-F1':>10s} {'Cont-F1':>10s} {'Delta':>10s}")
    print(" " + "-" * 50)
    for label in labels:
        e_f1 = exact_pc.get(label, {}).get("f1", 0.0)
        c_f1 = cont_pc.get(label, {}).get("f1", 0.0)
        print(f" {label:<20s} {e_f1:10.4f} {c_f1:10.4f} {c_f1 - e_f1:+10.4f}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _load_records(path: Path, max_examples: int | None = None) -> list[dict]:
    """Read up to *max_examples* JSON records from the JSONL file at *path*.

    Blank lines are skipped and do not count towards the limit. Reading stops
    as soon as the limit is reached, so the rest of the file is never parsed.

    Raises:
        ValueError: If a non-blank line is not valid JSON; the message names
            the file and 1-based line number.
    """
    records: list[dict] = []
    with open(path, "r", encoding="utf-8") as f:
        for lineno, raw_line in enumerate(f, start=1):
            if max_examples is not None and len(records) >= max_examples:
                break  # limit reached: don't parse the remainder of the file
            line = raw_line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as err:
                raise ValueError(f"{path}:{lineno}: invalid JSON: {err.msg}") from err
    return records


def _run_inference(
    redactor, records: list[dict]
) -> tuple[list[list[Span]], list[list[Span]], float]:
    """Collect gold and predicted spans for every record.

    Progress (count, elapsed seconds, examples/second) is printed to stderr
    every 100 examples.

    Returns:
        ``(all_gold, all_pred, elapsed_seconds)``, where the two lists are
        aligned with *records*.
    """
    all_gold: list[list[Span]] = []
    all_pred: list[list[Span]] = []
    total = len(records)
    start_time = time.perf_counter()
    for done, record in enumerate(records, start=1):
        text = record["text"]
        all_gold.append(parse_gold_spans(record))
        all_pred.append(predict_spans(redactor, text))
        if done % 100 == 0:
            elapsed = time.perf_counter() - start_time
            print(f" [{done}/{total}] {elapsed:.1f}s "
                  f"({done / elapsed:.1f} ex/s)", file=sys.stderr)
    return all_gold, all_pred, time.perf_counter() - start_time


def main() -> None:
    """CLI entry point: evaluate a checkpoint on a JSONL test set.

    Loads the model, runs inference over the (optionally truncated) test set,
    prints exact-match and containment span metrics plus a side-by-side
    comparison, and optionally writes them to ``--json-out``.
    """
    parser = argparse.ArgumentParser(
        description="Exact-match span F1 evaluation for Arcspan NER models"
    )
    parser.add_argument("--checkpoint", required=True, help="Path to OPF checkpoint dir")
    parser.add_argument("--test-data", required=True, help="Path to test JSONL file")
    parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
    parser.add_argument("--max-examples", type=int, default=None,
                        help="Limit number of examples (for quick testing)")
    parser.add_argument("--decode-mode", default="viterbi", choices=["viterbi", "argmax"])
    parser.add_argument("--json-out", default=None, help="Write metrics JSON to this path")
    args = parser.parse_args()
    if args.max_examples is not None and args.max_examples < 0:
        # Reject rather than let a negative value act as a slice from the end.
        parser.error("--max-examples must be a non-negative integer")

    # Load model. Imported here, after argument parsing, so --help and
    # argument errors do not require importing opf.
    print(f"Loading model from {args.checkpoint} on {args.device}...")
    from opf import OPF
    redactor = OPF(
        model=args.checkpoint,
        device=args.device,
        output_mode="typed",
        decode_mode=args.decode_mode,
        trim_whitespace=True,
        discard_overlapping_predicted_spans=False,
    )
    if args.decode_mode == "viterbi":
        # NOTE(review): decode_mode is already passed to the constructor;
        # confirm whether OPF still needs this explicit call.
        redactor.set_viterbi_decoder()
    print("Model loaded.")

    # Load test data
    test_path = Path(args.test_data)
    records = _load_records(test_path, args.max_examples)
    print(f"Loaded {len(records)} test examples from {test_path}")

    # Run inference
    all_gold, all_pred, elapsed = _run_inference(redactor, records)
    n_gold_spans = sum(len(spans) for spans in all_gold)
    n_pred_spans = sum(len(spans) for spans in all_pred)
    throughput = len(records) / elapsed if elapsed > 0 else 0.0
    print(f"\nInference complete: {len(records)} examples, "
          f"{n_gold_spans} gold spans, {n_pred_spans} predicted spans, "
          f"{elapsed:.1f}s ({throughput:.1f} ex/s)")

    # Compute metrics
    exact_metrics = compute_exact_match_metrics(all_gold, all_pred)
    containment_metrics = compute_containment_metrics(all_gold, all_pred)
    print_metrics_table("EXACT-MATCH Span Metrics (CoNLL/seqeval style)",
                        exact_metrics, show_counts=True)
    print_metrics_table("CONTAINMENT Span Metrics (OPF native style)",
                        containment_metrics)
    print_comparison(exact_metrics, containment_metrics)

    # Optional JSON output
    if args.json_out:
        output = {
            "exact_match": {
                "micro": exact_metrics["micro"],
                "macro": exact_metrics["macro"],
                "per_class": exact_metrics["per_class"],
            },
            "containment": {
                "micro": containment_metrics["micro"],
                "macro": containment_metrics["macro"],
                "per_class": containment_metrics["per_class"],
            },
            "n_examples": len(records),
            "n_gold_spans": n_gold_spans,
            "n_pred_spans": n_pred_spans,
            "checkpoint": args.checkpoint,
            "test_data": args.test_data,
            "decode_mode": args.decode_mode,
        }
        out_path = Path(args.json_out)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(output, indent=2), encoding="utf-8")
        print(f"\nMetrics written to {args.json_out}")
    print()
if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing this module (e.g.
    # to reuse the metric functions) has no side effects beyond definitions.
    main()