static-embedding-chess / scripts /diag_ce_vs_bm25.py
oneryalcin's picture
Add files using upload-large-folder tool
f8392aa verified
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = ["sentence-transformers[train]>=5.5.0", "datasets>=2.19", "numpy", "rank-bm25", "chess"]
# ///
"""Compare trained CE vs BM25 on English-bridged docs, plus top-K sweep.
Tests:
1. Is the 0.59 CE result just lexical matching that BM25 could also achieve?
2. Does increasing the static shortlist size K to 200/300 push the oracle ceiling past 0.59 (0.59 → 0.77 → 0.87)?
"""
import os
import sys
from collections import defaultdict
import numpy as np
from datasets import Dataset, load_dataset
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder, SentenceTransformer
# Flush stdout at every newline so progress prints show up promptly even when
# output is redirected to a file or pipe.
sys.stdout.reconfigure(line_buffering=True)
# Put this script's own directory first on sys.path so the sibling module
# convert_to_english (imported on the next line) can be found.
sys.path.insert(0, os.path.dirname(__file__))
from convert_to_english import build_english_anchor, build_english_doc
# A chess anchor is eligible as a held-out query only if the number of puzzles
# sharing it lies in the inclusive range [HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX].
HELDOUT_FREQ_MIN = 3
HELDOUT_FREQ_MAX = 30
# How many eligible anchors (least frequent first) are used as eval queries.
EVAL_QUERIES = 200
def _join_tags(tags):
return " ".join(t.replace("_", " ") for t in tags) if tags else ""
def _bigram(m):
toks = m.split()
return m + " " + " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:])) if len(toks) > 1 else m
def build_chess_anchor(themes, op):
    """Build the chess-side anchor string: theme words, then opening-tag words if any.

    Args:
        themes: theme tags, passed to _join_tags.
        op: opening tags or None; None is treated as no tags.
    """
    parts = [_join_tags(themes)]
    opening = _join_tags(op or [])
    if opening:
        parts.append(opening)
    return " ".join(parts)
def build_chess_doc_stripped(themes, op, moves):
    """Build a corpus document from the move sequence alone.

    themes and op are accepted but unused: only "moves " plus the move string
    with its adjacent-move bigrams ends up in the document (presumably so the
    doc cannot lexically match the theme/opening anchor).
    """
    return "moves " + _bigram(moves)
def ndcg_at_k(scores, rel, k=10):
    """Binary-relevance nDCG@k.

    Args:
        scores: iterable of (doc_id, score) pairs; a higher score ranks first and
            ties keep their input order.
        rel: collection of relevant doc ids (membership-tested with ``in``).
        k: rank cutoff.

    Returns:
        DCG of the top-k ranking divided by the ideal DCG of min(len(rel), k)
        relevant docs, or 0 when that ideal DCG is not positive.
    """
    ranking = sorted(scores, key=lambda pair: -pair[1])
    top = ranking[:k]
    dcg = sum(
        1.0 / np.log2(pos + 2)
        for pos, (doc_id, _score) in enumerate(top)
        if doc_id in rel
    )
    ideal = sum(1.0 / np.log2(pos + 2) for pos in range(min(len(rel), k)))
    if ideal > 0:
        return dcg / ideal
    return 0
def _build_eval_set():
    """Pick rare chess anchors as queries and gather their puzzles as the corpus.

    Groups every Lichess puzzle that has themes by its chess anchor, then keeps
    the EVAL_QUERIES least-frequent anchors whose puzzle count lies in
    [HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX].

    Returns:
        qchess: chess-anchor query strings, one per held-out anchor.
        qen: English anchor per query, built from that anchor's first puzzle.
        corp_chess: move-only chess docs for every puzzle of every held-out anchor.
        corp_en: English docs, index-aligned with corp_chess.
        rel_sets: per query, the set of corpus indices sharing its anchor.
    """
    puzzles = load_dataset("Lichess/chess-puzzles", split="train")
    freq = defaultdict(int)
    rows_by_anchor = defaultdict(list)
    for r in puzzles:
        if not r["Themes"]:
            continue
        ca = build_chess_anchor(r["Themes"], r["OpeningTags"])
        freq[ca] += 1
        # Only anchors with <= HELDOUT_FREQ_MAX puzzles can be held out, so stop
        # keeping rows for an anchor once it exceeds that cap instead of holding
        # the whole dataset in memory.
        if freq[ca] <= HELDOUT_FREQ_MAX:
            rows_by_anchor[ca].append(r)
        elif freq[ca] == HELDOUT_FREQ_MAX + 1:
            rows_by_anchor.pop(ca, None)
    rare = sorted(
        ((a, c) for a, c in freq.items() if HELDOUT_FREQ_MIN <= c <= HELDOUT_FREQ_MAX),
        key=lambda kv: kv[1],
    )
    heldout = [a for a, _ in rare[:EVAL_QUERIES]]
    print(f" {len(heldout)} held-out anchors")

    qen, corp_chess, corp_en, rel_sets = [], [], [], []
    for ca in heldout:
        rows = rows_by_anchor[ca]
        start = len(corp_chess)
        for r in rows:
            corp_chess.append(build_chess_doc_stripped(r["Themes"], r["OpeningTags"], r["Moves"]))
            corp_en.append(build_english_doc(r))
        qen.append(build_english_anchor(rows[0]))
        # An anchor's docs are appended contiguously, so its relevant set is a range.
        rel_sets.append(set(range(start, len(corp_chess))))
    return heldout, qen, corp_chess, corp_en, rel_sets


def _unit_rows(x):
    """Scale each row of a 2-D array to unit L2 norm.

    All-zero rows stay zero instead of becoming NaN (which would silently
    corrupt the argsort-based shortlists).
    """
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.maximum(norms, 1e-12)


def _ideal_dcg(n):
    """DCG of n relevant docs placed at ranks 1..n (binary gain, log2 discount)."""
    return sum(1.0 / np.log2(r + 2) for r in range(n))


def _print_k_sweep(ks, qen, corp_en, static_sims, rel_sets, ce, bm25):
    """Print one row per shortlist size K: static / +CE / +BM25 / oracle nDCG@10.

    Args:
        ks: shortlist sizes (non-empty), each <= the corpus size.
        qen: English query strings, one per row of static_sims.
        corp_en: English docs, indexed like the columns of static_sims.
        static_sims: (num_queries, num_docs) cosine similarities from the static model.
        rel_sets: per-query sets of relevant corpus indices.
        ce: cross-encoder that reranks each shortlist.
        bm25: BM25 index over corp_en that reranks each shortlist.
    """
    max_k = max(ks)
    static_ndcg, shortlists, ce_scores, bm_scores = [], [], [], []
    # K-independent work, done once per query. Every top-K shortlist is a prefix
    # of the same static ranking, so one CE pass over the largest shortlist and
    # one BM25 scoring of the query serve all K values.
    # NOTE: CE batching differs from scoring each top-K separately, which may
    # perturb CE scores at floating-point precision.
    for qi, q_en in enumerate(qen):
        sims = static_sims[qi]
        order = [int(i) for i in np.argsort(-sims)[:max_k]]
        rel = rel_sets[qi]
        # max_k >= 10 because the smallest swept K is 10.
        static_ndcg.append(ndcg_at_k([(i, float(sims[i])) for i in order[:10]], rel, k=10))
        pairs = [[q_en, corp_en[i]] for i in order]
        scores = ce.predict(pairs, batch_size=64, show_progress_bar=False, convert_to_numpy=True)
        ce_scores.append([float(s) for s in scores])
        bm_full = bm25.get_scores(q_en.split())
        bm_scores.append([float(bm_full[i]) for i in order])
        shortlists.append(order)

    # Static nDCG@10 does not depend on K, so every row repeats it.
    static_v = np.mean(static_ndcg)
    for k in ks:
        ce_ndcg, bm25_ndcg, oracle_ndcg = [], [], []
        for qi, rel in enumerate(rel_sets):
            topk = shortlists[qi][:k]
            ce_ndcg.append(ndcg_at_k(list(zip(topk, ce_scores[qi][:k])), rel, k=10))
            bm25_ndcg.append(ndcg_at_k(list(zip(topk, bm_scores[qi][:k])), rel, k=10))
            # Oracle ceiling: best nDCG@10 reachable by reordering this shortlist.
            ideal = _ideal_dcg(min(len(rel), 10))
            hits = min(10, len(rel.intersection(topk)))
            oracle_ndcg.append(_ideal_dcg(hits) / ideal if ideal > 0 else 0)
        print(f" {k:>4} {static_v:>10.4f} {np.mean(ce_ndcg):>10.4f} {np.mean(bm25_ndcg):>10.4f} {np.mean(oracle_ndcg):>10.4f}")


def main():
    """Compare CE and BM25 reranking of static-model shortlists; print nDCG@10 per K."""
    print("Building eval set...")
    qchess, qen, corp_chess, corp_en, rel_sets = _build_eval_set()
    print(f" corpus: {len(corp_chess)} docs")
    if not corp_chess:
        # Encoding/normalising and BM25Okapi construction fail on an empty corpus.
        print("No held-out anchors in the frequency range; nothing to evaluate.")
        return

    print("\nLoading static (v4-C2) for first-stage...")
    static = SentenceTransformer("models/static-embedding-chess-multitask-5000x/final")
    sc = _unit_rows(static.encode(corp_chess, batch_size=128, convert_to_numpy=True, show_progress_bar=False))
    sq = _unit_rows(static.encode(qchess, batch_size=128, convert_to_numpy=True, show_progress_bar=False))
    static_sims = sq @ sc.T

    print("Loading trained CE...")
    ce = CrossEncoder("models/chess-reranker-english/final")
    print("Building BM25 over English docs...")
    bm25 = BM25Okapi([d.split() for d in corp_en])

    print("\n" + "=" * 80)
    print(f" {'K':>4} {'Static':>10} {'+CE':>10} {'+BM25':>10} {'Oracle':>10}")
    print("=" * 80)
    # Shortlists larger than the corpus are skipped, as they add nothing.
    ks = [k for k in (10, 50, 100, 200, 300) if k <= len(corp_chess)]
    if ks:
        _print_k_sweep(ks, qen, corp_en, static_sims, rel_sets, ce, bm25)
    print("=" * 80)
# Run the diagnostic only when executed as a script, not when imported.
if __name__ == "__main__":
    main()