| """ | |
| Evaluation Script | |
| ================= | |
| End-to-end evaluation of the obligation extraction pipeline on the CUAD test set. | |
| Measures: | |
| 1. QA-level: Exact Match (EM) and F1 scores | |
| 2. Pipeline-level: How many valid structured obligations are extracted | |
| 3. Per-category breakdown | |
| Memory Safety: | |
| - Default num_contracts=2 to prevent RAM exhaustion | |
| - GC after each contract | |
| - Results stored as summaries, not full objects | |
| Usage: | |
| python -m src.evaluate | |
| python -m src.evaluate --model_path ckpt_obligation --num_contracts 5 | |
| """ | |
import argparse
import collections
import gc
import json
import logging
import os
import re
import string
import sys
from typing import Any, Dict, List, Tuple

logger = logging.getLogger(__name__)
# ──────────────────────────────────────────────────────────────────────────────
# QA METRICS (standard SQuAD evaluation)
# ──────────────────────────────────────────────────────────────────────────────
def normalize_answer(s: str) -> str:
    """Lower text and remove punctuation, articles, and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punctuation(text):
        return "".join(ch for ch in text if ch not in string.punctuation)

    return white_space_fix(remove_articles(remove_punctuation(s.lower())))
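# Illustrative example (hypothetical input, not drawn from the CUAD data):
#   normalize_answer("The Supplier's obligation!")  ->  "suppliers obligation"
# Lowercasing and punctuation stripping run before article removal, so the
# leading "The" is dropped and the apostrophe/exclamation mark disappear.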
def compute_f1(prediction: str, ground_truth: str) -> float:
    """Compute token-level F1 score."""
    pred_tokens = normalize_answer(prediction).split()
    truth_tokens = normalize_answer(ground_truth).split()
    if not pred_tokens or not truth_tokens:
        return float(pred_tokens == truth_tokens)
    common = collections.Counter(pred_tokens) & collections.Counter(truth_tokens)
    num_common = sum(common.values())
    if num_common == 0:
        return 0.0
    precision = num_common / len(pred_tokens)
    recall = num_common / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)
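# Worked example (hypothetical strings, not from the dataset):
#   prediction   = "net revenue of $5 million"  -> tokens [net, revenue, of, 5, million]
#   ground_truth = "revenue of $5 million"      -> tokens [revenue, of, 5, million]
#   4 shared tokens -> precision = 4/5, recall = 4/4, F1 = 2*0.8*1.0/1.8 ≈ 0.889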
def compute_em(prediction: str, ground_truth: str) -> float:
    """Compute exact match score."""
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))
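# Because both sides are normalized first, illustrative strings such as
# "The Purchase Price" and "purchase price!" compare equal (EM = 1.0).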
def evaluate_qa_predictions(
    predictions: Dict[str, str],
    references: Dict[str, List[str]],
) -> Dict[str, float]:
    """Evaluate QA predictions against ground truth.

    Args:
        predictions: Dict of {qa_id: predicted_answer}.
        references: Dict of {qa_id: [ground_truth_answers]}.

    Returns:
        Dict with 'exact_match' and 'f1' scores.
    """
    total_em = 0.0
    total_f1 = 0.0
    count = 0
    for qid, pred in predictions.items():
        if qid not in references:
            continue
        golds = references[qid]
        if not golds:
            # No answer expected
            em = float(not pred or pred.strip() == "")
            f1 = em
        else:
            em = max(compute_em(pred, gold) for gold in golds)
            f1 = max(compute_f1(pred, gold) for gold in golds)
        total_em += em
        total_f1 += f1
        count += 1
    return {
        "exact_match": (total_em / count * 100) if count > 0 else 0,
        "f1": (total_f1 / count * 100) if count > 0 else 0,
        "num_evaluated": count,
    }
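# Minimal usage sketch (the IDs and answers below are made up for illustration):
#   preds = {"qa_001": "the purchase price", "qa_002": ""}
#   refs  = {"qa_001": ["Purchase Price"], "qa_002": []}   # qa_002 is unanswerable
#   evaluate_qa_predictions(preds, refs)
#   -> {"exact_match": 100.0, "f1": 100.0, "num_evaluated": 2}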
# ──────────────────────────────────────────────────────────────────────────────
# FULL PIPELINE EVALUATION
# ──────────────────────────────────────────────────────────────────────────────
def evaluate_pipeline(
    test_path: str = "data/test.json",
    model_path: str = "deepset/roberta-base-squad2",
    num_contracts: int = 2,
    output_path: str = "evaluation_results.json",
    device: str = "auto",
):
    """Run full pipeline evaluation on CUAD test contracts.

    Args:
        test_path: Path to test.json.
        model_path: Model name or checkpoint path.
        num_contracts: Number of test contracts to evaluate.
        output_path: Path to save evaluation results.
        device: 'auto' (detect GPU), 'cuda', or 'cpu'.
    """
    from all_model_code.model_1_code.pipeline import ObligationPipeline, save_results
    from all_model_code.model_1_code.train_qa import load_cuad_data
    from all_model_code.model_1_code.utils import get_device

    device = get_device(device)
    # ─── Load test data ──────────────────────────────────────────────────
    logger.info(f"Loading test data from {test_path}")
    test_data = load_cuad_data(test_path)

    # ─── Initialize pipeline ─────────────────────────────────────────────
    config = {
        "model_name": model_path,
        "device": device,
        "filter_min_confidence": 0.3,
        "min_fields": 1,  # more lenient for evaluation
    }
    pipeline = ObligationPipeline(config)

    # ─── Process test contracts ──────────────────────────────────────────
    articles = test_data["data"][:num_contracts]
    all_results = {}
    qa_predictions = {}
    qa_references = {}
    per_category_stats = collections.defaultdict(lambda: {
        "total": 0, "detected": 0, "extracted": 0,
    })
    for i, article in enumerate(articles):
        title = article.get("title", f"contract_{i}")
        logger.info(f"\n{'=' * 60}")
        logger.info(f"Evaluating [{i + 1}/{len(articles)}]: {title}")
        logger.info(f"{'=' * 60}")

        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            results = pipeline.process(
                source=context,
                source_type="text",
                contract_id=title,
                debug=True,
            )
            all_results.setdefault(title, []).extend(results)

            # Collect QA-level references
            for qa in paragraph["qas"]:
                qid = qa["id"]
                if qa["is_impossible"]:
                    qa_references[qid] = []
                else:
                    qa_references[qid] = [a["text"] for a in qa["answers"]]

                # Extract category from question
                category = qa.get("question_type", "financial")
                per_category_stats[category]["total"] += 1
                if not qa["is_impossible"]:
                    per_category_stats[category]["detected"] += 1

        # Free memory between contracts
        gc.collect()
    # ─── Pipeline Statistics ─────────────────────────────────────────────
    total_obligations = sum(len(v) for v in all_results.values())

    # Count obligations with full extraction
    full_extraction = sum(
        1 for results in all_results.values()
        for r in results
        if r.get("metric_name") and r.get("threshold_value") is not None
    )

    # ─── Compile Results ─────────────────────────────────────────────────
    eval_results = {
        "pipeline_stats": {
            "num_contracts_evaluated": len(articles),
            "total_obligations_extracted": total_obligations,
            "obligations_with_metric_and_value": full_extraction,
            "per_contract": {
                title: len(results) for title, results in all_results.items()
            },
        },
        "category_stats": dict(per_category_stats),
        "sample_outputs": {
            title: results[:5] for title, results in all_results.items()
        },
    }
    # ─── Print Summary ───────────────────────────────────────────────────
    print("\n" + "=" * 60)
    print("EVALUATION RESULTS")
    print("=" * 60)
    print(f"Contracts evaluated: {len(articles)}")
    print(f"Total obligations found: {total_obligations}")
    print(f"With metric + value: {full_extraction}")
    print()
    print("Per-contract breakdown:")
    for title, results in all_results.items():
        print(f"  {title[:50]:50s} → {len(results)} obligations")

    if total_obligations > 0:
        print("\nSample extracted obligations:")
        for title, results in all_results.items():
            for r in results[:3]:
                print(f"  [{title[:30]}] {r.get('metric_name', 'N/A')} "
                      f"{r.get('operator', '?')} {r.get('threshold_value', '?')} "
                      f"(conf: {r.get('confidence_score', 0):.3f})")

    # ─── Save Results ────────────────────────────────────────────────────
    save_results(eval_results, output_path)
    logger.info(f"Evaluation results saved to {output_path}")
    return eval_results
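# Illustrative programmatic usage (a sketch; the checkpoint name and settings
# are placeholders, not a recommended configuration):
#   results = evaluate_pipeline(
#       model_path="ckpt_obligation",
#       num_contracts=1,
#       device="cpu",
#   )
#   print(results["pipeline_stats"]["total_obligations_extracted"])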
def main():
    parser = argparse.ArgumentParser(description="Evaluate obligation extraction pipeline")
    parser.add_argument("--test_path", default="data/test.json")
    parser.add_argument("--model_path", default="deepset/roberta-base-squad2")
    parser.add_argument("--num_contracts", type=int, default=2)
    parser.add_argument("--output_path", default="evaluation_results.json")
    parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"],
                        help="Device: 'auto' detects GPU, 'cuda' forces GPU, 'cpu' forces CPU")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )
    evaluate_pipeline(**vars(args))


if __name__ == "__main__":
    main()