#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = ["sentence-transformers[train]>=5.5.0", "datasets>=2.19", "numpy", "rank-bm25", "chess"]
# ///
"""DIRECT comparison: Do we need the static embedding at all?

Tests BM25 over the English-bridged corpus AS THE ONLY RETRIEVER (not as a
reranker after static). If BM25 alone matches or beats our static embedding,
we don't need the static model.

Four configurations are evaluated; each one ranks the full held-out corpus:
  1. Static-only (v4-C2 embeds queries and docs; cosine similarity ranks all docs)
  2. BM25-only over the chess-format corpus (theme tokens stripped -- our
     original eval format, i.e. the same docs the static model sees)
  3. BM25-only over the English-bridged corpus (no static model)
  4. Static + BM25-English fused with reciprocal rank fusion (RRF, K=60)

This is the apples-to-apples question.
"""
import os
import sys
from collections import defaultdict

import numpy as np
from datasets import load_dataset
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer

sys.stdout.reconfigure(line_buffering=True)
# Make the sibling helper module importable regardless of the working directory.
sys.path.insert(0, os.path.dirname(__file__))
from convert_to_english import build_english_anchor, build_english_doc  # noqa: E402

HELDOUT_FREQ_MIN = 3
HELDOUT_FREQ_MAX = 30
EVAL_QUERIES = 200


def _join_tags(tags):
    """Join *tags* with spaces, turning underscores into spaces; falsy input gives ""."""
    return " ".join(t.replace("_", " ") for t in tags) if tags else ""


def _bigram(m):
    """Return *m* followed by an ``a+b`` token for each adjacent token pair in it."""
    toks = m.split()
    if len(toks) <= 1:
        return m
    return m + " " + " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:]))


def build_chess_anchor(themes, op):
    """Build the query/anchor string: joined themes, then joined opening tags if any."""
    return _join_tags(themes) + (f" {_join_tags(op)}" if op else "")


def build_chess_doc_stripped(themes, op, moves):
    """Build a chess-format corpus doc from the move string only.

    *themes* and *op* are accepted for signature parity with the unstripped
    format but are deliberately ignored: the eval hides theme/opening tokens
    from the docs.
    """
    return f"moves {_bigram(moves)}"


def ndcg_at_k(scores, rel, k=10):
    """Binary-relevance NDCG@k.

    Args:
        scores: iterable of ``(doc_id, score)`` pairs; higher score ranks earlier.
        rel: set of relevant doc_ids.
        k: cutoff depth.

    Returns:
        NDCG@k in [0, 1]; 0.0 when *rel* is empty.

    Score ties are broken pessimistically (non-relevant docs rank first).
    The held-out corpus stores each anchor's relevant docs contiguously, so a
    plain stable sort would break ties by corpus index. For example, all-zero
    BM25 scores would then credit the first anchor's query with its own docs.
    """
    ranked = sorted(scores, key=lambda kv: (-kv[1], kv[0] in rel))[:k]
    dcg = sum(1.0 / np.log2(pos + 2) for pos, (d, _) in enumerate(ranked) if d in rel)
    idcg = sum(1.0 / np.log2(pos + 2) for pos in range(min(len(rel), k)))
    return dcg / idcg if idcg > 0 else 0.0


def _ndcg_row(scores, rel, k=10):
    """NDCG@k for a dense per-doc score vector whose index is the doc id."""
    return ndcg_at_k(list(enumerate(np.asarray(scores, dtype=float).tolist())), rel, k=k)


def _competition_ranks(scores):
    """0-based descending ranks; tied scores share the best (lowest) rank.

    Used for RRF so that tied scores contribute equally instead of getting
    arbitrary distinct ranks from argsort order.
    """
    neg = -np.asarray(scores, dtype=float)
    return np.searchsorted(np.sort(neg), neg, side="left")


def _l2_normalize(mat, eps=1e-12):
    """Row-wise L2 normalisation; all-zero rows stay zero instead of becoming NaN."""
    norms = np.linalg.norm(mat, axis=1, keepdims=True)
    return mat / np.maximum(norms, eps)


def _build_eval_set():
    """Build the rare-anchor held-out eval set from the Lichess puzzle dataset.

    Returns:
        (qchess, qen, corp_chess, corp_en, by_anchor) where
        qchess[i] / qen[i] are the chess-format / English query for anchor i,
        corp_chess[j] / corp_en[j] are the two renderings of corpus doc j, and
        by_anchor maps a chess anchor to the indices of its corpus docs.
    """
    puzzles = load_dataset("Lichess/chess-puzzles", split="train")
    freq = defaultdict(int)
    rows_by_anchor = defaultdict(list)
    for r in puzzles:
        if not r["Themes"]:
            continue
        ca = build_chess_anchor(r["Themes"], r["OpeningTags"])
        freq[ca] += 1
        # Anchors above HELDOUT_FREQ_MAX are never held out; cap stored rows so
        # memory does not grow with the whole dataset.
        if freq[ca] <= HELDOUT_FREQ_MAX:
            rows_by_anchor[ca].append(r)

    # Rarest anchors first (stable sort keeps dataset order among equal counts).
    rare = sorted(
        ((a, c) for a, c in freq.items() if HELDOUT_FREQ_MIN <= c <= HELDOUT_FREQ_MAX),
        key=lambda kv: kv[1],
    )
    qchess = [a for a, _ in rare[:EVAL_QUERIES]]

    corp_chess, corp_en = [], []
    by_anchor = defaultdict(list)
    qen = []
    for ca in qchess:
        rows = rows_by_anchor[ca]
        # The English query is derived from the anchor's first puzzle row.
        qen.append(build_english_anchor(rows[0]))
        for r in rows:
            by_anchor[ca].append(len(corp_chess))
            corp_chess.append(build_chess_doc_stripped(r["Themes"], r["OpeningTags"], r["Moves"]))
            corp_en.append(build_english_doc(r))
    return qchess, qen, corp_chess, corp_en, by_anchor


def main():
    """Run the four retrieval configurations and print an NDCG@10 summary table."""
    print("Building held-out eval set (same as v3/v4)...")
    qchess, qen, corp_chess, corp_en, by_anchor = _build_eval_set()
    print(f" {len(qchess)} queries, {len(corp_chess)} corpus docs")
    if not qchess:
        print(" no held-out anchors in the frequency range; nothing to evaluate")
        return
    rels = [set(by_anchor[q]) for q in qchess]

    # 1. Static-only
    print("\n[1] Static (v4-C2) alone, ranks all corpus directly")
    static = SentenceTransformer("models/static-embedding-chess-multitask-5000x/final")
    sc = _l2_normalize(static.encode(corp_chess, batch_size=128, convert_to_numpy=True, show_progress_bar=False))
    sq = _l2_normalize(static.encode(qchess, batch_size=128, convert_to_numpy=True, show_progress_bar=False))
    static_sims = sq @ sc.T
    static_ndcgs = [_ndcg_row(row, rel) for row, rel in zip(static_sims, rels)]
    print(f" static-only NDCG@10: {np.mean(static_ndcgs):.4f}")

    # 2. BM25 over chess-format corpus (theme stripped)
    print("\n[2] BM25 alone over chess-format corpus (theme tokens stripped — same docs static sees)")
    bm25_chess = BM25Okapi([d.split() for d in corp_chess])
    bm25_chess_ndcgs = [
        _ndcg_row(bm25_chess.get_scores(q.split()), rel) for q, rel in zip(qchess, rels)
    ]
    print(f" BM25 (chess docs, query=chess anchor) NDCG@10: {np.mean(bm25_chess_ndcgs):.4f}")

    # 3. BM25 over English-bridged corpus (theme tokens visible)
    print("\n[3] BM25 alone over English-bridged corpus")
    bm25_en = BM25Okapi([d.split() for d in corp_en])
    # Cached: reused by the RRF fusion below.
    bm25_en_scores = [bm25_en.get_scores(q.split()) for q in qen]
    bm25_en_ndcgs = [_ndcg_row(s, rel) for s, rel in zip(bm25_en_scores, rels)]
    print(f" BM25 (English docs, query=English anchor) NDCG@10: {np.mean(bm25_en_ndcgs):.4f}")

    # 4. Static + BM25 hybrid (RRF fusion)
    print("\n[4] Static + BM25 fusion (RRF, K=60)")
    K_RRF = 60
    rrf_ndcgs = []
    for st, bm, rel in zip(static_sims, bm25_en_scores, rels):
        fused = (
            1.0 / (K_RRF + _competition_ranks(st) + 1)
            + 1.0 / (K_RRF + _competition_ranks(bm) + 1)
        )
        rrf_ndcgs.append(_ndcg_row(fused, rel))
    print(f" Static + BM25-English RRF fusion NDCG@10: {np.mean(rrf_ndcgs):.4f}")

    # Summary
    print("\n" + "=" * 70)
    print(f"{'Approach':<55} {'NDCG@10':>12}")
    print("=" * 70)
    print(f"{'Static (v4-C2) alone':<55} {np.mean(static_ndcgs):>12.4f}")
    print(f"{'BM25 alone over chess-format docs':<55} {np.mean(bm25_chess_ndcgs):>12.4f}")
    print(f"{'BM25 alone over English-bridged docs':<55} {np.mean(bm25_en_ndcgs):>12.4f}")
    print(f"{'Static + BM25-English RRF fusion':<55} {np.mean(rrf_ndcgs):>12.4f}")
    print("=" * 70)


if __name__ == "__main__":
    main()