# fever-ner / setup_fever_benchmark.py
# Uploaded by Kim-el with huggingface_hub (commit 5a6cb26, verified; 5.13 kB).
# (Residue of the Hugging Face file view: Raw | History | Blame | Contribute | Delete)
#!/usr/bin/env python3
"""Setup FEVER benchmark from scratch.
Downloads BEIR FEVER, builds Pyserini index, runs BM25 retrieval,
and creates the frozen pool file (beir_pool.json).
Usage:
    python setup_fever_benchmark.py [--data-dir ./fever_data] [--skip-download] [--skip-index]
"""
import os, sys, json, time, argparse, urllib.request, zipfile
def log(msg):
    """Write *msg* to stdout behind a ``[HH:MM:SS]`` wall-clock prefix.

    The stream is flushed on every call so progress shows up immediately
    during long-running steps.
    """
    timestamp = time.strftime('%H:%M:%S')
    print(f"[{timestamp}] {msg}", flush=True)
# ── Command-line arguments ──
# --data-dir:      directory holding the BEIR FEVER files, the Pyserini index
#                  and the generated pool.
# --skip-download: do not fetch/extract the BEIR zip (files must already be
#                  in --data-dir).
# --skip-index:    do not build the Pyserini index (it must already exist in
#                  --data-dir/pyserini_index).
parser = argparse.ArgumentParser(
    description="Set up the FEVER benchmark: download BEIR FEVER, build a "
                "Pyserini index, run BM25 retrieval and write beir_pool.json.",
)
parser.add_argument(
    "--data-dir",
    default="./fever_data",
    help="Directory for downloaded data, the index and the pool "
         "(default: %(default)s).",
)
parser.add_argument(
    "--skip-download",
    action="store_true",
    help="Skip downloading/extracting BEIR FEVER.",
)
parser.add_argument(
    "--skip-index",
    action="store_true",
    help="Skip building the Pyserini index.",
)
args = parser.parse_args()
data_dir = args.data_dir
# Created up front so every later step can write into it.
os.makedirs(data_dir, exist_ok=True)
# ── Step 1: Download BEIR FEVER ──
# Fetches the BEIR FEVER zip into data_dir and extracts it so that
# corpus.jsonl (and the other dataset files) sit directly in data_dir.
# Skipped when data_dir/corpus.jsonl is already present.
if not args.skip_download:
    import shutil

    zip_path = os.path.join(data_dir, "fever.zip")
    corpus_path = os.path.join(data_dir, "corpus.jsonl")
    if not os.path.exists(corpus_path):
        if not os.path.exists(zip_path):
            url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fever.zip"
            log(f"Downloading {url} ...")
            # Download under a temporary name and rename only on success, so
            # an interrupted download never leaves a truncated fever.zip that
            # the existence check above would accept on the next run.
            part_path = zip_path + ".part"
            urllib.request.urlretrieve(url, part_path)
            os.replace(part_path, zip_path)
            log("Download complete.")
        log("Extracting...")
        with zipfile.ZipFile(zip_path, 'r') as z:
            z.extractall(data_dir)
        # NOTE(review): BEIR archives are expected to wrap their files in a
        # top-level "fever/" folder. Later steps read corpus.jsonl,
        # queries.jsonl and qrels/ directly from data_dir, so hoist the nested
        # entries up one level when that layout is detected. This is a no-op
        # if the zip is already flat.
        nested_dir = os.path.join(data_dir, "fever")
        if not os.path.exists(corpus_path) and os.path.exists(
            os.path.join(nested_dir, "corpus.jsonl")
        ):
            for name in os.listdir(nested_dir):
                target = os.path.join(data_dir, name)
                # Never clobber something the user already has in data_dir.
                if not os.path.exists(target):
                    shutil.move(os.path.join(nested_dir, name), target)
            # Drop the wrapper folder only when everything was moved out.
            if not os.listdir(nested_dir):
                os.rmdir(nested_dir)
        if not os.path.exists(corpus_path):
            log(f"WARNING: {corpus_path} not found after extraction.")
        log("Extraction complete.")
    else:
        log("Corpus already exists, skipping download.")
else:
    log("Skipping download.")
# ── Step 2: Build Pyserini index ──
# Converts the BEIR corpus (data_dir/corpus.jsonl) to Pyserini's JsonCollection
# input (pyserini_input/docs.jsonl) and builds a Lucene index in
# data_dir/pyserini_index. Skipped when that index directory already exists.
if not args.skip_index:
    import shutil
    import subprocess
    import pyserini  # noqa: F401 -- fail fast if Pyserini is not installed

    index_dir = os.path.join(data_dir, "pyserini_index")
    pyserini_input = os.path.join(data_dir, "pyserini_input")
    if not os.path.exists(index_dir):
        # Convert BEIR format to Pyserini format
        log("Converting corpus to Pyserini input format...")
        os.makedirs(pyserini_input, exist_ok=True)
        corpus_path = os.path.join(data_dir, "corpus.jsonl")
        out_path = os.path.join(pyserini_input, "docs.jsonl")
        count = 0
        # Explicit UTF-8: the platform default encoding is not guaranteed to
        # be UTF-8, and json.loads needs the text decoded correctly.
        with open(corpus_path, encoding="utf-8") as fin, \
                open(out_path, 'w', encoding="utf-8") as fout:
            for line in fin:
                d = json.loads(line)
                # `or ""` also guards against explicit nulls in the JSON.
                title = d.get("title") or ""
                text = d.get("text") or ""
                fout.write(json.dumps({
                    "id": d["_id"],
                    "title": title,
                    "text": text,
                    # JsonCollection + DefaultLuceneDocumentGenerator index
                    # the "contents" field.
                    "contents": f"{title} {text}",
                }) + "\n")
                count += 1
                if count % 1_000_000 == 0:
                    log(f" Converted {count}/{5_416_568} docs")
        log(f"Converted {count} docs")
        # Build into a temporary directory and rename only on success. The
        # existence check above would otherwise treat a half-built index from
        # an interrupted or failed run as complete.
        tmp_index_dir = index_dir + ".tmp"
        if os.path.exists(tmp_index_dir):
            shutil.rmtree(tmp_index_dir)
        log("Building Pyserini index (this takes ~6 min)...")
        # Use an argv list (no shell) so paths containing spaces work.
        # Use sys.executable so the indexer runs under this interpreter and
        # its installed Pyserini. check=True makes a failed build stop the
        # script instead of being reported as success.
        subprocess.run(
            [
                sys.executable, "-m", "pyserini.index.lucene",
                "--collection", "JsonCollection",
                "--input", pyserini_input,
                "--index", tmp_index_dir,
                "--generator", "DefaultLuceneDocumentGenerator",
                "--threads", "8",
                "--storePositions", "--storeDocvectors", "--storeRaw",
            ],
            check=True,
        )
        os.rename(tmp_index_dir, index_dir)
        log("Index built.")
    else:
        log("Index already exists, skipping.")
else:
    log("Skipping index build.")
# ── Step 3: Run BM25 retrieval ──
# Retrieves the top-100 BM25 hits (k1=1.2, b=0.75) for every query that has
# test qrels and writes {"qids": [...], "pool": {qid: [[docid, score], ...]}}
# to data_dir/beir_pool.json. Skipped when the pool file already exists.
pool_path = os.path.join(data_dir, "beir_pool.json")
if not os.path.exists(pool_path):
    import csv  # was used below but never imported (NameError)

    log("Running BM25 retrieval (k1=1.2, b=0.75)...")
    try:
        # Current Pyserini API.
        from pyserini.search.lucene import LuceneSearcher as _Searcher
    except ImportError:
        # Older Pyserini releases only expose the deprecated SimpleSearcher.
        from pyserini.search import SimpleSearcher as _Searcher
    searcher = _Searcher(os.path.join(data_dir, "pyserini_index"))
    searcher.set_bm25(k1=1.2, b=0.75)
    # Load queries: one JSON object per line; "_id" -> "text".
    queries = {}
    with open(os.path.join(data_dir, "queries.jsonl"), encoding="utf-8") as f:
        for line in f:
            d = json.loads(line)
            queries[d['_id']] = d['text']
    # Load qrels: tab-separated with one header row; columns are used as
    # (query id, doc id, integer relevance).
    qrels = {}
    with open(os.path.join(data_dir, "qrels", "test.tsv"),
              encoding="utf-8", newline="") as f:
        reader = csv.reader(f, delimiter='\t')
        next(reader, None)  # skip the header; tolerate an empty file
        for row in reader:
            if not row:
                continue
            qrels.setdefault(row[0], {})[row[1]] = int(row[2])
    # Only queries with relevance judgements are evaluated, kept in the
    # order they appear in queries.jsonl.
    eval_qids = [qid for qid in queries if qid in qrels]
    log(f"{len(eval_qids)} queries to retrieve")
    pool = {}
    t0 = time.time()
    for qi, qid in enumerate(eval_qids):
        hits = searcher.search(queries[qid], k=100)
        if hits:
            pool[qid] = [(hit.docid, float(hit.score)) for hit in hits]
        if qi > 0 and qi % 500 == 0:
            rate = (qi + 1) / (time.time() - t0)
            remaining = (len(eval_qids) - qi - 1) / rate if rate > 0 else 0
            log(f" {qi}/{len(eval_qids)} @ {rate:.0f}q/s, ~{remaining:.0f}s left")
    # Save the pool atomically: write a temp file, then rename it. An
    # interrupted dump therefore cannot leave a truncated beir_pool.json that
    # the existence check above would accept on the next run.
    tmp_pool_path = pool_path + ".tmp"
    with open(tmp_pool_path, 'w', encoding="utf-8") as f:
        json.dump({"qids": eval_qids, "pool": pool}, f)
    os.replace(tmp_pool_path, pool_path)
    log(f"Pool saved ({len(pool)} queries, {time.time()-t0:.0f}s)")
else:
    log(f"Pool already exists at {pool_path}")
# Closing hint: print, line by line, how the generated pool is meant to be used.
_usage_lines = (
    "\nSetup complete! You can now use:",
    " from fever_benchmark import FEVERBenchmark",
    " bench = FEVERBenchmark()",
    " pool = bench.load_pool('beir_pool.json')",
    " # Re-rank, then evaluate:",
    " results = bench.evaluate(your_rankings)",
)
for _usage_line in _usage_lines:
    log(_usage_line)