| """FEVER Benchmark β frozen evaluation for testing new retrieval methods. |
| |
| Usage: |
| from fever_benchmark import FEVERBenchmark |
| bench = FEVERBenchmark() |
| |
| # Get the BM25 pool (top-100 docs per query) |
| pool = bench.load_pool() # {qid: [(docid, bm25_score), ...]} |
| |
| # Evaluate your rankings |
| results = bench.evaluate(your_rankings) |
| # your_rankings: {qid: [(docid, score), ...]} |
| # results: {"ndcg@10": 0.6497, "recall@100": 0.871} |
| """ |
|
|
| import json, csv, math, os, gzip, hashlib, urllib.request, itertools |
| from collections import defaultdict |
|
|
| |
# Public archive of the BEIR-formatted FEVER dataset (corpus, queries, qrels);
# FEVERBenchmark.download() fetches it when the data is not present locally.
BEIR_URL = (
    "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/"
    "fever.zip"
)

# SHA-256 prefixes of the dataset files.
# NOTE(review): the values end in "..." (truncated) and no code visible in this
# file checks them, so they cannot serve as integrity checks as written.
SHA256_CORPUS = "f7ddf098ead635e46242c435547215f4..."
SHA256_QUERIES = "bf8e8fed5785ad0e4abc63268b6d0123..."
SHA256_QRELS = "6599dab4e2189ca5eb5216546021710c..."

# Reference nDCG@10 per baseline system, keyed by the exact system name that
# callers must use. Order matters: the reproduction report walks this mapping
# in insertion order.
EXPECTED_SCORES = dict(
    (
        ("BM25 (k1=1.2, b=0.75)", 0.5214),
        ("MiniLM Dense", 0.6497),
        ("Dense + Tawatur", 0.6279),
        ("Dense + Shadhdh", 0.6420),
        ("Dense + Muttafaq", 0.6461),
        ("Dense + MTS", 0.5918),
    )
)
|
|
class FEVERBenchmark:
    """Frozen BEIR FEVER evaluation harness.

    Lazily loads (and caches) the FEVER queries and test qrels from
    ``data_dir``, loads an externally supplied BM25 candidate pool, and scores
    rankings with nDCG@10 and Recall@100.

    Args:
        data_dir: Directory holding, or receiving, the extracted BEIR FEVER
            files. Both a flat layout (``data_dir/queries.jsonl``) and the
            nested layout the BEIR archive unpacks into
            (``data_dir/fever/queries.jsonl``) are accepted.
    """

    # Subdirectory name the BEIR fever.zip archive extracts into.
    _ARCHIVE_SUBDIR = "fever"

    def __init__(self, data_dir="fever_data"):
        self.data_dir = data_dir
        self._queries = None     # {qid: query text}, filled on first access
        self._qrels = None       # {qid: {docid: relevance}}, filled on first access
        self._pool = None        # {qid: [(docid, bm25_score), ...]}
        self._pool_path = None   # path the cached pool was read from

    def _data_path(self, *parts):
        """Return the path of a dataset file, preferring the flat layout.

        Falls back to ``data_dir/fever/...`` only when the flat path is
        missing and the nested one exists. Flat installs therefore behave
        exactly as before, and "file not found" errors still name the flat
        path.
        """
        flat = os.path.join(self.data_dir, *parts)
        if os.path.exists(flat):
            return flat
        nested = os.path.join(self.data_dir, self._ARCHIVE_SUBDIR, *parts)
        return nested if os.path.exists(nested) else flat

    def download(self):
        """Download and extract BEIR FEVER if not present.

        The archive is downloaded to ``fever.zip.part`` and renamed only once
        the transfer completes. An interrupted download is therefore never
        mistaken for a complete archive on the next call.
        """
        os.makedirs(self.data_dir, exist_ok=True)
        if not os.path.exists(self._data_path("corpus.jsonl")):
            zip_path = os.path.join(self.data_dir, "fever.zip")
            if not os.path.exists(zip_path):
                print(f"Downloading BEIR FEVER ({BEIR_URL})...")
                part_path = zip_path + ".part"
                urllib.request.urlretrieve(BEIR_URL, part_path)
                os.replace(part_path, zip_path)
            import zipfile
            with zipfile.ZipFile(zip_path, 'r') as z:
                z.extractall(self.data_dir)
        print(f"Data ready at {self.data_dir}")

    @property
    def queries(self):
        """Map query id -> query text, read once from ``queries.jsonl``."""
        if self._queries is None:
            queries = {}
            # Read as UTF-8 explicitly; don't depend on the platform default.
            with open(self._data_path("queries.jsonl"), encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue  # tolerate blank / trailing lines
                    d = json.loads(line)
                    queries[d['_id']] = d['text']
            # Publish only after a complete read, so a failure mid-file
            # cannot leave a silently partial cache behind.
            self._queries = queries
        return self._queries

    @property
    def qrels(self):
        """Map query id -> {doc id: relevance}, read once from ``qrels/test.tsv``."""
        if self._qrels is None:
            qrels = {}
            path = self._data_path("qrels", "test.tsv")
            # newline='' is what the csv module expects for files it parses.
            with open(path, encoding="utf-8", newline="") as f:
                reader = csv.reader(f, delimiter='\t')
                next(reader, None)  # skip the header row (if any)
                for row in reader:
                    if not row:
                        continue
                    qrels.setdefault(row[0], {})[row[1]] = int(row[2])
            self._qrels = qrels
        return self._qrels

    @property
    def eval_qids(self):
        """Return query IDs that have both query text and qrels (query-file order)."""
        qrels = self.qrels
        return [qid for qid in self.queries if qid in qrels]

    def load_pool(self, pool_path="beir_pool.json"):
        """Load the BM25 candidate pool from a JSON file.

        The file must contain ``{"pool": {qid: [[docid, bm25_score], ...]}}``.
        It is returned as ``{qid: [(docid, bm25_score), ...]}`` in file order
        (by contract: top-100 docs per query, BM25 score descending; this is
        not checked here). The result is cached per path, so requesting a
        different ``pool_path`` reloads instead of returning a stale pool.

        Raises:
            FileNotFoundError: if ``pool_path`` does not exist.
            KeyError: if the JSON has no top-level ``"pool"`` key.
        """
        if self._pool is None or self._pool_path != pool_path:
            with open(pool_path, encoding="utf-8") as f:
                data = json.load(f)
            self._pool = {qid: [(p[0], p[1]) for p in entries]
                          for qid, entries in data["pool"].items()}
            self._pool_path = pool_path
        return self._pool

    @staticmethod
    def ndcg10(ranked, gt, k=10):
        """nDCG@k (default k=10) of one ranked list against its judgments.

        Args:
            ranked: [(docid, score), ...] already sorted best-first. Only the
                order is used; the scores are ignored.
            gt: {docid: relevance} judgments for the query.
            k: Rank cutoff.

        Uses exponential gain ``2**rel - 1`` with a ``log2(rank + 1)``
        discount. For binary (0/1) judgments this equals linear-gain nDCG
        (as computed by trec_eval); graded judgments would differ.

        Returns:
            0.0 when no judged document has positive gain.
        """
        dcg = sum((2 ** gt.get(did, 0) - 1) / math.log2(rank + 2)
                  for rank, (did, _) in enumerate(ranked[:k]))
        ideal = sorted(gt.values(), reverse=True)[:k]
        idcg = sum((2 ** rel - 1) / math.log2(rank + 2)
                   for rank, rel in enumerate(ideal))
        return dcg / idcg if idcg > 0 else 0.0

    @staticmethod
    def recall100(ranked, gt, k=100):
        """Recall@k (default k=100): fraction of relevant docs found in the top k.

        Only judgments with relevance > 0 count as relevant. Zero-grade qrels
        mark judged non-relevant documents and must not enlarge the
        denominator.

        Returns:
            0.0 when the query has no relevant documents.
        """
        if not gt:
            return 0.0
        relevant = {did for did, rel in gt.items() if rel > 0}
        if not relevant:
            return 0.0
        retrieved = {did for did, _ in ranked[:k]}
        return len(retrieved & relevant) / len(relevant)

    def evaluate(self, rankings_dict, top_k=10):
        """Average nDCG@10 and Recall@100 over the evaluated queries.

        Args:
            rankings_dict: {qid: [(docid, score), ...]}, with docids sorted by
                score descending. Only qids in ``eval_qids`` are scored. Eval
                qids missing from ``rankings_dict`` are skipped, not counted
                as zero, so compare ``queries_evaluated`` with
                ``total_queries``.
            top_k: Additional nDCG cutoff. ``"ndcg@10"`` is always reported;
                when ``top_k != 10`` an extra ``"ndcg@{top_k}"`` entry is
                added.

        Returns:
            {"ndcg@10": float, "recall@100": float,
             "queries_evaluated": int, "total_queries": int},
            plus "ndcg@{top_k}" when top_k != 10. Averages are 0.0 when no
            query was evaluated.
        """
        eval_qids = self.eval_qids  # the property builds a new list per access
        qrels = self.qrels
        extra_cutoff = top_k != 10
        ndcg_sum = 0.0
        ndcg_k_sum = 0.0
        recall_sum = 0.0
        n = 0
        for qid in eval_qids:
            if qid not in rankings_dict:
                continue
            ranked = rankings_dict[qid]
            gt = qrels.get(qid, {})
            ndcg_sum += self.ndcg10(ranked, gt)
            if extra_cutoff:
                ndcg_k_sum += self.ndcg10(ranked, gt, k=top_k)
            recall_sum += self.recall100(ranked, gt)
            n += 1
        denom = max(n, 1)
        results = {
            "ndcg@10": ndcg_sum / denom,
            "recall@100": recall_sum / denom,
            "queries_evaluated": n,
            "total_queries": len(eval_qids),
        }
        if extra_cutoff:
            results[f"ndcg@{top_k}"] = ndcg_k_sum / denom
        return results

    @staticmethod
    def check_scores(results_dict):
        """Compare scores against expected baselines.

        Args:
            results_dict: {system_name: ndcg_score}. Only names that are keys
                of EXPECTED_SCORES are compared; other names are ignored.

        Returns:
            {system_name: {"expected": float, "got": float | None,
            "match": bool}}, in EXPECTED_SCORES order. "got" is rounded to 4
            decimals. "match" means |got - expected| < 0.001, and is False
            when the system was not supplied.
        """
        checks = {}
        for name, expected in EXPECTED_SCORES.items():
            got = results_dict.get(name)
            if got is None:
                checks[name] = {"expected": expected, "got": None, "match": False}
            else:
                checks[name] = {
                    "expected": expected,
                    "got": round(got, 4),
                    "match": abs(got - expected) < 0.001,
                }
        return checks

    def verify_reproduction(self, your_results):
        """Run check_scores and print a reproduction report.

        Args:
            your_results: {system_name: ndcg@10 score}. Names must match
                EXPECTED_SCORES keys exactly; others are reported as having
                no baseline and are ignored.

        Returns:
            True only if at least one baseline was supplied and every
            supplied baseline matches; False otherwise.
        """
        checks = self.check_scores(your_results)
        tested = [c for c in checks.values() if c["got"] is not None]
        # all() of an empty sequence is True. Without the bool(tested) guard,
        # a report in which no baseline was recognised would claim success.
        all_pass = bool(tested) and all(c["match"] for c in tested)

        print("=" * 60)
        print("FEVER BENCHMARK — Reproduction Check")
        print("=" * 60)
        for name, c in checks.items():
            if c["got"] is None:
                print(f" {name:<25} NOT TESTED")
            else:
                status = "✓" if c["match"] else "✗ MISMATCH"
                print(f" {name:<25} expected={c['expected']:.4f} got={c['got']:.4f} {status}")
        for name in your_results:
            if name not in checks:
                print(f" {name:<25} no baseline (ignored)")
        print("=" * 60)
        if not tested:
            print(" ✗ No baseline systems were tested — nothing to verify.")
        elif all_pass:
            print(" ✓ All baseline scores match — reproduction valid.")
        else:
            print(" ✗ Some scores differ — see mismatches above.")
        return all_pass
|
|
|
|
| |
if __name__ == "__main__":
    # Smoke-run: fetch the data, print one sample query, then print usage.
    bench = FEVERBenchmark()
    bench.download()

    qids = bench.eval_qids  # the property rebuilds a list on each access
    print(f"Queries with qrels: {len(qids)}")
    if qids:
        first = qids[0]
        print(f"Sample query: qid={first}")
        print(f" Text: {bench.queries[first][:80]}...")
        print(f" Qrels: {len(bench.qrels.get(first, {}))} judgments")
    else:
        print("No queries with qrels found - check the contents of the data directory.")

    print("\nTo evaluate a new method:")
    print(" 1. bench = FEVERBenchmark(); bench.download()")
    print(" 2. pool = bench.load_pool('beir_pool.json')")
    print(" 3. Re-rank each query's pool with your method")
    print(" 4. results = bench.evaluate(your_rankings)")
    print(" 5. Compare results['ndcg@10'] with EXPECTED_SCORES. To verify a")
    print("    baseline reproduction, pass the baseline's exact name, e.g.")
    print("    bench.verify_reproduction({'MiniLM Dense': results['ndcg@10']})")
|
|