MADQA-Leaderboard / eval/evaluate.py
Borchmann's picture
Upload folder using huggingface_hub
4829aac verified
raw
history blame
14.9 kB
#!/usr/bin/env python3
"""
Evaluation CLI for Agentic Document AI.
Evaluates model predictions against the agentic-document-ai/dataset benchmark.
Usage:
python evaluate.py results.jsonl [--by-category] [--by-domain]
python evaluate.py results_*.jsonl --compare
"""
import argparse
import json
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from datasets import load_dataset
from metrics import (
anls_star,
anls_star_llm,
aggregate_anls_star_llm,
citation_f1,
kuiper_statistic,
wasted_effort_ratio
)
def derive_hop_type(evidence: list) -> str:
    """Classify the reasoning span implied by an evidence list.

    Args:
        evidence: List of dicts with 'document' and 'page' keys.

    Returns:
        'cross_doc'  -- evidence spans more than one document
        'cross_page' -- several pages within a single document
        'single'     -- at most one (document, page) pair (or no evidence)
    """
    if not evidence:
        return 'single'

    # Distinct documents, ignoring entries without a document key.
    docs = {item.get('document') for item in evidence
            if item.get('document') is not None}
    # Distinct (document, page) pairs; both keys must be present.
    doc_pages = {(item.get('document'), item.get('page')) for item in evidence
                 if item.get('document') is not None and item.get('page') is not None}

    if len(docs) > 1:
        return 'cross_doc'
    if len(doc_pages) > 1:
        return 'cross_page'
    return 'single'
def load_gold_standard(dataset_name: str = "agentic-document-ai/dataset", split: str = "dev"):
    """Fetch the benchmark split from HuggingFace and index it two ways.

    Returns:
        (by_text, by_id): by_text maps stripped question text to the gold
        record (primary key); by_id maps question id to the same record
        (fallback key, only populated when the example has a non-empty id).
    """
    print(f"Loading {dataset_name} ({split} split)...")
    dataset = load_dataset(dataset_name, split=split)

    by_text: Dict[str, Dict] = {}
    by_id: Dict[str, Dict] = {}
    for example in dataset:
        evidence = example.get('evidence', [])
        record = {
            'answers': example.get('answer_variants', []),
            'evidence': evidence,
            'category': example.get('document_category', ''),
            'domain': example.get('domain', ''),
            # Hop type is derived, not stored in the dataset.
            'hop_type': derive_hop_type(evidence),
        }
        by_text[example['question'].strip()] = record
        example_id = example.get('id', '')
        if example_id:
            by_id[example_id] = record

    print(f"Loaded {len(by_text)} gold examples")
    return by_text, by_id
def load_results(filepath: Path) -> List[Dict]:
    """Load predictions from a JSONL file.

    Blank lines are skipped; every other line must be a valid JSON object.

    Args:
        filepath: Path to a JSON-lines results file.

    Returns:
        List of parsed prediction dicts, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the file cannot be opened.
    """
    results = []
    # Explicit encoding: results files are written as UTF-8; relying on the
    # platform default (e.g. cp1252 on Windows) can corrupt non-ASCII answers.
    with open(filepath, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                results.append(json.loads(line))
    return results
def evaluate_single(
    result: Dict,
    gold_by_text: Dict[str, Dict],
    gold_by_id: Dict[str, Dict],
    use_semantic: bool = False
) -> Optional[Dict[str, Any]]:
    """Score one prediction against the gold standard.

    Matching is by question text first, with question ID as fallback.

    Args:
        result: Prediction dict with 'question', 'answer', 'citations'.
        gold_by_text: Gold data indexed by question text.
        gold_by_id: Gold data indexed by question ID.
        use_semantic: If True, also compute semantic accuracy with LLM judge.

    Returns:
        Per-question metrics dict, or None when no gold entry matches.
    """
    question = result.get('question', '').strip()
    qid = result.get('id', '')

    gold = gold_by_text.get(question)
    if gold is None and qid:
        gold = gold_by_id.get(qid)
    if gold is None:
        return None

    answer = result.get('answer', '')
    citations = result.get('citations', [])

    # String-level ANLS* is always computed.
    anls = anls_star(answer, gold['answers'])

    # When enabled, the LLM-judged score drives the correctness threshold;
    # otherwise the string score stands in for the semantic one.
    if use_semantic:
        semantic = anls_star_llm(answer, gold['answers'], question)['score']
    else:
        semantic = anls
    correct = semantic >= 0.5

    doc_f1 = citation_f1(citations, gold['evidence'], level='document')
    page_f1 = citation_f1(citations, gold['evidence'], level='page')

    # Step count for Kuiper: prefer explicit search history, else iterations.
    history = result.get('search_history', [])
    steps = len(history) if history else result.get('iterations', 0)

    return {
        'question': question,
        'anls': anls,
        'semantic': semantic,
        'correct': correct,
        'doc_f1': doc_f1['f1'],
        'page_f1': page_f1['f1'],
        'steps': steps,
        'category': gold['category'],
        'domain': gold['domain'],
        'hop_type': gold.get('hop_type', 'single'),
    }
def aggregate_metrics(evals: List[Dict], use_semantic: bool = False) -> Dict[str, Any]:
    """Fold per-question evaluations into one summary dict.

    Returns an empty dict when `evals` is empty.
    """
    if not evals:
        return {}
    n = len(evals)

    def mean_of(key: str) -> float:
        # Simple arithmetic mean over all evaluations.
        return sum(e[key] for e in evals) / n

    accuracy = mean_of('correct')
    mean_anls = mean_of('anls')

    # Semantic accuracy: bias-corrected aggregate (with CI) when the LLM
    # judge was used; otherwise fall back to the plain ANLS* mean.
    semantic_ci = None
    if use_semantic and 'semantic' in evals[0]:
        agg = aggregate_anls_star_llm([e['semantic'] for e in evals],
                                      apply_bias_correction=True)
        mean_semantic = agg['adjusted_score']
        semantic_ci = (agg['ci_lower'], agg['ci_upper'])
    else:
        mean_semantic = mean_anls

    kuiper = kuiper_statistic(evals)
    wasted = wasted_effort_ratio(evals)

    return {
        'n': n,
        'accuracy': accuracy,
        'mean_anls': mean_anls,
        'mean_semantic': mean_semantic,
        'semantic_ci': semantic_ci,
        'doc_f1': mean_of('doc_f1'),
        'page_f1': mean_of('page_f1'),
        'kuiper_stat': kuiper['kuiper_stat'],
        'kuiper_degenerate': kuiper['degenerate'],
        'wasted_effort_ratio': wasted['ratio'],
        'mean_steps_correct': wasted['mean_steps_correct'],
        'mean_steps_incorrect': wasted['mean_steps_incorrect'],
    }
def print_metrics(name: str, metrics: Dict, indent: int = 0, use_semantic: bool = False):
    """Render one metrics dict to stdout.

    Args:
        name: Label for this group of metrics.
        metrics: Output of aggregate_metrics (may be empty).
        indent: Number of leading spaces (used for nested breakdowns).
        use_semantic: Prefer the LLM-judged semantic accuracy line.
    """
    prefix = " " * indent

    # An empty dict means aggregate_metrics saw no evaluations.
    if 'n' not in metrics:
        print(f"{prefix}{name}: No data")
        return

    print(f"{prefix}{name} (n={metrics['n']}):")
    semantic_available = use_semantic and 'mean_semantic' in metrics
    if semantic_available:
        ci = metrics.get('semantic_ci')
        ci_str = f" [{ci[0]:.2%}-{ci[1]:.2%}]" if ci else ""
        print(f"{prefix} Semantic Accuracy: {metrics['mean_semantic']:.2%}{ci_str}")
        print(f"{prefix} ANLS* (string): {metrics['mean_anls']:.4f}")
    else:
        print(f"{prefix} Accuracy (ANLS*≥0.5): {metrics['accuracy']:.1%}")
        print(f"{prefix} Mean ANLS*: {metrics['mean_anls']:.4f}")
    print(f"{prefix} Document F1: {metrics['doc_f1']:.4f}")
    print(f"{prefix} Page F1: {metrics['page_f1']:.4f}")
    # 'degenerate' flag comes from kuiper_statistic; suppress the stat then.
    if not metrics.get('kuiper_degenerate'):
        print(f"{prefix} Kuiper Statistic: {metrics['kuiper_stat']:.2f}")
    # Skip the wasted-effort ratio when it is non-finite.
    if metrics.get('wasted_effort_ratio', 0) < float('inf'):
        print(f"{prefix} Wasted Effort Ratio: {metrics['wasted_effort_ratio']:.3f}")
def evaluate_file(
    filepath: Path,
    gold_by_text: Dict[str, Dict],
    gold_by_id: Dict[str, Dict],
    by_category: bool = False,
    by_domain: bool = False,
    by_hop_type: bool = True,
    use_semantic: bool = False
) -> Dict[str, Any]:
    """Evaluate every prediction in one JSONL file and aggregate the scores.

    Returns:
        Dict with 'overall' metrics and 'use_semantic' flag, plus optional
        per-group breakdowns ('by_hop_type', 'by_category', 'by_domain').
    """
    results = load_results(filepath)
    total = len(results)

    evals = []
    unmatched = 0
    for i, prediction in enumerate(results):
        # LLM judging is slow; show progress every 50 predictions.
        if use_semantic and (i + 1) % 50 == 0:
            print(f" Processing {i+1}/{total}...")
        scored = evaluate_single(prediction, gold_by_text, gold_by_id,
                                 use_semantic=use_semantic)
        if scored is None:
            unmatched += 1
        else:
            evals.append(scored)
    if unmatched:
        print(f" Warning: {unmatched} questions not found in gold standard")

    def grouped(key_fn) -> Dict[str, Any]:
        # Bucket evaluations by key_fn, then aggregate each bucket (sorted by key).
        buckets = defaultdict(list)
        for e in evals:
            buckets[key_fn(e)].append(e)
        return {k: aggregate_metrics(v, use_semantic) for k, v in sorted(buckets.items())}

    output: Dict[str, Any] = {
        'overall': aggregate_metrics(evals, use_semantic=use_semantic),
        'use_semantic': use_semantic,
    }
    if by_hop_type:
        output['by_hop_type'] = grouped(lambda e: e.get('hop_type', 'single'))
    if by_category:
        output['by_category'] = grouped(lambda e: e['category'] or 'Unknown')
    if by_domain:
        output['by_domain'] = grouped(lambda e: e['domain'] or 'Other')
    return output
def main():
    """CLI entry point.

    Parses arguments, loads the gold standard, evaluates each results file,
    then prints either JSON or a formatted report (optionally a side-by-side
    comparison table across models).
    """
    parser = argparse.ArgumentParser(
        description="Evaluate model predictions on Agentic Document AI benchmark",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python evaluate.py results.jsonl
  python evaluate.py results.jsonl --by-category --by-domain
  python evaluate.py model1.jsonl model2.jsonl --compare
"""
    )
    parser.add_argument('files', nargs='+', type=Path, help='Result JSONL file(s)')
    parser.add_argument('--dataset', default='agentic-document-ai/dataset',
                        help='HuggingFace dataset name')
    parser.add_argument('--split', default='dev', help='Dataset split to evaluate on')
    parser.add_argument('--by-category', action='store_true', help='Show metrics by document category')
    parser.add_argument('--by-domain', action='store_true', help='Show metrics by domain')
    parser.add_argument('--compare', action='store_true', help='Compare multiple models side-by-side')
    parser.add_argument('--json', action='store_true', help='Output as JSON')
    parser.add_argument('--semantic', action='store_true',
                        help='Use semantic accuracy (ANLS* + LLM judge) instead of pure ANLS*. Requires GOOGLE_API_KEY.')
    args = parser.parse_args()

    # Load gold standard (indexed by question text, with id as fallback).
    gold_by_text, gold_by_id = load_gold_standard(args.dataset, args.split)
    if not gold_by_text:
        print("Error: No gold standard data loaded", file=sys.stderr)
        sys.exit(1)

    all_results = {}
    for filepath in args.files:
        if not filepath.exists():
            # Missing files are skipped rather than fatal so the remaining
            # files are still evaluated.
            print(f"Error: File not found: {filepath}", file=sys.stderr)
            continue
        # Extract model name from the filename stem by stripping the
        # conventional "results_" prefix / "_results" suffix.
        name = filepath.stem
        if name.startswith("results_"):
            name = name[8:]
        if name.endswith("_results"):
            name = name[:-8]
        print(f"\nEvaluating: {filepath.name}")
        if args.semantic:
            print(" Using semantic accuracy (ANLS* + LLM judge)...")
        result = evaluate_file(
            filepath, gold_by_text, gold_by_id,
            args.by_category, args.by_domain,
            use_semantic=args.semantic
        )
        all_results[name] = result

    # Output
    if args.json:
        # Convert for JSON serialization
        def sanitize(obj):
            """Recursively replace NaN/inf floats with None for valid JSON output."""
            # NOTE(review): this catches NaN (obj != obj) and +inf, but not
            # -inf — confirm negative infinity cannot appear in any metric.
            if isinstance(obj, float) and (obj != obj or obj == float('inf')):  # NaN or inf
                return None
            if isinstance(obj, dict):
                return {k: sanitize(v) for k, v in obj.items()}
            if isinstance(obj, list):
                return [sanitize(v) for v in obj]
            return obj
        print(json.dumps(sanitize(all_results), indent=2))
    else:
        # Print formatted output
        print("\n" + "=" * 70)
        print("EVALUATION RESULTS")
        print("=" * 70)
        if args.compare and len(all_results) > 1:
            # Comparison table: one row per model, best score first.
            models = list(all_results.keys())
            if args.semantic:
                print(f"\n{'Model':<35} {'Semantic':<10} {'ANLS*':<8} {'Doc F1':<8} {'Page F1':<8} {'Kuiper':<8}")
                print("-" * 85)
                # Sort descending by semantic accuracy.
                for model in sorted(models, key=lambda m: -all_results[m]['overall'].get('mean_semantic', 0)):
                    m = all_results[model]['overall']
                    # Kuiper stat is meaningless when flagged degenerate.
                    kuiper_str = f"{m['kuiper_stat']:.2f}" if not m.get('kuiper_degenerate') else "N/A"
                    print(f"{model:<35} {m.get('mean_semantic', 0):.1%} {m.get('mean_anls', 0):.4f} "
                          f"{m.get('doc_f1', 0):.4f} {m.get('page_f1', 0):.4f} {kuiper_str}")
            else:
                print(f"\n{'Model':<35} {'Acc':<8} {'ANLS*':<8} {'Doc F1':<8} {'Page F1':<8} {'Kuiper':<8}")
                print("-" * 75)
                # Sort descending by accuracy.
                for model in sorted(models, key=lambda m: -all_results[m]['overall'].get('accuracy', 0)):
                    m = all_results[model]['overall']
                    kuiper_str = f"{m['kuiper_stat']:.2f}" if not m.get('kuiper_degenerate') else "N/A"
                    print(f"{model:<35} {m.get('accuracy', 0):.1%} {m.get('mean_anls', 0):.4f} "
                          f"{m.get('doc_f1', 0):.4f} {m.get('page_f1', 0):.4f} {kuiper_str}")
        else:
            # Detailed per-model output
            for model, result in all_results.items():
                print(f"\n{'─' * 40}")
                use_sem = result.get('use_semantic', False)
                print_metrics(model, result['overall'], use_semantic=use_sem)
                if 'by_category' in result:
                    print(f"\n By Category:")
                    # Largest groups first.
                    for cat, metrics in sorted(result['by_category'].items(),
                                               key=lambda x: -x[1].get('n', 0)):
                        print_metrics(cat, metrics, indent=2, use_semantic=use_sem)
                if 'by_domain' in result:
                    print(f"\n By Domain:")
                    for dom, metrics in sorted(result['by_domain'].items(),
                                               key=lambda x: -x[1].get('n', 0)):
                        print_metrics(dom, metrics, indent=2, use_semantic=use_sem)
                print()
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()