# fever-ner / fever_benchmark.py
# Uploaded by Kim-el with huggingface_hub (verified commit fe0f178, 7.48 kB).
"""FEVER Benchmark β€” frozen evaluation for testing new retrieval methods.
Usage:
from fever_benchmark import FEVERBenchmark
bench = FEVERBenchmark()
# Get the BM25 pool (top-100 docs per query)
pool = bench.load_pool() # {qid: [(docid, bm25_score), ...]}
# Evaluate your rankings
results = bench.evaluate(your_rankings)
# your_rankings: {qid: [(docid, score), ...]}
# results: {"ndcg@10": 0.6497, "recall@100": 0.871}
"""
import json, csv, math, os, gzip, hashlib, urllib.request, itertools
from collections import defaultdict
# ── Config ──
# Archive holding the BEIR FEVER corpus, queries and qrels (fetched by download()).
BEIR_URL = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fever.zip"

# Truncated SHA-256 digests ("..." marks where they were cut) of the data files.
# NOTE(review): nothing in the code shown here compares against these values,
# so they are informational only — confirm whether an integrity check was intended.
SHA256_CORPUS, SHA256_QUERIES, SHA256_QRELS = (
    "f7ddf098ead635e46242c435547215f4...",
    "bf8e8fed5785ad0e4abc63268b6d0123...",
    "6599dab4e2189ca5eb5216546021710c...",
)

# Reference nDCG@10 per system name, used by check_scores(). Insertion order is
# the row order of the reproduction report.
EXPECTED_SCORES = dict(
    [
        ("BM25 (k1=1.2, b=0.75)", 0.5214),
        ("MiniLM Dense", 0.6497),
        ("Dense + Tawatur", 0.6279),
        ("Dense + Shadhdh", 0.6420),
        ("Dense + Muttafaq", 0.6461),
        ("Dense + MTS", 0.5918),
    ]
)
class FEVERBenchmark:
    """Frozen evaluator for the BEIR FEVER test split.

    Reads queries and test qrels from ``data_dir`` (``download()`` fetches the
    BEIR archive when needed), loads a precomputed BM25 candidate pool from
    JSON, and scores user rankings with nDCG@10 and Recall@100.
    """

    def __init__(self, data_dir="fever_data"):
        self.data_dir = data_dir
        self._queries = None  # lazily loaded {qid: query text}
        self._qrels = None  # lazily loaded {qid: {docid: relevance}}
        self._pool = None  # pool cached by load_pool()
        self._pool_path = None  # path the cached pool was read from

    def _path(self, *parts):
        """Return the path of a dataset file under ``data_dir``.

        The BEIR archive may unpack into a ``fever/`` sub-directory, so both
        ``data_dir/<parts>`` and ``data_dir/fever/<parts>`` are tried. If neither
        exists, the flat path is returned so that ``open`` raises a
        FileNotFoundError naming the expected location.
        """
        flat = os.path.join(self.data_dir, *parts)
        nested = os.path.join(self.data_dir, "fever", *parts)
        if not os.path.exists(flat) and os.path.exists(nested):
            return nested
        return flat

    def download(self):
        """Download and extract BEIR FEVER into ``data_dir`` if not present.

        Extraction is skipped when ``corpus.jsonl`` already exists, either
        directly in ``data_dir`` or in its ``fever/`` sub-directory. The zip is
        saved as ``data_dir/fever.zip`` and kept, so later runs can reuse it.
        """
        import zipfile

        os.makedirs(self.data_dir, exist_ok=True)
        zip_path = os.path.join(self.data_dir, "fever.zip")
        if not os.path.exists(self._path("corpus.jsonl")):
            if not os.path.exists(zip_path):
                print(f"Downloading BEIR FEVER ({BEIR_URL})...")
                # Download under a temporary name and rename only on success, so
                # an interrupted transfer never leaves a truncated fever.zip that
                # the next run would try to extract.
                part_path = zip_path + ".part"
                urllib.request.urlretrieve(BEIR_URL, part_path)
                os.replace(part_path, zip_path)
            with zipfile.ZipFile(zip_path, "r") as z:
                z.extractall(self.data_dir)
        print(f"Data ready at {self.data_dir}")

    @property
    def queries(self):
        """{qid: query text}, read once from ``queries.jsonl`` and cached.

        The cache is assigned only after the whole file has been read. A
        failed read (e.g. before ``download()``) is therefore retried on the
        next access, instead of leaving an empty dict cached.
        """
        if self._queries is None:
            loaded = {}
            with open(self._path("queries.jsonl"), encoding="utf-8") as f:
                for line in f:
                    if not line.strip():  # tolerate blank/trailing lines
                        continue
                    d = json.loads(line)
                    loaded[d["_id"]] = d["text"]
            self._queries = loaded
        return self._queries

    @property
    def qrels(self):
        """{qid: {docid: relevance}}, read once from ``qrels/test.tsv`` and cached.

        The first TSV row is skipped as a header. As with ``queries``, the
        cache is set only after a complete read.
        """
        if self._qrels is None:
            loaded = {}
            path = self._path("qrels", "test.tsv")
            with open(path, encoding="utf-8", newline="") as f:
                reader = csv.reader(f, delimiter="\t")
                next(reader, None)  # header row
                for row in reader:
                    if not row:
                        continue
                    loaded.setdefault(row[0], {})[row[1]] = int(row[2])
            self._qrels = loaded
        return self._qrels

    @property
    def eval_qids(self):
        """Query IDs that have both query text and qrels, in queries.jsonl order."""
        qrels = self.qrels
        return [qid for qid in self.queries if qid in qrels]

    def load_pool(self, pool_path="beir_pool.json"):
        """Load the BM25 top-100 candidate pool from a JSON file.

        The file must have the form ``{"pool": {qid: [[docid, bm25_score, ...], ...]}}``.
        Only the first two fields of each entry are kept, and entry order is
        preserved as stored (documented as BM25 score descending).

        The result ``{qid: [(docid, bm25_score), ...]}`` is cached per path.
        Repeating the same ``pool_path`` returns the cached dict; a different
        path reloads from disk.

        Raises:
            FileNotFoundError: if ``pool_path`` does not exist.
            KeyError: if the JSON has no top-level "pool" key.
        """
        if self._pool is None or getattr(self, "_pool_path", None) != pool_path:
            with open(pool_path, encoding="utf-8") as f:
                data = json.load(f)
            self._pool = {
                qid: [(p[0], p[1]) for p in entries]
                for qid, entries in data["pool"].items()
            }
            self._pool_path = pool_path
        return self._pool

    @staticmethod
    def _top_unique(ranked, k):
        """First ``k`` distinct docids of ``ranked`` ((docid, score) pairs), in order.

        Repeated docids are dropped so that one document cannot earn gain twice.
        """
        seen = set()
        top = []
        for did, _ in ranked:
            if len(top) >= k:
                break
            if did not in seen:
                seen.add(did)
                top.append(did)
        return top

    @staticmethod
    def _ndcg_at(ranked, gt, k):
        """nDCG@k with gain 2**rel - 1 and discount log2(rank + 1).

        ``ranked`` is used in the given order (it is not re-sorted by score),
        and ``gt`` is {docid: relevance}. With binary qrels this is the
        standard binary nDCG. Returns 0.0 when the ideal DCG is 0.
        """
        top = FEVERBenchmark._top_unique(ranked, k)
        dcg = sum(
            (2 ** gt.get(did, 0) - 1) / math.log2(rank + 2)
            for rank, did in enumerate(top)
        )
        ideal = sorted(gt.values(), reverse=True)[:k]
        idcg = sum((2 ** rel - 1) / math.log2(rank + 2) for rank, rel in enumerate(ideal))
        return dcg / idcg if idcg > 0 else 0.0

    @staticmethod
    def ndcg10(ranked, gt):
        """nDCG@10 of ``ranked`` ([(docid, score), ...]) against qrels ``gt``."""
        return FEVERBenchmark._ndcg_at(ranked, gt, 10)

    @staticmethod
    def recall100(ranked, gt):
        """Fraction of relevant judged docs (relevance > 0) found in the top 100.

        Qrels entries with relevance 0 are judged non-relevant and are not
        counted. Returns 0.0 when ``gt`` contains no relevant document.
        """
        relevant = {did for did, rel in gt.items() if rel > 0}
        if not relevant:
            return 0.0
        retrieved = set(FEVERBenchmark._top_unique(ranked, 100))
        return len(retrieved & relevant) / len(relevant)

    def evaluate(self, rankings_dict, top_k=10):
        """Average nDCG and Recall@100 over the evaluation queries.

        Args:
            rankings_dict: {qid: [(docid, score), ...]}. Each list must already
                be sorted by score descending, because list order is the rank
                order.
            top_k: nDCG cutoff. "ndcg@10" is always reported; when top_k != 10,
                an extra "ndcg@{top_k}" entry is added.

        Only qids in ``eval_qids`` that also appear in ``rankings_dict`` are
        scored. Missing queries are skipped rather than counted as zero, so
        compare "queries_evaluated" with "total_queries" to detect gaps.

        Returns:
            {"ndcg@10": float, "recall@100": float, "queries_evaluated": int,
             "total_queries": int}, plus "ndcg@{top_k}" when top_k != 10.
            Averages are 0.0 when no query is scored.
        """
        eval_qids = self.eval_qids
        qrels = self.qrels
        extra_cutoff = top_k != 10
        ndcg_sum = ndcg_k_sum = recall_sum = 0.0
        n = 0
        for qid in eval_qids:
            if qid not in rankings_dict:
                continue
            ranked = rankings_dict[qid]
            gt = qrels.get(qid, {})
            ndcg_sum += self.ndcg10(ranked, gt)
            if extra_cutoff:
                ndcg_k_sum += self._ndcg_at(ranked, gt, top_k)
            recall_sum += self.recall100(ranked, gt)
            n += 1
        denom = max(n, 1)
        results = {
            "ndcg@10": ndcg_sum / denom,
            "recall@100": recall_sum / denom,
            "queries_evaluated": n,
            "total_queries": len(eval_qids),
        }
        if extra_cutoff:
            results[f"ndcg@{top_k}"] = ndcg_k_sum / denom
        return results

    @staticmethod
    def check_scores(results_dict, expected=None, tol=0.001):
        """Compare reported scores against baseline nDCG@10 values.

        Args:
            results_dict: {system_name: ndcg_score}.
            expected: {system_name: baseline_score}; defaults to EXPECTED_SCORES.
            tol: a score matches when abs(got - expected) < tol.

        Returns:
            {system_name: {"expected": float, "got": float | None, "match": bool}},
            with one entry per baseline system. Baselines absent from
            ``results_dict`` get got=None and match=False. Names without a
            baseline are not included.
        """
        baselines = EXPECTED_SCORES if expected is None else expected
        checks = {}
        for name, target in baselines.items():
            got = results_dict.get(name)
            if got is None:
                checks[name] = {"expected": target, "got": None, "match": False}
            else:
                checks[name] = {
                    "expected": target,
                    "got": round(got, 4),
                    "match": abs(got - target) < tol,
                }
        return checks

    def verify_reproduction(self, your_results, expected=None):
        """Print a reproduction report for ``your_results`` and return pass/fail.

        Args:
            your_results: {system_name: ndcg_score}, as accepted by check_scores().
            expected: baseline {system_name: score}; defaults to EXPECTED_SCORES.

        Returns:
            True only when at least one baseline system is present in
            ``your_results`` and every one present matches. Systems without a
            baseline (e.g. a new method) are listed in the report but cannot
            make the check pass on their own.
        """
        checks = self.check_scores(your_results, expected)
        tested = [c for c in checks.values() if c["got"] is not None]
        # Without the bool(tested) guard, input naming only new methods would
        # report success because all([]) is True.
        all_pass = bool(tested) and all(c["match"] for c in tested)
        print("=" * 60)
        print("FEVER BENCHMARK — Reproduction Check")
        print("=" * 60)
        for name, c in checks.items():
            if c["got"] is None:
                print(f" {name:<25} NOT TESTED")
            else:
                status = "✓" if c["match"] else "✗ MISMATCH"
                print(f" {name:<25} expected={c['expected']:.4f} got={c['got']:.4f} {status}")
        for name, score in your_results.items():
            if name not in checks and score is not None:
                print(f" {name:<25} got={score:.4f} (no baseline)")
        print("=" * 60)
        if all_pass:
            print(" ✓ All baseline scores match — reproduction valid.")
        elif not tested:
            print(" ⚠ No baseline systems supplied — nothing to verify.")
        else:
            print(" ⚠ Some scores differ — see mismatches above.")
        return all_pass
# ── Example: fetch the data, show a sample query, print evaluation steps ──
if __name__ == "__main__":
    bench = FEVERBenchmark()
    bench.download()
    # eval_qids is a property that rebuilds its list on each access; read it once.
    eval_qids = bench.eval_qids
    print(f"Queries with qrels: {len(eval_qids)}")
    if eval_qids:
        first = eval_qids[0]
        print(f"Sample query: qid={first}")
        print(f" Text: {bench.queries[first][:80]}...")
        print(f" Qrels: {len(bench.qrels.get(first, {}))} judgments")
    else:
        print("No queries with qrels found — check the data directory.")
    print("\nTo evaluate a new method:")
    print(" 1. bench = FEVERBenchmark()")
    print(" 2. pool = bench.load_pool('beir_pool.json')")
    print(" 3. Re-rank each query's pool with your method")
    print(" 4. results = bench.evaluate(your_rankings)")
    print(" 5. Compare: bench.verify_reproduction({'My Method': results['ndcg@10']})")