Token Classification
GLiNER
PyTorch
ONNX
English
multilingual
named-entity-recognition
information-extraction
legal
contracts
nlp
Instructions to use agilelab-org/Contractner with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- GLiNER
How to use agilelab-org/Contractner with GLiNER:
from gliner import GLiNER
model = GLiNER.from_pretrained("agilelab-org/Contractner")
- Notebooks
- Google Colab
- Kaggle
#!/usr/bin/env python3
"""
Evaluation script for lucasorrentino/Contractner.

Downloads the test set from HuggingFace, loads the model from the local repo,
and runs the full evaluation: threshold sweep, per-entity metrics, latency benchmark.

Usage:
    uv run eval.py
    uv run eval.py --threshold 0.9
    uv run eval.py --all-thresholds
    uv run eval.py --skip-latency
    uv run eval.py --output-dir results/
"""
| import argparse | |
| import json | |
| import statistics | |
| import time | |
| import warnings | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from typing import Dict, List, Literal, Tuple, Union | |
| import numpy as np | |
| from tqdm import tqdm | |
# ── Evaluation helpers (self-contained, no external dependencies) ─────────────
| def _span_overlap(a_start, a_end, b_start, b_end): | |
| return max(0, min(a_end, b_end) - max(a_start, b_start)) | |
| def _is_match(true_entity, pred_entity, tolerance, gold_cover_thresh): | |
| t_type, (t_s, t_e), t_idx = true_entity | |
| p_type, (p_s, p_e), p_idx = pred_entity | |
| if t_idx != p_idx: | |
| return False | |
| if (t_type or "").casefold() != (p_type or "").casefold(): | |
| return False | |
| overlap = _span_overlap(t_s, t_e, p_s, p_e) | |
| if overlap == 0: | |
| return False | |
| gold_len = max(0, t_e - t_s) | |
| if gold_len == 0: | |
| return False | |
| return (overlap / gold_len) >= gold_cover_thresh | |
def extract_tp_fp_fn(y_true_flat, y_pred_flat, tolerance=1, gold_cover_thresh=1.0):
    """Count predicted, matched and gold entities for every entity type.

    Args:
        y_true_flat: Iterable of gold ``(type, (start, end), sample_index)``
            entries.
        y_pred_flat: Iterable of predicted entries in the same layout.
        tolerance: Forwarded unchanged to ``_is_match``.
        gold_cover_thresh: Forwarded unchanged to ``_is_match``.

    Returns:
        ``(pred_sum, tp_sum, true_sum, target_names)`` where the first three
        are ``int32`` arrays aligned with ``target_names``, the sorted union
        of the types seen in either input. Duplicate entries are counted
        once, and each gold entity can be matched by at most one prediction.
    """
    entities_true = defaultdict(set)
    entities_pred = defaultdict(set)
    for type_name, (start, end), idx in y_true_flat:
        entities_true[type_name].add((type_name, (start, end), idx))
    for type_name, (start, end), idx in y_pred_flat:
        entities_pred[type_name].add((type_name, (start, end), idx))

    target_names = sorted(set(entities_true) | set(entities_pred))
    n_types = len(target_names)
    tp_sum = np.zeros(n_types, dtype=np.int32)
    pred_sum = np.zeros(n_types, dtype=np.int32)
    true_sum = np.zeros(n_types, dtype=np.int32)

    def _order(entity):
        # Sample index first, then span, so the greedy matching below does
        # not depend on set iteration order (which varies between runs
        # because string hashes are randomised per process).
        return entity[2], entity[1]

    for i, name in enumerate(target_names):
        true_set = entities_true.get(name, set())
        pred_set = entities_pred.get(name, set())
        pred_sum[i] = len(pred_set)
        true_sum[i] = len(true_set)

        # Bucket the still-unmatched golds by sample: entities from different
        # samples can never match, so there is no point comparing them.
        unmatched = defaultdict(list)
        for gold in sorted(true_set, key=_order):
            unmatched[gold[2]].append(gold)

        for pred in sorted(pred_set, key=_order):
            candidates = unmatched.get(pred[2])
            if not candidates:
                continue
            for pos, gold in enumerate(candidates):
                if _is_match(gold, pred, tolerance, gold_cover_thresh):
                    tp_sum[i] += 1
                    # Consume the gold so it cannot be matched twice.
                    del candidates[pos]
                    break

    return pred_sum, tp_sum, true_sum, target_names
def compute_micro_prf(pred_sum, tp_sum, true_sum):
    """Compute micro-averaged precision, recall and F1 from per-type counts.

    Args:
        pred_sum: Array of predicted-entity counts per type.
        tp_sum: Array of true-positive counts per type.
        true_sum: Array of gold-entity counts per type.

    Returns:
        ``(precision, recall, f1)`` as plain floats in ``[0, 1]``. Any ratio
        whose denominator is zero is reported as ``0.0``.
    """
    total_tp = tp_sum.sum()
    total_pred = pred_sum.sum()
    total_true = true_sum.sum()

    precision = total_tp / total_pred if total_pred > 0 else 0.0
    recall = total_tp / total_true if total_true > 0 else 0.0

    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0
    return float(precision), float(recall), float(f1)
def flatten_for_eval(y_true, y_pred):
    """Flatten per-sample entity lists into two flat lists tagged by sample.

    ``y_true`` and ``y_pred`` are walked in lockstep (stopping at the shorter
    one). Every entity contributes ``[entity[0], entity[1], sample_index]``
    to the corresponding output list.

    Returns:
        ``(flat_true, flat_pred)``.
    """
    flat_true = []
    flat_pred = []
    for sample_idx, (gold_entities, pred_entities) in enumerate(zip(y_true, y_pred)):
        for entity in gold_entities:
            flat_true.append([entity[0], entity[1], sample_idx])
        for entity in pred_entities:
            flat_pred.append([entity[0], entity[1], sample_idx])
    return flat_true, flat_pred
def map_tokens_to_chars(text, tokens):
    """Locate each token in ``text`` and return its character span.

    Tokens are searched left to right, each search starting where the
    previously found token ended. A token that cannot be found from that
    position yields ``(-1, -1)`` and does not move the search position.

    Returns:
        A list of ``(start, end)`` pairs, one per token, with ``end``
        equal to ``start + len(token)``.
    """
    char_spans = []
    cursor = 0
    for tok in tokens:
        begin = text.find(tok, cursor)
        if begin == -1:
            char_spans.append((-1, -1))
            continue
        cursor = begin + len(tok)
        char_spans.append((begin, cursor))
    return char_spans
def process_sample(sample):
    """Convert one dataset sample into ``[label, (char_start, char_end)]`` pairs.

    Two layouts are handled:

    * Samples with ``"ner"``, ``"tokenized_text"`` and ``"text"``: every
      ``(start_token, end_token, label)`` annotation is converted to
      character offsets using ``map_tokens_to_chars``. Annotations whose
      token indices are past the end, or whose boundary tokens could not be
      located in the text, are dropped.
    * Samples with ``"entities"``: each entry's ``"label"``, ``"start"`` and
      ``"end"`` are used directly.

    Labels are lower-cased. Any other sample yields an empty list.
    """
    if all(key in sample for key in ("ner", "tokenized_text", "text")):
        token_spans = map_tokens_to_chars(sample["text"], sample["tokenized_text"])
        n_tokens = len(token_spans)
        converted = []
        for start_tok, end_tok, label in sample["ner"]:
            if start_tok >= n_tokens or end_tok >= n_tokens:
                continue
            char_start = token_spans[start_tok][0]
            char_end = token_spans[end_tok][1]
            if char_start == -1 or char_end == -1:
                continue
            converted.append([label.lower(), (char_start, char_end)])
        return converted

    if "entities" not in sample:
        return []

    return [
        [ent["label"].lower(), (ent["start"], ent["end"])]
        for ent in sample["entities"]
    ]
def run_inference(model, samples, labels, threshold, desc=""):
    """Run ``model.predict_entities`` on every sample and collect the results.

    Args:
        model: Object exposing ``predict_entities(text, labels, threshold=...)``.
        samples: Sequence of samples; each sample's ``"text"`` is predicted on.
        labels: Entity labels passed to the model.
        threshold: Confidence threshold passed to the model.
        desc: Progress-bar caption; defaults to the formatted threshold.

    Returns:
        One list per sample of ``[label_lowercased, (start, end)]`` pairs.
    """
    progress = tqdm(samples, desc=desc or f"thresh={threshold:.1f}", leave=False)
    predictions = []
    for sample in progress:
        found = model.predict_entities(sample["text"], labels, threshold=threshold)
        predictions.append(
            [[ent["label"].lower(), (ent["start"], ent["end"])] for ent in found]
        )
    return predictions
def evaluate(ground_truth, predictions, tolerance=1):
    """Score predictions against the ground truth.

    The per-sample lists are flattened, counted per entity type with a gold
    coverage threshold of ``1.0``, and summarised as micro-averaged metrics.

    Returns:
        ``(precision, recall, f1, pred_sum, tp_sum, true_sum, names)``: the
        micro metrics followed by the per-type count arrays and type names.
    """
    gold_flat, pred_flat = flatten_for_eval(ground_truth, predictions)
    counts = extract_tp_fp_fn(
        gold_flat, pred_flat, tolerance=tolerance, gold_cover_thresh=1.0
    )
    pred_sum, tp_sum, true_sum, names = counts
    precision, recall, f1 = compute_micro_prf(pred_sum, tp_sum, true_sum)
    return precision, recall, f1, pred_sum, tp_sum, true_sum, names
# ── Plotting ──────────────────────────────────────────────────────────────────
def plot_threshold_sweep(sweep, best_thresh, out_path):
    """Plot precision, recall and F1 against the threshold and save the figure.

    Args:
        sweep: Sequence of dicts with ``"threshold"``, ``"precision"``,
            ``"recall"`` and ``"f1"`` keys. The three scores are expected on
            a 0-100 percentage scale, which is how ``main`` stores them.
        best_thresh: Threshold highlighted with a dashed vertical line.
        out_path: Destination passed to ``plt.savefig``.
    """
    import matplotlib.pyplot as plt
    import matplotlib.ticker as mtick

    thresholds = [s["threshold"] for s in sweep]
    # The y axis is formatted with PercentFormatter(xmax=1.0), i.e. 1.0 is
    # rendered as "100%". Scores arrive as percentages, so rescale them to
    # fractions; otherwise the ticks would read e.g. "8500%".
    f1s = [s["f1"] / 100 for s in sweep]
    precs = [s["precision"] / 100 for s in sweep]
    recs = [s["recall"] / 100 for s in sweep]

    plt.style.use("dark_background")
    fig, ax = plt.subplots(figsize=(10, 5))
    fig.patch.set_facecolor("#0a0a0a")
    ax.set_facecolor("#1a1a1a")

    ax.plot(thresholds, f1s, "o-", color="#ff6b6b", lw=2.5, label="F1", ms=8)
    ax.plot(thresholds, precs, "s-", color="#2ecc71", lw=2, label="Precision", ms=6)
    ax.plot(thresholds, recs, "^-", color="#3498db", lw=2, label="Recall", ms=6)
    ax.axvline(best_thresh, color="#ffd700", ls="--", alpha=0.7,
               label=f"Best = {best_thresh}")

    ax.set_xlabel("Threshold", fontsize=12)
    ax.set_ylabel("Score", fontsize=12)
    ax.set_title("Precision / Recall / F1 vs Threshold (test set)", fontsize=14, fontweight="bold")
    ax.legend(fontsize=10, facecolor="#2a2a2a", edgecolor="white")
    ax.grid(True, alpha=0.3)
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1.0))

    plt.tight_layout()
    plt.savefig(out_path, dpi=150, bbox_inches="tight", facecolor="#0a0a0a")
    plt.close()
    print(f" Saved {out_path}")
def plot_per_entity(per_entity, micro_f1, best_thresh, out_path):
    """Plot per-entity F1 and precision/recall bars side by side and save them.

    Args:
        per_entity: Sequence of dicts with ``"entity"``, ``"precision"``,
            ``"recall"`` and ``"f1"`` keys; the scores are on a 0-100 scale
            and are converted to fractions for plotting.
        micro_f1: Overall micro F1 as a fraction (drawn on the 0-1 axis and
            multiplied by 100 for its legend label).
        best_thresh: Threshold shown in the left panel's title.
        out_path: Destination passed to ``plt.savefig``.
    """
    import matplotlib.pyplot as plt
    import matplotlib.ticker as mtick

    # Ascending F1 so the best entity ends up at the top of the horizontal bars.
    rows = sorted(per_entity, key=lambda x: x["f1"])
    names = [r["entity"] for r in rows]
    f1s = [r["f1"] / 100 for r in rows]
    precs = [r["precision"] / 100 for r in rows]
    recs = [r["recall"] / 100 for r in rows]

    plt.style.use("dark_background")
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))
    fig.patch.set_facecolor("#0a0a0a")

    # Left panel: F1 per entity, colour-coded by score band.
    ax1.set_facecolor("#1a1a1a")
    colors = ["#ff6b6b" if f >= 0.7 else "#f0a500" if f >= 0.5 else "#e74c3c" for f in f1s]
    bars = ax1.barh(names, f1s, color=colors, edgecolor="white", lw=0.5)
    for bar, val in zip(bars, f1s):
        ax1.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height() / 2,
                 f"{val*100:.1f}%", va="center", fontsize=9, color="white")
    ax1.axvline(micro_f1, color="#ffd700", ls="--", alpha=0.6,
                label=f"Micro F1 = {micro_f1*100:.1f}%")
    ax1.set_title(f"F1 per Entity (threshold={best_thresh})", fontsize=13, fontweight="bold")
    ax1.set_xlim(0, 1.15)
    ax1.xaxis.set_major_formatter(mtick.PercentFormatter(xmax=1.0))
    ax1.grid(True, alpha=0.3, axis="x")
    ax1.legend(fontsize=9, facecolor="#2a2a2a")

    # Right panel: precision and recall as paired bars per entity.
    ax2.set_facecolor("#1a1a1a")
    x, w = np.arange(len(names)), 0.35
    ax2.barh(x - w / 2, precs, w, label="Precision", color="#2ecc71", alpha=0.85)
    ax2.barh(x + w / 2, recs, w, label="Recall", color="#3498db", alpha=0.85)
    ax2.set_yticks(x)
    ax2.set_yticklabels(names, fontsize=9)
    ax2.set_title("Precision vs Recall per Entity", fontsize=13, fontweight="bold")
    ax2.xaxis.set_major_formatter(mtick.PercentFormatter(xmax=1.0))
    ax2.set_xlim(0, 1.05)
    ax2.legend(fontsize=10, facecolor="#2a2a2a", edgecolor="white")
    ax2.grid(True, alpha=0.3, axis="x")

    # The separator in this title had been corrupted by a bad character
    # decoding; it is restored here as an em dash.
    plt.suptitle("GLiNER ContractNER — Test Set Evaluation", fontsize=15, fontweight="bold", y=1.01)
    plt.tight_layout()
    plt.savefig(out_path, dpi=150, bbox_inches="tight", facecolor="#0a0a0a")
    plt.close()
    print(f" Saved {out_path}")
# ── Main ──────────────────────────────────────────────────────────────────────
def _per_entity_rows(names, pred_sum, tp_sum, true_sum):
    """Build per-entity metric rows (percentages), sorted by descending F1."""
    rows = []
    for i, name in enumerate(names):
        tp, n_pred, n_true = int(tp_sum[i]), int(pred_sum[i]), int(true_sum[i])
        p = tp / n_pred if n_pred > 0 else 0.0
        r = tp / n_true if n_true > 0 else 0.0
        f = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
        rows.append({
            "entity": name.upper(),
            "precision": round(p * 100, 2),
            "recall": round(r * 100, 2),
            "f1": round(f * 100, 2),
            "support": n_true,
            "tp": tp,
            "fp": n_pred - tp,
            "fn": n_true - tp,
        })
    rows.sort(key=lambda row: -row["f1"])
    return rows


def _time_predictions(model, text, labels, threshold, n_warmup, n_runs):
    """Return per-call latencies in ms for ``n_runs`` predictions after warmup."""
    for _ in range(n_warmup):
        model.predict_entities(text, labels, threshold=threshold)
    timings = []
    for _ in range(n_runs):
        t0 = time.perf_counter()
        model.predict_entities(text, labels, threshold=threshold)
        timings.append((time.perf_counter() - t0) * 1000)
    return timings


def main():
    """Run the evaluation end to end from the command line.

    Steps: parse arguments, load the test split and the label list, load the
    model from the directory containing this script, evaluate at one
    threshold (or sweep 0.3-0.9 with ``--all-thresholds``), print and plot
    per-entity metrics at the best threshold, optionally benchmark latency
    (including an ONNX comparison when ``model.onnx`` is present), and write
    ``eval_results.json`` plus the plots to ``--output-dir``.
    """
    parser = argparse.ArgumentParser(description="Evaluate lucasorrentino/Contractner on test set")
    parser.add_argument("--threshold", type=float, default=0.9,
                        help="Confidence threshold for predictions (default: 0.9)")
    parser.add_argument("--all-thresholds", action="store_true",
                        help="Sweep thresholds 0.3–0.9 to find the best F1")
    parser.add_argument("--skip-latency", action="store_true",
                        help="Skip the latency benchmark")
    parser.add_argument("--output-dir", type=str, default=".",
                        help="Directory to save plots and eval_results.json (default: .)")
    args = parser.parse_args()

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    repo_dir = Path(__file__).parent

    # ±1 char boundary tolerance. NOTE(review): this value is passed down and
    # exported as "tolerance_chars", but `_is_match` does not consult it.
    tolerance = 1

    # ── Load dataset from HuggingFace ─────────────────────────────────────────
    print("Loading dataset lucasorrentino/ContractNER from HuggingFace...")
    from datasets import load_dataset
    ds = load_dataset("lucasorrentino/ContractNER")
    testset = list(ds["test"])

    labels_file = repo_dir / "labels.json"
    if labels_file.exists():
        labels = json.loads(labels_file.read_text(encoding="utf-8"))
    else:
        # Fall back to the labels that occur in the test annotations.
        # `or []` also covers samples whose "ner" field is present but None.
        labels = sorted({
            label
            for s in testset
            for _, _, label in (s.get("ner") or [])
        })
    print(f" Test set : {len(testset)} samples")
    print(f" Labels : {len(labels)} entity types")

    # ── Load model from local repo ────────────────────────────────────────────
    print("\nLoading model from local repo...")
    from gliner import GLiNER
    model = GLiNER.from_pretrained(str(repo_dir))
    model.eval()
    print(" Model loaded.")

    # ── Ground truth ──────────────────────────────────────────────────────────
    ground_truth = [process_sample(s) for s in testset]
    total_annotations = sum(len(g) for g in ground_truth)
    print(f" Gold annotations: {total_annotations}")

    # ── Threshold sweep ───────────────────────────────────────────────────────
    thresholds = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] if args.all_thresholds else [args.threshold]
    sweep = []
    # Keep each threshold's predictions so the best one can be reused below
    # instead of running inference over the whole test set a second time.
    preds_by_threshold = {}
    rule_60 = "─" * 60
    print(f"\n{rule_60}")
    print(f"{'Threshold':>10} {'Precision':>10} {'Recall':>8} {'F1':>8}")
    print(rule_60)
    for thresh in thresholds:
        preds = run_inference(model, testset, labels, thresh)
        preds_by_threshold[thresh] = preds
        p, r, f, *_ = evaluate(ground_truth, preds, tolerance)
        sweep.append({"threshold": thresh, "precision": round(p * 100, 2),
                      "recall": round(r * 100, 2), "f1": round(f * 100, 2)})
        print(f" {thresh:>8.1f} {p*100:>9.2f}% {r*100:>7.2f}% {f*100:>7.2f}%")

    best = max(sweep, key=lambda x: x["f1"])
    best_thresh = best["threshold"]
    print(rule_60)
    print(f" Best threshold: {best_thresh} → F1 = {best['f1']:.2f}%\n")

    if args.all_thresholds:
        plot_threshold_sweep(sweep, best_thresh, out_dir / "threshold_sweep.png")

    # ── Per-entity breakdown ──────────────────────────────────────────────────
    print("Running per-entity evaluation at threshold", best_thresh)
    preds_best = preds_by_threshold[best_thresh]
    p_best, r_best, f_best, pred_sum, tp_sum, true_sum, names = evaluate(
        ground_truth, preds_best, tolerance
    )
    per_entity = _per_entity_rows(names, pred_sum, tp_sum, true_sum)

    rule_70 = "─" * 70
    print(f"\n{rule_70}")
    print(f"{'Entity':<20} {'P':>7} {'R':>7} {'F1':>7} {'Support':>8} {'TP':>5} {'FP':>5} {'FN':>5}")
    print(rule_70)
    for row in per_entity:
        print(f" {row['entity']:<18} {row['precision']:>6.2f}% {row['recall']:>6.2f}% "
              f"{row['f1']:>6.2f}% {row['support']:>8} {row['tp']:>5} {row['fp']:>5} {row['fn']:>5}")

    # Guard the mean: with no entity types np.mean([]) would be NaN plus a warning.
    macro_f1 = float(np.mean([r["f1"] for r in per_entity])) if per_entity else 0.0
    print(rule_70)
    print(f" {'Micro F1':<18} {p_best*100:>6.2f}% {r_best*100:>6.2f}% {f_best*100:>6.2f}%")
    print(f" {'Macro F1':<18} {'':>7} {'':>7} {macro_f1:>6.2f}%")

    plot_per_entity(per_entity, f_best, best_thresh, out_dir / "per_entity_metrics.png")

    # ── Latency benchmark ─────────────────────────────────────────────────────
    latency_results = {}
    # The benchmark text comes from the first test sample, so it needs one.
    if not args.skip_latency and testset:
        n_warmup, n_runs = 3, 20
        sample_text = testset[0]["text"]
        scenarios = [
            ("Short (~300 chars)", sample_text[:300]),
            ("Medium (~800 chars)", sample_text[:800]),
            ("Long (full chunk)", sample_text),
        ]
        print(f"\nLatency benchmark — CPU, {n_runs} runs after {n_warmup} warmup\n")
        for name, text in scenarios:
            times = _time_predictions(model, text, labels, best_thresh, n_warmup, n_runs)
            med = statistics.median(times)
            p95 = float(np.percentile(times, 95))
            latency_results[name.strip()] = {"chars": len(text), "median_ms": round(med, 1), "p95_ms": round(p95, 1)}
            print(f" {name} | {len(text):>5} chars | median {med:6.1f} ms | p95 {p95:6.1f} ms | ~{1000/med:.1f} docs/s")

        # ONNX comparison
        onnx_path = repo_dir / "model.onnx"
        if onnx_path.exists():
            # Deliberately best-effort: any failure to load or run the ONNX
            # model is reported and the rest of the evaluation continues.
            try:
                model_onnx = GLiNER.from_pretrained(
                    str(repo_dir), load_onnx_model=True, onnx_model_file="model.onnx"
                )
                times_onnx = _time_predictions(
                    model_onnx, sample_text, labels, best_thresh, n_warmup, n_runs
                )
                onnx_med = statistics.median(times_onnx)
                pt_med = latency_results["Long (full chunk)"]["median_ms"]
                latency_results["onnx_full_chunk"] = {"median_ms": round(onnx_med, 1), "speedup": round(pt_med / onnx_med, 2)}
                print(f"\n ONNX vs PyTorch (full chunk): {onnx_med:.1f} ms vs {pt_med:.1f} ms → {pt_med/onnx_med:.2f}x speedup")
            except Exception as e:
                print(f" ONNX benchmark skipped: {e}")

    # ── Export results ────────────────────────────────────────────────────────
    results = {
        "model": "lucasorrentino/Contractner",
        "dataset": "lucasorrentino/ContractNER",
        "test_set_size": len(testset),
        "threshold": best_thresh,
        "tolerance_chars": tolerance,
        "match_mode": "overlap_cover",
        "gold_cover_thresh": 1.0,
        "overall": {
            "precision": round(p_best * 100, 2),
            "recall": round(r_best * 100, 2),
            "f1": round(f_best * 100, 2),
        },
        "per_entity": per_entity,
        "threshold_sweep": sweep,
        "latency_cpu": latency_results,
    }
    results_path = out_dir / "eval_results.json"
    results_path.write_text(json.dumps(results, indent=2), encoding="utf-8")
    print(f"\n Saved {results_path}")
    print(f"\nDone. Micro F1 = {f_best*100:.2f}% at threshold={best_thresh}")


if __name__ == "__main__":
    main()