| |
"""Set up the FEVER benchmark from scratch.

Downloads BEIR FEVER, builds a Pyserini index, runs BM25 retrieval,
and creates the frozen pool file (beir_pool.json).

Usage:
    python setup_fever_benchmark.py [--data-dir ./fever_data]
                                    [--skip-download] [--skip-index]
"""
import argparse
import csv
import json
import os
import subprocess
import sys
import time
import urllib.request
import zipfile
|
|
def log(msg):
    """Print *msg* to stdout prefixed with the current local time.

    The prefix has the form ``[HH:MM:SS]``.  Output is flushed on every call
    so progress lines appear immediately.
    """
    stamp = time.strftime("%H:%M:%S")
    print(f"[{stamp}] {msg}", flush=True)
|
|
# Command-line interface.  Parsing happens at module level because this file
# is a flat script that runs top to bottom.
parser = argparse.ArgumentParser(
    description=(
        "Set up the FEVER benchmark: download BEIR FEVER, build a Pyserini "
        "index, run BM25 retrieval and write the frozen pool file."
    ),
)
parser.add_argument(
    "--data-dir",
    default="./fever_data",
    help="directory holding the dataset, index and pool file "
         "(default: %(default)s)",
)
parser.add_argument(
    "--skip-download",
    action="store_true",
    help="do not download or extract the BEIR FEVER archive",
)
parser.add_argument(
    "--skip-index",
    action="store_true",
    help="do not convert the corpus or build the Pyserini index",
)
args = parser.parse_args()

# Every later step reads from and writes to this directory.
data_dir = args.data_dir
os.makedirs(data_dir, exist_ok=True)
|
|
| |
# Step 1: download and extract the BEIR FEVER archive, unless the corpus is
# already present or the user passed --skip-download.
if not args.skip_download:
    zip_path = os.path.join(data_dir, "fever.zip")
    corpus_path = os.path.join(data_dir, "corpus.jsonl")
    if not os.path.exists(corpus_path):
        if not os.path.exists(zip_path):
            url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fever.zip"
            log(f"Downloading {url} ...")
            # Download to a temporary name and rename on success, so an
            # interrupted transfer never leaves a truncated fever.zip that a
            # later run would mistake for a complete archive.
            partial_path = zip_path + ".part"
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, zip_path)
            log("Download complete.")
        log("Extracting...")
        with zipfile.ZipFile(zip_path, 'r') as z:
            z.extractall(data_dir)
        # NOTE(review): BEIR archives are expected to unpack into a top-level
        # directory named after the dataset ("fever/").  The later steps read
        # corpus.jsonl, queries.jsonl and qrels/ directly from data_dir, so
        # lift the extracted entries up one level when that layout is found.
        nested_dir = os.path.join(data_dir, "fever")
        if not os.path.exists(corpus_path) and os.path.isdir(nested_dir):
            for entry in os.listdir(nested_dir):
                os.replace(
                    os.path.join(nested_dir, entry),
                    os.path.join(data_dir, entry),
                )
            os.rmdir(nested_dir)
        if not os.path.exists(corpus_path):
            log(f"WARNING: {corpus_path} not found after extraction")
        log("Extraction complete.")
    else:
        log("Corpus already exists, skipping download.")
else:
    log("Skipping download.")
|
|
| |
# Step 2: convert the BEIR corpus to Pyserini's JSON collection format and
# build a Lucene index from it, unless the index directory already exists or
# the user passed --skip-index.
if not args.skip_index:
    # Neither name is referenced below; presumably they are imported here so
    # a missing Pyserini install fails before the slow conversion starts.
    # NOTE(review): confirm before removing.
    from pyserini.index.lucene import LuceneIndexer
    import pyserini

    index_dir = os.path.join(data_dir, "pyserini_index")
    pyserini_input = os.path.join(data_dir, "pyserini_input")

    if not os.path.exists(index_dir):
        log("Converting corpus to Pyserini input format...")
        os.makedirs(pyserini_input, exist_ok=True)
        corpus_path = os.path.join(data_dir, "corpus.jsonl")
        out_path = os.path.join(pyserini_input, "docs.jsonl")

        # Hard-coded total shown in the progress messages.
        expected_total = 5_416_568
        count = 0
        # Explicit UTF-8 so the result does not depend on the platform's
        # default encoding.
        with open(corpus_path, encoding="utf-8") as fin, \
                open(out_path, 'w', encoding="utf-8") as fout:
            for line in fin:
                if not line.strip():
                    # Tolerate blank lines instead of crashing in json.loads.
                    continue
                d = json.loads(line)
                title = d.get("title", "")
                text = d.get("text", "")
                fout.write(json.dumps({
                    "id": d["_id"],
                    "title": title,
                    "text": text,
                    # "contents" is the title and text joined by a space.
                    "contents": f"{title} {text}",
                }) + "\n")
                count += 1
                if count % 1_000_000 == 0:
                    log(f"  Converted {count}/{expected_total} docs")
        log(f"Converted {count} docs")

        log("Building Pyserini index (this takes ~6 min)...")
        # Run the indexer with an argument list (no shell), so paths that
        # contain spaces or shell metacharacters are passed through intact,
        # and with the interpreter running this script.  check=True raises
        # CalledProcessError on a non-zero exit instead of silently logging
        # success.
        # NOTE(review): a failed build may leave a partial index_dir behind,
        # which the existence check above would then treat as complete.
        subprocess.run(
            [
                sys.executable or "python",
                "-m", "pyserini.index.lucene",
                "--collection", "JsonCollection",
                "--input", pyserini_input,
                "--index", index_dir,
                "--generator", "DefaultLuceneDocumentGenerator",
                "--threads", "8",
                "--storePositions",
                "--storeDocvectors",
                "--storeRaw",
            ],
            check=True,
        )
        log("Index built.")
    else:
        log("Index already exists, skipping.")
else:
    log("Skipping index build.")
|
|
| |
# Step 3: run BM25 over every query that has test qrels and freeze the top
# 100 hits per query into beir_pool.json.  Skipped when the file exists.
pool_path = os.path.join(data_dir, "beir_pool.json")
if not os.path.exists(pool_path):
    log("Running BM25 retrieval (k1=1.2, b=0.75)...")
    # Prefer the current class name and fall back to the older alias so the
    # script works across Pyserini releases.
    try:
        from pyserini.search.lucene import LuceneSearcher as _Searcher
    except ImportError:
        from pyserini.search import SimpleSearcher as _Searcher

    searcher = _Searcher(os.path.join(data_dir, "pyserini_index"))
    searcher.set_bm25(k1=1.2, b=0.75)

    # Query id -> query text.
    queries = {}
    with open(os.path.join(data_dir, "queries.jsonl"), encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            d = json.loads(line)
            queries[d['_id']] = d['text']

    # Query id -> {doc id -> relevance}, read from the tab-separated test
    # qrels.  newline="" is what the csv module expects for its input files.
    qrels = {}
    qrels_path = os.path.join(data_dir, "qrels", "test.tsv")
    with open(qrels_path, encoding="utf-8", newline="") as f:
        reader = csv.reader(f, delimiter='\t')
        next(reader, None)  # skip the header row; tolerate an empty file
        for row in reader:
            if not row:
                continue
            qrels.setdefault(row[0], {})[row[1]] = int(row[2])

    # Only queries that have at least one test judgment are retrieved.
    eval_qids = [qid for qid in queries if qid in qrels]
    log(f"{len(eval_qids)} queries to retrieve")

    pool = {}
    t0 = time.time()
    for qi, qid in enumerate(eval_qids):
        hits = searcher.search(queries[qid], k=100)
        if hits:
            # Queries with no hits get no entry in the pool.
            pool[qid] = [(hit.docid, float(hit.score)) for hit in hits]
        if qi > 0 and qi % 500 == 0:
            rate = (qi + 1) / (time.time() - t0)
            remaining = (len(eval_qids) - qi - 1) / rate if rate > 0 else 0
            log(f"  {qi}/{len(eval_qids)} @ {rate:.0f}q/s, ~{remaining:.0f}s left")

    # Write to a temporary name and rename, so an interrupted run cannot
    # leave a truncated pool file that the existence check above would
    # accept on the next run.
    tmp_pool_path = pool_path + ".tmp"
    with open(tmp_pool_path, 'w', encoding="utf-8") as f:
        json.dump({"qids": eval_qids, "pool": pool}, f)
    os.replace(tmp_pool_path, pool_path)
    log(f"Pool saved ({len(pool)} queries, {time.time()-t0:.0f}s)")
else:
    log(f"Pool already exists at {pool_path}")
|
|
# Closing usage hint.  The blank separator line is printed on its own so the
# timestamp prefix added by log() stays on the same line as the message,
# instead of being followed by a line break.
print(flush=True)
for hint in (
    "Setup complete! You can now use:",
    "  from fever_benchmark import FEVERBenchmark",
    "  bench = FEVERBenchmark()",
    "  pool = bench.load_pool('beir_pool.json')",
    "  # Re-rank, then evaluate:",
    "  results = bench.evaluate(your_rankings)",
):
    log(hint)
|
|