static-embedding-chess / scripts /train_chess_static.py
oneryalcin's picture
Add files using upload-large-folder tool
f8392aa verified
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "sentence-transformers[train]>=5.5.0",
# "datasets>=2.19.0",
# "accelerate>=0.26.0",
# "tokenizers>=0.20",
# "trackio",
# ]
# ///
"""Train a StaticEmbedding model for chess retrieval.
Pair shape:
anchor = "<themes> [<opening words>]"
positive = "themes <themes> [opening <words>] moves <uci>" (puzzles)
"name <words> eco <code> pgn <san>" (openings)
Datasets:
- Lichess/chess-puzzles (5.8M rows; themes + opening tags + UCI moves)
- Lichess/chess-openings (3.6K rows; opening name + ECO + SAN moves)
Use case: free-text search over a chess corpus. "fork endgame short" -> puzzles
with that motif; "Sicilian Najdorf" -> matching openings.
Design choices:
- Custom WordLevel + Whitespace tokenizer trained on the corpus. Every chess
token (UCI move e2e4, SAN move Nxd4, ECO code B90, theme name, opening word)
is one whole token -- BERT WordPiece would shred them 4-way.
- FEN dropped: position-as-character-soup doesn't fit a token-bag.
- PGN move numbers stripped ("1. e4 c5" -> "e4 c5") so SAN moves are high-freq.
- IR eval is custom (themes -> puzzles), not NanoBEIR -- general-English IR
benchmarks don't measure chess retrieval.
Run:
SMOKE_TEST=1 uv run --exclude-newer=2026-05-12 train_chess_static.py
uv run --exclude-newer=2026-05-12 train_chess_static.py
"""
from __future__ import annotations
import logging
import os
import re
from collections import defaultdict
from contextlib import nullcontext
import datasets
import random
import torch
from datasets import Dataset, concatenate_datasets, load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerModelCardData,
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
)
from sentence_transformers.base.sampler import BatchSamplers
from sentence_transformers.sentence_transformer.evaluation import (
InformationRetrievalEvaluator,
SequentialEvaluator,
)
from sentence_transformers.sentence_transformer.losses import (
MatryoshkaLoss,
MultipleNegativesRankingLoss,
)
from sentence_transformers.sentence_transformer.modules import StaticEmbedding
from transformers import EarlyStoppingCallback, TrainerCallback
import time
# ---------------------------------------------------------------------------
# Module-level configuration. Everything below is evaluated once at import
# time; the os.environ-backed knobs are read here and never re-read.
# ---------------------------------------------------------------------------
# Width of the embedding table passed to StaticEmbedding, and the nested
# dimensions handed to MatryoshkaLoss.
EMBEDDING_DIM = 512 # was 256; 512 gives more capacity for bigram tokens
MATRYOSHKA_DIMS = [512, 256, 128, 64, 32]
# Upper bound given to WordLevelTrainer(vocab_size=...).
VOCAB_SIZE = 100_000 # was 50_000; UCI/SAN bigrams add ~20-50k vocab
OUTPUT_DIR = "models/static-embedding-chess"
RUN_NAME = "static-embedding-chess"
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "oneryalcin/static-embedding-chess")
# TOKENIZER_PATH default lives next to the model output. On Modal, set this to
# a path on the persistent volume (e.g. /cache/chess_tokenizer.json) so the
# 6-min WordLevelTrainer run is amortized across launches.
TOKENIZER_PATH = os.environ.get(
"TOKENIZER_PATH", f"{OUTPUT_DIR}/chess_tokenizer.json"
)
# Boolean env flags: only the exact string "1" enables them.
RETRAIN_TOKENIZER = os.environ.get("RETRAIN_TOKENIZER") == "1"
SMOKE_TEST = os.environ.get("SMOKE_TEST") == "1"
FORCE_CPU = os.environ.get("FORCE_CPU") == "1"
# Diagnostic knobs (default: full recipe). Both MPS and T4 show monotonic
# step-time growth with the full Matryoshka stack -- toggle these to isolate.
DISABLE_MATRYOSHKA = os.environ.get("DISABLE_MATRYOSHKA") == "1"
# Unset or "0" -> None (no override). A non-integer value raises ValueError
# here, at import time.
MAX_STEPS_OVERRIDE = int(os.environ.get("MAX_STEPS", "0")) or None
EVAL_STEPS_OVERRIDE = int(os.environ.get("EVAL_STEPS", "0")) or None
# Number of held-out anchors used as eval queries (non-smoke runs).
EVAL_QUERIES = 200
# NOTE(review): EVAL_CORPUS is not referenced anywhere else in the visible
# file -- confirm whether it is still needed.
EVAL_CORPUS = 5_000
# Held-out anchor selection: pick rare combos in this freq range. Low end > 1
# keeps multi-relevant NDCG meaningful; high end caps memorization potential.
HELDOUT_FREQ_MIN = 3
HELDOUT_FREQ_MAX = 30
# Balanced-dataset config: each unique anchor expands to N (anchor, sampled_pos)
# rows. The original 5.8M pairs let the model memorize specific (anchor, pos)
# pairings since each anchor has ~1933 distinct positives. Capping at 100
# random samples per anchor gives the model meaningful variety without the
# 50x redundancy that fuels overfitting.
BALANCED_POSITIVES_PER_ANCHOR = int(os.environ.get("POSITIVES_PER_ANCHOR", "100"))
# Anchor token masking probability during training. 0 disables.
ANCHOR_MASK_PROB = float(os.environ.get("ANCHOR_MASK_PROB", "0.15"))
# Device-aware defaults. MPS (Apple Silicon) can't do bf16 and has unified-
# memory pressure, so the CUDA-targeted skill template defaults (batch=2048,
# bf16=True) don't apply. Scale BATCH_SIZE up if your M-series has 36GB+.
# FORCE_CPU=1 turns both flags off; IS_MPS is only considered when CUDA is not
# selected, so the two flags are never both True.
IS_CUDA = torch.cuda.is_available() and not FORCE_CPU
IS_MPS = (not IS_CUDA) and torch.backends.mps.is_available() and not FORCE_CPU
# StaticEmbedding is a lookup+average -- no transformer activations to fit.
# Memory cost is the (batch x batch) similarity matrix + (batch x seq x dim)
# lookups, both tiny. CachedMultipleNegativesRankingLoss is NOT compatible
# with StaticEmbedding (no encoder to GradCache through), so we just crank
# the real batch. Scale up freely if your M-series has the headroom.
# CUDA and MPS currently share the same value (4096); plain CPU falls back to 256.
BATCH_SIZE = 4096 if IS_CUDA else (4096 if IS_MPS else 256)
# One or more digits followed by one or more dots, i.e. PGN move-number
# markers such as "1." or "12...".
MOVE_NUM_RE = re.compile(r"\d+\.+")
class StepTimingCallback(TrainerCallback):
    """Trainer callback that logs wall time (and CUDA memory) per optimizer step.

    A timestamp is taken in ``on_step_begin`` and the elapsed time is computed
    in ``on_step_end``. When CUDA is available the device is synchronized on
    both sides so that the measured interval covers the queued GPU work.
    Every one of the first 20 steps is logged, then every 10th step.
    """

    def on_step_begin(self, args, state, control, **kw):
        """Record the start time of the step (after a CUDA sync, if any)."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        self._t0 = time.perf_counter()

    def on_step_end(self, args, state, control, **kw):
        """Log elapsed time for the step; include allocator stats on CUDA."""
        has_cuda = torch.cuda.is_available()
        if has_cuda:
            torch.cuda.synchronize()
        dt = time.perf_counter() - self._t0

        # Dense logging during start-up, sparse afterwards.
        past_startup = state.global_step > 20
        if past_startup and state.global_step % 10 != 0:
            return

        if not has_cuda:
            logging.info(f"STEP {state.global_step}: dt={dt:.3f}s (cpu/mps)")
            return

        # Bytes -> MB for readability.
        mem = torch.cuda.memory_allocated() / 1e6
        reserved = torch.cuda.memory_reserved() / 1e6
        logging.info(
            f"STEP {state.global_step}: dt={dt:.3f}s mem={mem:.0f}MB reserved={reserved:.0f}MB"
        )
def autocast_ctx():
    """Return the autocast context manager matching the selected device.

    CUDA uses bfloat16 when the device reports support for it and float16
    otherwise; MPS uses float16; any other device gets a no-op context.
    """
    if IS_CUDA:
        if torch.cuda.is_bf16_supported():
            dtype = torch.bfloat16
        else:
            dtype = torch.float16
        return torch.autocast("cuda", dtype=dtype)
    return torch.autocast("mps", dtype=torch.float16) if IS_MPS else nullcontext()
def setup_logging():
    """Configure root logging (console + ``logs/<RUN_NAME>.log``) and quiet noisy libs.

    Creates the ``logs`` and ``OUTPUT_DIR`` directories if needed, replaces
    any existing root handlers (``force=True``), raises the level of several
    HTTP/filesystem loggers to WARNING, and, when CUDA is available, sets
    float32 matmul precision to ``"high"``.
    """
    for directory in ("logs", OUTPUT_DIR):
        os.makedirs(directory, exist_ok=True)

    sinks = [logging.StreamHandler(), logging.FileHandler(f"logs/{RUN_NAME}.log")]
    logging.basicConfig(
        format="%(asctime)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=sinks,
        force=True,
    )

    chatty_loggers = ("httpx", "httpcore", "huggingface_hub", "urllib3", "filelock", "fsspec")
    for logger_name in chatty_loggers:
        logging.getLogger(logger_name).setLevel(logging.WARNING)

    if torch.cuda.is_available():
        torch.set_float32_matmul_precision("high")
def _join_tags(tags) -> str:
if not tags:
return ""
return " ".join(t.replace("_", " ") for t in tags)
def _strip_pgn_move_numbers(pgn: str) -> str:
    """Remove PGN move-number markers and return single-space-separated moves.

    ``"1. e4 c5 2. Nf3 d6"`` -> ``"e4 c5 Nf3 d6"``.

    Deleting a marker from the middle of the string leaves the spaces that
    surrounded it, so the remaining tokens are re-joined with exactly one
    space instead of only trimming the ends.
    """
    return " ".join(MOVE_NUM_RE.sub("", pgn).split())
def _bigram_token_str(moves: str) -> str:
"""Append bigram tokens to a whitespace-separated move sequence.
"f2g3 e6e7 b2b1" -> "f2g3 e6e7 b2b1 f2g3+e6e7 e6e7+b2b1"
Bigrams use `+` as the join char so they're distinct from unigrams in the
WordLevel tokenizer's whitespace pretokenizer. A token-bag averaging across
unigrams alone loses move ordering; adding adjacent-pair tokens lets the
model learn that "e2e4 e7e5" (king's pawn opening) is its own pattern.
"""
tokens = moves.split()
if len(tokens) < 2:
return moves
bigrams = " ".join(f"{a}+{b}" for a, b in zip(tokens, tokens[1:]))
return f"{moves} {bigrams}"
def build_puzzle_pairs(row_batch: dict) -> dict:
    """Map a batch of puzzle rows to (anchor, positive) text pairs.

    Reads the ``Themes``, ``OpeningTags`` and ``Moves`` columns of the batch.
    Rows whose themes join to an empty string are skipped, so the output may
    be shorter than the input batch.

    anchor   = "<themes> [<opening words>]"
    positive = "themes <themes> [opening <words>] moves <moves + bigrams>"
    """
    result = {"anchor": [], "positive": []}
    rows = zip(row_batch["Themes"], row_batch["OpeningTags"], row_batch["Moves"])
    for themes, opening_tags, moves in rows:
        themes_txt = _join_tags(themes)
        opening_txt = _join_tags(opening_tags)
        if not themes_txt:
            continue

        anchor = themes_txt
        if opening_txt:
            anchor += f" {opening_txt}"

        opening_part = f" opening {opening_txt}" if opening_txt else ""
        positive = f"themes {themes_txt}" + opening_part + f" moves {_bigram_token_str(moves)}"

        result["anchor"].append(anchor)
        result["positive"].append(positive)
    return result
def build_opening_pairs(row_batch: dict) -> dict:
    """Map a batch of opening rows to (anchor, positive) text pairs.

    Reads the ``name``, ``eco`` and ``pgn`` columns of the batch.

    anchor   = "<name> <eco>"
    positive = "name <name> eco <eco> pgn <SAN moves without move numbers + bigrams>"
    """
    rows = [
        (name, eco, _strip_pgn_move_numbers(pgn))
        for name, eco, pgn in zip(row_batch["name"], row_batch["eco"], row_batch["pgn"])
    ]
    return {
        "anchor": [f"{name} {eco}" for name, eco, _ in rows],
        "positive": [
            f"name {name} eco {eco} pgn {_bigram_token_str(san)}" for name, eco, san in rows
        ],
    }
def load_chess_pairs() -> tuple[Dataset, Dataset]:
    """Returns (train, holdout) where the holdout anchors are rare combinations
    NEVER seen in train.

    Old eval used the top-200 most-common theme strings as queries. The model
    memorized these in training (each appears ~50k times) so eval was a recall
    test on memorized lookups, not generalization. Replaced with compositional
    held-out anchors:

    - Pick anchor strings with frequency in [HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX]
      (the lower bound drops to 2 in smoke mode): rare enough to be
      informative, common enough to have multiple positives for
      multi-relevant eval. Rarest anchors are taken first.
    - REMOVE all pairs with those anchors from train (no leakage).
    - Use those rare anchors as eval queries; the held-out pairs become the
      eval corpus.
    - Individual theme tokens within those anchors still appear *separately*
      in many other training anchors, so the model has learned each token's
      embedding -- it just hasn't seen this particular combination. Tests
      compositional generalization.

    Returns:
        ``train``: non-held-out puzzle pairs plus all opening pairs, shuffled
        with seed 12. ``holdout``: the puzzle pairs whose anchor was held out.
        Both have the columns ``anchor`` and ``positive``.
    """
    logging.info("Loading Lichess/chess-puzzles (5.8M rows)")
    puzzles = load_dataset("Lichess/chess-puzzles", split="train")
    if SMOKE_TEST:
        puzzles = puzzles.select(range(2_000))
    pair_puzzles = puzzles.map(
        build_puzzle_pairs,
        batched=True,
        batch_size=10_000,
        remove_columns=puzzles.column_names,
        desc="puzzles -> pairs",
    )
    logging.info(f" built {len(pair_puzzles):,} puzzle pairs")

    logging.info("Loading Lichess/chess-openings (3.6K rows)")
    openings = load_dataset("Lichess/chess-openings", split="train").remove_columns(["img"])
    pair_openings = openings.map(
        build_opening_pairs,
        batched=True,
        remove_columns=openings.column_names,
        desc="openings -> pairs",
    )
    logging.info(f" built {len(pair_openings):,} opening pairs")

    # Count anchor frequencies across the puzzle pairs.
    logging.info("Computing anchor frequencies for held-out selection")
    anchors = pair_puzzles["anchor"]
    freq: dict[str, int] = defaultdict(int)
    for a in anchors:
        freq[a] += 1
    logging.info(f" {len(freq):,} unique anchors in puzzle pairs")

    # Pick rare anchors: each appears in [min_freq, max_freq] pairs.
    # In smoke mode, lower the min so the tiny corpus still produces enough
    # held-out queries (smoke has ~2k puzzles, most anchors freq 1-2).
    min_freq = 2 if SMOKE_TEST else HELDOUT_FREQ_MIN
    max_freq = HELDOUT_FREQ_MAX
    rare_pool = sorted(
        ((a, c) for a, c in freq.items() if min_freq <= c <= max_freq),
        key=lambda kv: kv[1],  # ascending: rarest first
    )
    n_queries_target = 20 if SMOKE_TEST else EVAL_QUERIES
    if len(rare_pool) < n_queries_target:
        # Report the bounds that were actually applied (they differ from the
        # HELDOUT_* constants in smoke mode).
        logging.warning(
            f"Only {len(rare_pool)} anchors in freq range [{min_freq},{max_freq}]; "
            f"using all of them ({n_queries_target} requested)"
        )
    selected = rare_pool[:n_queries_target]
    heldout_anchors = {a for a, _ in selected}
    # `selected` is sorted ascending, so its ends give the chosen freq range.
    lowest = selected[0][1] if selected else 0
    highest = selected[-1][1] if selected else 0
    logging.info(
        f" selected {len(heldout_anchors)} held-out anchors "
        f"(freq range: {lowest}..{highest})"
    )

    # Filter: pairs whose anchor is held-out -> eval; everything else -> train.
    held_indices: list[int] = []
    train_indices: list[int] = []
    for i, a in enumerate(anchors):
        (held_indices if a in heldout_anchors else train_indices).append(i)
    holdout = pair_puzzles.select(held_indices)
    train_puzzles = pair_puzzles.select(train_indices)
    logging.info(
        f" split by held-out anchors: train={len(train_puzzles):,}, holdout={len(holdout):,}"
    )

    # Train includes the (non-held) puzzle pairs + all openings.
    train = concatenate_datasets([train_puzzles, pair_openings]).shuffle(seed=12)
    logging.info(f" train: {len(train):,} pairs | holdout: {len(holdout):,} pairs")
    return train, holdout
def make_balanced_dataset(train: Dataset, n_per_anchor: int) -> Dataset:
    """Keep at most ``n_per_anchor`` randomly chosen positives for each anchor.

    Rows are grouped by their ``anchor`` string in first-seen order. Groups
    larger than the cap are down-sampled without replacement using a
    ``random.Random(12)`` generator; smaller groups are kept whole. The
    result is a new two-column dataset (``anchor``, ``positive``) shuffled
    with seed 12.

    The motivation is to remove the heavy per-anchor redundancy of the raw
    pairs so the model cannot simply memorize (anchor, positive) pairings.
    """
    by_anchor: dict[str, list[str]] = {}
    for row in train:
        by_anchor.setdefault(row["anchor"], []).append(row["positive"])

    rng = random.Random(12)
    new_anchors: list[str] = []
    new_positives: list[str] = []
    for anchor, positives in by_anchor.items():
        if len(positives) > n_per_anchor:
            chosen = rng.sample(positives, n_per_anchor)
        else:
            chosen = positives
        new_anchors.extend([anchor] * len(chosen))
        new_positives.extend(chosen)

    logging.info(
        f"Balanced dataset: {len(by_anchor):,} unique anchors -> "
        f"{len(new_anchors):,} pairs (cap {n_per_anchor}/anchor)"
    )
    balanced = Dataset.from_dict({"anchor": new_anchors, "positive": new_positives})
    return balanced.shuffle(seed=12)
def make_anchor_masker(mask_prob: float, rng_seed: int = 12):
    """Return a `set_transform` callable that randomly replaces anchor tokens
    with [UNK]. Token-bag dropout: forces the model to use the remaining
    tokens instead of memorizing the exact combination.

    Args:
        mask_prob: Independent probability of masking each whitespace-separated
            anchor token. Values <= 0 disable masking.
        rng_seed: Seed of the private ``random.Random`` shared by all calls of
            the returned transform.

    Returns:
        ``None`` when masking is disabled, otherwise a function mapping a
        batch dict to a batch dict with the ``anchor`` column masked. All
        other columns are passed through untouched, and a batch without an
        ``anchor`` column is returned as-is.
    """
    if mask_prob <= 0:
        return None
    rng = random.Random(rng_seed)

    def _mask_anchor(text: str) -> str:
        tokens = text.split()
        # Single-token (or empty) anchors are never masked: there would be
        # nothing left to retrieve with.
        if len(tokens) <= 1:
            return text
        kept = [t if rng.random() >= mask_prob else "[UNK]" for t in tokens]
        # Guard against masking everything: put one randomly chosen original
        # token back at its own position.
        if all(t == "[UNK]" for t in kept):
            i = rng.randrange(len(tokens))
            kept[i] = tokens[i]
        return " ".join(kept)

    def _mask(batch: dict) -> dict:
        # The formatted batch may carry a different column set than
        # (anchor, positive); only the anchor column is rewritten.
        if "anchor" not in batch:
            return batch
        return {**batch, "anchor": [_mask_anchor(a) for a in batch["anchor"]]}

    return _mask
def train_chess_tokenizer(train: Dataset) -> Tokenizer:
    """Train or load a WordLevel tokenizer for the chess corpus.

    Every space-separated unit (theme word, opening word, ECO code, UCI move,
    SAN move, `a+b` move bigram) becomes one whole token. Compare to BERT
    WordPiece which fragments "f2g3" into 4 subword pieces -- a token-bag
    wastes capacity on subword joins that carry no chess meaning.

    The pre-tokenizer is `WhitespaceSplit`, which splits on whitespace only.
    The `Whitespace` pre-tokenizer must NOT be used here: it splits on the
    pattern `\\w+|[^\\w\\s]+`, i.e. it separates punctuation from word
    characters, which would break "f2g3+e6e7", "Nxd4+", "O-O" and "e8=Q#"
    into several tokens.

    Caching: if TOKENIZER_PATH exists, load and return it instead of rebuilding.
    The WordLevelTrainer is single-threaded Rust and takes ~6 min on 11.6M
    strings. Tokenizer is deterministic given the same corpus + config, so
    caching is safe. A cached file is returned exactly as saved, including
    its pre-tokenizer, so set RETRAIN_TOKENIZER=1 to force a rebuild after
    any change to the corpus or to the tokenizer settings.

    Args:
        train: Dataset whose rows provide "anchor" and "positive" strings;
            both columns are fed to the vocabulary pass.

    Returns:
        The loaded or freshly trained tokenizer. A freshly trained one is
        also written to TOKENIZER_PATH.
    """
    # Function-scope import: the module's top-level imports only bring in
    # `Whitespace` from this package.
    from tokenizers.pre_tokenizers import WhitespaceSplit

    if not RETRAIN_TOKENIZER and os.path.exists(TOKENIZER_PATH):
        tok = Tokenizer.from_file(TOKENIZER_PATH)
        logging.info(
            f"Reusing cached tokenizer ({tok.get_vocab_size():,} tokens) from {TOKENIZER_PATH}"
        )
        return tok

    logging.info(f"Training WordLevel tokenizer on {len(train):,} pairs (vocab={VOCAB_SIZE})")
    tok = Tokenizer(WordLevel(unk_token="[UNK]"))
    tok.pre_tokenizer = WhitespaceSplit()
    trainer = WordLevelTrainer(
        vocab_size=VOCAB_SIZE,
        special_tokens=["[UNK]", "[PAD]"],
        min_frequency=2,
    )

    def text_iter():
        # Two strings per row, matching the `length` hint below.
        for row in train:
            yield row["anchor"]
            yield row["positive"]

    tok.train_from_iterator(text_iter(), trainer=trainer, length=2 * len(train))
    actual_vocab = tok.get_vocab_size()
    logging.info(f" tokenizer trained: {actual_vocab:,} tokens (cap was {VOCAB_SIZE:,})")

    # dirname is "" for a bare filename; fall back to the current directory.
    os.makedirs(os.path.dirname(TOKENIZER_PATH) or ".", exist_ok=True)
    tok.save(TOKENIZER_PATH)
    logging.info(f" saved tokenizer to {TOKENIZER_PATH}")
    return tok
def _strip_theme_echo(positive: str) -> str:
"""Eval corpus must not echo the themes the query asks about, or the
baseline (random-init) scores high just from lexical token overlap. Keep
only the moves segment."""
idx = positive.find(" moves ")
return positive[idx + 1 :] if idx != -1 else positive
def _build_compositional_ir_evaluator(
    holdout: Dataset, corpus: dict[str, str], name: str
) -> InformationRetrievalEvaluator:
    """Build the evaluator whose queries are the distinct held-out anchor strings.

    The relevant documents of a query are the ids ``d<row index>`` of every
    holdout row carrying that anchor. Queries are numbered ``q0, q1, ...`` in
    descending order of relevant-document count.
    """
    by_anchor: dict[str, set[str]] = {}
    for i, row in enumerate(holdout):
        by_anchor.setdefault(row["anchor"], set()).add(f"d{i}")

    # Stable sort: anchors with equal counts keep their first-seen order.
    ranked = sorted(by_anchor.items(), key=lambda kv: len(kv[1]), reverse=True)

    queries: dict[str, str] = {}
    relevant_docs: dict[str, set[str]] = {}
    for i, (anchor, docs) in enumerate(ranked):
        queries[f"q{i}"] = anchor
        relevant_docs[f"q{i}"] = docs

    total_relevant = sum(map(len, relevant_docs.values()))
    avg_rel = total_relevant / max(1, len(relevant_docs))
    logging.info(
        f" [{name}] {len(queries)} queries (unseen combos), avg relevant/query={avg_rel:.1f}"
    )
    return _ir_evaluator(queries, corpus, relevant_docs, name)
def _build_single_theme_ir_evaluator(
    holdout: Dataset, corpus: dict[str, str], name: str
) -> InformationRetrievalEvaluator:
    """Build the evaluator whose queries are single tokens of the held-out anchors.

    For a query token such as "fork", the relevant documents are all holdout
    rows whose anchor contains that whitespace-separated token. Tokens with
    fewer than 3 relevant documents (2 in smoke mode) are dropped. This is
    coarser than the compositional evaluator (far more relevant documents
    per query) but probes whether each token embedding is useful alone.
    Queries are numbered ``t0, t1, ...`` in descending order of
    relevant-document count.
    """
    theme_to_docs: dict[str, set[str]] = {}
    for i, row in enumerate(holdout):
        doc_id = f"d{i}"
        for token in row["anchor"].split():
            theme_to_docs.setdefault(token, set()).add(doc_id)

    min_relevant = 2 if SMOKE_TEST else 3
    # Stable sort: tokens with equal counts keep their first-seen order.
    candidates = sorted(
        (item for item in theme_to_docs.items() if len(item[1]) >= min_relevant),
        key=lambda kv: len(kv[1]),
        reverse=True,
    )

    queries: dict[str, str] = {}
    relevant_docs: dict[str, set[str]] = {}
    for i, (tok, docs) in enumerate(candidates):
        queries[f"t{i}"] = tok
        relevant_docs[f"t{i}"] = docs

    total_relevant = sum(map(len, relevant_docs.values()))
    avg_rel = total_relevant / max(1, len(relevant_docs))
    logging.info(
        f" [{name}] {len(queries)} single-token queries, avg relevant/query={avg_rel:.1f}"
    )
    return _ir_evaluator(queries, corpus, relevant_docs, name)
def _ir_evaluator(queries, corpus, relevant_docs, name):
    """Create an InformationRetrievalEvaluator with this script's fixed metric setup.

    Metrics: NDCG@10, MRR@10, accuracy@{1,10}, precision/recall@{1,10};
    encoding batch size 256; progress bar off.
    """
    metric_settings = {
        "ndcg_at_k": [10],
        "mrr_at_k": [10],
        "accuracy_at_k": [1, 10],
        "precision_recall_at_k": [1, 10],
    }
    return InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=name,
        show_progress_bar=False,
        batch_size=256,
        **metric_settings,
    )
def build_ir_evaluator(holdout: Dataset, name: str = "chess-ir") -> SequentialEvaluator:
    """Combine the compositional and single-token evaluators over one shared corpus.

    The corpus maps ``d<row index>`` to each holdout positive with its theme
    echo stripped. The compositional evaluator runs first and its score is
    the sequence's main score; the single-token evaluator (named
    ``<name>-tokens``) is informational.
    """
    corpus: dict[str, str] = {}
    for i, row in enumerate(holdout):
        corpus[f"d{i}"] = _strip_theme_echo(row["positive"])
    logging.info(f"IR eval setup ({len(corpus)} corpus docs):")

    evaluators = [
        _build_compositional_ir_evaluator(holdout, corpus, name=name),
        _build_single_theme_ir_evaluator(holdout, corpus, name=f"{name}-tokens"),
    ]

    def _first_score(scores):
        # Main score = score of the first (compositional) evaluator.
        return scores[0]

    return SequentialEvaluator(evaluators, main_score_function=_first_score)
def main() -> None:
    """Run the full pipeline: data, tokenizer, model, baseline eval, training, save, push.

    Order of operations:
      1. Configure logging and build the (train, holdout) pair datasets.
      2. Train or load the tokenizer on the pre-balanced train pairs.
      3. Cap positives per anchor and, if enabled, attach on-the-fly anchor
         masking to the training set.
      4. Create a StaticEmbedding model and the IR evaluator, then record a
         baseline score before any training.
      5. Train with MultipleNegativesRankingLoss (wrapped in MatryoshkaLoss
         unless DISABLE_MATRYOSHKA is set), with early stopping.
      6. Re-evaluate, log a verdict against the baseline, save the model to
         ``OUTPUT_DIR/final`` and, outside smoke mode, push it to the Hub.

    A failed final Hub push is logged and does not raise.
    """
    setup_logging()
    train_dataset, holdout = load_chess_pairs()
    if SMOKE_TEST:
        # Smoke mode: keep at most 500 training pairs.
        train_dataset = train_dataset.select(range(min(500, len(train_dataset))))
    # Train the tokenizer on the FULL (pre-balanced) corpus -- we want every
    # token to be seen as many times as possible for the vocab pass.
    tokenizer = train_chess_tokenizer(train_dataset)
    # Now down-sample to a balanced dataset for the contrastive training.
    train_dataset = make_balanced_dataset(train_dataset, BALANCED_POSITIVES_PER_ANCHOR)
    # Optional anchor-token masking applied on the fly via set_transform.
    masker = make_anchor_masker(ANCHOR_MASK_PROB)
    if masker is not None:
        logging.info(f"Anchor token masking enabled (p={ANCHOR_MASK_PROB})")
        train_dataset.set_transform(masker)
    logging.info(f"Random-init StaticEmbedding (dim={EMBEDDING_DIM})")
    static_embedding = StaticEmbedding(tokenizer, embedding_dim=EMBEDDING_DIM)
    model = SentenceTransformer(
        modules=[static_embedding],
        model_card_data=SentenceTransformerModelCardData(
            language="en",
            license="apache-2.0",
            model_name=f"Static chess embedding ({EMBEDDING_DIM}d) -- themes/openings <-> positions",
        ),
    )
    evaluator = build_ir_evaluator(holdout)
    inner = MultipleNegativesRankingLoss(model)
    if DISABLE_MATRYOSHKA:
        logging.info("Matryoshka DISABLED -- training at single dim (diagnostic)")
        loss = inner
    else:
        loss = MatryoshkaLoss(model, inner, matryoshka_dims=MATRYOSHKA_DIMS)
    logging.info("Baseline evaluation (random init -- expect near-zero):")
    with autocast_ctx():
        baseline_eval = evaluator(model)[evaluator.primary_metric]
    # The trainer is told to track the evaluator's primary metric under an
    # "eval_" prefix (see metric_for_best_model below).
    metric_key = f"eval_{evaluator.primary_metric}"
    logging.info(f" baseline {evaluator.primary_metric} = {baseline_eval:.4f}")
    # Step budget: smoke test runs one step; MAX_STEPS env overrides; -1 is
    # passed otherwise (presumably "no step cap, use num_train_epochs").
    if SMOKE_TEST:
        max_steps = 1
    elif MAX_STEPS_OVERRIDE:
        max_steps = MAX_STEPS_OVERRIDE
    else:
        max_steps = -1
    # Eval and save share the same cadence.
    eval_steps = EVAL_STEPS_OVERRIDE if EVAL_STEPS_OVERRIDE else 0.05 # 20 evals/run
    save_steps = EVAL_STEPS_OVERRIDE if EVAL_STEPS_OVERRIDE else 0.05
    # NOTE(review): eval_steps / save_steps / logging_steps / warmup_steps are
    # given as floats below 1 here; presumably they are interpreted as
    # fractions of the total training steps -- confirm against the installed
    # transformers version.
    args = SentenceTransformerTrainingArguments(
        output_dir=OUTPUT_DIR,
        # Balanced dataset is small (~300k pairs); need many epochs to reach
        # comparable total training signal. Early stopping handles excess.
        num_train_epochs=20,
        max_steps=max_steps,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=1e-2, # was 5e-2 -- much slower convergence, shifts peak later
        weight_decay=0.01, # was 0.0 -- regularization on the embedding table
        warmup_steps=0.1,
        lr_scheduler_type="linear",
        bf16=IS_CUDA and torch.cuda.is_bf16_supported(),
        fp16=IS_CUDA and not torch.cuda.is_bf16_supported(),
        # was NO_DUPLICATES -- linked-list scan over deferred conflicts gives
        # O(epoch_progress) per-batch cost. With ~3000 unique anchors over
        # 5.8M pairs, dedup is fighting impossible odds. BATCH_SAMPLER (random)
        # is fast and accepts mild within-batch anchor duplication.
        batch_sampler=BatchSamplers.BATCH_SAMPLER,
        eval_strategy="steps",
        eval_steps=eval_steps,
        save_strategy="steps",
        save_steps=save_steps,
        save_total_limit=2,
        logging_steps=0.01,
        logging_first_step=True,
        load_best_model_at_end=True,
        metric_for_best_model=metric_key,
        greater_is_better=True,
        # Trackio crashes at first checkpoint push: empty `router_mapping`
        # struct can't be written to parquet. Disable.
        report_to="none",
        run_name=RUN_NAME,
        seed=12,
        # HF Jobs: container is destroyed after run -- push every checkpoint to
        # the Hub so partial progress survives a timeout. The end-of-run
        # model.push_to_hub() below is the belt to this suspenders.
        push_to_hub=not SMOKE_TEST,
        hub_model_id=HUB_MODEL_ID,
        hub_strategy="every_save",
    )
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
        evaluator=evaluator,
        callbacks=[
            # Auto-stop if compositional NDCG@10 doesn't improve for 3 evals.
            # Lower lr makes curves smoother -- give it slack vs the patience=2
            # we used at lr=5e-2.
            EarlyStoppingCallback(early_stopping_patience=3),
            # Per-step memory + dt logging.
            StepTimingCallback(),
        ],
    )
    trainer.train()
    logging.info("Post-training evaluation:")
    with autocast_ctx():
        score = evaluator(model)[evaluator.primary_metric]
    # Verdict thresholds: >= +0.005 is a WIN, any other non-negative delta is
    # MARGINAL, a negative delta is a REGRESSION.
    delta = score - baseline_eval
    verdict = "WIN" if delta >= 0.005 else "MARGINAL" if delta >= 0 else "REGRESSION"
    logging.info(
        f"VERDICT: {verdict} | score={score:.4f} | baseline={baseline_eval:.4f} | delta={delta:+.4f}"
    )
    final_dir = f"{OUTPUT_DIR}/final"
    model.save_pretrained(final_dir)
    logging.info(f"Saved final model to {final_dir}")
    if SMOKE_TEST:
        logging.info("SMOKE_TEST=1: skipping Hub push")
        return
    # Best-effort final push: the model is already saved locally, so a Hub
    # failure is logged with its traceback instead of aborting the run.
    try:
        commit_url = model.push_to_hub(HUB_MODEL_ID)
        logging.info(f"Pushed model to {commit_url.rsplit('/commit/', 1)[0]}")
    except Exception:
        import traceback
        logging.error(f"Hub push failed:\n{traceback.format_exc()}")
if __name__ == "__main__":
main()