#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = ["sentence-transformers[train]>=5.5.0", "datasets>=2.19", "numpy", "rank-bm25", "chess"]
# ///
"""Compare the trained cross-encoder (CE) against BM25 on English-bridged docs, plus a top-K sweep.

A static-embedding model ranks the chess-notation corpus (first stage). Each reranker then
rescores the same top-K shortlist using the English renderings of the queries and documents.
Each table row reports mean nDCG@10 over the held-out queries for four systems:

* static first stage alone;
* CE rerank of the shortlist;
* BM25 rerank of the shortlist;
* an oracle that ranks every relevant doc in the shortlist first (the reranking ceiling).

Tests:
    1. Is the 0.59 CE result just lexical match that BM25 could also do?
    2. Does increasing K to 200/300 push results past 0.59, as the oracle ceiling rises
       (0.59 -> 0.77 -> 0.87)?
"""
import os
import sys
from collections import defaultdict

import numpy as np
from datasets import Dataset, load_dataset
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder, SentenceTransformer

sys.stdout.reconfigure(line_buffering=True)
# Put this script's directory first on sys.path so the sibling convert_to_english module is found.
sys.path.insert(0, os.path.dirname(__file__))
from convert_to_english import build_english_anchor, build_english_doc  # noqa: E402

HELDOUT_FREQ_MIN = 3
HELDOUT_FREQ_MAX = 30
EVAL_QUERIES = 200
# Shortlist sizes swept for the rerankers; sizes larger than the corpus are skipped.
K_VALUES = (10, 50, 100, 200, 300)


def _join_tags(tags):
    """Join tag strings with spaces, turning underscores into spaces; falsy input gives ""."""
    return " ".join(t.replace("_", " ") for t in tags) if tags else ""


def _bigram(m):
    """Return *m* followed by its adjacent-token bigrams joined with "+".

    Input with fewer than two whitespace tokens is returned unchanged.
    """
    toks = m.split()
    if len(toks) < 2:
        return m
    return m + " " + " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:]))


def build_chess_anchor(themes, op):
    """Build the chess-side anchor (query) text: theme tags, then opening tags if any."""
    tt = _join_tags(themes)
    ot = _join_tags(op or [])
    return tt + (f" {ot}" if ot else "")


def build_chess_doc_stripped(themes, op, moves):
    """Build a chess-side document from the move sequence only (with move bigrams).

    ``themes`` and ``op`` are accepted for signature parity with the anchor builder but are
    deliberately not included in the document text.
    """
    return f"moves {_bigram(moves)}"


def _idcg(n_rel, k):
    """Ideal DCG@k for ``n_rel`` binary-relevant documents."""
    return sum(1.0 / np.log2(rank + 2) for rank in range(min(n_rel, k)))


def ndcg_at_k(scores, rel, k=10):
    """Binary-relevance nDCG@k.

    Args:
        scores: iterable of ``(doc_id, score)`` pairs; higher score ranks first, and ties keep
            input order.
        rel: collection of relevant doc ids.
        k: cutoff depth.

    Returns:
        nDCG@k as a float; 0.0 when ``rel`` is empty.
    """
    ranked = sorted(scores, key=lambda kv: -kv[1])[:k]
    dcg = sum(1.0 / np.log2(rank + 2) for rank, (doc, _) in enumerate(ranked) if doc in rel)
    idcg = _idcg(len(rel), k)
    return dcg / idcg if idcg > 0 else 0.0


def _oracle_ndcg_at_k(rel, shortlist, k=10):
    """nDCG@k of a perfect reranker: every relevant doc present in *shortlist* ranked first."""
    n_hit = min(k, len(rel.intersection(int(i) for i in shortlist)))
    dcg = sum(1.0 / np.log2(rank + 2) for rank in range(n_hit))
    idcg = _idcg(len(rel), k)
    return dcg / idcg if idcg > 0 else 0.0


def _encode_normalized(model, texts):
    """Encode *texts* with *model* and L2-normalise each row.

    The norm is clamped so that an all-zero embedding stays zero instead of becoming NaN;
    a NaN row would poison every similarity computed from it.
    """
    emb = model.encode(texts, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    return emb / np.maximum(norms, 1e-12)


def _build_eval_set(puzzles):
    """Select rare anchors as held-out queries and build the parallel chess/English corpus.

    Held-out anchors are those occurring between HELDOUT_FREQ_MIN and HELDOUT_FREQ_MAX times.
    The EVAL_QUERIES least frequent are kept, with ties in first-seen order. Every puzzle row
    carrying a held-out anchor becomes one corpus document.

    Returns:
        ``(qchess, qen, corp_chess, corp_en, rels)``, where ``rels[i]`` is the set of corpus
        indices whose anchor equals query ``i``.
    """
    # Store row indices rather than row dicts: only the held-out rows are needed, and keeping
    # every row of the full dataset in memory is very costly.
    rows_by_anchor = defaultdict(list)
    for idx, r in enumerate(puzzles):
        if not r["Themes"]:
            continue
        rows_by_anchor[build_chess_anchor(r["Themes"], r["OpeningTags"])].append(idx)

    rare = sorted(
        (
            (a, len(ix))
            for a, ix in rows_by_anchor.items()
            if HELDOUT_FREQ_MIN <= len(ix) <= HELDOUT_FREQ_MAX
        ),
        key=lambda kv: kv[1],
    )
    heldout = [a for a, _ in rare[:EVAL_QUERIES]]
    print(f" {len(heldout)} held-out anchors")

    qen, corp_chess, corp_en, rels = [], [], [], []
    for ca in heldout:
        rows = [puzzles[i] for i in rows_by_anchor[ca]]
        start = len(corp_chess)
        for r in rows:
            corp_chess.append(build_chess_doc_stripped(r["Themes"], r["OpeningTags"], r["Moves"]))
            corp_en.append(build_english_doc(r))
        # The English query is rendered from the first puzzle carrying this anchor.
        qen.append(build_english_anchor(rows[0]))
        rels.append(set(range(start, len(corp_chess))))
    return list(heldout), qen, corp_chess, corp_en, rels


def main():
    """Run the static / CE / BM25 / oracle comparison and print one table row per K."""
    print("Building eval set...")
    puzzles = load_dataset("Lichess/chess-puzzles", split="train")
    qchess, qen, corp_chess, corp_en, rels = _build_eval_set(puzzles)
    print(f" corpus: {len(corp_chess)} docs")
    if not corp_chess:
        print("No held-out anchors in the frequency range; nothing to evaluate.")
        return

    print("\nLoading static (v4-C2) for first-stage...")
    static = SentenceTransformer("models/static-embedding-chess-multitask-5000x/final")
    sc = _encode_normalized(static, corp_chess)
    sq = _encode_normalized(static, qchess)
    static_sims = sq @ sc.T

    # Loaded trained CE
    print("Loading trained CE...")
    ce = CrossEncoder("models/chess-reranker-english/final")

    # BM25 on English docs
    print("Building BM25 over English docs...")
    bm25 = BM25Okapi([d.split() for d in corp_en])

    ks = [k for k in K_VALUES if k <= len(corp_chess)]
    k_max = max(ks, default=0)

    # Each K's shortlist is a prefix of the top-k_max one. Ranking, CE scoring and BM25 scoring
    # therefore run once per query and are sliced per K, rather than being redone for every K.
    static_ndcg = []
    shortlists = []  # per query: (relevant ids, shortlist ids, CE scores, BM25 scores)
    if ks:
        print(f"Scoring top-{k_max} shortlists for {len(qchess)} queries...")
        for qi, rel in enumerate(rels):
            # Stable sort gives deterministic tie-breaking (lower doc index first).
            order = np.argsort(-static_sims[qi], kind="stable")
            static_ndcg.append(
                ndcg_at_k([(int(i), float(static_sims[qi, i])) for i in order[:10]], rel, k=10)
            )
            ids = [int(i) for i in order[:k_max]]
            ce_scores = ce.predict(
                [[qen[qi], corp_en[i]] for i in ids],
                batch_size=64,
                show_progress_bar=False,
                convert_to_numpy=True,
            )
            bm_full = bm25.get_scores(qen[qi].split())
            shortlists.append(
                (rel, ids, [float(s) for s in ce_scores], [float(bm_full[i]) for i in ids])
            )

    print("\n" + "=" * 80)
    print(f" {'K':>4} {'Static':>10} {'+CE':>10} {'+BM25':>10} {'Oracle':>10}")
    print("=" * 80)
    # Static-only nDCG@10 does not depend on the shortlist size.
    static_v = np.mean(static_ndcg) if static_ndcg else 0.0
    for k in ks:
        ce_ndcg, bm25_ndcg, oracle_ndcg = [], [], []
        for rel, ids, ce_scores, bm_scores in shortlists:
            top_ids = ids[:k]
            ce_ndcg.append(ndcg_at_k(zip(top_ids, ce_scores[:k]), rel, k=10))
            bm25_ndcg.append(ndcg_at_k(zip(top_ids, bm_scores[:k]), rel, k=10))
            oracle_ndcg.append(_oracle_ndcg_at_k(rel, top_ids, k=10))
        print(f" {k:>4} {static_v:>10.4f} {np.mean(ce_ndcg):>10.4f} {np.mean(bm25_ndcg):>10.4f} {np.mean(oracle_ndcg):>10.4f}")
    print("=" * 80)


if __name__ == "__main__":
    main()