Sentence Similarity
sentence-transformers
Safetensors
English
static-embedding
chess
retrieval
exploratory
Instructions to use oneryalcin/static-embedding-chess with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use oneryalcin/static-embedding-chess with sentence-transformers:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("oneryalcin/static-embedding-chess")

sentences = [
    "That is a happy person",
    "That is a happy dog",
    "That is a very happy person",
    "Today is a sunny day"
]
embeddings = model.encode(sentences)

similarities = model.similarity(embeddings, embeddings)
print(similarities.shape)
# [4, 4]
- Notebooks
- Google Colab
- Kaggle
Upload scripts/diag_static_vs_bm25_alone.py with huggingface_hub
Browse files
scripts/diag_static_vs_bm25_alone.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# /// script
|
| 3 |
+
# requires-python = ">=3.10"
|
| 4 |
+
# dependencies = ["sentence-transformers[train]>=5.5.0", "datasets>=2.19", "numpy", "rank-bm25", "chess"]
|
| 5 |
+
# ///
|
| 6 |
+
"""DIRECT comparison: Do we need the static embedding at all?
|
| 7 |
+
|
| 8 |
+
Tests BM25 over English-bridged corpus AS THE ONLY RETRIEVER (not as a
|
| 9 |
+
reranker after static). If BM25 alone hits or beats our static embedding,
|
| 10 |
+
we don't need the static model.
|
| 11 |
+
|
| 12 |
+
Three single-retriever configurations are evaluated (the script additionally
reports a static + BM25-English RRF fusion):
|
| 13 |
+
1. Static-only (v4-C2 retrieves top-10 directly)
|
| 14 |
+
2. BM25-only over English-bridged corpus (no static)
|
| 15 |
+
3. BM25-only over chess-format corpus (theme tokens stripped — our original
|
| 16 |
+
eval format)
|
| 17 |
+
|
| 18 |
+
This is the apples-to-apples question.
|
| 19 |
+
"""
|
| 20 |
+
import os
|
| 21 |
+
import sys
|
| 22 |
+
from collections import defaultdict
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
from datasets import load_dataset
|
| 26 |
+
from rank_bm25 import BM25Okapi
|
| 27 |
+
from sentence_transformers import SentenceTransformer
|
| 28 |
+
|
| 29 |
+
sys.stdout.reconfigure(line_buffering=True)
|
| 30 |
+
sys.path.insert(0, os.path.dirname(__file__))
|
| 31 |
+
from convert_to_english import build_english_anchor, build_english_doc
|
| 32 |
+
|
| 33 |
+
# An anchor is eligible as a held-out query only if the number of puzzles
# sharing it is at least HELDOUT_FREQ_MIN ...
HELDOUT_FREQ_MIN = 3
# ... and at most HELDOUT_FREQ_MAX (both bounds inclusive).
HELDOUT_FREQ_MAX = 30
# Upper bound on how many eligible anchors (rarest first) become eval queries.
EVAL_QUERIES = 200
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _join_tags(tags):
|
| 39 |
+
return " ".join(t.replace("_", " ") for t in tags) if tags else ""
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _bigram(m):
|
| 43 |
+
toks = m.split()
|
| 44 |
+
return m + " " + " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:])) if len(toks) > 1 else m
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def build_chess_anchor(themes, op):
    """Build the chess-format anchor text from theme and opening tags.

    The anchor is the joined theme tags, followed (only when ``op`` is truthy)
    by a single space and the joined opening tags.
    """
    anchor = _join_tags(themes)
    if op:
        anchor += " " + _join_tags(op)
    return anchor
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def build_chess_doc_stripped(themes, op, moves):
    """Build the chess-format document text without any theme tokens.

    ``themes`` and ``op`` are accepted but not used; the document consists of
    the literal prefix ``"moves "`` followed by the move string with its
    adjacent-move bigrams appended by ``_bigram``.
    """
    move_text = _bigram(moves)
    return "moves " + move_text
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def ndcg_at_k(scores, rel, k=10):
    """Compute binary-relevance NDCG@k for a single query.

    Args:
        scores: Iterable of ``(doc_id, score)`` pairs; higher scores rank first.
        rel: Sized container of relevant doc ids (tested with ``in``).
        k: Rank cut-off.

    Returns:
        DCG of the top-``k`` ranked docs divided by the ideal DCG for
        ``min(len(rel), k)`` relevant docs, or ``0`` when the ideal DCG is not
        positive (for example when ``rel`` is empty).
    """
    def descending(pair):
        # Negated score so an ascending, stable sort puts the best doc first.
        return -pair[1]

    top = sorted(scores, key=descending)[:k]

    dcg = 0
    for rank, (doc_id, _) in enumerate(top):
        gain = 1.0 if doc_id in rel else 0.0
        dcg += gain / np.log2(rank + 2)

    idcg = 0
    for rank in range(min(len(rel), k)):
        idcg += 1.0 / np.log2(rank + 2)

    if idcg > 0:
        return dcg / idcg
    return 0
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _score_pairs(row):
    """Turn a 1-D score array into ``(doc_index, python_float)`` pairs."""
    return list(enumerate(np.asarray(row).tolist()))


def main():
    """Compare static-embedding retrieval with BM25 on held-out puzzle anchors.

    Steps:
      1. Load ``Lichess/chess-puzzles`` and group puzzles by their chess-format
         anchor (rows without themes are skipped).
      2. Keep anchors whose puzzle count lies in
         ``[HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX]``, rarest first, capped at
         ``EVAL_QUERIES``. Their puzzles form the corpus, and each anchor is a
         query whose relevant docs are the puzzles that share it.
      3. Rank the whole corpus for every query with: the static model
         (cosine similarity), BM25 over chess-format docs, BM25 over
         English-bridged docs, and an RRF fusion of static + BM25-English.
      4. Print the mean NDCG@10 per approach and a summary table.

    Raises:
        SystemExit: if no anchor falls inside the held-out frequency window.
    """
    print("Building held-out eval set (same as v3/v4)...")
    puzzles = load_dataset("Lichess/chess-puzzles", split="train")
    rows_by_anchor = defaultdict(list)
    for row in puzzles:
        if not row["Themes"]:
            continue
        anchor = build_chess_anchor(row["Themes"], row["OpeningTags"])
        rows_by_anchor[anchor].append(row)

    # The group size is the anchor frequency, so no separate counter is needed.
    # The sort is stable: equally rare anchors keep dataset (insertion) order.
    rare = sorted(
        (
            (anchor, len(rows))
            for anchor, rows in rows_by_anchor.items()
            if HELDOUT_FREQ_MIN <= len(rows) <= HELDOUT_FREQ_MAX
        ),
        key=lambda kv: kv[1],
    )
    heldout = [anchor for anchor, _ in rare[:EVAL_QUERIES]]
    if not heldout:
        raise SystemExit(
            "No anchors fall inside the held-out frequency window; nothing to evaluate."
        )

    corp_chess, corp_en = [], []
    by_anchor = defaultdict(list)  # anchor -> corpus indices of its puzzles
    ch_to_en = {}  # chess anchor -> English anchor (taken from its first puzzle)
    for anchor in heldout:
        for row in rows_by_anchor[anchor]:
            by_anchor[anchor].append(len(corp_chess))
            corp_chess.append(
                build_chess_doc_stripped(row["Themes"], row["OpeningTags"], row["Moves"])
            )
            corp_en.append(build_english_doc(row))
            if anchor not in ch_to_en:
                ch_to_en[anchor] = build_english_anchor(row)
    qchess = list(heldout)
    qen = [ch_to_en[anchor] for anchor in qchess]
    # Relevance sets are identical for every approach; build them once.
    relevant = [set(by_anchor[anchor]) for anchor in qchess]
    print(f" {len(qchess)} queries, {len(corp_chess)} corpus docs")

    # 1. Static-only
    print("\n[1] Static (v4-C2) alone, ranks all corpus directly")
    static = SentenceTransformer("models/static-embedding-chess-multitask-5000x/final")
    sc = static.encode(corp_chess, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
    sc = sc / np.linalg.norm(sc, axis=1, keepdims=True)
    sq = static.encode(qchess, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
    sq = sq / np.linalg.norm(sq, axis=1, keepdims=True)
    static_sims = sq @ sc.T  # rows: queries, columns: corpus docs
    static_ndcgs = [
        ndcg_at_k(_score_pairs(static_sims[qi]), rel, k=10)
        for qi, rel in enumerate(relevant)
    ]
    static_mean = np.mean(static_ndcgs)
    print(f" static-only NDCG@10: {static_mean:.4f}")

    # 2. BM25 over chess-format corpus (theme stripped)
    print("\n[2] BM25 alone over chess-format corpus (theme tokens stripped — same docs static sees)")
    bm25_chess = BM25Okapi([d.split() for d in corp_chess])
    bm25_chess_ndcgs = [
        ndcg_at_k(_score_pairs(bm25_chess.get_scores(q.split())), rel, k=10)
        for q, rel in zip(qchess, relevant)
    ]
    bm25_chess_mean = np.mean(bm25_chess_ndcgs)
    print(f" BM25 (chess docs, query=chess anchor) NDCG@10: {bm25_chess_mean:.4f}")

    # 3. BM25 over English-bridged corpus (theme tokens visible)
    print("\n[3] BM25 alone over English-bridged corpus")
    bm25_en = BM25Okapi([d.split() for d in corp_en])
    # Keep the per-query score arrays: the RRF fusion below reuses them.
    bm25_en_scores = [bm25_en.get_scores(q.split()) for q in qen]
    bm25_en_ndcgs = [
        ndcg_at_k(_score_pairs(scores), rel, k=10)
        for scores, rel in zip(bm25_en_scores, relevant)
    ]
    bm25_en_mean = np.mean(bm25_en_ndcgs)
    print(f" BM25 (English docs, query=English anchor) NDCG@10: {bm25_en_mean:.4f}")

    # Also: static + BM25 hybrid (RRF fusion)
    print("\n[4] Static + BM25 fusion (RRF, K=60)")
    rrf_k = 60
    rrf_ndcgs = []
    for qi, rel in enumerate(relevant):
        # argsort of argsort turns scores into 0-based ranks (0 = best).
        st_rank = np.argsort(-static_sims[qi]).argsort()
        bm_rank = np.argsort(-bm25_en_scores[qi]).argsort()
        fused = 1.0 / (rrf_k + st_rank + 1) + 1.0 / (rrf_k + bm_rank + 1)
        rrf_ndcgs.append(ndcg_at_k(_score_pairs(fused), rel, k=10))
    rrf_mean = np.mean(rrf_ndcgs)
    print(f" Static + BM25-English RRF fusion NDCG@10: {rrf_mean:.4f}")

    # Summary
    print("\n" + "=" * 70)
    print(f"{'Approach':<55} {'NDCG@10':>12}")
    print("=" * 70)
    print(f"{'Static (v4-C2) alone':<55} {static_mean:>12.4f}")
    print(f"{'BM25 alone over chess-format docs':<55} {bm25_chess_mean:>12.4f}")
    print(f"{'BM25 alone over English-bridged docs':<55} {bm25_en_mean:>12.4f}")
    print(f"{'Static + BM25-English RRF fusion':<55} {rrf_mean:>12.4f}")
    print("=" * 70)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# Run the comparison only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|