| |
"""Export per-query NDCG@10 — lean version. No TSV I/O during run.
Saves periodic checkpoints for resume support."""
| import json, os, time, math, gc, csv, traceback, sys |
| import numpy as np |
|
|
# Placeholder credential written into the process environment.
# NOTE(review): presumably some imported dependency insists on this variable
# being set — confirm it is still required.
os.environ["OPENAI_API_KEY"] = "dummy"

# Input artefacts.
INDEX_DIR = "/workspace/pyserini_index"
QUERIES_FILE = "/workspace/fever/queries.jsonl"
QRELS_FILE = "/workspace/fever/qrels/test.tsv"
POOL_FILE = "/workspace/beir_pool.json"

# Output directory and the two files this script writes into it.
OUT_DIR = "/workspace/rankings"
CHECKPOINT, DELTAS_CSV = (
    os.path.join(OUT_DIR, leaf) for leaf in ("_checkpoint.json", "query_deltas.csv")
)

# Create the output directory up front; an existing directory is fine.
os.makedirs(OUT_DIR, exist_ok=True)
|
|
print("Loading queries and qrels...")

# Query id -> query text, read from a JSON-lines file with '_id' and 'text' keys.
queries = {}
with open(QUERIES_FILE, encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            # A blank (e.g. trailing) line would make json.loads raise.
            continue
        d = json.loads(line)
        queries[d['_id']] = d['text']

# Query id -> {doc id: integer relevance}, read from a tab-separated file
# whose first row is skipped.
qrels = {}
with open(QRELS_FILE, encoding="utf-8", newline="") as f:
    reader = csv.reader(f, delimiter='\t')
    # The default keeps an empty file from raising StopIteration.
    next(reader, None)
    for row in reader:
        if not row:
            continue
        qrels.setdefault(row[0], {})[row[1]] = int(row[2])

# Candidate pools: the list of query ids to evaluate ("qids") and, per query
# id, the pooled candidates ("pool").
with open(POOL_FILE, encoding="utf-8") as f:
    pool_data = json.load(f)
eval_qids = pool_data["qids"]
pool_dict = pool_data["pool"]
print(f"{len(eval_qids)} queries with pool data")
|
|
print("Loading Pyserini searcher...")
from pyserini.search import SimpleSearcher

# BM25 parameters applied to the searcher opened on INDEX_DIR.
_BM25_PARAMS = {"k1": 1.2, "b": 0.75}

searcher = SimpleSearcher(INDEX_DIR)
searcher.set_bm25(**_BM25_PARAMS)
|
|
def ndcg10(ranked_list, gt):
    """Return NDCG@10 of a ranking against graded relevance judgements.

    Args:
        ranked_list: sequence of ``(doc_id, score)`` pairs, best first. Only
            the first ten entries are scored; the score itself is ignored.
        gt: mapping of doc id to relevance grade. Ids absent from it count
            as grade 0.

    Returns:
        DCG of the top ten divided by the ideal DCG built from the ten
        highest grades in ``gt``, or ``0.0`` when the ideal DCG is not
        positive (for example when ``gt`` is empty).
    """

    def _gain(grade, position):
        # Exponential gain with a log2 discount; ``position`` is zero-based.
        return (2 ** grade - 1) / math.log2(position + 2)

    top_ten = ranked_list[:10]
    actual = sum(
        _gain(gt.get(doc_id, 0), position)
        for position, (doc_id, _score) in enumerate(top_ten)
    )

    best_grades = sorted(gt.values(), reverse=True)[:10]
    ideal = sum(_gain(grade, position) for position, grade in enumerate(best_grades))

    if ideal > 0:
        return actual / ideal
    return 0.0
|
|
print("Loading MiniLM...")
import torch
from transformers import AutoTokenizer, AutoModel

# Limit torch to four threads.
torch.set_num_threads(4)

model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
# Put the model in evaluation mode; it is only used for inference below.
model.eval()
|
|
def encode(texts):
    """Embed a batch of strings with the module-level MiniLM tokenizer/model.

    Token embeddings are averaged per text using the attention mask as
    weights. A plain ``mean(dim=1)`` would also average the padding
    positions added by ``padding=True``, so a text's embedding would depend
    on how long the other texts in the same batch are.

    Args:
        texts: list of strings to embed; each is truncated to 128 tokens.

    Returns:
        A tensor with one embedding row per input text, computed without
        gradient tracking.
    """
    inputs = tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state
        # 1 for real tokens, 0 for padding; the extra axis lets it broadcast
        # across the embedding dimension.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        token_sum = (hidden * mask).sum(dim=1)
        # Clamp guards against division by zero if a row had no real tokens.
        token_count = mask.sum(dim=1).clamp(min=1e-9)
        emb = token_sum / token_count
    return emb
|
|
# Run one throwaway encode before the main loop starts.
# NOTE(review): presumably this keeps first-call latency out of the timed
# loop — confirm.
_ = encode(["warmup"])
print("MiniLM ready")
|
|
# Weights used whenever the T, S or M component is switched on.
_W_T, _W_S, _W_M = 0.2, 0.1, 0.2

# System name -> (tw, sh, mu, base_type). A weight of 0 disables that
# component; base_type selects which normalised score ("bm25" or "dense")
# the components are applied on top of.
ABLATIONS = dict([
    ("bm25", (0, 0, 0, "bm25")),
    ("dense", (0, 0, 0, "dense")),
    ("dense+T", (_W_T, 0, 0, "dense")),
    ("dense+S", (0, _W_S, 0, "dense")),
    ("dense+M", (0, 0, _W_M, "dense")),
    ("dense+TS", (_W_T, _W_S, 0, "dense")),
    ("dense+MT", (_W_T, 0, _W_M, "dense")),
    ("dense+MS", (0, _W_S, _W_M, "dense")),
    ("dense+MTS", (_W_T, _W_S, _W_M, "dense")),
])

# Grid for the sweep: similarity thresholds (CT_VALS) crossed with
# weights (TW_VALS).
CT_VALS = [0.3, 0.5, 0.7, 0.8, 0.85, 0.9]
TW_VALS = [0.1, 0.15, 0.2, 0.3]
|
|
| |
# Run state; overwritten below when a usable checkpoint exists.
results = []
start_idx = 0
errors = 0
if os.path.exists(CHECKPOINT):
    try:
        with open(CHECKPOINT, encoding="utf-8") as f:
            cp = json.load(f)
        # Read every field before touching the run state so a malformed
        # checkpoint cannot leave it half-updated.
        cp_results = cp["results"]
        cp_next_idx = cp["next_idx"]
        cp_errors = cp.get("errors", 0)
    except (OSError, ValueError, KeyError, TypeError, AttributeError) as exc:
        # A run killed while the checkpoint was being written can leave
        # truncated JSON behind. Restart from the beginning instead of
        # crashing before any work is done.
        print(f"Ignoring unreadable checkpoint {CHECKPOINT}: {exc!r}; starting from scratch")
    else:
        results, start_idx, errors = cp_results, cp_next_idx, cp_errors
        print(f"Resuming from query index {start_idx} ({len(results)} completed)")
|
|
t0 = time.time()

MAX_ERRORS = 20          # abort once more than this many queries have failed
CHECKPOINT_EVERY = 200   # progress report / checkpoint interval, in query indices
POOL_DEPTH = 100         # candidates re-ranked per query


def _minmax(values):
    """Rescale a 1-D array to [0, 1]; the epsilon keeps a constant array finite."""
    return (values - values.min()) / (values.max() - values.min() + 1e-10)


def _rank(doc_ids, scores):
    """Pair doc ids with scores and sort by descending score (ties keep pool order)."""
    return sorted(((doc_id, float(s)) for doc_id, s in zip(doc_ids, scores)), key=lambda x: -x[1])


def _save_checkpoint(next_idx):
    """Persist progress via a temp file + rename so a kill mid-write cannot corrupt it."""
    tmp_path = CHECKPOINT + ".tmp"
    with open(tmp_path, 'w') as f:
        json.dump({"results": results, "next_idx": next_idx, "errors": errors}, f)
    os.replace(tmp_path, CHECKPOINT)


def _evaluate_query(qid):
    """Score one query under every configuration.

    Returns a dict with "qid" plus one NDCG@10 value per ablation and per
    (ct, tw) sweep cell, or None when the query has no pool or none of its
    pooled documents has text.
    """
    pool_100 = pool_dict.get(qid, [])
    if not pool_100:
        return None
    # Truncate ids and scores together. Previously only the texts were cut
    # to 100, so a deeper pool produced arrays of different lengths and an
    # IndexError during ranking.
    pool_100 = pool_100[:POOL_DEPTH]

    pool = [p[0] for p in pool_100]
    bm25_s = np.array([p[1] for p in pool_100], dtype=float)

    # Fetch document text from the index; a missing document becomes "" so
    # positions stay aligned with `pool`.
    pool_texts = []
    for docid in pool:
        doc = searcher.doc(docid)
        if doc:
            pool_texts.append(json.loads(doc.raw()).get('contents', ''))
        else:
            pool_texts.append("")

    if not any(pool_texts):
        return None

    bm25_n = _minmax(bm25_s)
    q_emb = encode([queries[qid]])
    pool_embs = encode(pool_texts)

    n_pool = len(pool)
    # Cosine similarity via L2-normalised embeddings.
    q_normed = q_emb / (torch.norm(q_emb, dim=1, keepdim=True) + 1e-10)
    pool_normed = pool_embs / (torch.norm(pool_embs, dim=1, keepdim=True) + 1e-10)
    dense_scores = (q_normed @ pool_normed.T).cpu().numpy().flatten()
    dense_n = _minmax(dense_scores)

    # Document-document similarity within the pool.
    sim = (pool_normed @ pool_normed.T).cpu().numpy()
    # cs: per document, how many pool documents exceed similarity 0.7
    # (the count includes the document itself when its self-similarity does).
    cs = np.sum(sim > 0.7, axis=1).astype(float)
    csn = _minmax(cs)
    # iso: documents with at most one such neighbour.
    iso = cs <= 1
    # Similarity between the first two pool entries.
    sim_top2 = sim[0][1] if n_pool > 1 else 0

    gt = qrels.get(qid, {})

    row = {"qid": qid}
    for a_name, (tw, sh, mu, base_type) in ABLATIONS.items():
        base = bm25_n if base_type == "bm25" else dense_n
        # With all three weights at 0 (the plain "bm25" / "dense" systems)
        # none of the adjustments fire and the base score is ranked as is.
        fs = base.copy()
        if tw > 0:
            fs += tw * csn
        if sh > 0:
            fs -= sh * iso.astype(float)
        if mu > 0 and n_pool > 2 and sim_top2 > 0.7:
            # Boost documents similar to either of the first two pool
            # entries, but only when those two agree with each other.
            fs[(sim[0] > 0.7) | (sim[1] > 0.7)] += mu
        row[a_name] = ndcg10(_rank(pool, fs), gt)

    # Sweep the similarity threshold and the count weight on top of dense.
    for ct in CT_VALS:
        csn2 = _minmax(np.sum(sim > ct, axis=1).astype(float))
        for tw2 in TW_VALS:
            row[f"tawatur_ct={ct}_tw={tw2}"] = ndcg10(_rank(pool, dense_n + tw2 * csn2), gt)

    return row


for qi in range(start_idx, len(eval_qids)):
    qid = eval_qids[qi]
    try:
        row = _evaluate_query(qid)
        if row is not None:
            results.append(row)
            # Per-query tensors are already out of scope here; collect eagerly.
            gc.collect()

        # Report and checkpoint on the interval even when this query was
        # skipped, so a skipped query on a boundary no longer postpones the
        # checkpoint by a whole interval.
        if qi > 0 and qi % CHECKPOINT_EVERY == 0:
            rate = (qi + 1 - start_idx) / max(time.time() - t0, 1)
            remaining = (len(eval_qids) - qi - 1) / rate if rate > 0 else 0
            print(f" {qi+1}/{len(eval_qids)} @ {rate:.1f}q/s, ~{remaining:.0f}s left")
            _save_checkpoint(qi + 1)

    except Exception as e:
        # Best effort: one failing query must not end the run. The traceback
        # is printed because messages like a bare KeyError key say little.
        errors += 1
        print(f" ERROR qid={qid}: {e}")
        traceback.print_exc()
        if errors > MAX_ERRORS:
            print("TOO MANY ERRORS")
            break
|
|
| |
# Write one CSV row per evaluated query; columns follow the first row's keys.
if results:
    fields = list(results[0].keys())
    # newline='' is what the csv module expects; lineterminator keeps the
    # plain "\n" line endings this file has always had.
    with open(DELTAS_CSV, 'w', newline='', encoding="utf-8") as out:
        writer = csv.DictWriter(
            out,
            fieldnames=fields,
            restval="",              # a key missing from a row becomes an empty cell
            extrasaction="ignore",   # keys outside `fields` are dropped
            lineterminator="\n",
        )
        # csv handles quoting, so a value containing a comma or quote can no
        # longer shift the columns.
        writer.writeheader()
        writer.writerows(results)
    print(f"\nExported {len(results)} queries to {DELTAS_CSV}")
|
|
| |
# Delete the checkpoint only after a complete run. The evaluation loop stops
# early once more than 20 errors have accumulated (keep this literal in sync
# with the loop's threshold). In that case the checkpoint is the only way to
# resume, so it is left in place.
run_aborted = errors > 20
if run_aborted:
    print(f"Run aborted after {errors} errors; keeping checkpoint for resume")
elif os.path.exists(CHECKPOINT):
    os.remove(CHECKPOINT)
|
|
| |
# Summary table: mean NDCG@10 per system and its difference from plain dense.
dense_vals = np.array([r["dense"] for r in results])
dense_mean = np.mean(dense_vals)

print(f"\n{'System':<25} {'NDCG@10':>10} {'vs Dense':>12}")
print("-" * 50)

system_names = (
    "bm25", "dense", "dense+T", "dense+S", "dense+M",
    "dense+TS", "dense+MT", "dense+MS", "dense+MTS",
)
for system in system_names:
    mean_v = np.mean(np.array([r[system] for r in results]))
    delta = mean_v - dense_mean
    print(f"{system:<25} {mean_v:>10.4f} {delta:>+12.4f}")

# Wall-clock time measured from the t0 timestamp set before the main loop.
elapsed = time.time() - t0
print(f"\nTotal: {len(results)} queries, {errors} errors, {elapsed:.0f}s")
|
|