"""
Evaluation Script
=================
End-to-end evaluation of the obligation extraction pipeline on the CUAD test set.
Measures:
1. QA-level: Exact Match (EM) and token-level F1 scores
2. Pipeline-level: how many valid structured obligations are extracted
3. Per-category breakdown of answerable vs. total questions
Memory Safety:
- Default num_contracts=2 to prevent RAM exhaustion
- GC after each contract
- Results stored as summaries, not full objects
Usage:
python -m src.evaluate
python -m src.evaluate --model_path ckpt_obligation --num_contracts 5
"""
import argparse
import collections
import gc
import json
import logging
import os
import re
import string
import sys
from typing import Any, Dict, List, Tuple
logger = logging.getLogger(__name__)
# ═══════════════════════════════════════════════════════════════════════════════
# QA METRICS (standard SQuAD evaluation)
# ═══════════════════════════════════════════════════════════════════════════════
def normalize_answer(s: str) -> str:
"""Lower text and remove punctuation, articles, and extra whitespace."""
def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punctuation(text):
return "".join(ch for ch in text if ch not in string.punctuation)
return white_space_fix(remove_articles(remove_punctuation(s.lower())))
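# Illustrative sanity check (not executed anywhere in the pipeline):
#   normalize_answer("The Buyer's Fee!") -> "buyers fee"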
def compute_f1(prediction: str, ground_truth: str) -> float:
"""Compute token-level F1 score."""
pred_tokens = normalize_answer(prediction).split()
truth_tokens = normalize_answer(ground_truth).split()
if not pred_tokens or not truth_tokens:
return float(pred_tokens == truth_tokens)
common = collections.Counter(pred_tokens) & collections.Counter(truth_tokens)
num_common = sum(common.values())
if num_common == 0:
return 0.0
precision = num_common / len(pred_tokens)
recall = num_common / len(truth_tokens)
return 2 * precision * recall / (precision + recall)
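# Illustrative example: compute_f1("pay within 30 days", "within 30 days")
# -> precision 3/4, recall 3/3, F1 ≈ 0.857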
def compute_em(prediction: str, ground_truth: str) -> float:
"""Compute exact match score."""
return float(normalize_answer(prediction) == normalize_answer(ground_truth))
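# Exact match requires the normalized strings to be identical, e.g.
# compute_em("The Purchase Price", "purchase price") == 1.0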
def evaluate_qa_predictions(
predictions: Dict[str, str],
references: Dict[str, List[str]],
) -> Dict[str, float]:
"""Evaluate QA predictions against ground truth.
Args:
predictions: Dict of {qa_id: predicted_answer}.
references: Dict of {qa_id: [ground_truth_answers]}.
Returns:
Dict with 'exact_match' and 'f1' scores.
"""
total_em = 0.0
total_f1 = 0.0
count = 0
for qid, pred in predictions.items():
if qid not in references:
continue
golds = references[qid]
if not golds:
# No answer expected
em = float(not pred or pred.strip() == "")
f1 = em
else:
em = max(compute_em(pred, gold) for gold in golds)
f1 = max(compute_f1(pred, gold) for gold in golds)
total_em += em
total_f1 += f1
count += 1
return {
"exact_match": (total_em / count * 100) if count > 0 else 0,
"f1": (total_f1 / count * 100) if count > 0 else 0,
"num_evaluated": count,
}
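# Minimal usage sketch with made-up qa_ids (an empty reference list marks an
# unanswerable question, which is scored as correct only for an empty prediction):
#   preds = {"q1": "within 30 days", "q2": ""}
#   refs = {"q1": ["within 30 days"], "q2": []}
#   evaluate_qa_predictions(preds, refs)
#   # -> {"exact_match": 100.0, "f1": 100.0, "num_evaluated": 2}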
# ═══════════════════════════════════════════════════════════════════════════════
# FULL PIPELINE EVALUATION
# ═══════════════════════════════════════════════════════════════════════════════
def evaluate_pipeline(
test_path: str = "data/test.json",
model_path: str = "deepset/roberta-base-squad2",
num_contracts: int = 2,
output_path: str = "evaluation_results.json",
device: str = "auto",
):
"""Run full pipeline evaluation on CUAD test contracts.
Args:
test_path: Path to test.json.
model_path: Model name or checkpoint path.
num_contracts: Number of test contracts to evaluate.
output_path: Path to save evaluation results.
device: 'auto' (detect GPU), 'cuda', or 'cpu'.
"""
from all_model_code.model_1_code.pipeline import ObligationPipeline, save_results
from all_model_code.model_1_code.train_qa import load_cuad_data
from all_model_code.model_1_code.utils import get_device
device = get_device(device)
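    # Resolve "auto" to a concrete device ("cuda" if a GPU is detected, else "cpu").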
# ─── Load test data ──────────────────────────────────────────────────
logger.info(f"Loading test data from {test_path}")
test_data = load_cuad_data(test_path)
# ─── Initialize pipeline ─────────────────────────────────────────────
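    # Thresholds here are kept lenient so that borderline extractions still
    # show up in the evaluation report (see the min_fields note below).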
config = {
"model_name": model_path,
"device": device,
"filter_min_confidence": 0.3,
"min_fields": 1, # more lenient for evaluation
}
pipeline = ObligationPipeline(config)
# ─── Process test contracts ──────────────────────────────────────────
articles = test_data["data"][:num_contracts]
all_results = {}
qa_predictions = {}
qa_references = {}
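    # Accumulators: all_results maps contract title -> list of obligation dicts;
    # qa_predictions is a placeholder for QA-level answers keyed by qa_id;
    # qa_references holds gold answer texts per qa_id (empty list = unanswerable).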
per_category_stats = collections.defaultdict(lambda: {
"total": 0, "detected": 0, "extracted": 0,
})
for i, article in enumerate(articles):
title = article.get("title", f"contract_{i}")
logger.info(f"\n{'='*60}")
logger.info(f"Evaluating [{i+1}/{len(articles)}]: {title}")
logger.info(f"{'='*60}")
for paragraph in article["paragraphs"]:
context = paragraph["context"]
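            # pipeline.process is expected to return a list of structured
            # obligation dicts (fields such as metric_name, operator,
            # threshold_value, and confidence_score are read further below).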
results = pipeline.process(
source=context,
source_type="text",
contract_id=title,
debug=True,
)
all_results.setdefault(title, []).extend(results)
# Collect QA-level references
for qa in paragraph["qas"]:
qid = qa["id"]
if qa["is_impossible"]:
qa_references[qid] = []
else:
qa_references[qid] = [a["text"] for a in qa["answers"]]
                # Category label comes from the QA entry's question_type field (default: "financial")
category = qa.get("question_type", "financial")
per_category_stats[category]["total"] += 1
if not qa["is_impossible"]:
per_category_stats[category]["detected"] += 1
# Free memory between contracts
gc.collect()
# ─── Pipeline Statistics ─────────────────────────────────────────────
total_obligations = sum(len(v) for v in all_results.values())
# Count obligations with full extraction
full_extraction = sum(
1 for results in all_results.values()
for r in results
if r.get("metric_name") and r.get("threshold_value") is not None
)
# ─── Compile Results ─────────────────────────────────────────────────
eval_results = {
"pipeline_stats": {
"num_contracts_evaluated": len(articles),
"total_obligations_extracted": total_obligations,
"obligations_with_metric_and_value": full_extraction,
"per_contract": {
title: len(results) for title, results in all_results.items()
},
},
"category_stats": dict(per_category_stats),
"sample_outputs": {
title: results[:5] for title, results in all_results.items()
},
}
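    # Only the first five obligations per contract are kept as samples, in line
    # with the memory-safety note above (store summaries, not full result sets).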
# ─── Print Summary ───────────────────────────────────────────────────
print("\n" + "=" * 60)
print("EVALUATION RESULTS")
print("=" * 60)
print(f"Contracts evaluated: {len(articles)}")
print(f"Total obligations found: {total_obligations}")
print(f"With metric + value: {full_extraction}")
print()
print("Per-contract breakdown:")
for title, results in all_results.items():
print(f" {title[:50]:50s} β†’ {len(results)} obligations")
if total_obligations > 0:
print(f"\nSample extracted obligations:")
for title, results in all_results.items():
for r in results[:3]:
print(f" [{title[:30]}] {r.get('metric_name', 'N/A')} "
f"{r.get('operator', '?')} {r.get('threshold_value', '?')} "
f"(conf: {r.get('confidence_score', 0):.3f})")
# ─── Save Results ────────────────────────────────────────────────────
save_results(eval_results, output_path)
logger.info(f"Evaluation results saved to {output_path}")
return eval_results
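# Example of calling the evaluator directly from Python instead of the CLI
# (arguments mirror the argparse defaults below; paths are placeholders):
#   evaluate_pipeline(test_path="data/test.json", num_contracts=1, device="cpu")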
def main():
parser = argparse.ArgumentParser(description="Evaluate obligation extraction pipeline")
parser.add_argument("--test_path", default="data/test.json")
parser.add_argument("--model_path", default="deepset/roberta-base-squad2")
parser.add_argument("--num_contracts", type=int, default=2)
parser.add_argument("--output_path", default="evaluation_results.json")
parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"],
help="Device: 'auto' detects GPU, 'cuda' forces GPU, 'cpu' forces CPU")
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
evaluate_pipeline(**vars(args))
if __name__ == "__main__":
main()