# static-embedding-chess / scripts/generate_theme_defs.py
# Uploaded by oneryalcin — commit f8392aa ("Add files using upload-large-folder tool")
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "datasets>=2.19.0",
# "openai>=1.0",
# "sentence-transformers[train]>=5.5.0",
# "tqdm",
# "numpy",
# ]
# ///
"""Generate natural-language definitions for each Lichess theme via DeepSeek,
then embed those definitions with a general sentence-transformer (MPNet).
The resulting (theme_token, definition_embedding) pairs form a "chess-aware
teacher" — an English description of each chess concept that MPNet CAN
understand semantically. We can then distill those embeddings into our
StaticEmbedding model's token table.
Solves the "MPNet doesn't know chess" problem: MPNet can't read UCI moves,
but it CAN read English ("A tactical motif where one piece attacks two pieces
simultaneously" → semantically near "A tactic where you create a double
attack threatening two pieces at once"). Token-level semantic structure
emerges from the LLM bridge.
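Run (requires a DeepSeek API key, read from the macOS keychain generic
password with service name "deepseek-api", or else from the
DEEPSEEK_API_KEY environment variable):
    SMOKE_TEST=1 uv run --exclude-newer=2026-05-12 generate_theme_defs.py
    uv run --exclude-newer=2026-05-12 generate_theme_defs.py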
"""
import json
import os
import subprocess
import sys
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from itertools import islice

import numpy as np
from datasets import Dataset, load_dataset
from openai import OpenAI
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
# --- Configuration -----------------------------------------------------------
# Chat model asked for one English definition per theme.
MODEL: str = "deepseek-v4-flash"
# General-purpose sentence encoder that embeds the generated definitions.
TEACHER_MODEL: str = "sentence-transformers/all-mpnet-base-v2"
# Parquet file holding (theme, definition, embedding) rows.
OUTPUT_PATH: str = "models/theme_definitions.parquet"
# True only when the SMOKE_TEST environment variable is exactly "1";
# shrinks both the puzzle sample and the number of themes defined.
SMOKE_TEST: bool = os.getenv("SMOKE_TEST") == "1"
# Number of chat-completion requests in flight at once.
PARALLEL_WORKERS: int = 4

# System message sent with every definition request: fixes the output format
# (one 10-25 word sentence, no labels) and gives three few-shot examples.
SYSTEM_PROMPT: str = """You write concise dictionary-style definitions of chess
concepts. Given a theme/concept name (often in camelCase from Lichess.org's
puzzle tagging system), write a single English sentence of 10-25 words
explaining the concept. Be specific and use the standard chess vocabulary that
would appear in any chess textbook.
Output ONLY the definition sentence. No labels, no quotes, no commentary.
Examples:
Input: fork
Output: A tactical motif where a single piece attacks two or more enemy pieces simultaneously, forcing a material gain.
Input: backRankMate
Output: A checkmate delivered along the opponent's back rank, typically with a rook or queen, when the king is trapped by its own pawns.
Input: zugzwang
Output: A position in which any move worsens the player's position, so being forced to move becomes a disadvantage.
"""
def get_deepseek_key():
    """Return the DeepSeek API key, or ``None`` if none is configured.

    The macOS keychain is tried first (generic password with service name
    ``deepseek-api``, read through the ``security`` command-line tool). If that
    lookup is unavailable, fails, or yields an empty string, the
    ``DEEPSEEK_API_KEY`` environment variable is used instead.

    Returns:
        The non-empty keychain value if found, otherwise the value of
        ``DEEPSEEK_API_KEY`` (``None`` when that variable is unset).
    """
    try:
        r = subprocess.run(
            ["security", "find-generic-password", "-s", "deepseek-api", "-w"],
            capture_output=True, text=True, timeout=5,
        )
    except (OSError, subprocess.TimeoutExpired):
        # `security` exists only on macOS; elsewhere subprocess.run raises
        # FileNotFoundError (an OSError). Treat that, or a lookup exceeding the
        # 5 s timeout, as "no keychain key" so the env-var fallback is reached.
        r = None

    if r is not None and r.returncode == 0:
        key = r.stdout.strip()
        if key:
            return key
    return os.environ.get("DEEPSEEK_API_KEY")
def define_theme(client, theme, debug=False):
    """Ask the chat model for a one-sentence definition of a Lichess theme.

    Args:
        client: An ``openai.OpenAI``-compatible client; the request goes to
            ``client.chat.completions.create`` with ``MODEL`` and
            ``SYSTEM_PROMPT``.
        theme: Theme name sent verbatim as the user message, e.g. ``"fork"``.
        debug: When true, report why a theme produced no definition. Output
            goes through ``tqdm.write`` so it does not corrupt a progress bar.

    Returns:
        The stripped definition text, or ``None`` if the request raised or the
        response contained no usable text. Failures are deliberately not
        raised so one bad theme cannot abort a whole thread-pool run.
    """
    try:
        resp = client.chat.completions.create(
            model=MODEL,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": theme},
            ],
            temperature=0.2,
            # DeepSeek-v4-flash spends tokens on reasoning_content before the
            # answer; obscure mate-pattern names need a large budget.
            max_tokens=1500,
            timeout=30,
        )
    except Exception as e:  # best-effort boundary: any API/network error -> None
        if debug:
            tqdm.write(f" EXC for {theme!r}: {type(e).__name__}: {e}")
        return None

    if not resp.choices:
        if debug:
            tqdm.write(f" EMPTY for {theme!r}: response had no choices")
        return None

    choice = resp.choices[0]
    content = (choice.message.content or "").strip()
    if not content:
        # Previously a silent failure. Likely cause: reasoning consumed the whole
        # max_tokens budget (finish_reason 'length') -- report it so it can be tuned.
        if debug:
            tqdm.write(f" EMPTY for {theme!r}: finish_reason={choice.finish_reason!r}")
        return None
    return content
def _collect_themes(sample_size):
    """Return the sorted set of theme tags seen in the first ``sample_size`` puzzles."""
    puzzles = load_dataset("Lichess/chess-puzzles", split="train", streaming=True)
    seen = set()
    for row in islice(puzzles, sample_size):
        tags = row["Themes"] or []
        # NOTE(review): Themes is consumed as a list of tag strings here; if it
        # ever arrives as one space-separated string (as in the raw Lichess
        # CSV), split it rather than iterating it character by character.
        if isinstance(tags, str):
            tags = tags.split()
        seen.update(tags)
    return sorted(seen)


def _generate_definitions(client, themes):
    """Request definitions concurrently; map every theme to its text or None."""
    defs = {}
    with ThreadPoolExecutor(max_workers=PARALLEL_WORKERS) as ex:
        futs = {ex.submit(define_theme, client, t, debug=True): t for t in themes}
        for fut in tqdm(as_completed(futs), total=len(futs)):
            defs[futs[fut]] = fut.result()
    return defs


def _print_similarity_sanity_check(names, embs):
    """Print cosine similarity for a few hand-picked theme pairs present in ``names``."""
    norms = np.linalg.norm(embs, axis=1, keepdims=True)
    emb_norm = embs / np.maximum(norms, 1e-12)  # avoid division by a zero vector
    sim = emb_norm @ emb_norm.T

    print("\nSanity check: pairwise similarities for related themes")
    name_to_idx = {t: i for i, t in enumerate(names)}
    pairs = [
        ("fork", "skewer"), ("fork", "pin"), ("backRankMate", "smotheredMate"),
        ("kingsideAttack", "queensideAttack"), ("endgame", "middlegame"),
        ("fork", "promotion"),  # not directly related
    ]
    for a, b in pairs:
        if a in name_to_idx and b in name_to_idx:
            print(f" {a!r:>20} <-> {b!r:25} = {sim[name_to_idx[a], name_to_idx[b]]:+.3f}")


def _save_definitions(names, sentences, embs):
    """Write (theme, definition, embedding) rows to OUTPUT_PATH as Parquet."""
    out = Dataset.from_dict({
        "theme": names,
        "definition": sentences,
        "embedding": embs.tolist(),
    })
    os.makedirs(os.path.dirname(OUTPUT_PATH) or ".", exist_ok=True)
    out.to_parquet(OUTPUT_PATH)
    print(f"\nSaved {len(out)} theme definitions to {OUTPUT_PATH}")


def main():
    """Build and save the theme-definition "teacher" table.

    Steps:
      1. Read the DeepSeek API key.
      2. Collect theme names from a streamed sample of Lichess puzzles.
      3. Generate one definition per theme with ``MODEL``.
      4. Embed the definitions with ``TEACHER_MODEL``.
      5. Print a similarity sanity check and write the table to ``OUTPUT_PATH``.

    Exits via ``sys.exit`` when no API key is available or when no definition
    was generated successfully.
    """
    key = get_deepseek_key()
    if not key:
        sys.exit(
            "No DeepSeek API key found (macOS keychain service 'deepseek-api' "
            "or DEEPSEEK_API_KEY environment variable)"
        )
    client = OpenAI(api_key=key, base_url="https://api.deepseek.com/v1")

    print("Enumerating themes from Lichess puzzles...")
    themes = _collect_themes(50_000 if SMOKE_TEST else 1_000_000)
    print(f" {len(themes)} unique themes")
    if SMOKE_TEST:
        themes = themes[:10]
        print(f" SMOKE_TEST=1: limited to {len(themes)}")

    print(f"\nGenerating definitions via {MODEL}...")
    defs = _generate_definitions(client, themes)
    # Iterate in sorted theme order (not completion order) for reproducible output.
    failed = [t for t in themes if not defs[t]]
    if failed:
        print(f" {len(failed)} themes failed: {failed[:5]}")
    print(f" {len(defs) - len(failed)}/{len(defs)} succeeded")

    print("\nSample definitions:")
    for t in themes[:8]:
        if defs[t]:
            print(f" {t:>20s} -> {defs[t]}")

    valid = [(t, defs[t]) for t in themes if defs[t]]
    if not valid:
        # Nothing to embed or save; stop with a clear message instead of
        # failing inside the encoder / axis=1 normalisation.
        sys.exit("No definitions were generated; nothing to embed.")
    names = [t for t, _ in valid]
    sentences = [d for _, d in valid]

    print(f"\nEmbedding {len(valid)} definitions with {TEACHER_MODEL}...")
    teacher = SentenceTransformer(TEACHER_MODEL)
    embs = teacher.encode(sentences, batch_size=64, show_progress_bar=True, convert_to_numpy=True)

    _print_similarity_sanity_check(names, embs)
    _save_definitions(names, sentences, embs)
if __name__ == "__main__":
    # Only run the pipeline when executed as a script, never on import.
    main()