static-embedding-chess / scripts /compare_variants.py
oneryalcin's picture
Add files using upload-large-folder tool
f8392aa verified
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "sentence-transformers[train]>=5.5.0",
# "datasets>=2.19.0",
# "numpy",
# ]
# ///
"""Side-by-side comparison of all chess static-embedding variants on the same
held-out compositional eval. Produces the final table for NOTES.md.
"""
from __future__ import annotations
import os
import sys
from collections import defaultdict
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
# Make stdout flush on every newline (presumably so progress lines appear
# promptly when the output is piped or redirected to a log).
sys.stdout.reconfigure(line_buffering=True)
# (display name, local model directory) for every variant to compare, listed in
# the order they appear in the final table. main() skips any entry whose
# directory does not exist.
VARIANTS = [
("v3 baseline", "models/static-embedding-chess/final"),
("v4-A hard-neg only", "models/static-embedding-chess-triplet/final"),
("v4-B theme distill", "models/static-embedding-chess-theme-only/final"),
("v4-C multitask 500x", "models/static-embedding-chess-multitask-500x/final"),
("v4-C2 multitask 5000x", "models/static-embedding-chess-multitask-5000x/final"),
]
# Inclusive bounds on how many times an anchor string must occur in the dataset
# for it to be eligible as a held-out query.
HELDOUT_FREQ_MIN = 3
HELDOUT_FREQ_MAX = 30
# Maximum number of eligible anchors kept as evaluation queries; main() sorts
# the eligible anchors by ascending frequency and keeps the first EVAL_QUERIES.
EVAL_QUERIES = 200
def _join_tags(tags):
return " ".join(t.replace("_", " ") for t in tags) if tags else ""
def _bigram_token_str(moves):
toks = moves.split()
if len(toks) < 2:
return moves
return moves + " " + " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:]))
def build_puzzle_pairs(batch):
    """Turn a batch of puzzle rows into (anchor, positive) text pairs.

    ``batch`` is a mapping with parallel ``"Themes"``, ``"OpeningTags"`` and
    ``"Moves"`` columns. Rows whose themes join to an empty string are dropped.
    For every kept row:

    * the anchor is the theme text, followed by the opening text when present;
    * the positive is ``"themes <themes>"``, then ``" opening <opening>"`` when
      present, then ``" moves <moves with bigram tokens>"``.

    Returns a dict with ``"anchor"`` and ``"positive"`` lists of equal length.
    """
    out = {"anchor": [], "positive": []}
    rows = zip(batch["Themes"], batch["OpeningTags"], batch["Moves"])
    for theme_tags, opening_tags, move_str in rows:
        theme_text = _join_tags(theme_tags)
        opening_text = _join_tags(opening_tags)
        if not theme_text:
            continue
        if opening_text:
            query = f"{theme_text} {opening_text}"
            prefix = f"themes {theme_text} opening {opening_text}"
        else:
            query = theme_text
            prefix = f"themes {theme_text}"
        out["anchor"].append(query)
        out["positive"].append(f"{prefix} moves {_bigram_token_str(move_str)}")
    return out
def strip_theme_echo(p):
    """Drop everything before the first ``" moves "`` marker in ``p``.

    The returned text starts at the word ``moves`` (the leading space of the
    marker is removed). If the marker is absent, ``p`` is returned unchanged.
    """
    _, marker, remainder = p.partition(" moves ")
    if not marker:
        return p
    return "moves " + remainder
def ndcg_at_k(scores, rel, k=10):
    """Compute binary-relevance NDCG@k for one query.

    ``scores`` is an iterable of ``(doc_id, score)`` pairs and ``rel`` is the
    collection of relevant doc ids. The ``k`` highest-scoring documents are
    ranked, each relevant one contributes ``1 / log2(rank + 2)`` (rank is
    0-based), and the sum is divided by the ideal DCG for
    ``min(len(rel), k)`` relevant documents. Returns 0.0 when the ideal DCG
    is zero.
    """

    def _descending(pair):
        return -pair[1]

    top_hits = sorted(scores, key=_descending)[:k]

    dcg = 0
    for rank, (doc_id, _score) in enumerate(top_hits):
        gain = 1.0 if doc_id in rel else 0.0
        dcg += gain / np.log2(rank + 2)

    ideal_hits = min(len(rel), k)
    idcg = 0
    for rank in range(ideal_hits):
        idcg += 1.0 / np.log2(rank + 2)

    if idcg > 0:
        return dcg / idcg
    return 0.0
def _unit_rows(matrix):
    """Return ``matrix`` with every row scaled to unit L2 norm.

    Norms are floored at 1e-9 so an all-zero embedding row stays zero instead
    of becoming NaN through a division by zero.
    """
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    return matrix / np.maximum(norms, 1e-9)


def main():
    """Evaluate every available variant on the held-out set and print tables.

    Steps:

    1. Load the ``Lichess/chess-puzzles`` train split and map it to
       (anchor, positive) pairs with ``build_puzzle_pairs``.
    2. Select as queries up to ``EVAL_QUERIES`` anchors whose frequency lies in
       ``[HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX]`` (rarest first); the corpus is
       every row carrying one of those anchors, with the theme/opening prefix
       stripped from the positive text.
    3. For each entry of ``VARIANTS`` whose path exists, encode queries and
       corpus, rank by cosine similarity and report mean/median NDCG@10 and
       the number of queries scoring zero.
    4. Print a summary table, then a table of cosine similarities between
       pairs of theme tokens for each evaluated variant.

    Each model is loaded once; the probe similarities are computed while the
    model is still in memory instead of reloading it for every probe pair.
    """
    # Theme-token pairs for the similarity probe printed at the end.
    PROBES = [
        ("fork", "skewer"),  # tactical motifs (should be close)
        ("fork", "pin"),
        ("backRankMate", "smotheredMate"),  # mate patterns
        ("kingsideAttack", "queensideAttack"),
        ("endgame", "middlegame"),  # phases
        ("fork", "promotion"),  # unrelated (control)
    ]
    # Unique probe tokens in first-seen order, so each is encoded once per model.
    probe_tokens = list(dict.fromkeys(tok for pair in PROBES for tok in pair))

    print("Loading + held-out selection...")
    puzzles = load_dataset("Lichess/chess-puzzles", split="train")
    pair_puzzles = puzzles.map(
        build_puzzle_pairs,
        batched=True, batch_size=20_000,
        remove_columns=puzzles.column_names,
        num_proc=4,
    )
    # Fetch each column exactly once; re-indexing the dataset by column name
    # inside a loop would re-read the column for every held-out row.
    anchors = pair_puzzles["anchor"]
    positives = pair_puzzles["positive"]

    freq = defaultdict(int)
    for a in anchors:
        freq[a] += 1
    rare_pool = sorted(
        ((a, c) for a, c in freq.items() if HELDOUT_FREQ_MIN <= c <= HELDOUT_FREQ_MAX),
        key=lambda kv: kv[1],
    )
    heldout = {a for a, _ in rare_pool[:EVAL_QUERIES]}

    # (row index, anchor) for every row whose anchor was selected as a query.
    held_rows = [(i, a) for i, a in enumerate(anchors) if a in heldout]
    corpus_texts = [strip_theme_echo(positives[i]) for i, _ in held_rows]
    corpus_ids = [f"d{n}" for n in range(len(corpus_texts))]
    by_anchor = defaultdict(list)
    for doc_id, (_, a) in zip(corpus_ids, held_rows):
        by_anchor[a].append(doc_id)
    queries = list(by_anchor)
    print(f"  {len(queries)} queries, {len(corpus_texts)} corpus")
    if not queries:
        # Nothing to rank; encoding/normalising empty inputs would only fail.
        print("No held-out queries found; nothing to evaluate.")
        return
    # Relevant-document sets do not depend on the model, so build them once.
    relevant = [set(by_anchor[query]) for query in queries]

    results = []
    probe_columns = []  # (variant name, {probe pair: cosine similarity})
    for name, path in VARIANTS:
        if not os.path.exists(path):
            print(f"\nSKIPPING {name}: {path} not found")
            continue
        print(f"\n=== {name} ({path}) ===")
        m = SentenceTransformer(path)
        c = _unit_rows(
            m.encode(corpus_texts, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
        )
        q = _unit_rows(
            m.encode(queries, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
        )
        sims = q @ c.T
        ndcgs = [
            ndcg_at_k(list(zip(corpus_ids, sims[qi].tolist())), rel, k=10)
            for qi, rel in enumerate(relevant)
        ]
        ndcg = np.mean(ndcgs)
        median = np.median(ndcgs)
        zero = sum(1 for n in ndcgs if n == 0)
        results.append((name, ndcg, median, zero, len(ndcgs)))
        print(f"  NDCG@10 = {ndcg:.4f}  median = {median:.4f}  zero = {zero}/{len(ndcgs)}")

        # Probe similarities, computed now so the model is never reloaded.
        probe_emb = _unit_rows(m.encode(probe_tokens, convert_to_numpy=True))
        by_token = dict(zip(probe_tokens, probe_emb))
        probe_columns.append(
            (name, {(a, b): float(np.dot(by_token[a], by_token[b])) for a, b in PROBES})
        )

    print("\n" + "=" * 70)
    print(f"{'Variant':<30} {'NDCG@10':>10} {'Median':>10} {'Zero/All':>15}")
    print("=" * 70)
    for name, ndcg, median, zero, total in results:
        print(f"{name:<30} {ndcg:>10.4f} {median:>10.4f} {zero:>7}/{total:<7}")
    print("=" * 70)

    # === Token-similarity probe ===
    # Measures the orthogonal-tokens problem from Phase 1: do related themes
    # cluster in embedding space? Higher = more semantic structure.
    print("\n=== Theme-token similarity (higher = more semantic clustering) ===")
    print(f"{'Pair':<40}", end="")
    for name, _ in probe_columns:
        print(f" {name[:14]:>16}", end="")
    print()
    print("-" * 70)
    for a, b in PROBES:
        line = f"{a} <-> {b}".ljust(40)
        for _, pair_sims in probe_columns:
            line += f" {pair_sims[(a, b)]:>+16.3f}"
        print(line)
# Script entry point. Keep exactly one guard: a repeated guard would run the
# entire (expensive) evaluation a second time.
if __name__ == "__main__":
    main()