fever-ner / export_per_query_v2.py
Kim-el's picture
Upload export_per_query_v2.py with huggingface_hub
0e4284f verified
Raw
History Blame Contribute Delete
7.36 kB
#!/usr/bin/env python3
"""Export per-query NDCG@10 — lean version. No TSV I/O during the run.
Saves periodic checkpoints for resume support."""
import json, os, time, math, gc, csv, traceback, sys
import numpy as np
# Placeholder API key. setdefault (instead of assignment) leaves a real key
# untouched if one is already exported. Nothing visible in this script reads it;
# presumably some imported library checks for its presence. Confirm before removing.
os.environ.setdefault("OPENAI_API_KEY", "dummy")

# Inputs
INDEX_DIR = "/workspace/pyserini_index"   # Pyserini index; used here only for doc-text lookup
QUERIES_FILE = "/workspace/fever/queries.jsonl"
QRELS_FILE = "/workspace/fever/qrels/test.tsv"
POOL_FILE = "/workspace/beir_pool.json"   # per-query candidate pools with BM25 scores

# Outputs
OUT_DIR = "/workspace/rankings"
CHECKPOINT = os.path.join(OUT_DIR, "_checkpoint.json")  # removed at the end of the script
DELTAS_CSV = os.path.join(OUT_DIR, "query_deltas.csv")
os.makedirs(OUT_DIR, exist_ok=True)
print("Loading queries and qrels...")
# queries.jsonl: one JSON object per line; "_id" -> "text".
queries = {}
with open(QUERIES_FILE, encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            # Tolerate blank/trailing lines instead of crashing in json.loads.
            continue
        d = json.loads(line)
        queries[d['_id']] = d['text']

# qrels TSV: a header row, then query id, doc id, integer relevance grade.
qrels = {}
# newline="" is what the csv module expects for files it reads.
with open(QRELS_FILE, encoding="utf-8", newline="") as f:
    reader = csv.reader(f, delimiter='\t')
    next(reader, None)  # skip the header; an empty file just yields no qrels
    for row in reader:
        if not row:
            continue
        qrels.setdefault(row[0], {})[row[1]] = int(row[2])

# Pool file: "qids" (evaluation order) and "pool" (qid -> [[docid, bm25_score], ...]).
with open(POOL_FILE, encoding="utf-8") as f:
    pool_data = json.load(f)
eval_qids = pool_data["qids"]
pool_dict = pool_data["pool"]
print(f"{len(eval_qids)} queries with pool data")
print("Loading Pyserini searcher...")
# Newer Pyserini releases call this class LuceneSearcher; SimpleSearcher is the
# older, deprecated name. Try the new location first so either version works.
try:
    from pyserini.search.lucene import LuceneSearcher as _SearcherCls
except ImportError:
    from pyserini.search import SimpleSearcher as _SearcherCls
searcher = _SearcherCls(INDEX_DIR)
# This script only calls searcher.doc() (BM25 scores come from the pool file),
# so these parameters do not affect the exported numbers.
searcher.set_bm25(k1=1.2, b=0.75)
def ndcg10(ranked_list, gt):
    """Return NDCG@10 of one ranking against graded relevance labels.

    ranked_list: sequence of (docid, score) pairs, best first. Only the first
        10 entries are scored and the score component is ignored.
    gt: mapping docid -> integer relevance grade; unlisted docids count as 0.

    Gain is 2**rel - 1 and the discount is log2(rank + 1) for 1-based ranks.
    Returns 0.0 when the ideal DCG is not positive (e.g. no relevant docs).
    """
    dcg = 0.0
    for rank, (docid, _score) in enumerate(ranked_list[:10], start=1):
        dcg += (2 ** gt.get(docid, 0) - 1) / math.log2(rank + 1)

    ideal_grades = sorted(gt.values(), reverse=True)[:10]
    idcg = 0.0
    for rank, grade in enumerate(ideal_grades, start=1):
        idcg += (2 ** grade - 1) / math.log2(rank + 1)

    if idcg > 0:
        return dcg / idcg
    return 0.0
print("Loading MiniLM...")
import torch
from transformers import AutoModel, AutoTokenizer

# Cap the number of CPU threads torch uses for intra-op parallelism.
torch.set_num_threads(4)
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.eval()  # inference mode (e.g. dropout disabled); eval() works in place
def encode(texts, max_length=128):
    """Embed a batch of strings with the module-level MiniLM model.

    texts: list of strings, encoded in a single forward pass.
    max_length: tokenizer truncation length in word pieces (default 128).

    Returns a float tensor with one row per input text: the mean of the last
    hidden states over that text's real (non-padding) tokens. Rows are not
    L2-normalized; callers normalize before taking dot products.
    """
    inputs = tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state
    # With padding=True every shorter text is padded up to the longest one in
    # the batch. A plain .mean(dim=1) would average those pad positions into its
    # embedding, so weight by the attention mask instead. This matches the
    # pooling shown on the all-MiniLM-L6-v2 model card.
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts


# One throwaway forward pass so lazy initialisation doesn't land on the first query.
encode(["warmup"])
print("MiniLM ready")
# Ablation grid. Each entry is name -> (tw, sh, mu, base), as consumed by the
# main loop:
#   tw   - weight of the normalized neighbour-count bonus
#   sh   - penalty subtracted from docs with no neighbour above 0.7
#   mu   - bonus for docs similar to pool entries 0 or 1
#   base - which normalized score the ranking starts from ("bm25" or "dense")
#                   name         tw   sh   mu   base
_ABLATION_ROWS = [
    ("bm25",      0,   0,   0,   "bm25"),
    ("dense",     0,   0,   0,   "dense"),
    ("dense+T",   0.2, 0,   0,   "dense"),
    ("dense+S",   0,   0.1, 0,   "dense"),
    ("dense+M",   0,   0,   0.2, "dense"),
    ("dense+TS",  0.2, 0.1, 0,   "dense"),
    ("dense+MT",  0.2, 0,   0.2, "dense"),
    ("dense+MS",  0,   0.1, 0.2, "dense"),
    ("dense+MTS", 0.2, 0.1, 0.2, "dense"),
]
ABLATIONS = {name: tuple(knobs) for name, *knobs in _ABLATION_ROWS}

# Sweep: similarity thresholds (CT) x neighbour-bonus weights (TW) on top of dense.
CT_VALS = [0.3, 0.5, 0.7, 0.8, 0.85, 0.9]
TW_VALS = [0.1, 0.15, 0.2, 0.3]
# ── Resume logic ──
# A checkpoint from an interrupted run holds the rows finished so far, the
# index into eval_qids to continue from, and the running error count.
results, start_idx, errors = [], 0, 0
if os.path.exists(CHECKPOINT):
    with open(CHECKPOINT) as fh:
        saved = json.load(fh)
    results, start_idx = saved["results"], saved["next_idx"]
    errors = saved.get("errors", 0)
    print(f"Resuming from query index {start_idx} ({len(results)} completed)")
t0 = time.time()
# ── Main loop ──
def _rank(docids, scores):
    """Pair docids with scores and sort by descending score.

    Python's sort is stable, so tied scores keep their pool order.
    """
    return sorted(zip(docids, (float(s) for s in scores)), key=lambda x: -x[1])


def _score_query(qid):
    """Compute NDCG@10 for every ablation and sweep setting of one query.

    Returns {"qid": qid, <ablation name>: ndcg, "tawatur_ct=.._tw=..": ndcg, ...},
    or None when the query has no pool or none of its pool docs has text.
    Reads module-level state: pool_dict, searcher, queries, qrels, encode,
    ndcg10, ABLATIONS, CT_VALS, TW_VALS.
    """
    # Cut the pool to 100 once, so docids, BM25 scores, texts and embeddings all
    # have the same length n_pool. Previously only the texts were cut, and longer
    # pools raised IndexError.
    pool_100 = pool_dict.get(qid, [])[:100]
    if not pool_100:
        return None
    pool = [p[0] for p in pool_100]
    bm25_s = np.array([p[1] for p in pool_100], dtype=float)

    pool_texts = []
    for docid in pool:
        doc = searcher.doc(docid)
        if doc:
            pool_texts.append(json.loads(doc.raw()).get('contents', ''))
        else:
            # Keep positions aligned with `pool` even when a doc is missing.
            pool_texts.append("")
    if not any(pool_texts):
        return None

    n_pool = len(pool)
    bm25_n = (bm25_s - bm25_s.min()) / (bm25_s.max() - bm25_s.min() + 1e-10)

    # Query-doc cosine similarity, min-max normalized like the BM25 scores.
    q_emb = encode([queries[qid]])
    pool_embs = encode(pool_texts)
    q_normed = q_emb / (torch.norm(q_emb, dim=1, keepdim=True) + 1e-10)
    pool_normed = pool_embs / (torch.norm(pool_embs, dim=1, keepdim=True) + 1e-10)
    dense_scores = (q_normed @ pool_normed.T).cpu().numpy().flatten()
    dense_n = (dense_scores - dense_scores.min()) / (dense_scores.max() - dense_scores.min() + 1e-10)

    # Doc-doc cosine similarities. cs counts, per doc, the pool docs above 0.7.
    # The doc itself is normally included (self-similarity ~1), so iso
    # (cs <= 1) marks docs with no other close neighbour.
    sim = (pool_normed @ pool_normed.T).cpu().numpy()
    cs = np.sum(sim > 0.7, axis=1).astype(float)
    csn = (cs - cs.min()) / (cs.max() - cs.min() + 1e-10)
    iso = cs <= 1
    sim_top2 = sim[0][1] if n_pool > 1 else 0

    gt = qrels.get(qid, {})
    row = {"qid": qid}
    for a_name, (tw, sh, mu, base_type) in ABLATIONS.items():
        base = bm25_n if base_type == "bm25" else dense_n
        if a_name in ("bm25", "dense"):
            ranked = _rank(pool, base)
        else:
            fs = base.copy()
            if tw > 0:
                fs += tw * csn
            if sh > 0:
                fs -= sh * iso.astype(float)
            # NOTE(review): sim[0]/sim[1] are the first two entries in pool-file
            # order. They are the top-2 under the current base score only if the
            # pool is sorted that way; confirm this is intended for dense variants.
            if mu > 0 and n_pool > 2 and sim_top2 > 0.7:
                fs[(sim[0] > 0.7) | (sim[1] > 0.7)] += mu
            ranked = _rank(pool, fs)
        row[a_name] = ndcg10(ranked, gt)

    # Sweep: neighbour threshold ct x bonus weight tw2 on top of dense.
    for ct in CT_VALS:
        cs2 = np.sum(sim > ct, axis=1).astype(float)
        csn2 = (cs2 - cs2.min()) / (cs2.max() - cs2.min() + 1e-10)
        for tw2 in TW_VALS:
            row[f"tawatur_ct={ct}_tw={tw2}"] = ndcg10(_rank(pool, dense_n + tw2 * csn2), gt)
    return row


def _save_checkpoint(rows, next_idx, n_errors):
    """Write the resume checkpoint atomically.

    The JSON goes to a temp file that os.replace() then moves over CHECKPOINT.
    A kill mid-write therefore leaves the previous checkpoint intact instead of
    a truncated file that json.load() would reject on the next start.
    """
    tmp_path = CHECKPOINT + ".tmp"
    with open(tmp_path, "w") as fh:
        json.dump({"results": rows, "next_idx": next_idx, "errors": n_errors}, fh)
    os.replace(tmp_path, CHECKPOINT)


for qi in range(start_idx, len(eval_qids)):
    qid = eval_qids[qi]
    try:
        row = _score_query(qid)
    except Exception as e:
        # Best effort: report the failure, skip the query, and abort once more
        # than 20 queries have failed.
        errors += 1
        print(f" ERROR qid={qid}: {type(e).__name__}: {e}")
        if errors > 20:
            print("TOO MANY ERRORS")
            break
    else:
        if row is not None:
            results.append(row)
    gc.collect()
    # Progress report and checkpoint every 200 queries. This runs even when the
    # query at that index was skipped or failed, so no checkpoint is missed.
    if qi > 0 and qi % 200 == 0:
        rate = (qi + 1 - start_idx) / max(time.time() - t0, 1)
        remaining = (len(eval_qids) - qi - 1) / rate if rate > 0 else 0
        print(f" {qi+1}/{len(eval_qids)} @ {rate:.1f}q/s, ~{remaining:.0f}s left")
        try:
            _save_checkpoint(results, qi + 1, errors)
        except OSError as e:
            # A failed checkpoint must not abort the run or count as a query error.
            print(f" WARNING: checkpoint not saved: {e}")
# ── Write CSV ──
if results:
    fields = list(results[0].keys())
    # Design of the DictWriter call:
    # - It quotes any value containing a comma, quote or newline; a bare
    #   ",".join would silently shift columns.
    # - lineterminator="\n" keeps Unix line endings.
    # - restval="" fills keys a row lacks; extrasaction="ignore" drops keys
    #   not present in the first row.
    with open(DELTAS_CSV, 'w', encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fields, restval="", extrasaction="ignore", lineterminator="\n")
        writer.writeheader()
        writer.writerows(results)
    print(f"\nExported {len(results)} queries to {DELTAS_CSV}")

# Cleanup: the run is over, so the resume checkpoint is no longer needed.
# NOTE(review): this also runs after a "TOO MANY ERRORS" abort, discarding the
# checkpoint of a partial run.
if os.path.exists(CHECKPOINT):
    os.remove(CHECKPOINT)

# Print summary: mean NDCG@10 per ablation and its delta against plain dense.
if results:
    dense_mean = np.mean([r["dense"] for r in results])
    print(f"\n{'System':<25} {'NDCG@10':>10} {'vs Dense':>12}")
    print("-" * 50)
    for k in ABLATIONS:
        mean_v = np.mean([r[k] for r in results])
        print(f"{k:<25} {mean_v:>10.4f} {mean_v - dense_mean:>+12.4f}")
else:
    # np.mean of an empty list would print nan and emit a RuntimeWarning.
    print("\nNo queries were scored; nothing to summarize.")
elapsed = time.time() - t0
print(f"\nTotal: {len(results)} queries, {errors} errors, {elapsed:.0f}s")