static-embedding-chess / scripts /diag_static_vs_bm25_alone.py
oneryalcin's picture
Upload scripts/diag_static_vs_bm25_alone.py with huggingface_hub
dc31f7d verified
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = ["sentence-transformers[train]>=5.5.0", "datasets>=2.19", "numpy", "rank-bm25", "chess"]
# ///
"""DIRECT comparison: Do we need the static embedding at all?
Tests BM25 over English-bridged corpus AS THE ONLY RETRIEVER (not as a
reranker after static). If BM25 alone hits or beats our static embedding,
we don't need the static model.
Three configurations evaluated:
1. Static-only (v4-C2 retrieves top-10 directly)
2. BM25-only over English-bridged corpus (no static)
3. BM25-only over chess-format corpus (theme tokens stripped — our original
eval format)
This is the apples-to-apples question.
"""
import os
import sys
from collections import defaultdict
import numpy as np
from datasets import load_dataset
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
# Flush stdout on every newline so progress lines appear promptly in piped/job logs.
sys.stdout.reconfigure(line_buffering=True)
# Make modules that sit next to this script importable (needed by the import below).
sys.path.insert(0, os.path.dirname(__file__))
from convert_to_english import build_english_anchor, build_english_doc
# An anchor is eligible for the held-out set when its puzzle count lies in
# [HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX] (both bounds inclusive).
HELDOUT_FREQ_MIN = 3
HELDOUT_FREQ_MAX = 30
# Maximum number of held-out anchors (= queries); the rarest eligible anchors are taken first.
EVAL_QUERIES = 200
def _join_tags(tags):
return " ".join(t.replace("_", " ") for t in tags) if tags else ""
def _bigram(m):
toks = m.split()
return m + " " + " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:])) if len(toks) > 1 else m
def build_chess_anchor(themes, op):
    """Build the chess-format anchor/query string: joined themes, then opening tags if any.

    ``op`` is appended (space-separated) only when it is truthy.
    """
    anchor = _join_tags(themes)
    if op:
        anchor += " " + _join_tags(op)
    return anchor
def build_chess_doc_stripped(themes, op, moves):
    """Build a chess-format corpus document from the move string only.

    ``themes`` and ``op`` are accepted so call sites can pass the full puzzle
    fields, but they are intentionally ignored ("stripped"): the document
    contains only ``moves`` plus its adjacent-move bigram tokens.
    """
    del themes, op  # deliberately unused: theme/opening tokens are stripped
    return "moves " + _bigram(moves)
def ndcg_at_k(scores, rel, k=10):
    """Binary-relevance NDCG@k with tie-aware (expected-value) DCG.

    Args:
        scores: iterable of ``(doc_id, score)`` pairs; higher score ranks higher.
        rel: set of relevant ``doc_id``s.
        k: cutoff rank.

    Returns:
        NDCG@k as a float in [0, 1]; ``0.0`` when there is nothing relevant.

    Documents with exactly equal scores are treated as one tied group. Every
    position the group occupies receives the group's mean gain
    (hits / group size). This is the expected DCG under a random ordering of
    the ties.

    Why this matters: with a plain stable sort, ties fall back to input order.
    The eval corpus is laid out grouped by anchor, so input order correlates
    with relevance. A degenerate scorer (e.g. all-zero BM25 scores) would
    then get spurious credit.
    """
    ranked = sorted(scores, key=lambda kv: -kv[1])
    discounts = 1.0 / np.log2(np.arange(max(k, 0)) + 2)
    n = len(ranked)
    limit = min(n, k)
    dcg = 0.0
    start = 0
    while start < limit:
        # Always take at least one item so NaN scores (NaN != NaN) cannot stall the loop.
        end = start + 1
        while end < n and ranked[end][1] == ranked[start][1]:
            end += 1
        hits = sum(1 for d, _ in ranked[start:end] if d in rel)
        if hits:
            dcg += hits / (end - start) * float(discounts[start:min(end, k)].sum())
        start = end
    idcg = float(discounts[:min(len(rel), max(k, 0))].sum())
    return float(dcg / idcg) if idcg > 0 else 0.0
def _competition_rank(scores):
    """Return 0-based competition ranks for ``scores`` (higher score = better rank).

    Each entry's rank is the number of entries with a strictly higher score,
    so tied entries share the best rank. ``argsort().argsort()`` would instead
    order ties arbitrarily, typically by corpus position. Here the corpus is
    grouped by anchor, so that order would leak relevance into the fused ranking.
    """
    neg = -np.asarray(scores, dtype=float)
    return np.searchsorted(np.sort(neg), neg, side="left")


def _query_ndcg(row, rel, k=10):
    """NDCG@k for one query, given a dense per-document score vector ``row``."""
    return ndcg_at_k([(j, float(s)) for j, s in enumerate(row)], rel, k=k)


def _build_heldout_eval():
    """Build the held-out queries and corpus (same selection as v3/v4).

    Returns ``(qchess, qen, corp_chess, corp_en, by_anchor)``:
    chess-format queries, their English-anchor counterparts, the chess-format
    and English corpora (index-aligned), and a map from each chess anchor to
    the corpus indices of its puzzles.
    """
    puzzles = load_dataset("Lichess/chess-puzzles", split="train")
    rows_by_anchor = defaultdict(list)
    for r in puzzles:
        if not r["Themes"]:
            continue
        rows_by_anchor[build_chess_anchor(r["Themes"], r["OpeningTags"])].append(r)
    # Rarest eligible anchors first; sorted() is stable, so anchors with equal
    # counts keep their first-seen dataset order.
    eligible = [a for a, rows in rows_by_anchor.items()
                if HELDOUT_FREQ_MIN <= len(rows) <= HELDOUT_FREQ_MAX]
    qchess = sorted(eligible, key=lambda a: len(rows_by_anchor[a]))[:EVAL_QUERIES]

    qen, corp_chess, corp_en = [], [], []
    by_anchor = {}
    for ca in qchess:
        rows = rows_by_anchor[ca]
        # The English query is built from the first puzzle carrying this anchor.
        qen.append(build_english_anchor(rows[0]))
        start = len(corp_chess)
        for r in rows:
            corp_chess.append(build_chess_doc_stripped(r["Themes"], r["OpeningTags"], r["Moves"]))
            corp_en.append(build_english_doc(r))
        by_anchor[ca] = list(range(start, len(corp_chess)))
    return qchess, qen, corp_chess, corp_en, by_anchor


def main():
    """Score each retrieval configuration on the held-out set and print NDCG@10 for each."""
    print("Building held-out eval set (same as v3/v4)...")
    qchess, qen, corp_chess, corp_en, by_anchor = _build_heldout_eval()
    print(f" {len(qchess)} queries, {len(corp_chess)} corpus docs")
    rels = [set(by_anchor[q]) for q in qchess]

    # 1. Static-only
    print("\n[1] Static (v4-C2) alone, ranks all corpus directly")
    static = SentenceTransformer("models/static-embedding-chess-multitask-5000x/final")
    sc = static.encode(corp_chess, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
    sc = sc / np.linalg.norm(sc, axis=1, keepdims=True)
    sq = static.encode(qchess, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
    sq = sq / np.linalg.norm(sq, axis=1, keepdims=True)
    static_sims = sq @ sc.T  # cosine similarity, (num_queries, num_docs)
    static_ndcgs = [_query_ndcg(row, rel) for row, rel in zip(static_sims, rels)]
    print(f" static-only NDCG@10: {np.mean(static_ndcgs):.4f}")

    # 2. BM25 over chess-format corpus (theme stripped)
    print("\n[2] BM25 alone over chess-format corpus (theme tokens stripped — same docs static sees)")
    # NOTE(review): docs hold only move tokens while queries are theme/opening
    # words, so scores are likely all zero here. The tie-aware NDCG then reports
    # a random-order baseline rather than crediting corpus order.
    bm25_chess = BM25Okapi([d.split() for d in corp_chess])
    bm25_chess_ndcgs = [_query_ndcg(bm25_chess.get_scores(q.split()), rel)
                        for q, rel in zip(qchess, rels)]
    print(f" BM25 (chess docs, query=chess anchor) NDCG@10: {np.mean(bm25_chess_ndcgs):.4f}")

    # 3. BM25 over English-bridged corpus (theme tokens visible)
    print("\n[3] BM25 alone over English-bridged corpus")
    bm25_en = BM25Okapi([d.split() for d in corp_en])
    # Cached: reused by the fusion step below instead of rescoring every query.
    bm25_en_scores = [bm25_en.get_scores(q.split()) for q in qen]
    bm25_en_ndcgs = [_query_ndcg(s, rel) for s, rel in zip(bm25_en_scores, rels)]
    print(f" BM25 (English docs, query=English anchor) NDCG@10: {np.mean(bm25_en_ndcgs):.4f}")

    # Also: static + BM25 hybrid (RRF fusion)
    print("\n[4] Static + BM25 fusion (RRF, K=60)")
    K_RRF = 60
    rrf_ndcgs = []
    for st, bm, rel in zip(static_sims, bm25_en_scores, rels):
        fused = (1.0 / (K_RRF + _competition_rank(st) + 1)
                 + 1.0 / (K_RRF + _competition_rank(bm) + 1))
        rrf_ndcgs.append(_query_ndcg(fused, rel))
    print(f" Static + BM25-English RRF fusion NDCG@10: {np.mean(rrf_ndcgs):.4f}")

    # Summary
    print("\n" + "=" * 70)
    print(f"{'Approach':<55} {'NDCG@10':>12}")
    print("=" * 70)
    print(f"{'Static (v4-C2) alone':<55} {np.mean(static_ndcgs):>12.4f}")
    print(f"{'BM25 alone over chess-format docs':<55} {np.mean(bm25_chess_ndcgs):>12.4f}")
    print(f"{'BM25 alone over English-bridged docs':<55} {np.mean(bm25_en_ndcgs):>12.4f}")
    print(f"{'Static + BM25-English RRF fusion':<55} {np.mean(rrf_ndcgs):>12.4f}")
    print("=" * 70)


if __name__ == "__main__":
    main()