Sentence Similarity
sentence-transformers
Safetensors
English
static-embedding
chess
retrieval
exploratory
Instructions to use oneryalcin/static-embedding-chess with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use oneryalcin/static-embedding-chess with sentence-transformers:
```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("oneryalcin/static-embedding-chess")

sentences = [
    "That is a happy person",
    "That is a happy dog",
    "That is a very happy person",
    "Today is a sunny day",
]
embeddings = model.encode(sentences)

similarities = model.similarity(embeddings, embeddings)
print(similarities.shape)  # [4, 4]
```
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| # /// script | |
| # requires-python = ">=3.10" | |
| # dependencies = [ | |
| # "sentence-transformers[train]>=5.5.0", | |
| # "datasets>=2.19.0", | |
| # "numpy", | |
| # ] | |
| # /// | |
| """Side-by-side comparison of all chess static-embedding variants on the same | |
| held-out compositional eval. Produces the final table for NOTES.md. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import sys | |
| from collections import defaultdict | |
| import numpy as np | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
# Line-buffer stdout so progress lines appear promptly when output is piped.
# Guarded because a replaced sys.stdout (e.g. some notebook kernels' streams)
# may not implement reconfigure(), which would otherwise crash at import time.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(line_buffering=True)

# (display name, local model directory) for every variant compared.
# Variants whose directory does not exist are skipped when main() runs.
VARIANTS = [
    ("v3 baseline", "models/static-embedding-chess/final"),
    ("v4-A hard-neg only", "models/static-embedding-chess-triplet/final"),
    ("v4-B theme distill", "models/static-embedding-chess-theme-only/final"),
    ("v4-C multitask 500x", "models/static-embedding-chess-multitask-500x/final"),
    ("v4-C2 multitask 5000x", "models/static-embedding-chess-multitask-5000x/final"),
]

# An anchor string is eligible as an eval query when it occurs between
# HELDOUT_FREQ_MIN and HELDOUT_FREQ_MAX times (inclusive) in the dataset.
HELDOUT_FREQ_MIN = 3
HELDOUT_FREQ_MAX = 30
# Number of distinct eligible anchors kept as queries, rarest first.
EVAL_QUERIES = 200
| def _join_tags(tags): | |
| return " ".join(t.replace("_", " ") for t in tags) if tags else "" | |
| def _bigram_token_str(moves): | |
| toks = moves.split() | |
| if len(toks) < 2: | |
| return moves | |
| return moves + " " + " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:])) | |
def build_puzzle_pairs(batch):
    """Turn a batch of puzzle rows into (anchor, positive) text pairs.

    Reads the parallel ``"Themes"``, ``"OpeningTags"`` and ``"Moves"`` columns
    of *batch*. For each row:

    * anchor   = themes text, then the opening text if there is one;
    * positive = ``"themes <themes> [opening <opening>] moves <moves+bigrams>"``.

    Rows whose themes text is empty are dropped. Returns a dict with list
    columns ``"anchor"`` and ``"positive"`` of equal length.
    """
    out = {"anchor": [], "positive": []}
    rows = zip(batch["Themes"], batch["OpeningTags"], batch["Moves"])
    for theme_tags, opening_tags, move_str in rows:
        theme_text = _join_tags(theme_tags)
        opening_text = _join_tags(opening_tags)
        if not theme_text:
            # Without themes there is no query text to build an anchor from.
            continue
        anchor_parts = [theme_text]
        positive_parts = ["themes", theme_text]
        if opening_text:
            anchor_parts.append(opening_text)
            positive_parts += ["opening", opening_text]
        positive_parts += ["moves", _bigram_token_str(move_str)]
        out["anchor"].append(" ".join(anchor_parts))
        out["positive"].append(" ".join(positive_parts))
    return out
def strip_theme_echo(p):
    """Keep only the ``"moves ..."`` tail of a positive text.

    Positives built by ``build_puzzle_pairs`` start with the themes/opening
    text, which repeats the anchor verbatim. Cutting at the first
    ``" moves "`` marker means corpus documents no longer contain the query
    text itself. If the marker is absent, *p* is returned unchanged.
    """
    _prefix, marker, tail = p.partition(" moves ")
    if not marker:
        return p
    # marker is " moves "; drop its leading space so the result starts "moves ".
    return marker.lstrip() + tail
def ndcg_at_k(scores, rel, k=10):
    """Binary-relevance NDCG@k for a single query.

    Args:
        scores: iterable of ``(doc_id, score)`` pairs; higher scores rank first
            (ties keep their input order because the sort is stable).
        rel: collection of relevant doc ids, membership-tested with ``in``.
        k: rank cutoff.

    Returns:
        DCG of the top-k ranking divided by the ideal DCG for
        ``min(len(rel), k)`` relevant hits, or ``0.0`` when the ideal DCG is 0.
    """
    top = sorted(scores, key=lambda pair: -pair[1])[:k]
    dcg = 0.0
    for rank, (doc_id, _score) in enumerate(top):
        if doc_id in rel:
            dcg += 1.0 / np.log2(rank + 2)
    ideal_hits = min(len(rel), k)
    idcg = sum(1.0 / np.log2(rank + 2) for rank in range(ideal_hits))
    if idcg <= 0:
        return 0.0
    return dcg / idcg
# Theme pairs probed after the retrieval eval: do related themes land close
# together in embedding space? The last pair is an unrelated control.
_TOKEN_PROBES = [
    ("fork", "skewer"),  # tactical motifs (should be close)
    ("fork", "pin"),
    ("backRankMate", "smotheredMate"),  # mate patterns
    ("kingsideAttack", "queensideAttack"),
    ("endgame", "middlegame"),  # phases
    ("fork", "promotion"),  # unrelated (control)
]


def _l2_normalize(x, eps=1e-9):
    """Scale *x* to unit L2 norm along its last axis.

    Norms are clipped to *eps*, so an all-zero embedding stays zero instead of
    becoming NaN (NaN scores would make the ranking in ndcg_at_k meaningless).
    """
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norms, eps)


def _encode_unit(model, texts):
    """Encode *texts* with *model* and return L2-normalised rows."""
    emb = model.encode(texts, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
    return _l2_normalize(emb)


def _select_heldout(anchors, freq_min, freq_max, max_queries):
    """Return row indices whose anchor is among the rarest eligible anchors.

    Anchors occurring between *freq_min* and *freq_max* times (inclusive) are
    ordered by ascending count (ties keep first-seen order). The first
    *max_queries* distinct anchors are kept, and every row carrying one of
    them is returned in dataset order.
    """
    freq = defaultdict(int)
    for a in anchors:
        freq[a] += 1
    rare_pool = sorted(
        ((a, c) for a, c in freq.items() if freq_min <= c <= freq_max),
        key=lambda kv: kv[1],
    )
    heldout = {a for a, _ in rare_pool[:max_queries]}
    return [i for i, a in enumerate(anchors) if a in heldout]


def _ndcg_per_query(sims, queries, corpus_ids, by_anchor, k=10):
    """NDCG@k for each row of the (query x corpus) similarity matrix *sims*."""
    ndcgs = []
    for qi, query in enumerate(queries):
        # tolist() yields Python floats, matching float(sims[qi, ci]) per cell.
        score_pairs = list(zip(corpus_ids, sims[qi].tolist()))
        ndcgs.append(ndcg_at_k(score_pairs, set(by_anchor[query]), k=k))
    return ndcgs


def _probe_similarities(model):
    """Cosine similarity under *model* for each pair in _TOKEN_PROBES."""
    words = sorted({w for pair in _TOKEN_PROBES for w in pair})
    vecs = dict(zip(words, _encode_unit(model, words)))
    return [float(np.dot(vecs[a], vecs[b])) for a, b in _TOKEN_PROBES]


def _print_results_table(results):
    """Print the NDCG summary table; rows are (name, mean, median, zeros, total)."""
    print("\n" + "=" * 70)
    print(f"{'Variant':<30} {'NDCG@10':>10} {'Median':>10} {'Zero/All':>15}")
    print("=" * 70)
    for name, ndcg, median, zero, total in results:
        print(f"{name:<30} {ndcg:>10.4f} {median:>10.4f} {zero:>7}/{total:<7}")
    print("=" * 70)


def _print_probe_table(probe_sims):
    """Print one row per probe pair, one column per evaluated variant."""
    # Measures the orthogonal-tokens problem from Phase 1: do related themes
    # cluster in embedding space? Higher = more semantic structure.
    print("\n=== Theme-token similarity (higher = more semantic clustering) ===")
    header = f"{'Pair':<40}" + "".join(f" {name[:14]:>16}" for name in probe_sims)
    print(header)
    print("-" * len(header))
    for pi, (a, b) in enumerate(_TOKEN_PROBES):
        row = f"{a} <-> {b}".ljust(40)
        row += "".join(f" {sims[pi]:>+16.3f}" for sims in probe_sims.values())
        print(row)


def main():
    """Compare every model in VARIANTS on the same rare-anchor retrieval task.

    1. Build (anchor, positive) pairs from the Lichess puzzle dataset.
    2. Take the rarest anchors in the [HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX]
       frequency band as queries. They are presumably excluded from training;
       that is not verified here. Each query's relevant documents are its own
       positives, with the theme echo stripped.
    3. For each variant directory that exists, report mean/median NDCG@10 and
       the number of queries scoring 0, then a theme-token similarity probe.
    """
    print("Loading + held-out selection...")
    puzzles = load_dataset("Lichess/chess-puzzles", split="train")
    pair_puzzles = puzzles.map(
        build_puzzle_pairs,
        batched=True, batch_size=20_000,
        remove_columns=puzzles.column_names,
        num_proc=4,
    )
    # Fetch the anchor column once. Indexing pair_puzzles["positive"] inside a
    # comprehension (as before) can rebuild the whole column per element,
    # depending on the datasets version. Selecting only the held-out rows
    # avoids that.
    anchors = list(pair_puzzles["anchor"])
    held_idx = _select_heldout(anchors, HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX, EVAL_QUERIES)
    if not held_idx:
        print("No anchors fall in the held-out frequency band; nothing to evaluate.")
        return
    held_anchors = [anchors[i] for i in held_idx]
    corpus_texts = [strip_theme_echo(p) for p in pair_puzzles.select(held_idx)["positive"]]
    corpus_ids = [f"d{i}" for i in range(len(corpus_texts))]
    by_anchor = defaultdict(list)  # query text -> ids of its relevant docs
    for doc_id, a in zip(corpus_ids, held_anchors):
        by_anchor[a].append(doc_id)
    queries = list(by_anchor.keys())
    print(f" {len(queries)} queries, {len(corpus_texts)} corpus")

    results = []
    probe_sims = {}  # variant name -> one cosine per _TOKEN_PROBES pair
    for name, path in VARIANTS:
        if not os.path.exists(path):
            print(f"\nSKIPPING {name}: {path} not found")
            continue
        print(f"\n=== {name} ({path}) ===")
        model = SentenceTransformer(path)
        c = _encode_unit(model, corpus_texts)
        q = _encode_unit(model, queries)
        ndcgs = _ndcg_per_query(q @ c.T, queries, corpus_ids, by_anchor)
        ndcg = np.mean(ndcgs)
        median = np.median(ndcgs)
        zero = sum(1 for n in ndcgs if n == 0)
        results.append((name, ndcg, median, zero, len(ndcgs)))
        print(f" NDCG@10 = {ndcg:.4f} median = {median:.4f} zero = {zero}/{len(ndcgs)}")
        # Probe now, while the model is loaded, instead of reloading it once
        # per probe pair afterwards.
        probe_sims[name] = _probe_similarities(model)

    _print_results_table(results)
    _print_probe_table(probe_sims)
# Script entry point. The guard previously appeared twice, which ran the
# entire evaluation (dataset load, all model encodes) two times.
if __name__ == "__main__":
    main()