Sentence Similarity
sentence-transformers
Safetensors
English
static-embedding
chess
retrieval
exploratory
Instructions to use oneryalcin/static-embedding-chess with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use oneryalcin/static-embedding-chess with sentence-transformers:
```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("oneryalcin/static-embedding-chess")
sentences = [
    "That is a happy person",
    "That is a happy dog",
    "That is a very happy person",
    "Today is a sunny day"
]
embeddings = model.encode(sentences)
similarities = model.similarity(embeddings, embeddings)
print(similarities.shape)
# [4, 4]
```
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| # /// script | |
| # requires-python = ">=3.10" | |
| # dependencies = ["sentence-transformers[train]>=5.5.0", "datasets>=2.19", "numpy", "rank-bm25", "chess"] | |
| # /// | |
| """Compare trained CE vs BM25 on English-bridged docs, plus top-K sweep. | |
| Tests: | |
| 1. Is the 0.59 CE result just lexical match that BM25 could also do? | |
| 2. Does increasing K to 200/300 push past oracle 0.59 → 0.77 → 0.87? | |
| """ | |
| import os | |
| import sys | |
| from collections import defaultdict | |
| import numpy as np | |
| from datasets import Dataset, load_dataset | |
| from rank_bm25 import BM25Okapi | |
| from sentence_transformers import CrossEncoder, SentenceTransformer | |
# Line-buffer stdout so each progress line is flushed as soon as it is printed,
# even when output is redirected to a file or pipe.
sys.stdout.reconfigure(line_buffering=True)
# Put this script's directory first on the import path so the sibling module
# `convert_to_english` (imported on the next line) can be found.
sys.path.insert(0, os.path.dirname(__file__))
| from convert_to_english import build_english_anchor, build_english_doc | |
# A chess anchor is held out when its puzzle count lies in this inclusive band.
HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX = 3, 30
# Upper bound on how many held-out anchors (i.e. queries) are evaluated.
EVAL_QUERIES = 200
| def _join_tags(tags): | |
| return " ".join(t.replace("_", " ") for t in tags) if tags else "" | |
| def _bigram(m): | |
| toks = m.split() | |
| return m + " " + " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:])) if len(toks) > 1 else m | |
def build_chess_anchor(themes, op):
    """Build the chess-anchor text: joined theme tags, then joined opening tags if any.

    Tags are rendered with `_join_tags`; *op* may be None or empty, in which
    case only the theme text is returned.
    """
    parts = [_join_tags(themes)]
    opening_text = _join_tags(op or [])
    if opening_text:
        parts.append(opening_text)
    return " ".join(parts)
def build_chess_doc_stripped(themes, op, moves):
    """Build a corpus document from the move string alone.

    *themes* and *op* are accepted but ignored, so the document contains no
    theme or opening text. The moves are expanded with `_bigram` and prefixed
    with "moves ".
    """
    return "moves " + _bigram(moves)
def ndcg_at_k(scores, rel, k=10):
    """Binary-relevance nDCG@k.

    Args:
        scores: iterable of (doc_id, score) pairs, ranked by descending score;
            equal scores keep their input order.
        rel: collection of relevant doc_ids.
        k: rank cutoff.

    Returns:
        DCG of the top-k ranking divided by the ideal DCG of
        min(len(rel), k) relevant docs, or 0 when that ideal DCG is 0.
    """
    ranked = sorted(scores, key=lambda pair: -pair[1])
    dcg = 0.0
    for rank, (doc_id, _score) in enumerate(ranked[:k]):
        gain = 1.0 if doc_id in rel else 0.0
        dcg += gain / np.log2(rank + 2)
    idcg = sum(1.0 / np.log2(rank + 2) for rank in range(min(len(rel), k)))
    if idcg > 0:
        return dcg / idcg
    return 0
def _build_eval_set():
    """Load the Lichess puzzles and build the held-out rare-anchor eval set.

    An anchor is ``build_chess_anchor(Themes, OpeningTags)`` of a puzzle with
    non-empty Themes. Selection works as follows:

    - Anchors with HELDOUT_FREQ_MIN..HELDOUT_FREQ_MAX puzzles (inclusive) are
      sorted by ascending count.
    - The first EVAL_QUERIES of them become queries.
    - Every puzzle of a chosen anchor becomes a corpus doc relevant to it.

    Returns:
        ``(qchess, qen, corp_chess, corp_en, by_anchor)``:

        - qchess: chess-anchor queries.
        - qen: their English anchors (parallel to qchess), built from each
          anchor's first puzzle.
        - corp_chess: stripped chess docs.
        - corp_en: English docs (parallel to corp_chess).
        - by_anchor: anchor -> list of corpus indices.
    """
    puzzles = load_dataset("Lichess/chess-puzzles", split="train")

    # First pass reads only the two anchor columns and keeps row *indices*
    # rather than whole row dicts, so memory does not scale with full rows of
    # the entire dump; full rows are fetched later only for held-out anchors.
    rows_by_anchor = defaultdict(list)
    for idx, r in enumerate(puzzles.select_columns(["Themes", "OpeningTags"])):
        if not r["Themes"]:
            continue
        rows_by_anchor[build_chess_anchor(r["Themes"], r["OpeningTags"])].append(idx)

    rare = sorted(
        (
            (anchor, len(idxs))
            for anchor, idxs in rows_by_anchor.items()
            if HELDOUT_FREQ_MIN <= len(idxs) <= HELDOUT_FREQ_MAX
        ),
        key=lambda kv: kv[1],
    )
    heldout = [anchor for anchor, _ in rare[:EVAL_QUERIES]]
    print(f" {len(heldout)} held-out anchors")

    qen, corp_chess, corp_en = [], [], []
    by_anchor = {}
    for ca in heldout:
        start = len(corp_chess)
        for n, idx in enumerate(rows_by_anchor[ca]):
            r = puzzles[idx]
            corp_chess.append(build_chess_doc_stripped(r["Themes"], r["OpeningTags"], r["Moves"]))
            corp_en.append(build_english_doc(r))
            if n == 0:
                # The English query text comes from the anchor's first puzzle.
                qen.append(build_english_anchor(r))
        # Docs of one anchor are appended contiguously.
        by_anchor[ca] = list(range(start, len(corp_chess)))
    return list(heldout), qen, corp_chess, corp_en, by_anchor


def _l2_normalize(x, eps=1e-12):
    """Row-wise L2-normalize a 2-D array; all-zero rows stay zero instead of becoming NaN."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.maximum(norms, eps)


def _ideal_dcg(n_rel, k=10):
    """DCG (binary gains) of a ranking whose first min(n_rel, k) positions are relevant."""
    return sum(1.0 / np.log2(r + 2) for r in range(min(n_rel, k)))


def main():
    """Compare static first-stage retrieval against CE / BM25 reranking of its top-K.

    For each K in the sweep that does not exceed the corpus size, prints the mean
    nDCG@10 of four rankings:

    - Static: static-embedding ranking alone (independent of K).
    - +CE: trained cross-encoder rerank of the static top-K.
    - +BM25: BM25 rerank of the static top-K.
    - Oracle: ceiling where every relevant doc inside the static top-K is
      ranked first.
    """
    k_sweep = [10, 50, 100, 200, 300]

    print("Building eval set...")
    qchess, qen, corp_chess, corp_en, by_anchor = _build_eval_set()
    print(f" corpus: {len(corp_chess)} docs")

    # Ks larger than the corpus are skipped.
    ks = [k for k in k_sweep if k <= len(corp_chess)]
    if not qchess or not ks:
        print("Nothing to evaluate: no held-out anchors, or corpus smaller than every K.")
        return
    k_max = max(ks)

    print("\nLoading static (v4-C2) for first-stage...")
    static = SentenceTransformer("models/static-embedding-chess-multitask-5000x/final")
    sc = _l2_normalize(static.encode(corp_chess, batch_size=128, convert_to_numpy=True, show_progress_bar=False))
    sq = _l2_normalize(static.encode(qchess, batch_size=128, convert_to_numpy=True, show_progress_bar=False))
    static_sims = sq @ sc.T

    # Load the trained CE
    print("Loading trained CE...")
    ce = CrossEncoder("models/chess-reranker-english/final")

    # BM25 on English docs
    print("Building BM25 over English docs...")
    bm25 = BM25Okapi([d.split() for d in corp_en])

    static_ndcg = []
    ce_ndcg = defaultdict(list)
    bm25_ndcg = defaultdict(list)
    oracle_ndcg = defaultdict(list)
    print(f"Scoring {len(qchess)} queries (shortlist up to K={k_max})...")
    for qi, q_chess in enumerate(qchess):
        rel = set(by_anchor[q_chess])
        row = static_sims[qi]
        order = np.argsort(-row)

        # Static-only nDCG@10 does not depend on K: compute it once per query.
        static_ndcg.append(ndcg_at_k([(int(i), float(row[i])) for i in order[:10]], rel, k=10))

        # Every top-K shortlist is a prefix of the top-k_max one, so score the
        # largest shortlist once and slice per K. The same (query, doc) pairs
        # are scored; CE values could differ from per-K scoring at most by
        # floating-point noise from different batch composition.
        shortlist = order[:k_max]
        ce_scores = ce.predict(
            [[qen[qi], corp_en[int(i)]] for i in shortlist],
            batch_size=64,
            show_progress_bar=False,
            convert_to_numpy=True,
        )
        bm_full = bm25.get_scores(qen[qi].split())
        ideal = _ideal_dcg(len(rel))

        for k in ks:
            topk = shortlist[:k]
            # CE rerank of the static top-K
            ce_sp = [(int(i), float(s)) for i, s in zip(topk, ce_scores[:k])]
            ce_ndcg[k].append(ndcg_at_k(ce_sp, rel, k=10))
            # BM25 rerank of the static top-K
            bm_sp = [(int(i), float(bm_full[i])) for i in topk]
            bm25_ndcg[k].append(ndcg_at_k(bm_sp, rel, k=10))
            # Oracle ceiling: all relevant docs inside the shortlist ranked first.
            rel_in_topk = len(rel.intersection(int(i) for i in topk))
            oracle_ndcg[k].append(_ideal_dcg(rel_in_topk) / ideal if ideal > 0 else 0)

        if (qi + 1) % 50 == 0:
            print(f" scored {qi + 1}/{len(qchess)} queries")

    static_v = np.mean(static_ndcg)
    print("\n" + "=" * 80)
    print(f" {'K':>4} {'Static':>10} {'+CE':>10} {'+BM25':>10} {'Oracle':>10}")
    print("=" * 80)
    for k in ks:
        print(f" {k:>4} {static_v:>10.4f} {np.mean(ce_ndcg[k]):>10.4f} {np.mean(bm25_ndcg[k]):>10.4f} {np.mean(oracle_ndcg[k]):>10.4f}")
    print("=" * 80)


if __name__ == "__main__":
    main()