# MADQA-Leaderboard/eval/batch_reevaluate.py
# Uploaded via huggingface_hub (commit 9794ac3, verified).
#!/usr/bin/env python3
"""
Batch re-evaluate all submissions with the new Semantic Accuracy metric.
This script downloads all prediction files from HuggingFace Hub and re-evaluates
them with the ANLS* + LLM judge metric.
Usage:
# Dry run - list files only
python batch_reevaluate.py --dry-run
# Re-evaluate all files
python batch_reevaluate.py
# Re-evaluate specific organization
python batch_reevaluate.py --org OpenAI
# Upload results after review
python batch_reevaluate.py --upload
"""
import json
import os
import sys
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, list_repo_files
from datasets import load_dataset
# Add parent for imports
sys.path.insert(0, str(Path(__file__).parent))
from metrics import (
anls_star,
anls_star_llm,
aggregate_anls_star_llm,
citation_f1,
kuiper_statistic
)
# Parallelization config
MAX_WORKERS = 24  # concurrent judge evaluations; tune to the LLM API's rate limit

# Config
RESULTS_REPO = "agentic-document-ai/backend-results"  # HF dataset repo holding submissions
TOKEN = os.environ.get("HF_TOKEN")  # may be None -> hub calls run unauthenticated
OUTPUT_DIR = Path(__file__).parent / "reevaluated_results"  # local staging dir before --upload
def load_gold_data():
    """Load gold standard from HuggingFace."""
    print("Loading gold standard...")
    dataset = load_dataset("agentic-document-ai/dataset-PRIVATE", split="test")

    def _derive_hop_type(evidence: list) -> str:
        # Classify by how many distinct documents / pages the evidence spans.
        if not evidence:
            return 'single'
        docs = {ev.get('document') for ev in evidence if ev.get('document') is not None}
        pages = {ev.get('page') for ev in evidence if ev.get('page') is not None}
        if len(docs) > 1:
            return 'cross_doc'
        if len(pages) > 1:
            return 'cross_page'
        return 'single'

    # Two lookup tables over the same entries: by example id and by the
    # stripped question text (used as the primary match key downstream).
    by_id, by_text = {}, {}
    for example in dataset:
        text = example['question'].strip()
        ev_list = example.get('evidence', [])
        entry = {
            'question': text,
            'answers': example.get('answer_variants', []),
            'evidence': ev_list,
            'category': example.get('document_category', ''),
            'domain': example.get('domain', ''),
            'hop_type': _derive_hop_type(ev_list),
        }
        by_id[example.get('id', '')] = entry
        by_text[text] = entry
    return by_id, by_text
def find_prediction_files(org_filter: str = None):
    """Find all prediction JSONL files in the results repo."""
    all_files = list_repo_files(RESULTS_REPO, repo_type="dataset", token=TOKEN)
    candidates = [
        name for name in all_files
        if '_predictions' in name and name.endswith('.jsonl')
    ]
    if org_filter:
        # Repo paths are laid out as {org}/{file}; keep only the given org.
        prefix = org_filter + '/'
        candidates = [name for name in candidates if name.startswith(prefix)]
    return candidates
def find_result_file(pred_file: str):
    """Find the corresponding results JSON file for a predictions file.

    Pattern: {org}/{model}_predictions_{timestamp}.jsonl
          -> {org}/{model}_results_{timestamp}.json

    Returns None when *pred_file* does not contain the '_predictions_'
    marker (no corresponding results file can be derived).
    """
    parts = pred_file.rsplit('_predictions_', 1)
    if len(parts) != 2:
        return None
    tail = parts[1]
    # Strip only the *trailing* extension; the previous str.replace() would
    # also rewrite an accidental '.jsonl' occurring earlier in the tail.
    if tail.endswith('.jsonl'):
        tail = tail[:-len('.jsonl')] + '.json'
    return parts[0] + '_results_' + tail
def download_file(filepath: str) -> str:
    """Download a file from HuggingFace Hub; returns the local cache path."""
    download_args = dict(
        repo_id=RESULTS_REPO,
        filename=filepath,
        repo_type="dataset",
        token=TOKEN,
    )
    return hf_hub_download(**download_args)
def _evaluate_single_prediction(args, max_retries=3):
    """Evaluate a single prediction (for parallel processing)."""
    idx, pred, gold_data = args

    answer = pred.get('answer', '')
    question = pred.get('question', '').strip()
    citations = pred.get('citations', [])
    history = pred.get('search_history', [])
    # Prefer the recorded search history length; fall back to 'iterations'.
    step_count = len(history) if history else pred.get('iterations', 0)

    # Non-LLM metrics are cheap and deterministic — compute them up front.
    anls = anls_star(answer, gold_data['answers'])
    doc_f1 = citation_f1(citations, gold_data['evidence'], level='document')
    page_f1 = citation_f1(citations, gold_data['evidence'], level='page')

    # The LLM judge call is flaky; retry with exponential backoff, and
    # re-raise once the retry budget is exhausted.
    for attempt in range(max_retries):
        try:
            semantic_score = anls_star_llm(answer, gold_data['answers'], question)['score']
            break
        except Exception as e:
            if attempt < max_retries - 1:
                print(f" Item {idx} attempt {attempt+1} failed: {e}, retrying...")
                time.sleep(2 ** attempt)  # Exponential backoff
            else:
                print(f" Failed item {idx} after {max_retries} retries: {e}")
                raise

    return {
        'idx': idx,
        'anls': anls,
        'semantic_score': semantic_score,
        'correct': semantic_score >= 0.5,
        'doc_f1': doc_f1['f1'],
        'page_f1': page_f1['f1'],
        'steps': step_count,
        'hop_type': gold_data.get('hop_type', 'single'),
        'category': gold_data['category'],
        'domain': gold_data['domain'],
    }
def _hop_summary(subset: list) -> dict:
    """Aggregate semantic accuracy and ANLS* over one subset of evals.

    Returns zeros when the subset is empty (matching the previous inline
    behavior) so every report key is always present.
    """
    if not subset:
        return {'semantic': 0, 'anls': 0, 'n': 0}
    agg = aggregate_anls_star_llm(
        [e['semantic_score'] for e in subset], apply_bias_correction=True
    )
    return {
        'semantic': agg['adjusted_score'] * 100,
        'anls': sum(e['anls'] for e in subset) / len(subset) * 100,
        'n': len(subset),
    }


def evaluate_with_semantic(predictions: list, gold_by_id: dict, gold_by_text: dict) -> dict:
    """Evaluate predictions with semantic accuracy metric (parallelized).

    Predictions are matched to the gold set by question text first, then by
    id; unmatched predictions are skipped. Judge calls run on a thread pool
    with a per-item timeout (120s) and an overall timeout (600s); items that
    time out or fail permanently are skipped rather than aborting the batch.

    Returns the aggregated results dict, or None if nothing was evaluated.
    """
    # concurrent.futures.TimeoutError is only an alias of the builtin
    # TimeoutError from Python 3.11 onward; import it explicitly so the
    # handlers below also catch pool timeouts on older interpreters.
    from concurrent.futures import TimeoutError as FuturesTimeout

    # First, filter predictions to only those in the test set.
    matched_predictions = []
    for pred in predictions:
        question = pred.get('question', '').strip()
        qid = pred.get('id', '')
        gold_data = None
        if question in gold_by_text:
            gold_data = gold_by_text[question]
        elif qid and qid in gold_by_id:
            gold_data = gold_by_id[qid]
        if gold_data:
            matched_predictions.append((pred, gold_data))
    unmatched = len(predictions) - len(matched_predictions)
    print(f" Matched {len(matched_predictions)}/{len(predictions)} predictions to test set (skipping {unmatched})")

    total = len(matched_predictions)
    evals = []
    completed = 0
    # Prepare items with index for tracking.
    items_with_idx = [(i, pred, gold) for i, (pred, gold) in enumerate(matched_predictions)]

    print(f" Evaluating with {MAX_WORKERS} parallel workers...")
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = {executor.submit(_evaluate_single_prediction, item): item[0]
                   for item in items_with_idx}
        completed_indices = set()
        try:
            for future in as_completed(futures, timeout=600):  # 10 min overall timeout
                idx = futures[future]
                try:
                    result = future.result(timeout=120)  # 2 min per item max
                    evals.append(result)
                    completed_indices.add(result['idx'])
                except (TimeoutError, FuturesTimeout):
                    print(f" TIMEOUT: Item {idx} took too long, skipping")
                    completed_indices.add(idx)
                except Exception as e:
                    # A worker that exhausted its LLM retries re-raises here;
                    # skip that item instead of aborting the entire file.
                    print(f" ERROR: Item {idx} failed permanently ({e}), skipping")
                    completed_indices.add(idx)
                completed += 1
                if completed % 50 == 0 or completed == total:
                    print(f" Progress: {completed}/{total}")
        except (TimeoutError, FuturesTimeout):
            # Overall deadline hit: report which items never finished.
            pending = set(range(total)) - completed_indices
            print(f" OVERALL TIMEOUT: {len(pending)} items still pending: {sorted(pending)[:10]}...")
            # Cancel remaining futures (only not-yet-started ones can cancel).
            for future in futures:
                future.cancel()

    if not evals:
        return None

    # Aggregate overall metrics.
    n = len(evals)
    agg = aggregate_anls_star_llm([e['semantic_score'] for e in evals],
                                  apply_bias_correction=True)
    kuiper = kuiper_statistic(evals)

    # Per-domain breakdown (empty domain strings are bucketed as 'Other').
    by_domain = defaultdict(list)
    for e in evals:
        by_domain[e['domain'] or 'Other'].append(e)
    domain_scores = {domain: _hop_summary(domain_evals)
                     for domain, domain_evals in sorted(by_domain.items())}

    return {
        'overall': {
            'semantic': agg['adjusted_score'] * 100,
            'semantic_ci': (agg['ci_lower'] * 100, agg['ci_upper'] * 100),  # 95% CI
            'anls': sum(e['anls'] for e in evals) / n * 100,
            'page_f1': sum(e['page_f1'] for e in evals) / n * 100,
            'doc_f1': sum(e['doc_f1'] for e in evals) / n * 100,
            'kuiper': kuiper['kuiper_stat'] if not kuiper.get('degenerate') else None,
        },
        'single_evidence': _hop_summary([e for e in evals if e['hop_type'] == 'single']),
        'multi_evidence_same_doc': _hop_summary([e for e in evals if e['hop_type'] == 'cross_page']),
        'multi_evidence_multi_doc': _hop_summary([e for e in evals if e['hop_type'] == 'cross_doc']),
        'by_domain': domain_scores,
        'n_evaluated': n,
        'n_unmatched': unmatched,
    }
def main():
    """CLI entry point: batch re-evaluate submissions.

    Modes:
      --dry-run        list matching prediction files and exit
      --upload         push previously generated local results to the Hub
      (default)        download predictions, re-evaluate, save locally
      --skip-existing  skip files whose re-evaluated output already exists
    """
    import argparse
    parser = argparse.ArgumentParser(description="Batch re-evaluate submissions")
    parser.add_argument('--dry-run', action='store_true', help="List files only, don't evaluate")
    parser.add_argument('--org', type=str, help="Filter by organization (e.g., 'OpenAI')")
    parser.add_argument('--upload', action='store_true', help="Upload already processed results to HuggingFace Hub (no re-evaluation)")
    parser.add_argument('--skip-existing', action='store_true', help="Skip already evaluated files")
    args = parser.parse_args()
    OUTPUT_DIR.mkdir(exist_ok=True)
    # Upload-only mode: just upload existing files
    if args.upload:
        print("Uploading existing results to HuggingFace Hub...")
        api = HfApi()
        result_files = list(OUTPUT_DIR.glob("**/*.json"))
        print(f"Found {len(result_files)} result files to upload")
        for result_file in result_files:
            # Mirror the local {org}/{file}.json layout into the repo.
            rel_path = result_file.relative_to(OUTPUT_DIR)
            print(f" Uploading: {rel_path}")
            try:
                api.upload_file(
                    path_or_fileobj=str(result_file),
                    path_in_repo=str(rel_path),
                    repo_id=RESULTS_REPO,
                    repo_type="dataset",
                    token=TOKEN,
                    commit_message=f"Re-evaluate with semantic accuracy: {rel_path.stem}"
                )
                print(f" ✓ Done")
            except Exception as e:
                # Best-effort: report and keep uploading the remaining files.
                print(f" ✗ Error: {e}")
        print("\nUpload complete!")
        return
    # Find prediction files
    print("Finding prediction files...")
    pred_files = find_prediction_files(args.org)
    print(f"Found {len(pred_files)} prediction files")
    if args.dry_run:
        for f in pred_files:
            print(f" - {f}")
        return
    # Load gold standard
    gold_by_id, gold_by_text = load_gold_data()
    print(f"Loaded {len(gold_by_id)} gold examples")
    # Process each file
    for i, pred_file in enumerate(pred_files):
        print(f"\n{'='*60}")
        print(f"[{i+1}/{len(pred_files)}] Processing: {pred_file}")
        print('='*60)
        # Check if already processed
        output_file = OUTPUT_DIR / (Path(pred_file).stem.replace('_predictions', '_results') + '_reevaluated.json')
        if args.skip_existing and output_file.exists():
            print(" Skipping (already processed)")
            continue
        try:
            # Download predictions (JSONL: one prediction object per line)
            print(" Downloading predictions...")
            local_pred = download_file(pred_file)
            predictions = []
            with open(local_pred) as f:
                for line in f:
                    if line.strip():
                        predictions.append(json.loads(line))
            print(f" Loaded {len(predictions)} predictions")
            # Download original results to preserve metadata
            result_file = find_result_file(pred_file)
            original_metadata = {}
            if result_file:
                try:
                    local_result = download_file(result_file)
                    with open(local_result) as f:
                        original_data = json.load(f)
                    # Keep only submission metadata; scores get recomputed.
                    original_metadata = {
                        'model_name': original_data.get('model_name'),
                        'organization': original_data.get('organization'),
                        'description': original_data.get('description'),
                        'link': original_data.get('link'),
                        'tags': original_data.get('tags'),
                        'submitted_by': original_data.get('submitted_by'),
                        'metadata': original_data.get('metadata'),
                        'submission_date': original_data.get('submission_date'),
                    }
                    print(f" Loaded metadata: model_name={original_metadata.get('model_name')}")
                except Exception as e:
                    # Missing results file is non-fatal; fallback below.
                    print(f" Warning: Could not load original results: {e}")
            # Fallback: extract metadata from filename if not found
            if not original_metadata.get('model_name'):
                # Pattern: Org/Model_Name_with_Stuff_predictions_timestamp.jsonl
                filename = Path(pred_file).stem  # e.g., GPT-5_(2025-08-07)_with_BM25_Search_Tool_predictions_20260109_152104
                parts = filename.rsplit('_predictions_', 1)
                # NOTE(review): rsplit never returns an empty list, so this
                # condition is always true — presumably meant len(parts) == 2.
                if parts:
                    model_name = parts[0].replace('_', ' ')  # Convert underscores to spaces
                    org = Path(pred_file).parts[0] if '/' in pred_file else 'Unknown'
                    original_metadata = {
                        'model_name': model_name,
                        'organization': org.replace('_', ' '),
                        'description': '',
                        'tags': ['Agentic'],
                        'metadata': {'model_type': 'unknown'},
                    }
                    print(f" Using fallback metadata: model_name={model_name}, org={org}")
            # Evaluate
            print(" Evaluating with semantic accuracy...")
            start_time = time.time()
            results = evaluate_with_semantic(predictions, gold_by_id, gold_by_text)
            elapsed = time.time() - start_time
            if results:
                print(f"\n Results (took {elapsed:.1f}s):")
                print(f" Semantic Accuracy: {results['overall']['semantic']:.1f}")
                print(f" ANLS*: {results['overall']['anls']:.1f}")
                print(f" Page F1: {results['overall']['page_f1']:.1f}")
                # Save with original metadata
                org = Path(pred_file).parts[0] if '/' in pred_file else 'Unknown'
                output_filename = Path(pred_file).name.replace('_predictions', '_results').replace('.jsonl', '.json')
                full_result = {
                    **original_metadata,
                    'results': results,
                    'reevaluated_date': datetime.now(timezone.utc).isoformat(),
                    'source_predictions_file': pred_file,
                    'result_file_path': f"{org}/{output_filename}",
                }
                # Create org subfolder
                org_dir = OUTPUT_DIR / org
                org_dir.mkdir(exist_ok=True)
                output_file = org_dir / output_filename
                with open(output_file, 'w') as f:
                    json.dump(full_result, f, indent=2)
                print(f" Saved to: {output_file}")
            else:
                print(" No valid evaluations")
        except Exception as e:
            # One bad submission must not stop the batch: log and continue.
            print(f" Error: {e}")
            import traceback
            traceback.print_exc()
            continue
    print(f"\n{'='*60}")
    print("DONE!")
    print(f"Results saved to: {OUTPUT_DIR}")
    print(f"\nTo upload results, run: python batch_reevaluate.py --upload")


if __name__ == "__main__":
    main()