Sentence Similarity
sentence-transformers
Safetensors
English
static-embedding
chess
retrieval
exploratory
Instructions to use oneryalcin/static-embedding-chess with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use oneryalcin/static-embedding-chess with sentence-transformers:
```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("oneryalcin/static-embedding-chess")

sentences = [
    "That is a happy person",
    "That is a happy dog",
    "That is a very happy person",
    "Today is a sunny day"
]
embeddings = model.encode(sentences)

similarities = model.similarity(embeddings, embeddings)
print(similarities.shape)  # [4, 4]
```
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| # /// script | |
| # requires-python = ">=3.10" | |
| # dependencies = ["sentence-transformers[train]>=5.5.0", "datasets>=2.19", "numpy", "rank-bm25", "chess"] | |
| # /// | |
| """DIRECT comparison: Do we need the static embedding at all? | |
| Tests BM25 AS THE ONLY RETRIEVER (not as a reranker after static). If BM25 | |
| alone matches or beats our static embedding, we don't need the static model. | |
| Four configurations are evaluated, in the order the script reports them: | |
| 1. Static-only (v4-C2 ranks the whole corpus directly, scored at top-10) | |
| 2. BM25-only over chess-format corpus (theme tokens stripped — our original | |
| eval format) | |
| 3. BM25-only over English-bridged corpus (no static) | |
| 4. Static + BM25-English hybrid (reciprocal rank fusion, K=60) | |
| This is the apples-to-apples question. | |
| """ | |
| import os | |
| import sys | |
| from collections import defaultdict | |
| import numpy as np | |
| from datasets import load_dataset | |
| from rank_bm25 import BM25Okapi | |
| from sentence_transformers import SentenceTransformer | |
| sys.stdout.reconfigure(line_buffering=True) | |
| sys.path.insert(0, os.path.dirname(__file__)) | |
| from convert_to_english import build_english_anchor, build_english_doc | |
| # An anchor is eligible for the held-out set when the number of puzzles | |
| # sharing it lies in [HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX] (inclusive). | |
| HELDOUT_FREQ_MIN = 3 | |
| HELDOUT_FREQ_MAX = 30 | |
| # Upper bound on the number of eligible anchors used as queries; main() takes | |
| # them in ascending order of puzzle count. | |
| EVAL_QUERIES = 200 | |
| def _join_tags(tags): | |
| return " ".join(t.replace("_", " ") for t in tags) if tags else "" | |
| def _bigram(m): | |
| toks = m.split() | |
| return m + " " + " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:])) if len(toks) > 1 else m | |
def build_chess_anchor(themes, op):
    """Build the chess-format anchor (query) text for one puzzle.

    The anchor is the joined theme tags, followed by a space and the joined
    opening tags when ``op`` is non-empty.
    """
    anchor = _join_tags(themes)
    if op:
        anchor += " " + _join_tags(op)
    return anchor
def build_chess_doc_stripped(themes, op, moves):
    """Build the chess-format document text with theme/opening tokens left out.

    ``themes`` and ``op`` are accepted but not used: the document consists
    only of the literal word ``moves`` followed by the move string and its
    adjacent-pair tokens.
    """
    move_text = _bigram(moves)
    return "moves {}".format(move_text)
def ndcg_at_k(scores, rel, k=10):
    """Compute binary-relevance NDCG@k for a single query.

    Args:
        scores: Iterable of ``(doc_id, score)`` pairs. Documents are ranked by
            descending score; the sort is stable, so tied scores keep their
            input order.
        rel: Container of relevant doc ids. It is membership-tested with
            ``in`` and measured with ``len``, so a ``set`` is the natural
            choice.
        k: Rank cut-off.

    Returns:
        DCG of the top-``k`` ranking divided by the ideal DCG, always as a
        plain Python ``float``. When the ideal DCG is zero (no relevant
        documents, or ``k <= 0``) the result is ``0.0``.
    """
    ranked = sorted(scores, key=lambda pair: -pair[1])[:k]
    # Gain is 1 for a relevant document and 0 otherwise, discounted by
    # log2(position + 1) with 1-based positions.
    dcg = sum(
        1.0 / np.log2(rank + 2)
        for rank, (doc, _) in enumerate(ranked)
        if doc in rel
    )
    # Ideal ranking places every relevant document (up to k) at the top.
    idcg = sum(1.0 / np.log2(rank + 2) for rank in range(min(len(rel), k)))
    # Coerce to a builtin float so callers never see an int or a NumPy scalar.
    return float(dcg / idcg) if idcg > 0 else 0.0
def _unit_rows(matrix):
    """L2-normalise each row; all-zero rows stay zero instead of becoming NaN."""
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    return matrix / np.where(norms == 0, 1.0, norms)


def _ndcg_per_query(score_rows, rel_sets, k=10):
    """Return NDCG@k for each (per-document score row, relevant-id set) pair."""
    return [
        ndcg_at_k([(j, float(s)) for j, s in enumerate(row)], rel, k=k)
        for row, rel in zip(score_rows, rel_sets)
    ]


def main():
    """Compare static-embedding, BM25 and fused retrieval on held-out anchors.

    Steps:
      1. Load the Lichess puzzle dataset and group puzzles by chess anchor
         (theme tags plus opening tags).
      2. Keep anchors whose puzzle count is within
         ``[HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX]``, take the ``EVAL_QUERIES``
         least frequent ones as queries, and use their puzzles as the corpus
         (in both chess-format and English-bridged form).
      3. Report mean NDCG@10 for: static embedding alone, BM25 over
         chess-format docs, BM25 over English-bridged docs, and reciprocal
         rank fusion of static + BM25-English.

    A document is relevant to a query when it was built from a puzzle that
    has the query's anchor. Results are printed to stdout; nothing is
    returned.

    Raises:
        SystemExit: if no anchor falls inside the held-out frequency band.
    """
    print("Building held-out eval set (same as v3/v4)...")
    puzzles = load_dataset("Lichess/chess-puzzles", split="train")
    freq = defaultdict(int)
    rows_by_anchor = defaultdict(list)
    for r in puzzles:
        if not r["Themes"]:
            continue
        ca = build_chess_anchor(r["Themes"], r["OpeningTags"])
        freq[ca] += 1
        rows_by_anchor[ca].append(r)

    # Least frequent anchors first; the sort is stable, so ties keep the
    # order in which anchors were first seen in the dataset.
    rare = sorted(
        ((a, c) for a, c in freq.items() if HELDOUT_FREQ_MIN <= c <= HELDOUT_FREQ_MAX),
        key=lambda kv: kv[1],
    )
    heldout = [a for a, _ in rare[:EVAL_QUERIES]]
    if not heldout:
        # Without this guard the BM25 index / encoder fail on an empty corpus
        # with an unrelated-looking error.
        raise SystemExit("No anchors fall inside the held-out frequency band; nothing to evaluate.")

    corp_chess, corp_en = [], []
    held_per_doc = []
    ch_to_en = {}
    for ca in heldout:
        for r in rows_by_anchor[ca]:
            corp_chess.append(build_chess_doc_stripped(r["Themes"], r["OpeningTags"], r["Moves"]))
            corp_en.append(build_english_doc(r))
            held_per_doc.append(ca)
            # The English query for an anchor comes from its first puzzle.
            if ca not in ch_to_en:
                ch_to_en[ca] = build_english_anchor(r)
    qchess = list(heldout)
    qen = [ch_to_en[a] for a in qchess]

    by_anchor = defaultdict(list)
    for i, a in enumerate(held_per_doc):
        by_anchor[a].append(i)
    # Relevant corpus indices per query, aligned with qchess / qen.
    rel_sets = [set(by_anchor[a]) for a in qchess]
    print(f" {len(qchess)} queries, {len(corp_chess)} corpus docs")

    # NOTE(review): ndcg_at_k breaks score ties by corpus index, and the corpus
    # is laid out anchor by anchor in query order. When many documents tie
    # (e.g. all-zero BM25 scores), the lowest-index documents win, which can
    # favour the earliest queries — confirm this is acceptable for the eval.

    # 1. Static-only
    print("\n[1] Static (v4-C2) alone, ranks all corpus directly")
    static = SentenceTransformer("models/static-embedding-chess-multitask-5000x/final")
    sc = _unit_rows(
        static.encode(corp_chess, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
    )
    sq = _unit_rows(
        static.encode(qchess, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
    )
    static_sims = sq @ sc.T  # cosine similarity, one row per query
    static_ndcgs = _ndcg_per_query(static_sims, rel_sets)
    print(f" static-only NDCG@10: {np.mean(static_ndcgs):.4f}")

    # 2. BM25 over chess-format corpus (theme stripped)
    print("\n[2] BM25 alone over chess-format corpus (theme tokens stripped — same docs static sees)")
    bm25_chess = BM25Okapi([d.split() for d in corp_chess])
    bm25_chess_ndcgs = _ndcg_per_query(
        (bm25_chess.get_scores(q.split()) for q in qchess), rel_sets
    )
    print(f" BM25 (chess docs, query=chess anchor) NDCG@10: {np.mean(bm25_chess_ndcgs):.4f}")

    # 3. BM25 over English-bridged corpus (theme tokens visible)
    print("\n[3] BM25 alone over English-bridged corpus")
    bm25_en = BM25Okapi([d.split() for d in corp_en])
    # Kept in memory because the fusion step below reuses the same scores.
    en_scores = [bm25_en.get_scores(q.split()) for q in qen]
    bm25_en_ndcgs = _ndcg_per_query(en_scores, rel_sets)
    print(f" BM25 (English docs, query=English anchor) NDCG@10: {np.mean(bm25_en_ndcgs):.4f}")

    # Also: static + BM25 hybrid (RRF fusion)
    print("\n[4] Static + BM25 fusion (RRF, K=60)")
    K_RRF = 60
    fused_rows = []
    for sim_row, bm_row in zip(static_sims, en_scores):
        # argsort-of-argsort turns scores into 0-based ranks (0 = best).
        st_rank = np.argsort(-sim_row).argsort()
        bm_rank = np.argsort(-bm_row).argsort()
        fused_rows.append(1.0 / (K_RRF + st_rank + 1) + 1.0 / (K_RRF + bm_rank + 1))
    rrf_ndcgs = _ndcg_per_query(fused_rows, rel_sets)
    print(f" Static + BM25-English RRF fusion NDCG@10: {np.mean(rrf_ndcgs):.4f}")

    # Summary
    summary = [
        ("Static (v4-C2) alone", static_ndcgs),
        ("BM25 alone over chess-format docs", bm25_chess_ndcgs),
        ("BM25 alone over English-bridged docs", bm25_en_ndcgs),
        ("Static + BM25-English RRF fusion", rrf_ndcgs),
    ]
    print("\n" + "=" * 70)
    print(f"{'Approach':<55} {'NDCG@10':>12}")
    print("=" * 70)
    for label, values in summary:
        print(f"{label:<55} {np.mean(values):>12.4f}")
    print("=" * 70)
| # Run the comparison only when this file is executed as a script. | |
| if __name__ == "__main__": | |
| main() | |