#!/usr/bin/env python3 # /// script # requires-python = ">=3.10" # dependencies = [ # "sentence-transformers[train]>=5.5.0", # "datasets>=2.19.0", # "numpy", # ] # /// """Side-by-side comparison of all chess static-embedding variants on the same held-out compositional eval. Produces the final table for NOTES.md. """ from __future__ import annotations import os import sys from collections import defaultdict import numpy as np from datasets import load_dataset from sentence_transformers import SentenceTransformer sys.stdout.reconfigure(line_buffering=True) VARIANTS = [ ("v3 baseline", "models/static-embedding-chess/final"), ("v4-A hard-neg only", "models/static-embedding-chess-triplet/final"), ("v4-B theme distill", "models/static-embedding-chess-theme-only/final"), ("v4-C multitask 500x", "models/static-embedding-chess-multitask-500x/final"), ("v4-C2 multitask 5000x", "models/static-embedding-chess-multitask-5000x/final"), ] HELDOUT_FREQ_MIN = 3 HELDOUT_FREQ_MAX = 30 EVAL_QUERIES = 200 def _join_tags(tags): return " ".join(t.replace("_", " ") for t in tags) if tags else "" def _bigram_token_str(moves): toks = moves.split() if len(toks) < 2: return moves return moves + " " + " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:])) def build_puzzle_pairs(batch): anchors, positives = [], [] for themes, op, moves in zip(batch["Themes"], batch["OpeningTags"], batch["Moves"]): themes_txt = _join_tags(themes) op_txt = _join_tags(op) if not themes_txt: continue anchor = themes_txt + (f" {op_txt}" if op_txt else "") positive = f"themes {themes_txt}" if op_txt: positive += f" opening {op_txt}" positive += f" moves {_bigram_token_str(moves)}" anchors.append(anchor) positives.append(positive) return {"anchor": anchors, "positive": positives} def strip_theme_echo(p): i = p.find(" moves ") return p[i + 1 :] if i != -1 else p def ndcg_at_k(scores, rel, k=10): ranked = sorted(scores, key=lambda kv: -kv[1])[:k] dcg = sum((1.0 if d in rel else 0.0) / np.log2(r + 2) for r, (d, _) in enumerate(ranked)) idcg = sum(1.0 / np.log2(r + 2) for r in range(min(len(rel), k))) return dcg / idcg if idcg > 0 else 0.0 def main(): print("Loading + held-out selection...") puzzles = load_dataset("Lichess/chess-puzzles", split="train") pair_puzzles = puzzles.map( build_puzzle_pairs, batched=True, batch_size=20_000, remove_columns=puzzles.column_names, num_proc=4, ) anchors = pair_puzzles["anchor"] freq = defaultdict(int) for a in anchors: freq[a] += 1 rare_pool = sorted( ((a, c) for a, c in freq.items() if HELDOUT_FREQ_MIN <= c <= HELDOUT_FREQ_MAX), key=lambda kv: kv[1], ) heldout = {a for a, _ in rare_pool[:EVAL_QUERIES]} held_idx = [i for i, h in enumerate([a in heldout for a in anchors]) if h] held_anchors = [anchors[i] for i in held_idx] corpus_texts = [strip_theme_echo(pair_puzzles["positive"][i]) for i in held_idx] corpus_ids = [f"d{i}" for i in range(len(corpus_texts))] by_anchor = defaultdict(list) for i, a in enumerate(held_anchors): by_anchor[a].append(corpus_ids[i]) queries = list(by_anchor.keys()) print(f" {len(queries)} queries, {len(corpus_texts)} corpus") results = [] for name, path in VARIANTS: if not os.path.exists(path): print(f"\nSKIPPING {name}: {path} not found") continue print(f"\n=== {name} ({path}) ===") m = SentenceTransformer(path) c = m.encode(corpus_texts, batch_size=128, convert_to_numpy=True, show_progress_bar=False) c = c / np.linalg.norm(c, axis=1, keepdims=True) q = m.encode(queries, batch_size=128, convert_to_numpy=True, show_progress_bar=False) q = q / np.linalg.norm(q, axis=1, keepdims=True) sims = q @ c.T ndcgs = [] for qi, query in enumerate(queries): score_pairs = [(corpus_ids[ci], float(sims[qi, ci])) for ci in range(len(corpus_ids))] rel = set(by_anchor[query]) ndcgs.append(ndcg_at_k(score_pairs, rel, k=10)) ndcg = np.mean(ndcgs) median = np.median(ndcgs) zero = sum(1 for n in ndcgs if n == 0) results.append((name, ndcg, median, zero, len(ndcgs))) print(f" NDCG@10 = {ndcg:.4f} median = {median:.4f} zero = {zero}/{len(ndcgs)}") print("\n" + "=" * 70) print(f"{'Variant':<30} {'NDCG@10':>10} {'Median':>10} {'Zero/All':>15}") print("=" * 70) for name, ndcg, median, zero, total in results: print(f"{name:<30} {ndcg:>10.4f} {median:>10.4f} {zero:>7}/{total:<7}") print("=" * 70) # === Token-similarity probe === # Measures the orthogonal-tokens problem from Phase 1: do related themes # cluster in embedding space? Higher = more semantic structure. print("\n=== Theme-token similarity (higher = more semantic clustering) ===") PROBES = [ ("fork", "skewer"), # tactical motifs (should be close) ("fork", "pin"), ("backRankMate", "smotheredMate"), # mate patterns ("kingsideAttack", "queensideAttack"), ("endgame", "middlegame"), # phases ("fork", "promotion"), # unrelated (control) ] print(f"{'Pair':<40}", end="") for name, _ in VARIANTS: if os.path.exists([p for n, p in VARIANTS if n == name][0]): print(f" {name[:14]:>16}", end="") print() print("-" * 70) for a, b in PROBES: line = f"{a} <-> {b}".ljust(40) for name, path in VARIANTS: if not os.path.exists(path): continue m = SentenceTransformer(path) ea = m.encode([a], convert_to_numpy=True)[0] eb = m.encode([b], convert_to_numpy=True)[0] ea = ea / max(np.linalg.norm(ea), 1e-9) eb = eb / max(np.linalg.norm(eb), 1e-9) sim = float(np.dot(ea, eb)) line += f" {sim:>+16.3f}" print(line) if __name__ == "__main__": main() if __name__ == "__main__": main()