Sentence Similarity
sentence-transformers
Safetensors
English
static-embedding
chess
retrieval
exploratory
Instructions to use oneryalcin/static-embedding-chess with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use oneryalcin/static-embedding-chess with sentence-transformers:
```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("oneryalcin/static-embedding-chess")
sentences = [
    "That is a happy person",
    "That is a happy dog",
    "That is a very happy person",
    "Today is a sunny day",
]
embeddings = model.encode(sentences)
similarities = model.similarity(embeddings, embeddings)
print(similarities.shape)
# [4, 4]
```
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| # /// script | |
| # requires-python = ">=3.10" | |
| # dependencies = [ | |
| # "sentence-transformers[train]>=5.5.0", | |
| # "datasets>=2.19.0", | |
| # "accelerate>=0.26.0", | |
| # "tokenizers>=0.20", | |
| # "trackio", | |
| # ] | |
| # /// | |
| """Train a StaticEmbedding model for chess retrieval. | |
| Pair shape: | |
| anchor = "<themes> [<opening words>]" | |
| positive = "themes <themes> [opening <words>] moves <uci>" (puzzles) | |
| "name <words> eco <code> pgn <san>" (openings) | |
| Datasets: | |
| - Lichess/chess-puzzles (5.8M rows; themes + opening tags + UCI moves) | |
| - Lichess/chess-openings (3.6K rows; opening name + ECO + SAN moves) | |
| Use case: free-text search over a chess corpus. "fork endgame short" -> puzzles | |
| with that motif; "Sicilian Najdorf" -> matching openings. | |
| Design choices: | |
| - Custom WordLevel + Whitespace tokenizer trained on the corpus. Every chess | |
| token (UCI move e2e4, SAN move Nxd4, ECO code B90, theme name, opening word) | |
| is one whole token -- BERT WordPiece would shred them 4-way. | |
| - FEN dropped: position-as-character-soup doesn't fit a token-bag. | |
| - PGN move numbers stripped ("1. e4 c5" -> "e4 c5") so SAN moves are high-freq. | |
| - IR eval is custom (themes -> puzzles), not NanoBEIR -- general-English IR | |
| benchmarks don't measure chess retrieval. | |
| Run: | |
| SMOKE_TEST=1 uv run --exclude-newer=2026-05-12 train_chess_static.py | |
| uv run --exclude-newer=2026-05-12 train_chess_static.py | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import os | |
| import re | |
| from collections import defaultdict | |
| from contextlib import nullcontext | |
| import datasets | |
| import random | |
| import torch | |
| from datasets import Dataset, concatenate_datasets, load_dataset | |
| from tokenizers import Tokenizer | |
| from tokenizers.models import WordLevel | |
| from tokenizers.pre_tokenizers import Whitespace | |
| from tokenizers.trainers import WordLevelTrainer | |
| from sentence_transformers import ( | |
| SentenceTransformer, | |
| SentenceTransformerModelCardData, | |
| SentenceTransformerTrainer, | |
| SentenceTransformerTrainingArguments, | |
| ) | |
| from sentence_transformers.base.sampler import BatchSamplers | |
| from sentence_transformers.sentence_transformer.evaluation import ( | |
| InformationRetrievalEvaluator, | |
| SequentialEvaluator, | |
| ) | |
| from sentence_transformers.sentence_transformer.losses import ( | |
| MatryoshkaLoss, | |
| MultipleNegativesRankingLoss, | |
| ) | |
| from sentence_transformers.sentence_transformer.modules import StaticEmbedding | |
| from transformers import EarlyStoppingCallback, TrainerCallback | |
| import time | |
EMBEDDING_DIM = 512  # was 256; 512 gives more capacity for bigram tokens
# Matryoshka truncation sizes; the first entry must equal EMBEDDING_DIM.
MATRYOSHKA_DIMS = [512, 256, 128, 64, 32]
VOCAB_SIZE = 100_000  # was 50_000; UCI/SAN bigrams add ~20-50k vocab
OUTPUT_DIR = "models/static-embedding-chess"
RUN_NAME = "static-embedding-chess"
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "oneryalcin/static-embedding-chess")
# TOKENIZER_PATH default lives next to the model output. On Modal, set this to
# a path on the persistent volume (e.g. /cache/chess_tokenizer.json) so the
# 6-min WordLevelTrainer run is amortized across launches.
TOKENIZER_PATH = os.environ.get(
    "TOKENIZER_PATH", f"{OUTPUT_DIR}/chess_tokenizer.json"
)
RETRAIN_TOKENIZER = os.environ.get("RETRAIN_TOKENIZER") == "1"
SMOKE_TEST = os.environ.get("SMOKE_TEST") == "1"
FORCE_CPU = os.environ.get("FORCE_CPU") == "1"
# Diagnostic knobs (default: full recipe). Both MPS and T4 show monotonic
# step-time growth with the full Matryoshka stack -- toggle these to isolate.
DISABLE_MATRYOSHKA = os.environ.get("DISABLE_MATRYOSHKA") == "1"
# Unset or "0" means "no override": `or None` maps 0 to None.
MAX_STEPS_OVERRIDE = int(os.environ.get("MAX_STEPS", "0")) or None
EVAL_STEPS_OVERRIDE = int(os.environ.get("EVAL_STEPS", "0")) or None
EVAL_QUERIES = 200
# NOTE(review): not referenced elsewhere in this script; the held-out corpus
# size is effectively bounded by EVAL_QUERIES * HELDOUT_FREQ_MAX instead.
EVAL_CORPUS = 5_000
# Held-out anchor selection: pick rare combos in this freq range. Low end > 1
# keeps multi-relevant NDCG meaningful; high end caps memorization potential.
HELDOUT_FREQ_MIN = 3
HELDOUT_FREQ_MAX = 30
# Balanced-dataset config: each unique anchor expands to N (anchor, sampled_pos)
# rows. The original 5.8M pairs let the model memorize specific (anchor, pos)
# pairings since each anchor has ~1933 distinct positives. Capping at 100
# random samples per anchor gives the model meaningful variety without the
# 50x redundancy that fuels overfitting.
BALANCED_POSITIVES_PER_ANCHOR = int(os.environ.get("POSITIVES_PER_ANCHOR", "100"))
# Anchor token masking probability during training. 0 disables.
ANCHOR_MASK_PROB = float(os.environ.get("ANCHOR_MASK_PROB", "0.15"))
# Device selection. FORCE_CPU=1 disables both accelerators; IS_MPS is only
# considered when CUDA is not in use, so at most one of the two is True.
# Mixed precision differs per device: the Trainer's bf16/fp16 flags are only
# set on CUDA, while autocast_ctx() uses fp16 autocast on MPS.
IS_CUDA = torch.cuda.is_available() and not FORCE_CPU
IS_MPS = (not IS_CUDA) and torch.backends.mps.is_available() and not FORCE_CPU
# StaticEmbedding is a lookup+average -- no transformer activations to fit.
# Memory cost is the (batch x batch) similarity matrix + (batch x seq x dim)
# lookups, both tiny. That makes gradient caching (CachedMultipleNegatives-
# RankingLoss) unnecessary, so we just crank the real batch: 4096 on either
# accelerator, 256 on plain CPU. Lower it if an M-series Mac with little
# unified memory runs out; raise it freely if there is headroom.
BATCH_SIZE = 4096 if (IS_CUDA or IS_MPS) else 256
# PGN move numbers: "1." before white's move, "1..." before a black move.
MOVE_NUM_RE = re.compile(r"\d+\.+")
class StepTimingCallback(TrainerCallback):
    """Per-step instrumentation: wall time plus CUDA allocator state.

    Run-once-and-read approach to diagnosing step-time growth instead of
    swapping configs and rerunning. The device queue is synchronized on both
    sides of the step so `dt` covers the work the step enqueued, not just the
    time it took to dispatch it. This is done for MPS as well as CUDA, since
    both execute kernels asynchronously with respect to the Python thread.
    """

    # Start timestamp of the current step; None until on_step_begin first runs.
    _t0: float | None = None

    @staticmethod
    def _synchronize() -> None:
        """Block until queued accelerator work finishes (no-op on plain CPU)."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elif torch.backends.mps.is_available():
            torch.mps.synchronize()

    def on_step_begin(self, args, state, control, **kw):
        self._synchronize()
        self._t0 = time.perf_counter()

    def on_step_end(self, args, state, control, **kw):
        if self._t0 is None:
            # No matching on_step_begin was seen, so there is nothing to time.
            return
        self._synchronize()
        dt = time.perf_counter() - self._t0
        # Log every step for the first 20 to see startup; then every 10th.
        if state.global_step > 20 and state.global_step % 10 != 0:
            return
        if torch.cuda.is_available():
            mem = torch.cuda.memory_allocated() / 1e6
            reserved = torch.cuda.memory_reserved() / 1e6
            logging.info(
                f"STEP {state.global_step}: dt={dt:.3f}s mem={mem:.0f}MB reserved={reserved:.0f}MB"
            )
        else:
            logging.info(f"STEP {state.global_step}: dt={dt:.3f}s (cpu/mps)")
def autocast_ctx():
    """Return the mixed-precision context used around evaluator calls.

    On CUDA, autocast uses bf16 when the device supports it and fp16
    otherwise. On MPS it uses fp16. Anywhere else (plain CPU, or FORCE_CPU)
    it returns a no-op context.
    """
    if IS_CUDA:
        half = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        return torch.autocast("cuda", dtype=half)
    return torch.autocast("mps", dtype=torch.float16) if IS_MPS else nullcontext()
def setup_logging():
    """Create the log/output directories and send INFO logs to console and file.

    Chatty HTTP/Hub libraries are turned down to WARNING. On machines with
    CUDA, float32 matmul precision is set to "high".
    """
    for directory in ("logs", OUTPUT_DIR):
        os.makedirs(directory, exist_ok=True)
    log_handlers = [
        logging.StreamHandler(),
        logging.FileHandler(f"logs/{RUN_NAME}.log"),
    ]
    logging.basicConfig(
        format="%(asctime)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=log_handlers,
        force=True,  # replace any root handlers installed before us
    )
    noisy_loggers = ("httpx", "httpcore", "huggingface_hub", "urllib3", "filelock", "fsspec")
    for logger_name in noisy_loggers:
        logging.getLogger(logger_name).setLevel(logging.WARNING)
    if torch.cuda.is_available():
        torch.set_float32_matmul_precision("high")
| def _join_tags(tags) -> str: | |
| if not tags: | |
| return "" | |
| return " ".join(t.replace("_", " ") for t in tags) | |
def _strip_pgn_move_numbers(pgn: str) -> str:
    """Drop PGN move numbers and normalize whitespace to single spaces.

    "1. e4 c5 2. Nf3" -> "e4 c5 Nf3" and "1.e4 e5" -> "e4 e5".
    Each MOVE_NUM_RE match ("1.", "12...") is replaced by a space and the
    result is re-joined on single spaces. Removing a mid-string move number
    therefore no longer leaves a double space behind, as the earlier
    `sub("", ...)` form did ("e4 c5  Nf3").
    """
    return " ".join(MOVE_NUM_RE.sub(" ", pgn).split())
| def _bigram_token_str(moves: str) -> str: | |
| """Append bigram tokens to a whitespace-separated move sequence. | |
| "f2g3 e6e7 b2b1" -> "f2g3 e6e7 b2b1 f2g3+e6e7 e6e7+b2b1" | |
| Bigrams use `+` as the join char so they're distinct from unigrams in the | |
| WordLevel tokenizer's whitespace pretokenizer. A token-bag averaging across | |
| unigrams alone loses move ordering; adding adjacent-pair tokens lets the | |
| model learn that "e2e4 e7e5" (king's pawn opening) is its own pattern. | |
| """ | |
| tokens = moves.split() | |
| if len(tokens) < 2: | |
| return moves | |
| bigrams = " ".join(f"{a}+{b}" for a, b in zip(tokens, tokens[1:])) | |
| return f"{moves} {bigrams}" | |
def build_puzzle_pairs(row_batch: dict) -> dict:
    """Batched `Dataset.map` fn: puzzle rows -> (anchor, positive) text pairs.

    anchor   = "<themes>[ <opening words>]"
    positive = "themes <themes>[ opening <opening words>] moves <moves + bigrams>"

    Rows whose themes are empty are dropped, so the output can be shorter than
    the input batch.
    """
    out = {"anchor": [], "positive": []}
    rows = zip(row_batch["Themes"], row_batch["OpeningTags"], row_batch["Moves"])
    for theme_list, opening_list, move_str in rows:
        themes_txt = _join_tags(theme_list)
        if not themes_txt:
            continue
        opening_txt = _join_tags(opening_list)
        anchor_parts = [themes_txt]
        positive_parts = ["themes", themes_txt]
        if opening_txt:
            anchor_parts.append(opening_txt)
            positive_parts += ["opening", opening_txt]
        positive_parts += ["moves", _bigram_token_str(move_str)]
        out["anchor"].append(" ".join(anchor_parts))
        out["positive"].append(" ".join(positive_parts))
    return out
def build_opening_pairs(row_batch: dict) -> dict:
    """Batched `Dataset.map` fn: opening rows -> (anchor, positive) text pairs.

    anchor   = "<name> <eco>"
    positive = "name <name> eco <eco> pgn <SAN moves + bigrams>"

    Move numbers are stripped from the PGN before bigram tokens are appended.
    """
    records = zip(row_batch["name"], row_batch["eco"], row_batch["pgn"])
    pairs = [
        (
            f"{name} {eco}",
            f"name {name} eco {eco} pgn {_bigram_token_str(_strip_pgn_move_numbers(pgn))}",
        )
        for name, eco, pgn in records
    ]
    return {
        "anchor": [anchor for anchor, _ in pairs],
        "positive": [positive for _, positive in pairs],
    }
def load_chess_pairs() -> tuple[Dataset, Dataset]:
    """Return (train, holdout); holdout anchors are rare combinations NEVER seen in train.

    Old eval used the top-200 most-common theme strings as queries. The model
    memorized these in training (each appears ~50k times), so eval was a recall
    test on memorized lookups, not generalization. Replaced with compositional
    held-out anchors:
      - Pick anchor strings with frequency in [HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX]
        (minimum lowered to 2 under SMOKE_TEST): rare enough to be informative,
        common enough to have multiple positives for multi-relevant eval.
      - REMOVE all pairs with those anchors from train (no leakage).
      - Use those rare anchors as eval queries; the held-out pairs become the
        eval corpus.
      - Individual theme tokens within those anchors still appear *separately*
        in many other training anchors, so the model has learned each token's
        embedding -- it just hasn't seen this particular combination. Tests
        compositional generalization.

    Only puzzle pairs are held out. All opening pairs go to train, which is
    shuffled with seed 12.
    """
    logging.info("Loading Lichess/chess-puzzles (5.8M rows)")
    puzzles = load_dataset("Lichess/chess-puzzles", split="train")
    if SMOKE_TEST:
        puzzles = puzzles.select(range(2_000))
    pair_puzzles = puzzles.map(
        build_puzzle_pairs,
        batched=True,
        batch_size=10_000,
        remove_columns=puzzles.column_names,
        desc="puzzles -> pairs",
    )
    logging.info(f" built {len(pair_puzzles):,} puzzle pairs")

    logging.info("Loading Lichess/chess-openings (3.6K rows)")
    openings = load_dataset("Lichess/chess-openings", split="train").remove_columns(["img"])
    pair_openings = openings.map(
        build_opening_pairs,
        batched=True,
        remove_columns=openings.column_names,
        desc="openings -> pairs",
    )
    logging.info(f" built {len(pair_openings):,} opening pairs")

    # Count anchor frequencies across the puzzle pairs.
    logging.info("Computing anchor frequencies for held-out selection")
    anchors = pair_puzzles["anchor"]
    freq: dict[str, int] = defaultdict(int)
    for a in anchors:
        freq[a] += 1
    logging.info(f" {len(freq):,} unique anchors in puzzle pairs")

    # Pick rare anchors: each appears in [min_freq, max_freq] pairs.
    # In smoke mode, lower the min so the tiny corpus still produces enough
    # held-out queries (smoke has ~2k puzzles, most anchors freq 1-2).
    min_freq = 2 if SMOKE_TEST else HELDOUT_FREQ_MIN
    max_freq = HELDOUT_FREQ_MAX
    rare_pool = sorted(
        ((a, c) for a, c in freq.items() if min_freq <= c <= max_freq),
        key=lambda kv: kv[1],  # ascending: rarest first (stable for ties)
    )
    n_queries_target = 20 if SMOKE_TEST else EVAL_QUERIES
    if len(rare_pool) < n_queries_target:
        # Report the bounds actually applied (min_freq differs in smoke mode).
        logging.warning(
            f"Only {len(rare_pool)} anchors in freq range [{min_freq},{max_freq}]; "
            f"using all of them ({n_queries_target} requested)"
        )
    selected = rare_pool[:n_queries_target]
    heldout_anchors = {a for a, _ in selected}
    lowest_freq = selected[0][1] if selected else 0
    highest_freq = selected[-1][1] if selected else 0
    logging.info(
        f" selected {len(heldout_anchors)} held-out anchors "
        f"(freq range: {lowest_freq}..{highest_freq})"
    )

    # Split in one pass: pairs whose anchor is held-out -> eval; rest -> train.
    held_idx: list[int] = []
    train_idx: list[int] = []
    for i, a in enumerate(anchors):
        if a in heldout_anchors:
            held_idx.append(i)
        else:
            train_idx.append(i)
    holdout = pair_puzzles.select(held_idx)
    train_puzzles = pair_puzzles.select(train_idx)
    logging.info(
        f" split by held-out anchors: train={len(train_puzzles):,}, holdout={len(holdout):,}"
    )

    # Train includes the (non-held) puzzle pairs + all openings.
    train = concatenate_datasets([train_puzzles, pair_openings]).shuffle(seed=12)
    logging.info(f" train: {len(train):,} pairs | holdout: {len(holdout):,} pairs")
    return train, holdout
def make_balanced_dataset(train: Dataset, n_per_anchor: int) -> Dataset:
    """Cap each anchor's positives at `n_per_anchor` random picks.

    Breaks the 5.8M pairs' redundancy (each anchor x ~1933 positives) so the
    model can't memorize specific (anchor, positive) pairings while still
    seeing useful positive variety per anchor. Anchors with at most
    `n_per_anchor` positives keep all of them.

    Sampling uses `random.Random(12)` and the result is shuffled with seed 12,
    so the output is reproducible for a given input order.
    """
    by_anchor: dict[str, list[str]] = defaultdict(list)
    # Read whole columns instead of `for row in train`. Row iteration formats
    # a dict per row, which is slow on millions of rows. Order is identical.
    for anchor, positive in zip(train["anchor"], train["positive"]):
        by_anchor[anchor].append(positive)

    rng = random.Random(12)
    new_anchors: list[str] = []
    new_positives: list[str] = []
    for anchor, positives in by_anchor.items():
        if len(positives) > n_per_anchor:
            positives = rng.sample(positives, n_per_anchor)
        new_anchors.extend([anchor] * len(positives))
        new_positives.extend(positives)

    logging.info(
        f"Balanced dataset: {len(by_anchor):,} unique anchors -> "
        f"{len(new_anchors):,} pairs (cap {n_per_anchor}/anchor)"
    )
    return Dataset.from_dict({"anchor": new_anchors, "positive": new_positives}).shuffle(seed=12)
def make_anchor_masker(mask_prob: float, rng_seed: int = 12):
    """Build a `Dataset.set_transform` callable implementing token-bag dropout.

    Each whitespace-separated anchor token (theme words and opening words
    alike) is independently replaced by "[UNK]" with probability `mask_prob`.
    This forces the model to use the remaining tokens instead of memorizing
    the exact combination. Rules:
      - Single-token anchors are passed through untouched.
      - If every token of an anchor ends up "[UNK]", one original token
        (chosen at random) is written back into a random slot.
      - "positive" passes through unchanged.

    Returns None when `mask_prob <= 0` (masking disabled). The RNG is seeded
    once per masker and keeps advancing across calls. Masks therefore vary
    between accesses, but the sequence is reproducible.
    """
    if mask_prob <= 0:
        return None
    rng = random.Random(rng_seed)

    def _mask_one(anchor: str) -> str:
        tokens = anchor.split()
        if len(tokens) < 2:
            return anchor
        masked = []
        for tok in tokens:
            masked.append(tok if rng.random() >= mask_prob else "[UNK]")
        if masked.count("[UNK]") == len(masked):
            # Draw the source token first, then the target slot. This matches
            # the RNG call order of the original one-line assignment.
            restored = tokens[rng.randrange(len(tokens))]
            masked[rng.randrange(len(masked))] = restored
        return " ".join(masked)

    def _mask(batch: dict) -> dict:
        return {
            "anchor": [_mask_one(a) for a in batch["anchor"]],
            "positive": batch["positive"],
        }

    return _mask
def train_chess_tokenizer(train: Dataset) -> Tokenizer:
    """Train or load a WordLevel tokenizer for the chess corpus.

    Every whitespace-separated unit becomes one whole token: theme word,
    opening word, ECO code, UCI move, SAN move, and the `a+b` move bigrams.
    Compare BERT WordPiece, which fragments "f2g3" into 4 subword pieces -- a
    token-bag wastes capacity on subword joins that carry no chess meaning.

    Why `WhitespaceSplit` rather than `Whitespace`: the `Whitespace`
    pre-tokenizer also splits on punctuation. That shredded every bigram
    ("e2e4+e7e5" -> "e2e4", "+", "e7e5") and SAN moves such as "Nxd4+",
    "O-O" and "e8=Q". The trade-off is that punctuation attached to a word
    (e.g. a trailing ":" in an opening name) now stays part of that token.

    Caching: if TOKENIZER_PATH exists, load and return it instead of
    rebuilding. The WordLevelTrainer is single-threaded Rust and takes ~6 min
    on 11.6M strings. The tokenizer is deterministic given the same corpus +
    config, so caching is safe. Set RETRAIN_TOKENIZER=1 to force a rebuild;
    a warning is logged if the cached file uses a different pre-tokenizer.
    """
    import json

    from tokenizers.pre_tokenizers import WhitespaceSplit

    if not RETRAIN_TOKENIZER and os.path.exists(TOKENIZER_PATH):
        tok = Tokenizer.from_file(TOKENIZER_PATH)
        logging.info(
            f"Reusing cached tokenizer ({tok.get_vocab_size():,} tokens) from {TOKENIZER_PATH}"
        )
        pre_tok_cfg = json.loads(tok.to_str()).get("pre_tokenizer") or {}
        if pre_tok_cfg.get("type") != "WhitespaceSplit":
            logging.warning(
                f"Cached tokenizer at {TOKENIZER_PATH} uses pre-tokenizer "
                f"{pre_tok_cfg.get('type')!r}, not 'WhitespaceSplit'; move bigrams "
                f"will be split. Set RETRAIN_TOKENIZER=1 to rebuild it."
            )
        return tok

    logging.info(f"Training WordLevel tokenizer on {len(train):,} pairs (vocab={VOCAB_SIZE})")
    tok = Tokenizer(WordLevel(unk_token="[UNK]"))
    tok.pre_tokenizer = WhitespaceSplit()
    trainer = WordLevelTrainer(
        vocab_size=VOCAB_SIZE,
        special_tokens=["[UNK]", "[PAD]"],
        min_frequency=2,
    )

    def text_iter():
        # Column access instead of formatting a dict per row; same order
        # (anchor then positive for each pair).
        for anchor, positive in zip(train["anchor"], train["positive"]):
            yield anchor
            yield positive

    tok.train_from_iterator(text_iter(), trainer=trainer, length=2 * len(train))
    actual_vocab = tok.get_vocab_size()
    logging.info(f" tokenizer trained: {actual_vocab:,} tokens (cap was {VOCAB_SIZE:,})")
    os.makedirs(os.path.dirname(TOKENIZER_PATH) or ".", exist_ok=True)
    tok.save(TOKENIZER_PATH)
    logging.info(f" saved tokenizer to {TOKENIZER_PATH}")
    return tok
| def _strip_theme_echo(positive: str) -> str: | |
| """Eval corpus must not echo the themes the query asks about, or the | |
| baseline (random-init) scores high just from lexical token overlap. Keep | |
| only the moves segment.""" | |
| idx = positive.find(" moves ") | |
| return positive[idx + 1 :] if idx != -1 else positive | |
def _build_compositional_ir_evaluator(
    holdout: Dataset, corpus: dict[str, str], name: str
) -> InformationRetrievalEvaluator:
    """Compositional IR eval: every distinct held-out anchor string is a query.

    Doc ids are "d<row index>" into `holdout`. A query's relevant docs are all
    held-out rows sharing its anchor. Queries are numbered "q0", "q1", ... in
    descending order of relevant-doc count, with ties kept in first-seen order.
    """
    docs_for_anchor: dict[str, set[str]] = defaultdict(set)
    for row_idx, row in enumerate(holdout):
        docs_for_anchor[row["anchor"]].add(f"d{row_idx}")
    ranked = sorted(docs_for_anchor.items(), key=lambda item: len(item[1]), reverse=True)

    queries: dict[str, str] = {}
    relevant_docs: dict[str, set[str]] = {}
    for q_idx, (anchor_text, doc_ids) in enumerate(ranked):
        queries[f"q{q_idx}"] = anchor_text
        relevant_docs[f"q{q_idx}"] = doc_ids

    avg_rel = sum(map(len, relevant_docs.values())) / max(1, len(relevant_docs))
    logging.info(
        f" [{name}] {len(queries)} queries (unseen combos), avg relevant/query={avg_rel:.1f}"
    )
    return _ir_evaluator(queries, corpus, relevant_docs, name)
def _build_single_theme_ir_evaluator(
    holdout: Dataset, corpus: dict[str, str], name: str
) -> InformationRetrievalEvaluator:
    """Single-token IR eval: each whitespace token of a held-out anchor is a query.

    Tests whether per-token embeddings are useful in isolation. Query "fork" is
    relevant to every held-out doc whose anchor contains the token "fork".
    Anchors also carry opening words, so those become queries as well. This is
    coarser than the compositional eval (much higher avg relevant/query) but a
    sharper test of token-level meaning.

    Tokens with fewer than 3 relevant docs (2 under SMOKE_TEST) are skipped.
    Queries are numbered "t0", "t1", ... by descending relevant-doc count.
    """
    docs_for_token: dict[str, set[str]] = defaultdict(set)
    for row_idx, row in enumerate(holdout):
        doc_id = f"d{row_idx}"
        for token in row["anchor"].split():
            docs_for_token[token].add(doc_id)

    min_relevant = 2 if SMOKE_TEST else 3
    kept = sorted(
        (
            (token, doc_ids)
            for token, doc_ids in docs_for_token.items()
            if len(doc_ids) >= min_relevant
        ),
        key=lambda item: len(item[1]),
        reverse=True,
    )

    queries: dict[str, str] = {}
    relevant_docs: dict[str, set[str]] = {}
    for t_idx, (token, doc_ids) in enumerate(kept):
        queries[f"t{t_idx}"] = token
        relevant_docs[f"t{t_idx}"] = doc_ids

    avg_rel = sum(map(len, relevant_docs.values())) / max(1, len(relevant_docs))
    logging.info(
        f" [{name}] {len(queries)} single-token queries, avg relevant/query={avg_rel:.1f}"
    )
    return _ir_evaluator(queries, corpus, relevant_docs, name)
def _ir_evaluator(queries, corpus, relevant_docs, name):
    """Build the InformationRetrievalEvaluator shared by both eval flavours.

    Cutoffs: NDCG@10, MRR@10, accuracy@{1,10} and precision/recall@{1,10}.
    Encoding uses batch size 256 with the progress bar off.
    """
    metric_cutoffs = {
        "ndcg_at_k": [10],
        "mrr_at_k": [10],
        "accuracy_at_k": [1, 10],
        "precision_recall_at_k": [1, 10],
    }
    return InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=name,
        show_progress_bar=False,
        batch_size=256,
        **metric_cutoffs,
    )
def build_ir_evaluator(holdout: Dataset, name: str = "chess-ir") -> SequentialEvaluator:
    """Run the compositional and single-token IR evaluators in sequence.

    Both share one corpus. Doc "d<i>" is holdout row i's positive passed
    through _strip_theme_echo, i.e. only its "moves ..." tail. The sequential
    main score is scores[0], the compositional evaluator's score, and that is
    what drives best-model selection. The single-token scores are
    informational.
    """
    corpus: dict[str, str] = {}
    for idx, row in enumerate(holdout):
        corpus[f"d{idx}"] = _strip_theme_echo(row["positive"])
    logging.info(f"IR eval setup ({len(corpus)} corpus docs):")
    evaluators = [
        _build_compositional_ir_evaluator(holdout, corpus, name=name),
        _build_single_theme_ir_evaluator(holdout, corpus, name=f"{name}-tokens"),
    ]
    # Index 0 (compositional) must stay first: it is the main score.
    return SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[0])
def main() -> None:
    """End-to-end run of the training pipeline.

    Steps: build pairs, train the tokenizer, balance and mask the data, train
    StaticEmbedding with MNRL (optionally wrapped in Matryoshka), evaluate
    before and after training, save locally, then push to the Hub unless
    SMOKE_TEST is set.
    """
    setup_logging()
    train_dataset, holdout = load_chess_pairs()
    if SMOKE_TEST:
        train_dataset = train_dataset.select(range(min(500, len(train_dataset))))

    # Train the tokenizer on the FULL (pre-balanced) corpus -- we want every
    # token to be seen as many times as possible for the vocab pass.
    tokenizer = train_chess_tokenizer(train_dataset)
    # Now down-sample to a balanced dataset for the contrastive training.
    train_dataset = make_balanced_dataset(train_dataset, BALANCED_POSITIVES_PER_ANCHOR)
    # Optional anchor-token masking applied on the fly via set_transform.
    masker = make_anchor_masker(ANCHOR_MASK_PROB)
    if masker is not None:
        logging.info(f"Anchor token masking enabled (p={ANCHOR_MASK_PROB})")
        train_dataset.set_transform(masker)

    logging.info(f"Random-init StaticEmbedding (dim={EMBEDDING_DIM})")
    static_embedding = StaticEmbedding(tokenizer, embedding_dim=EMBEDDING_DIM)
    model = SentenceTransformer(
        modules=[static_embedding],
        model_card_data=SentenceTransformerModelCardData(
            language="en",
            license="apache-2.0",
            model_name=f"Static chess embedding ({EMBEDDING_DIM}d) -- themes/openings <-> positions",
        ),
    )
    evaluator = build_ir_evaluator(holdout)

    inner = MultipleNegativesRankingLoss(model)
    if DISABLE_MATRYOSHKA:
        logging.info("Matryoshka DISABLED -- training at single dim (diagnostic)")
        loss = inner
    else:
        loss = MatryoshkaLoss(model, inner, matryoshka_dims=MATRYOSHKA_DIMS)

    logging.info("Baseline evaluation (random init -- expect near-zero):")
    with autocast_ctx():
        baseline_eval = evaluator(model)[evaluator.primary_metric]
    # NOTE(review): primary_metric is read only after the first evaluator call,
    # presumably because SequentialEvaluator sets it inside __call__. Keep this
    # line after the baseline eval unless that is confirmed otherwise.
    metric_key = f"eval_{evaluator.primary_metric}"
    logging.info(f" baseline {evaluator.primary_metric} = {baseline_eval:.4f}")

    if SMOKE_TEST:
        max_steps = 1
    elif MAX_STEPS_OVERRIDE:
        max_steps = MAX_STEPS_OVERRIDE
    else:
        max_steps = -1
    eval_steps = EVAL_STEPS_OVERRIDE if EVAL_STEPS_OVERRIDE else 0.05  # 20 evals/run
    save_steps = EVAL_STEPS_OVERRIDE if EVAL_STEPS_OVERRIDE else 0.05

    args = SentenceTransformerTrainingArguments(
        output_dir=OUTPUT_DIR,
        # Balanced dataset is small (~300k pairs); need many epochs to reach
        # comparable total training signal. Early stopping handles excess.
        num_train_epochs=20,
        max_steps=max_steps,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=1e-2,  # was 5e-2 -- much slower convergence, shifts peak later
        weight_decay=0.01,  # was 0.0 -- regularization on the embedding table
        warmup_steps=0.1,
        lr_scheduler_type="linear",
        bf16=IS_CUDA and torch.cuda.is_bf16_supported(),
        fp16=IS_CUDA and not torch.cuda.is_bf16_supported(),
        # was NO_DUPLICATES -- linked-list scan over deferred conflicts gives
        # O(epoch_progress) per-batch cost. With ~3000 unique anchors over
        # 5.8M pairs, dedup is fighting impossible odds. BATCH_SAMPLER (random)
        # is fast and accepts mild within-batch anchor duplication.
        batch_sampler=BatchSamplers.BATCH_SAMPLER,
        eval_strategy="steps",
        eval_steps=eval_steps,
        save_strategy="steps",
        save_steps=save_steps,
        save_total_limit=2,
        logging_steps=0.01,
        logging_first_step=True,
        load_best_model_at_end=True,
        metric_for_best_model=metric_key,
        greater_is_better=True,
        # Trackio crashes at first checkpoint push: empty `router_mapping`
        # struct can't be written to parquet. Disable.
        report_to="none",
        run_name=RUN_NAME,
        seed=12,
        # HF Jobs: container is destroyed after run -- push every checkpoint to
        # the Hub so partial progress survives a timeout. The end-of-run
        # model.push_to_hub() below is the belt to this suspenders.
        push_to_hub=not SMOKE_TEST,
        hub_model_id=HUB_MODEL_ID,
        hub_strategy="every_save",
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
        evaluator=evaluator,
        callbacks=[
            # Auto-stop if compositional NDCG@10 doesn't improve for 3 evals.
            # Lower lr makes curves smoother -- give it slack vs the patience=2
            # we used at lr=5e-2.
            EarlyStoppingCallback(early_stopping_patience=3),
            # Per-step memory + dt logging.
            StepTimingCallback(),
        ],
    )
    trainer.train()

    logging.info("Post-training evaluation:")
    with autocast_ctx():
        score = evaluator(model)[evaluator.primary_metric]
    delta = score - baseline_eval
    verdict = "WIN" if delta >= 0.005 else "MARGINAL" if delta >= 0 else "REGRESSION"
    logging.info(
        f"VERDICT: {verdict} | score={score:.4f} | baseline={baseline_eval:.4f} | delta={delta:+.4f}"
    )

    final_dir = f"{OUTPUT_DIR}/final"
    model.save_pretrained(final_dir)
    logging.info(f"Saved final model to {final_dir}")

    if SMOKE_TEST:
        logging.info("SMOKE_TEST=1: skipping Hub push")
        return
    try:
        commit_url = model.push_to_hub(HUB_MODEL_ID)
        logging.info(f"Pushed model to {commit_url.rsplit('/commit/', 1)[0]}")
    except Exception:
        # Deliberately best-effort: the model is already saved to final_dir,
        # and hub_strategy="every_save" pushed checkpoints during training.
        # Log the failure with its traceback instead of crashing the run.
        logging.exception("Hub push failed")
if __name__ == "__main__":
    # Script entry point; importing the module does not start training.
    main()