File size: 26,465 Bytes
f8392aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "sentence-transformers[train]>=5.5.0",
#     "datasets>=2.19.0",
#     "accelerate>=0.26.0",
#     "tokenizers>=0.20",
#     "trackio",
# ]
# ///
"""Train a StaticEmbedding model for chess retrieval.

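Pair shape (`<... a+b>` = adjacent-move bigram tokens appended after the moves):
    anchor   = "<themes> [<opening words>]"                                (puzzles)
               "<name> <eco>"                                              (openings)
    positive = "themes <themes> [opening <words>] moves <uci> <uci a+b>"   (puzzles)
               "name <words> eco <code> pgn <san> <san a+b>"               (openings)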

Datasets:
- Lichess/chess-puzzles  (5.8M rows; themes + opening tags + UCI moves)
- Lichess/chess-openings (3.6K rows; opening name + ECO + SAN moves)

Use case: free-text search over a chess corpus. "fork endgame short" -> puzzles
with that motif; "Sicilian Najdorf" -> matching openings.

Design choices:
- Custom WordLevel + Whitespace tokenizer trained on the corpus. Every chess
  token (UCI move e2e4, SAN move Nxd4, ECO code B90, theme name, opening word)
  is one whole token -- BERT WordPiece would shred them 4-way.
- FEN dropped: position-as-character-soup doesn't fit a token-bag.
- PGN move numbers stripped ("1. e4 c5" -> "e4 c5") so SAN moves are high-freq.
- IR eval is custom (themes -> puzzles), not NanoBEIR -- general-English IR
  benchmarks don't measure chess retrieval.

Run:
    SMOKE_TEST=1 uv run --exclude-newer=2026-05-12 train_chess_static.py
    uv run --exclude-newer=2026-05-12 train_chess_static.py
"""

from __future__ import annotations

import logging
import os
import re
from collections import defaultdict
from contextlib import nullcontext

import datasets
import random
import torch
from datasets import Dataset, concatenate_datasets, load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.base.sampler import BatchSamplers
from sentence_transformers.sentence_transformer.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
from sentence_transformers.sentence_transformer.losses import (
    MatryoshkaLoss,
    MultipleNegativesRankingLoss,
)
from sentence_transformers.sentence_transformer.modules import StaticEmbedding
from transformers import EarlyStoppingCallback, TrainerCallback
import time


# Embedding width (up from 256: more capacity for bigram tokens) and the
# Matryoshka truncation ladder -- the full width halved four times, 512 -> 32.
EMBEDDING_DIM = 512
MATRYOSHKA_DIMS = [EMBEDDING_DIM >> shift for shift in range(5)]
# Up from 50_000: UCI/SAN bigram tokens add roughly 20-50k vocabulary entries.
VOCAB_SIZE = 100_000

OUTPUT_DIR = "models/static-embedding-chess"
RUN_NAME = "static-embedding-chess"
HUB_MODEL_ID = os.getenv("HUB_MODEL_ID", "oneryalcin/static-embedding-chess")
# Cached tokenizer location; defaults to a file inside OUTPUT_DIR. On Modal,
# point it at the persistent volume (e.g. /cache/chess_tokenizer.json) so the
# ~6-minute WordLevelTrainer run is paid once across launches.
TOKENIZER_PATH = os.getenv("TOKENIZER_PATH", f"{OUTPUT_DIR}/chess_tokenizer.json")

# Boolean switches: each is on only when its env var is exactly "1".
RETRAIN_TOKENIZER = os.getenv("RETRAIN_TOKENIZER") == "1"
SMOKE_TEST = os.getenv("SMOKE_TEST") == "1"
FORCE_CPU = os.getenv("FORCE_CPU") == "1"
# Diagnostic knobs (defaults = full recipe). Monotonic step-time growth was
# seen on both MPS and T4 with the full Matryoshka stack; toggle to isolate.
DISABLE_MATRYOSHKA = os.getenv("DISABLE_MATRYOSHKA") == "1"
# 0 / unset means "no override" and maps to None.
MAX_STEPS_OVERRIDE = int(os.getenv("MAX_STEPS", "0")) or None
EVAL_STEPS_OVERRIDE = int(os.getenv("EVAL_STEPS", "0")) or None

EVAL_QUERIES = 200
# NOTE(review): not referenced anywhere else in this script.
EVAL_CORPUS = 5_000
# Held-out anchors are rare combos whose frequency lies in this range. A lower
# bound above 1 keeps multi-relevant NDCG meaningful; the upper bound limits
# how much the combination could be memorized.
HELDOUT_FREQ_MIN = 3
HELDOUT_FREQ_MAX = 30
# Balanced dataset: each unique anchor expands to at most N (anchor, sampled
# positive) rows. The raw 5.8M pairs let the model memorize specific pairings
# (each anchor has ~1933 distinct positives); a cap of 100 random samples per
# anchor keeps useful variety without the 50x redundancy that fuels overfitting.
BALANCED_POSITIVES_PER_ANCHOR = int(os.getenv("POSITIVES_PER_ANCHOR", "100"))
# Probability of masking each anchor token during training; 0 disables.
ANCHOR_MASK_PROB = float(os.getenv("ANCHOR_MASK_PROB", "0.15"))

# Device detection. MPS (Apple Silicon) can't do bf16 and has unified-memory
# pressure, so CUDA-targeted template defaults (batch=2048, bf16=True) don't
# carry over. FORCE_CPU disables both accelerators; CUDA wins over MPS.
IS_CUDA = torch.cuda.is_available() and not FORCE_CPU
IS_MPS = not IS_CUDA and not FORCE_CPU and torch.backends.mps.is_available()
# StaticEmbedding is lookup + average: no transformer activations to store.
# Memory goes to the (batch x batch) similarity matrix and the
# (batch x seq x dim) lookups, both small. CachedMultipleNegativesRankingLoss
# does NOT work with StaticEmbedding (no encoder to GradCache through), so the
# real batch is simply made large. Any accelerator gets 4096; CPU gets 256.
BATCH_SIZE = 4096 if (IS_CUDA or IS_MPS) else 256

# PGN move-number markers: "1." for white, "1..." for black-to-move.
MOVE_NUM_RE = re.compile(r"\d+\.+")


class StepTimingCallback(TrainerCallback):
    """Per-step instrumentation: wall time, CUDA memory, allocator state.

    Costs ~1ms/step. Run-once-and-read approach to diagnosing slowdowns
    instead of swapping configs and rerunning.

    CUDA and MPS both execute kernels asynchronously, so the active GPU is
    synchronized before each clock read; without that, dt measures only how
    long it took to enqueue the step's work, not how long the step ran.
    """

    # Wall-clock start of the current step; None until on_step_begin runs.
    _t0: float | None = None

    @staticmethod
    def _sync_device() -> None:
        """Block until queued GPU work finishes so perf_counter() brackets the real step."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elif torch.backends.mps.is_available():
            torch.mps.synchronize()

    def on_step_begin(self, args, state, control, **kw):
        self._sync_device()
        self._t0 = time.perf_counter()

    def on_step_end(self, args, state, control, **kw):
        # Guard against an on_step_end without a preceding on_step_begin.
        if self._t0 is None:
            return
        self._sync_device()
        dt = time.perf_counter() - self._t0
        # Log every step for the first 20 to see startup; then every 10th.
        if state.global_step <= 20 or state.global_step % 10 == 0:
            if torch.cuda.is_available():
                mem = torch.cuda.memory_allocated() / 1e6
                reserved = torch.cuda.memory_reserved() / 1e6
                logging.info(
                    f"STEP {state.global_step}: dt={dt:.3f}s mem={mem:.0f}MB reserved={reserved:.0f}MB"
                )
            else:
                logging.info(f"STEP {state.global_step}: dt={dt:.3f}s (cpu/mps)")


def autocast_ctx():
    """Return a mixed-precision autocast context for the active device.

    CUDA uses bf16 when the GPU supports it and fp16 otherwise; MPS uses fp16;
    on CPU (or with FORCE_CPU) a no-op context is returned.
    """
    if not (IS_CUDA or IS_MPS):
        return nullcontext()
    if IS_MPS:
        return torch.autocast("mps", dtype=torch.float16)
    # IS_CUDA and IS_MPS are mutually exclusive, so only CUDA reaches here.
    half = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.autocast("cuda", dtype=half)


def setup_logging():
    """Configure root logging to stderr and to logs/<RUN_NAME>.log.

    Also creates the log and model-output directories, lowers chatty
    HTTP/Hub/filesystem loggers to WARNING, and on CUDA machines sets the
    float32 matmul precision to "high".
    """
    for directory in ("logs", OUTPUT_DIR):
        os.makedirs(directory, exist_ok=True)

    log_handlers = [
        logging.StreamHandler(),
        logging.FileHandler(f"logs/{RUN_NAME}.log"),
    ]
    logging.basicConfig(
        format="%(asctime)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=log_handlers,
        force=True,  # replace any handlers installed by imported libraries
    )

    quiet_loggers = ("httpx", "httpcore", "huggingface_hub", "urllib3", "filelock", "fsspec")
    for logger_name in quiet_loggers:
        logging.getLogger(logger_name).setLevel(logging.WARNING)

    if torch.cuda.is_available():
        torch.set_float32_matmul_precision("high")


def _join_tags(tags) -> str:
    if not tags:
        return ""
    return " ".join(t.replace("_", " ") for t in tags)


def _strip_pgn_move_numbers(pgn: str) -> str:
    """Drop PGN move numbers and return single-space-separated SAN moves.

    "1. e4 c5 2. Nf3 d6" -> "e4 c5 Nf3 d6"; black-to-move markers such as
    "3..." are removed too. Deleting a mid-string move number leaves two
    adjacent spaces behind ("c5  Nf3"), so the result is re-joined on single
    spaces to keep the move sequence in canonical form. "" -> "".
    """
    return " ".join(MOVE_NUM_RE.sub("", pgn).split())


def _bigram_token_str(moves: str) -> str:
    """Append bigram tokens to a whitespace-separated move sequence.

    "f2g3 e6e7 b2b1" -> "f2g3 e6e7 b2b1 f2g3+e6e7 e6e7+b2b1"

    Bigrams use `+` as the join char so they're distinct from unigrams in the
    WordLevel tokenizer's whitespace pretokenizer. A token-bag averaging across
    unigrams alone loses move ordering; adding adjacent-pair tokens lets the
    model learn that "e2e4 e7e5" (king's pawn opening) is its own pattern.
    """
    tokens = moves.split()
    if len(tokens) < 2:
        return moves
    bigrams = " ".join(f"{a}+{b}" for a, b in zip(tokens, tokens[1:]))
    return f"{moves} {bigrams}"


def build_puzzle_pairs(row_batch: dict) -> dict:
    """Map a batched slice of Lichess puzzle rows to (anchor, positive) text pairs.

    anchor   = "<themes> [<opening words>]"
    positive = "themes <themes> [opening <words>] moves <moves + bigrams>"

    Rows with no themes are dropped, so the output can be shorter than the
    input batch. Returns a dict with "anchor" and "positive" lists.
    """
    out = {"anchor": [], "positive": []}
    rows = zip(row_batch["Themes"], row_batch["OpeningTags"], row_batch["Moves"])
    for themes, opening_tags, moves in rows:
        theme_words = _join_tags(themes)
        if not theme_words:
            continue
        opening_words = _join_tags(opening_tags)
        if opening_words:
            anchor = f"{theme_words} {opening_words}"
            head = f"themes {theme_words} opening {opening_words}"
        else:
            anchor = theme_words
            head = f"themes {theme_words}"
        out["anchor"].append(anchor)
        out["positive"].append(f"{head} moves {_bigram_token_str(moves)}")
    return out


def build_opening_pairs(row_batch: dict) -> dict:
    """Map a batched slice of Lichess opening rows to (anchor, positive) text pairs.

    anchor   = "<name> <eco>"
    positive = "name <name> eco <eco> pgn <SAN moves + bigrams>"

    Colons and commas in the opening name are turned into spaces
    ("Sicilian Defense: Najdorf Variation, English Attack" ->
    "Sicilian Defense Najdorf Variation English Attack"). The puzzle opening
    tags (rendered by `_join_tags`) carry no punctuation, so without this the
    name words either become "Defense:"/"Variation," variants that never
    match the puzzle tokens, or leave standalone ":" / "," tokens that
    appear in almost every opening pair.

    Returns a dict with "anchor" and "positive" lists (one pair per row).
    """
    punct_to_space = str.maketrans(":,", "  ")
    anchors, positives = [], []
    for name, eco, pgn in zip(row_batch["name"], row_batch["eco"], row_batch["pgn"]):
        clean_name = " ".join(name.translate(punct_to_space).split())
        san = _strip_pgn_move_numbers(pgn)
        anchors.append(f"{clean_name} {eco}")
        positives.append(f"name {clean_name} eco {eco} pgn {_bigram_token_str(san)}")
    return {"anchor": anchors, "positive": positives}


def load_chess_pairs() -> tuple[Dataset, Dataset]:
    """Returns (train, holdout) where the holdout anchors are rare combinations
    NEVER seen in train.

    Old eval used the top-200 most-common theme strings as queries. The model
    memorized these in training (each appears ~50k times) so eval was a recall
    test on memorized lookups, not generalization. Replaced with compositional
    held-out anchors:

      - Pick anchor strings with frequency in [HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX]
        (the minimum drops to 2 in smoke mode): rare enough to be informative,
        common enough to have multiple positives for multi-relevant eval.
      - REMOVE all pairs with those anchors from train (no leakage).
      - Use those rare anchors as eval queries; the held-out pairs become the
        eval corpus.
      - Individual theme tokens within those anchors still appear *separately*
        in many other training anchors, so the model has learned each token's
        embedding -- it just hasn't seen this particular combination. Tests
        compositional generalization.

    Raises:
        ValueError: if no puzzle anchor falls in the held-out frequency range,
            which would leave the eval set empty.
    """
    from collections import Counter

    logging.info("Loading Lichess/chess-puzzles (5.8M rows)")
    puzzles = load_dataset("Lichess/chess-puzzles", split="train")
    if SMOKE_TEST:
        puzzles = puzzles.select(range(2_000))
    pair_puzzles = puzzles.map(
        build_puzzle_pairs,
        batched=True,
        batch_size=10_000,
        remove_columns=puzzles.column_names,
        desc="puzzles -> pairs",
    )
    logging.info(f"  built {len(pair_puzzles):,} puzzle pairs")

    logging.info("Loading Lichess/chess-openings (3.6K rows)")
    openings = load_dataset("Lichess/chess-openings", split="train").remove_columns(["img"])
    pair_openings = openings.map(
        build_opening_pairs,
        batched=True,
        remove_columns=openings.column_names,
        desc="openings -> pairs",
    )
    logging.info(f"  built {len(pair_openings):,} opening pairs")

    # Count anchor frequencies across the puzzle pairs (Counter keeps
    # first-seen order, so ties below resolve in dataset order).
    logging.info("Computing anchor frequencies for held-out selection")
    anchors = pair_puzzles["anchor"]
    freq = Counter(anchors)
    logging.info(f"  {len(freq):,} unique anchors in puzzle pairs")

    # Pick rare anchors: each appears in [min_freq, max_freq] pairs.
    # In smoke mode, lower the min so the tiny corpus still produces enough
    # held-out queries (smoke has ~2k puzzles, most anchors freq 1-2).
    min_freq = 2 if SMOKE_TEST else HELDOUT_FREQ_MIN
    max_freq = HELDOUT_FREQ_MAX
    rare_pool = sorted(
        ((a, c) for a, c in freq.items() if min_freq <= c <= max_freq),
        key=lambda kv: kv[1],  # ascending: rarest first (stable on ties)
    )
    n_queries_target = 20 if SMOKE_TEST else EVAL_QUERIES
    if len(rare_pool) < n_queries_target:
        # Report the range actually applied (min_freq differs in smoke mode).
        logging.warning(
            f"Only {len(rare_pool)} anchors in freq range [{min_freq},{max_freq}]; "
            f"using all of them ({n_queries_target} requested)"
        )
    selected = rare_pool[:n_queries_target]
    if not selected:
        raise ValueError(
            f"No held-out anchors selected from freq range [{min_freq},{max_freq}] "
            f"({n_queries_target} requested); the IR eval would be empty"
        )
    heldout_anchors = {a for a, _ in selected}
    logging.info(
        f"  selected {len(heldout_anchors)} held-out anchors "
        f"(freq range: {selected[0][1]}..{selected[-1][1]})"
    )

    # Filter: pairs whose anchor is held-out -> eval; everything else -> train.
    held_mask = [a in heldout_anchors for a in anchors]
    holdout = pair_puzzles.select([i for i, h in enumerate(held_mask) if h])
    train_puzzles = pair_puzzles.select([i for i, h in enumerate(held_mask) if not h])
    logging.info(
        f"  split by held-out anchors: train={len(train_puzzles):,}, holdout={len(holdout):,}"
    )

    # Train includes the (non-held) puzzle pairs + all openings.
    train = concatenate_datasets([train_puzzles, pair_openings]).shuffle(seed=12)
    logging.info(f"  train: {len(train):,} pairs | holdout: {len(holdout):,} pairs")
    return train, holdout


def make_balanced_dataset(train: Dataset, n_per_anchor: int) -> Dataset:
    """Cap each anchor's positives to `n_per_anchor` random picks. Breaks the
    5.8M pairs' redundancy (each anchor x ~1933 positives) so the model can't
    memorize specific (anchor, positive) pairings while still seeing useful
    positive variety per anchor.

    Args:
        train: dataset with string columns "anchor" and "positive".
        n_per_anchor: maximum positives kept per unique anchor; anchors with
            fewer keep all of theirs.

    Returns:
        A new two-column dataset, shuffled with seed 12. Sampling uses a
        fixed-seed RNG, so the result is reproducible for the same input order.
    """
    by_anchor: dict[str, list[str]] = defaultdict(list)
    # Read column slices in batches instead of iterating row by row: per-row
    # access builds a Python dict for each of millions of rows. Row order is
    # unchanged, so the seeded sampling below picks exactly the same rows.
    for batch in train.iter(batch_size=10_000):
        for anchor, positive in zip(batch["anchor"], batch["positive"]):
            by_anchor[anchor].append(positive)
    rng = random.Random(12)
    new_anchors, new_positives = [], []
    for anchor, positives in by_anchor.items():
        sample = (
            rng.sample(positives, n_per_anchor)
            if len(positives) > n_per_anchor
            else positives
        )
        for p in sample:
            new_anchors.append(anchor)
            new_positives.append(p)
    logging.info(
        f"Balanced dataset: {len(by_anchor):,} unique anchors -> "
        f"{len(new_anchors):,} pairs (cap {n_per_anchor}/anchor)"
    )
    return Dataset.from_dict({"anchor": new_anchors, "positive": new_positives}).shuffle(seed=12)


def make_anchor_masker(mask_prob: float, rng_seed: int = 12):
    """Build a `Dataset.set_transform` callable that randomly masks anchor tokens.

    Every whitespace-separated token of a multi-token anchor -- theme words,
    opening-name words, and for opening pairs the ECO code -- is independently
    replaced by "[UNK]" with probability `mask_prob`. This is token-bag
    dropout: it forces the model to use the remaining tokens instead of
    memorizing the exact combination. Single-token anchors pass through
    untouched, and if every token of an anchor gets masked, one randomly drawn
    original token is written back at a random position. Positives are never
    modified.

    The returned callable shares one RNG seeded with `rng_seed`, so masks
    differ between accesses of the same row. Returns None when
    `mask_prob <= 0` (masking disabled).
    """
    if mask_prob <= 0:
        return None
    rng = random.Random(rng_seed)

    def _mask_one(anchor: str) -> str:
        tokens = anchor.split()
        if len(tokens) < 2:
            return anchor
        kept = [tok if rng.random() >= mask_prob else "[UNK]" for tok in tokens]
        if all(tok == "[UNK]" for tok in kept):
            # Draw the source token first, then the slot it goes into.
            restored = tokens[rng.randrange(len(tokens))]
            kept[rng.randrange(len(kept))] = restored
        return " ".join(kept)

    def _mask(batch: dict) -> dict:
        return {
            "anchor": [_mask_one(anchor) for anchor in batch["anchor"]],
            "positive": batch["positive"],
        }

    return _mask


def train_chess_tokenizer(train: Dataset) -> Tokenizer:
    """Train or load a WordLevel tokenizer for the chess corpus.

    Every whitespace-separated unit (theme word, opening word, ECO code, UCI
    move, SAN move, and the `a+b` move-bigram tokens) becomes one whole token.
    Compare to BERT WordPiece which fragments "f2g3" into 4 subword pieces --
    a token-bag wastes capacity on subword joins that carry no chess meaning.

    The pre-tokenizer is `WhitespaceSplit`, which splits on whitespace only.
    The `Whitespace` pre-tokenizer also splits punctuation away from word
    characters: "f2g3+e6e7" became "f2g3", "+", "e6e7" (every bigram token
    lost), "Nf3+" became "Nf3", "+" and "O-O" became "O", "-", "O".

    Caching: if TOKENIZER_PATH exists, load and return it instead of
    rebuilding. The WordLevelTrainer is single-threaded Rust and takes ~6 min
    on 11.6M strings. Tokenizer is deterministic given the same corpus +
    config, so caching is safe -- except that a cached file built with the
    punctuation-splitting `Whitespace` pre-tokenizer does not match this
    config and is rebuilt. Set RETRAIN_TOKENIZER=1 to force rebuild.

    Args:
        train: dataset with string columns "anchor" and "positive".

    Returns:
        The cached or freshly trained tokenizer (a fresh one is also saved to
        TOKENIZER_PATH).
    """
    from tokenizers.pre_tokenizers import WhitespaceSplit

    if not RETRAIN_TOKENIZER and os.path.exists(TOKENIZER_PATH):
        tok = Tokenizer.from_file(TOKENIZER_PATH)
        if isinstance(tok.pre_tokenizer, Whitespace):
            logging.warning(
                f"Cached tokenizer at {TOKENIZER_PATH} uses the punctuation-splitting "
                "Whitespace pre-tokenizer (breaks a+b bigram tokens); rebuilding"
            )
        else:
            logging.info(
                f"Reusing cached tokenizer ({tok.get_vocab_size():,} tokens) from {TOKENIZER_PATH}"
            )
            return tok

    logging.info(f"Training WordLevel tokenizer on {len(train):,} pairs (vocab={VOCAB_SIZE})")
    tok = Tokenizer(WordLevel(unk_token="[UNK]"))
    tok.pre_tokenizer = WhitespaceSplit()
    trainer = WordLevelTrainer(
        vocab_size=VOCAB_SIZE,
        special_tokens=["[UNK]", "[PAD]"],
        min_frequency=2,
    )

    def text_iter():
        # Batched column reads avoid building a Python dict per row; the
        # anchor/positive interleaving matches a row-by-row walk.
        for batch in train.iter(batch_size=10_000):
            for anchor, positive in zip(batch["anchor"], batch["positive"]):
                yield anchor
                yield positive

    tok.train_from_iterator(text_iter(), trainer=trainer, length=2 * len(train))
    actual_vocab = tok.get_vocab_size()
    logging.info(f"  tokenizer trained: {actual_vocab:,} tokens (cap was {VOCAB_SIZE:,})")
    os.makedirs(os.path.dirname(TOKENIZER_PATH) or ".", exist_ok=True)
    tok.save(TOKENIZER_PATH)
    logging.info(f"  saved tokenizer to {TOKENIZER_PATH}")
    return tok


def _strip_theme_echo(positive: str) -> str:
    """Eval corpus must not echo the themes the query asks about, or the
    baseline (random-init) scores high just from lexical token overlap. Keep
    only the moves segment."""
    idx = positive.find(" moves ")
    return positive[idx + 1 :] if idx != -1 else positive


def _build_compositional_ir_evaluator(
    holdout: Dataset, corpus: dict[str, str], name: str
) -> InformationRetrievalEvaluator:
    """Compositional view: each distinct held-out anchor string is one query.

    A query's relevant docs are every held-out row ("d<row index>", matching
    the corpus ids) carrying exactly that anchor. Queries are numbered
    "q0", "q1", ... in descending order of relevant-doc count, ties kept in
    first-seen order. The anchors chosen by load_chess_pairs are excluded from
    training, so these are unseen combinations.
    """
    docs_per_anchor: dict[str, set[str]] = defaultdict(set)
    for doc_idx, anchor in enumerate(holdout["anchor"]):
        docs_per_anchor[anchor].add(f"d{doc_idx}")

    ranked = sorted(docs_per_anchor.items(), key=lambda item: len(item[1]), reverse=True)

    queries: dict[str, str] = {}
    relevant_docs: dict[str, set[str]] = {}
    for q_idx, (anchor, doc_ids) in enumerate(ranked):
        qid = f"q{q_idx}"
        queries[qid] = anchor
        relevant_docs[qid] = doc_ids

    avg_rel = sum(map(len, relevant_docs.values())) / max(1, len(relevant_docs))
    logging.info(
        f"  [{name}] {len(queries)} queries (unseen combos), avg relevant/query={avg_rel:.1f}"
    )
    return _ir_evaluator(queries, corpus, relevant_docs, name)


def _build_single_theme_ir_evaluator(
    holdout: Dataset, corpus: dict[str, str], name: str
) -> InformationRetrievalEvaluator:
    """Single-token view: each whitespace token of the held-out anchors is a query.

    Anchors mix theme names with opening-name words, so the queries include
    both (e.g. "fork" as well as "Sicilian"). Relevant docs for query "fork"
    are every held-out doc whose anchor contains the token "fork". Only tokens
    with at least 3 relevant docs (2 in smoke mode) become queries, numbered
    "t0", "t1", ... by descending relevant-doc count. Coarser than the
    compositional eval (many more relevant docs per query) but a sharper test
    of per-token meaning.
    """
    token_docs: dict[str, set[str]] = defaultdict(set)
    for doc_idx, anchor in enumerate(holdout["anchor"]):
        doc_id = f"d{doc_idx}"
        for token in anchor.split():
            token_docs[token].add(doc_id)

    min_relevant = 2 if SMOKE_TEST else 3
    candidates = sorted(
        ((token, docs) for token, docs in token_docs.items() if len(docs) >= min_relevant),
        key=lambda item: len(item[1]),
        reverse=True,
    )

    queries = {f"t{q_idx}": token for q_idx, (token, _) in enumerate(candidates)}
    relevant_docs = {f"t{q_idx}": docs for q_idx, (_, docs) in enumerate(candidates)}
    avg_rel = sum(map(len, relevant_docs.values())) / max(1, len(relevant_docs))
    logging.info(
        f"  [{name}] {len(queries)} single-token queries, avg relevant/query={avg_rel:.1f}"
    )
    return _ir_evaluator(queries, corpus, relevant_docs, name)


def _ir_evaluator(queries, corpus, relevant_docs, name):
    """Construct the InformationRetrievalEvaluator shared by both eval views.

    Reports NDCG@10, MRR@10, accuracy@{1,10} and precision/recall@{1,10},
    encoding in batches of 256 with the progress bar off.
    """
    cutoffs = dict(
        ndcg_at_k=[10],
        mrr_at_k=[10],
        accuracy_at_k=[1, 10],
        precision_recall_at_k=[1, 10],
    )
    return InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=name,
        show_progress_bar=False,
        batch_size=256,
        **cutoffs,
    )


def build_ir_evaluator(holdout: Dataset, name: str = "chess-ir") -> SequentialEvaluator:
    """Chain the compositional and single-token IR evaluators.

    Both share one corpus: the held-out positives with their theme/opening
    prefix stripped, keyed "d<row index>". The main score function returns
    the compositional evaluator's score (index 0); the single-token evaluator
    ("<name>-tokens") is informational.
    """
    corpus: dict[str, str] = {}
    for doc_idx, positive in enumerate(holdout["positive"]):
        corpus[f"d{doc_idx}"] = _strip_theme_echo(positive)
    logging.info(f"IR eval setup ({len(corpus)} corpus docs):")

    evaluators = [
        _build_compositional_ir_evaluator(holdout, corpus, name=name),
        _build_single_theme_ir_evaluator(holdout, corpus, name=f"{name}-tokens"),
    ]
    # scores[0] = compositional; this is the score main() uses for model selection.
    return SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[0])


def main() -> None:
    """Run the whole pipeline end to end.

    Steps: load pairs + held-out split, build or load the tokenizer, balance
    and (optionally) mask the training set, build a random-init
    StaticEmbedding model, record a baseline eval score, train with early
    stopping, re-evaluate, save to OUTPUT_DIR/final, and (outside smoke mode)
    push the final model to the Hub.
    """
    setup_logging()

    train_dataset, holdout = load_chess_pairs()
    if SMOKE_TEST:
        train_dataset = train_dataset.select(range(min(500, len(train_dataset))))

    # Train the tokenizer on the FULL (pre-balanced) corpus -- we want every
    # token to be seen as many times as possible for the vocab pass.
    tokenizer = train_chess_tokenizer(train_dataset)

    # Now down-sample to a balanced dataset for the contrastive training.
    train_dataset = make_balanced_dataset(train_dataset, BALANCED_POSITIVES_PER_ANCHOR)

    # Optional anchor-token masking applied on the fly via set_transform.
    # Only the training set is masked; the held-out eval data is untouched.
    masker = make_anchor_masker(ANCHOR_MASK_PROB)
    if masker is not None:
        logging.info(f"Anchor token masking enabled (p={ANCHOR_MASK_PROB})")
        train_dataset.set_transform(masker)

    logging.info(f"Random-init StaticEmbedding (dim={EMBEDDING_DIM})")
    static_embedding = StaticEmbedding(tokenizer, embedding_dim=EMBEDDING_DIM)
    model = SentenceTransformer(
        modules=[static_embedding],
        model_card_data=SentenceTransformerModelCardData(
            language="en",
            license="apache-2.0",
            model_name=f"Static chess embedding ({EMBEDDING_DIM}d) -- themes/openings <-> positions",
        ),
    )

    evaluator = build_ir_evaluator(holdout)
    # In-batch-negatives contrastive loss, optionally wrapped so the same loss
    # is also applied at each truncated width in MATRYOSHKA_DIMS.
    inner = MultipleNegativesRankingLoss(model)
    if DISABLE_MATRYOSHKA:
        logging.info("Matryoshka DISABLED -- training at single dim (diagnostic)")
        loss = inner
    else:
        loss = MatryoshkaLoss(model, inner, matryoshka_dims=MATRYOSHKA_DIMS)

    logging.info("Baseline evaluation (random init -- expect near-zero):")
    # NOTE(review): evaluator(model) runs before `evaluator.primary_metric` is
    # read. Confirm SequentialEvaluator exposes primary_metric before its first
    # call before reordering these lines.
    with autocast_ctx():
        baseline_eval = evaluator(model)[evaluator.primary_metric]
    # Trainer-side metric name for best-model selection; presumably the trainer
    # prefixes evaluator keys with "eval_" -- confirm if selection misbehaves.
    metric_key = f"eval_{evaluator.primary_metric}"
    logging.info(f"  baseline {evaluator.primary_metric} = {baseline_eval:.4f}")

    # max_steps=-1 means "train for num_train_epochs".
    if SMOKE_TEST:
        max_steps = 1
    elif MAX_STEPS_OVERRIDE:
        max_steps = MAX_STEPS_OVERRIDE
    else:
        max_steps = -1
    # Fractional values are a share of total steps; save cadence mirrors eval
    # cadence so every eval has a checkpoint to load at the end.
    eval_steps = EVAL_STEPS_OVERRIDE if EVAL_STEPS_OVERRIDE else 0.05  # 20 evals/run
    save_steps = EVAL_STEPS_OVERRIDE if EVAL_STEPS_OVERRIDE else 0.05

    args = SentenceTransformerTrainingArguments(
        output_dir=OUTPUT_DIR,
        # Balanced dataset is small (~300k pairs); need many epochs to reach
        # comparable total training signal. Early stopping handles excess.
        num_train_epochs=20,
        max_steps=max_steps,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=1e-2,  # was 5e-2 -- much slower convergence, shifts peak later
        weight_decay=0.01,   # was 0.0  -- regularization on the embedding table
        warmup_steps=0.1,
        lr_scheduler_type="linear",
        bf16=IS_CUDA and torch.cuda.is_bf16_supported(),
        fp16=IS_CUDA and not torch.cuda.is_bf16_supported(),
        # was NO_DUPLICATES -- linked-list scan over deferred conflicts gives
        # O(epoch_progress) per-batch cost. With ~3000 unique anchors over
        # 5.8M pairs, dedup is fighting impossible odds. BATCH_SAMPLER (random)
        # is fast and accepts mild within-batch anchor duplication.
        batch_sampler=BatchSamplers.BATCH_SAMPLER,
        eval_strategy="steps",
        eval_steps=eval_steps,
        save_strategy="steps",
        save_steps=save_steps,
        save_total_limit=2,
        logging_steps=0.01,
        logging_first_step=True,
        load_best_model_at_end=True,
        metric_for_best_model=metric_key,
        greater_is_better=True,
        # Trackio crashes at first checkpoint push: empty `router_mapping`
        # struct can't be written to parquet. Disable.
        report_to="none",
        run_name=RUN_NAME,
        seed=12,
        # HF Jobs: container is destroyed after run -- push every checkpoint to
        # the Hub so partial progress survives a timeout. The end-of-run
        # model.push_to_hub() below is the belt to this suspenders.
        push_to_hub=not SMOKE_TEST,
        hub_model_id=HUB_MODEL_ID,
        hub_strategy="every_save",
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
        evaluator=evaluator,
        callbacks=[
            # Auto-stop if compositional NDCG@10 doesn't improve for 3 evals.
            # Lower lr makes curves smoother -- give it slack vs the patience=2
            # we used at lr=5e-2.
            EarlyStoppingCallback(early_stopping_patience=3),
            # Per-step memory + dt logging.
            StepTimingCallback(),
        ],
    )
    trainer.train()

    # With load_best_model_at_end=True, `model` should now hold the best
    # checkpoint's weights; compare it against the random-init baseline.
    logging.info("Post-training evaluation:")
    with autocast_ctx():
        score = evaluator(model)[evaluator.primary_metric]
    delta = score - baseline_eval
    # WIN needs at least +0.005 over baseline; any non-negative delta below
    # that is MARGINAL; a drop is a REGRESSION.
    verdict = "WIN" if delta >= 0.005 else "MARGINAL" if delta >= 0 else "REGRESSION"
    logging.info(
        f"VERDICT: {verdict} | score={score:.4f} | baseline={baseline_eval:.4f} | delta={delta:+.4f}"
    )

    final_dir = f"{OUTPUT_DIR}/final"
    model.save_pretrained(final_dir)
    logging.info(f"Saved final model to {final_dir}")

    if SMOKE_TEST:
        logging.info("SMOKE_TEST=1: skipping Hub push")
        return

    # Best-effort final push: a failure is logged with its traceback rather
    # than raised, since the model is already saved locally above.
    try:
        commit_url = model.push_to_hub(HUB_MODEL_ID)
        logging.info(f"Pushed model to {commit_url.rsplit('/commit/', 1)[0]}")
    except Exception:
        import traceback

        logging.error(f"Hub push failed:\n{traceback.format_exc()}")


# Script entry point.
if __name__ == "__main__":
    main()