#!/usr/bin/env python3
"""Export per-query NDCG@10 — lean version.

No TSV I/O during run. Saves periodic checkpoints for resume support.
"""
import csv
import gc
import json
import math
import os
import sys
import time
import traceback

import numpy as np

# NOTE(review): presumably set so a dependency that insists on an OpenAI key
# can be imported; note this overwrites any real key already in the env.
os.environ["OPENAI_API_KEY"] = "dummy"

INDEX_DIR = "/workspace/pyserini_index"
QUERIES_FILE = "/workspace/fever/queries.jsonl"
QRELS_FILE = "/workspace/fever/qrels/test.tsv"
POOL_FILE = "/workspace/beir_pool.json"
OUT_DIR = "/workspace/rankings"
CHECKPOINT = os.path.join(OUT_DIR, "_checkpoint.json")
DELTAS_CSV = os.path.join(OUT_DIR, "query_deltas.csv")
os.makedirs(OUT_DIR, exist_ok=True)

print("Loading queries and qrels...")
# queries: query id (JSONL "_id") -> query text (JSONL "text")
queries = {}
with open(QUERIES_FILE, encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines (e.g. trailing newline)
        d = json.loads(line)
        queries[d['_id']] = d['text']

# qrels: query id -> {doc id -> integer relevance grade}
qrels = {}
with open(QRELS_FILE, encoding="utf-8", newline="") as f:
    reader = csv.reader(f, delimiter='\t')
    next(reader, None)  # skip the header row
    for row in reader:
        if len(row) < 3:
            continue  # blank or malformed line
        qrels.setdefault(row[0], {})[row[1]] = int(row[2])

with open(POOL_FILE, encoding="utf-8") as f:
    pool_data = json.load(f)
eval_qids = pool_data["qids"]
# pool: query id -> list of [doc id, BM25 score, ...] entries (read as p[0], p[1])
pool_dict = pool_data["pool"]
print(f"{len(eval_qids)} queries with pool data")

print("Loading Pyserini searcher...")
from pyserini.search import SimpleSearcher

searcher = SimpleSearcher(INDEX_DIR)
searcher.set_bm25(k1=1.2, b=0.75)


def ndcg10(ranked_list, gt):
    """Return NDCG@10 of *ranked_list* against graded relevance *gt*.

    ranked_list: sequence of (docid, score) pairs, best first; only the
        docids are used.
    gt: mapping docid -> relevance grade; unlisted docs count as grade 0.

    Uses exponential gain ``2**rel - 1`` and a ``log2(rank + 1)`` discount.
    Returns 0.0 when the ideal DCG is zero (no positive grades).
    """
    dcg = sum(
        (2 ** gt.get(did, 0) - 1) / math.log2(k + 2)
        for k, (did, _) in enumerate(ranked_list[:10])
    )
    ideal = sorted(gt.values(), reverse=True)[:10]
    idcg = sum((2 ** rel - 1) / math.log2(k + 2) for k, rel in enumerate(ideal))
    return dcg / idcg if idcg > 0 else 0.0


print("Loading MiniLM...")
import torch
from transformers import AutoTokenizer, AutoModel

torch.set_num_threads(4)
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).eval()


def encode(texts):
    """Embed *texts* with MiniLM using attention-mask-weighted mean pooling.

    Returns a float tensor with one (unnormalized) row per input text.
    Padding positions are excluded from the mean, so a text's embedding does
    not depend on how long the other texts in the same batch are (plain
    ``mean(dim=1)`` over ``last_hidden_state`` averaged in pad tokens).
    """
    inputs = tokenizer(texts, padding=True, truncation=True, max_length=128,
                       return_tensors="pt")
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    # clamp guards against division by zero for an all-padding row
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
encode(["warmup"])
print("MiniLM ready")

# Ablation name -> (tw: consensus-count weight, sh: isolation penalty,
#                   mu: boost for docs similar to pool docs 0/1, base scores)
ABLATIONS = {
    "bm25": (0, 0, 0, "bm25"),
    "dense": (0, 0, 0, "dense"),
    "dense+T": (0.2, 0, 0, "dense"),
    "dense+S": (0, 0.1, 0, "dense"),
    "dense+M": (0, 0, 0.2, "dense"),
    "dense+TS": (0.2, 0.1, 0, "dense"),
    "dense+MT": (0.2, 0, 0.2, "dense"),
    "dense+MS": (0, 0.1, 0.2, "dense"),
    "dense+MTS": (0.2, 0.1, 0.2, "dense"),
}
CT_VALS = [0.3, 0.5, 0.7, 0.8, 0.85, 0.9]
TW_VALS = [0.1, 0.15, 0.2, 0.3]
POOL_DEPTH = 100        # candidates re-ranked per query
CHECKPOINT_EVERY = 200  # queries between checkpoint writes
MAX_ERRORS = 20         # abort once the error count exceeds this


def _minmax(x):
    """Min-max scale a numpy array into [0, 1) (epsilon avoids 0/0)."""
    return (x - x.min()) / (x.max() - x.min() + 1e-10)


def _rank(pool, scores):
    """Return (docid, score) pairs sorted by descending score (stable)."""
    return sorted(((pool[i], float(scores[i])) for i in range(len(pool))),
                  key=lambda x: -x[1])


def _save_checkpoint(results, next_idx, errors):
    """Persist progress atomically so a crash mid-write can't corrupt it."""
    tmp_path = CHECKPOINT + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as fh:
        json.dump({"results": results, "next_idx": next_idx, "errors": errors}, fh)
    os.replace(tmp_path, CHECKPOINT)


def _score_query(qid):
    """Compute NDCG@10 for every ablation and sweep config of one query.

    Returns a row dict ``{"qid": qid, <config name>: ndcg, ...}``, or None
    when the query has no pool or none of its pool docs has text.
    """
    # Truncate once so doc ids, BM25 scores, texts and embeddings stay
    # aligned (truncating only the texts made longer pools raise IndexError).
    entries = pool_dict.get(qid, [])[:POOL_DEPTH]
    if not entries:
        return None
    pool = [p[0] for p in entries]

    pool_texts = []
    for docid in pool:
        doc = searcher.doc(docid)
        pool_texts.append(json.loads(doc.raw()).get('contents', '') if doc else "")
    if not any(pool_texts):
        return None

    n_pool = len(pool)
    bm25_n = _minmax(np.array([p[1] for p in entries], dtype=float))

    q_emb = encode([queries[qid]])
    pool_embs = encode(pool_texts)
    q_normed = q_emb / (torch.norm(q_emb, dim=1, keepdim=True) + 1e-10)
    pool_normed = pool_embs / (torch.norm(pool_embs, dim=1, keepdim=True) + 1e-10)
    dense_n = _minmax((q_normed @ pool_normed.T).cpu().numpy().flatten())
    sim = (pool_normed @ pool_normed.T).cpu().numpy()  # doc-doc cosine matrix

    # cs counts neighbours above 0.7 (including the doc itself, sim == 1)
    cs = np.sum(sim > 0.7, axis=1).astype(float)
    csn = _minmax(cs)
    iso = (cs <= 1).astype(float)  # 1.0 for docs with no similar neighbour
    sim_top2 = sim[0][1] if n_pool > 1 else 0
    gt = qrels.get(qid, {})

    row = {"qid": qid}
    for name, (tw, sh, mu, base_type) in ABLATIONS.items():
        base = bm25_n if base_type == "bm25" else dense_n
        if name in ("bm25", "dense"):
            row[name] = ndcg10(_rank(pool, base), gt)
            continue
        fs = base.copy()
        if tw > 0:
            fs += tw * csn
        if sh > 0:
            fs -= sh * iso
        if mu > 0 and n_pool > 2 and sim_top2 > 0.7:
            fs[(sim[0] > 0.7) | (sim[1] > 0.7)] += mu
        row[name] = ndcg10(_rank(pool, fs), gt)

    # Sweep: consensus threshold ct x consensus weight tw2 on top of dense
    for ct in CT_VALS:
        csn2 = _minmax(np.sum(sim > ct, axis=1).astype(float))
        for tw2 in TW_VALS:
            row[f"tawatur_ct={ct}_tw={tw2}"] = ndcg10(_rank(pool, dense_n + tw2 * csn2), gt)
    return row


# ── Resume logic ──
results = []
start_idx = 0
errors = 0
if os.path.exists(CHECKPOINT):
    with open(CHECKPOINT, encoding="utf-8") as fh:
        cp = json.load(fh)
    results = cp["results"]
    start_idx = cp["next_idx"]
    errors = cp.get("errors", 0)
    print(f"Resuming from query index {start_idx} ({len(results)} completed)")

t0 = time.time()
aborted = False

# ── Main loop ──
for qi in range(start_idx, len(eval_qids)):
    qid = eval_qids[qi]
    try:
        row = _score_query(qid)
    except Exception as e:
        errors += 1
        print(f" ERROR qid={qid}: {e}")
        traceback.print_exc()
        if errors > MAX_ERRORS:
            print("TOO MANY ERRORS")
            # Keep progress and resume from the failing query next time.
            _save_checkpoint(results, qi, errors)
            aborted = True
            break
        row = None

    if row is not None:
        results.append(row)
        gc.collect()

    # Runs for skipped/failed queries too, so no checkpoint slot is missed.
    if qi > 0 and qi % CHECKPOINT_EVERY == 0:
        rate = (qi + 1 - start_idx) / max(time.time() - t0, 1)
        remaining = (len(eval_qids) - qi - 1) / rate if rate > 0 else 0
        print(f" {qi+1}/{len(eval_qids)} @ {rate:.1f}q/s, ~{remaining:.0f}s left")
        _save_checkpoint(results, qi + 1, errors)

# ── Write CSV ──
if results:
    fields = list(results[0].keys())
    with open(DELTAS_CSV, "w", newline="", encoding="utf-8") as fh:
        writer = csv.DictWriter(fh, fieldnames=fields, restval="",
                                extrasaction="ignore", lineterminator="\n")
        writer.writeheader()
        writer.writerows(results)
    print(f"\nExported {len(results)} queries to {DELTAS_CSV}")

# Cleanup: only discard the checkpoint once the run actually finished.
if aborted:
    print(f"Run aborted; checkpoint kept at {CHECKPOINT} for resume")
elif os.path.exists(CHECKPOINT):
    os.remove(CHECKPOINT)

# Print summary (skipped when nothing was scored, avoiding mean-of-empty NaN)
if results:
    dense_mean = np.mean([r["dense"] for r in results])
    print(f"\n{'System':<25} {'NDCG@10':>10} {'vs Dense':>12}")
    print("-" * 50)
    for k in ABLATIONS:
        mean_v = np.mean([r[k] for r in results])
        d = mean_v - dense_mean
        print(f"{k:<25} {mean_v:>10.4f} {d:>+12.4f}")

elapsed = time.time() - t0
print(f"\nTotal: {len(results)} queries, {errors} errors, {elapsed:.0f}s")