oneryalcin commited on
Commit
dc31f7d
·
verified ·
1 Parent(s): b82c4e6

Upload scripts/diag_static_vs_bm25_alone.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/diag_static_vs_bm25_alone.py +156 -0
scripts/diag_static_vs_bm25_alone.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # requires-python = ">=3.10"
4
+ # dependencies = ["sentence-transformers[train]>=5.5.0", "datasets>=2.19", "numpy", "rank-bm25", "chess"]
5
+ # ///
6
+ """DIRECT comparison: Do we need the static embedding at all?
7
+
8
+ Tests BM25 over English-bridged corpus AS THE ONLY RETRIEVER (not as a
9
+ reranker after static). If BM25 alone hits or beats our static embedding,
10
+ we don't need the static model.
11
+
12
+ Three configurations evaluated:
13
+ 1. Static-only (v4-C2 retrieves top-10 directly)
14
+ 2. BM25-only over English-bridged corpus (no static)
15
+ 3. BM25-only over chess-format corpus (theme tokens stripped — our original
16
+ eval format)
17
+
18
+ This is the apples-to-apples question.
19
+ """
20
+ import os
21
+ import sys
22
+ from collections import defaultdict
23
+
24
+ import numpy as np
25
+ from datasets import load_dataset
26
+ from rank_bm25 import BM25Okapi
27
+ from sentence_transformers import SentenceTransformer
28
+
29
+ sys.stdout.reconfigure(line_buffering=True)
30
+ sys.path.insert(0, os.path.dirname(__file__))
31
+ from convert_to_english import build_english_anchor, build_english_doc
32
+
33
+ HELDOUT_FREQ_MIN = 3
34
+ HELDOUT_FREQ_MAX = 30
35
+ EVAL_QUERIES = 200
36
+
37
+
38
+ def _join_tags(tags):
39
+ return " ".join(t.replace("_", " ") for t in tags) if tags else ""
40
+
41
+
42
+ def _bigram(m):
43
+ toks = m.split()
44
+ return m + " " + " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:])) if len(toks) > 1 else m
45
+
46
+
47
+ def build_chess_anchor(themes, op):
48
+ return _join_tags(themes) + (f" {_join_tags(op or [])}" if op else "")
49
+
50
+
51
+ def build_chess_doc_stripped(themes, op, moves):
52
+ return f"moves {_bigram(moves)}"
53
+
54
+
55
+ def ndcg_at_k(scores, rel, k=10):
56
+ r = sorted(scores, key=lambda kv: -kv[1])[:k]
57
+ dcg = sum((1.0 if d in rel else 0.0) / np.log2(rr + 2) for rr, (d, _) in enumerate(r))
58
+ idcg = sum(1.0 / np.log2(rr + 2) for rr in range(min(len(rel), k)))
59
+ return dcg / idcg if idcg > 0 else 0
60
+
61
+
62
+ def main():
63
+ print("Building held-out eval set (same as v3/v4)...")
64
+ puzzles = load_dataset("Lichess/chess-puzzles", split="train")
65
+ freq = defaultdict(int)
66
+ rows_by_anchor = defaultdict(list)
67
+ for r in puzzles:
68
+ if not r["Themes"]:
69
+ continue
70
+ ca = build_chess_anchor(r["Themes"], r["OpeningTags"])
71
+ freq[ca] += 1
72
+ rows_by_anchor[ca].append(r)
73
+ rare = sorted(((a, c) for a, c in freq.items() if HELDOUT_FREQ_MIN <= c <= HELDOUT_FREQ_MAX),
74
+ key=lambda kv: kv[1])
75
+ heldout = [a for a, _ in rare[:EVAL_QUERIES]]
76
+
77
+ qchess, qen = [], []
78
+ corp_chess, corp_en = [], []
79
+ held_per_doc = []
80
+ ch_to_en = {}
81
+ for ca in heldout:
82
+ for r in rows_by_anchor[ca]:
83
+ corp_chess.append(build_chess_doc_stripped(r["Themes"], r["OpeningTags"], r["Moves"]))
84
+ corp_en.append(build_english_doc(r))
85
+ held_per_doc.append(ca)
86
+ if ca not in ch_to_en:
87
+ ch_to_en[ca] = build_english_anchor(r)
88
+ qchess = list(heldout)
89
+ qen = [ch_to_en[a] for a in qchess]
90
+ by_anchor = defaultdict(list)
91
+ for i, a in enumerate(held_per_doc):
92
+ by_anchor[a].append(i)
93
+ print(f" {len(qchess)} queries, {len(corp_chess)} corpus docs")
94
+
95
+ # 1. Static-only
96
+ print("\n[1] Static (v4-C2) alone, ranks all corpus directly")
97
+ static = SentenceTransformer("models/static-embedding-chess-multitask-5000x/final")
98
+ sc = static.encode(corp_chess, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
99
+ sc = sc / np.linalg.norm(sc, axis=1, keepdims=True)
100
+ sq = static.encode(qchess, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
101
+ sq = sq / np.linalg.norm(sq, axis=1, keepdims=True)
102
+ static_sims = sq @ sc.T
103
+ static_ndcgs = []
104
+ for qi in range(len(qchess)):
105
+ rel = set(by_anchor[qchess[qi]])
106
+ score_pairs = [(int(j), float(static_sims[qi, j])) for j in range(len(corp_chess))]
107
+ static_ndcgs.append(ndcg_at_k(score_pairs, rel, k=10))
108
+ print(f" static-only NDCG@10: {np.mean(static_ndcgs):.4f}")
109
+
110
+ # 2. BM25 over chess-format corpus (theme stripped)
111
+ print("\n[2] BM25 alone over chess-format corpus (theme tokens stripped — same docs static sees)")
112
+ bm25_chess = BM25Okapi([d.split() for d in corp_chess])
113
+ bm25_chess_ndcgs = []
114
+ for qi, q in enumerate(qchess):
115
+ scores = bm25_chess.get_scores(q.split())
116
+ score_pairs = [(j, float(scores[j])) for j in range(len(corp_chess))]
117
+ bm25_chess_ndcgs.append(ndcg_at_k(score_pairs, set(by_anchor[q]), k=10))
118
+ print(f" BM25 (chess docs, query=chess anchor) NDCG@10: {np.mean(bm25_chess_ndcgs):.4f}")
119
+
120
+ # 3. BM25 over English-bridged corpus (theme tokens visible)
121
+ print("\n[3] BM25 alone over English-bridged corpus")
122
+ bm25_en = BM25Okapi([d.split() for d in corp_en])
123
+ bm25_en_ndcgs = []
124
+ for qi, q in enumerate(qen):
125
+ scores = bm25_en.get_scores(q.split())
126
+ score_pairs = [(j, float(scores[j])) for j in range(len(corp_en))]
127
+ bm25_en_ndcgs.append(ndcg_at_k(score_pairs, set(by_anchor[qchess[qi]]), k=10))
128
+ print(f" BM25 (English docs, query=English anchor) NDCG@10: {np.mean(bm25_en_ndcgs):.4f}")
129
+
130
+ # Also: static + BM25 hybrid (RRF fusion)
131
+ print("\n[4] Static + BM25 fusion (RRF, K=60)")
132
+ K_RRF = 60
133
+ rrf_ndcgs = []
134
+ for qi in range(len(qchess)):
135
+ rel = set(by_anchor[qchess[qi]])
136
+ st_rank = np.argsort(-static_sims[qi]).argsort()
137
+ bm = bm25_en.get_scores(qen[qi].split())
138
+ bm_rank = np.argsort(-bm).argsort()
139
+ fused = 1.0 / (K_RRF + st_rank + 1) + 1.0 / (K_RRF + bm_rank + 1)
140
+ score_pairs = [(j, float(fused[j])) for j in range(len(corp_chess))]
141
+ rrf_ndcgs.append(ndcg_at_k(score_pairs, rel, k=10))
142
+ print(f" Static + BM25-English RRF fusion NDCG@10: {np.mean(rrf_ndcgs):.4f}")
143
+
144
+ # Summary
145
+ print("\n" + "=" * 70)
146
+ print(f"{'Approach':<55} {'NDCG@10':>12}")
147
+ print("=" * 70)
148
+ print(f"{'Static (v4-C2) alone':<55} {np.mean(static_ndcgs):>12.4f}")
149
+ print(f"{'BM25 alone over chess-format docs':<55} {np.mean(bm25_chess_ndcgs):>12.4f}")
150
+ print(f"{'BM25 alone over English-bridged docs':<55} {np.mean(bm25_en_ndcgs):>12.4f}")
151
+ print(f"{'Static + BM25-English RRF fusion':<55} {np.mean(rrf_ndcgs):>12.4f}")
152
+ print("=" * 70)
153
+
154
+
155
+ if __name__ == "__main__":
156
+ main()