#!/usr/bin/env python3 # /// script # requires-python = ">=3.10" # dependencies = [ # "sentence-transformers[train]>=5.5.0", # "datasets>=2.19.0", # "accelerate>=0.26.0", # "tokenizers>=0.20", # "trackio", # ] # /// """Train a StaticEmbedding model for chess retrieval. Pair shape: anchor = " []" positive = "themes [opening ] moves " (puzzles) "name eco pgn " (openings) Datasets: - Lichess/chess-puzzles (5.8M rows; themes + opening tags + UCI moves) - Lichess/chess-openings (3.6K rows; opening name + ECO + SAN moves) Use case: free-text search over a chess corpus. "fork endgame short" -> puzzles with that motif; "Sicilian Najdorf" -> matching openings. Design choices: - Custom WordLevel + Whitespace tokenizer trained on the corpus. Every chess token (UCI move e2e4, SAN move Nxd4, ECO code B90, theme name, opening word) is one whole token -- BERT WordPiece would shred them 4-way. - FEN dropped: position-as-character-soup doesn't fit a token-bag. - PGN move numbers stripped ("1. e4 c5" -> "e4 c5") so SAN moves are high-freq. - IR eval is custom (themes -> puzzles), not NanoBEIR -- general-English IR benchmarks don't measure chess retrieval. Run: SMOKE_TEST=1 uv run --exclude-newer=2026-05-12 train_chess_static.py uv run --exclude-newer=2026-05-12 train_chess_static.py """ from __future__ import annotations import logging import os import re from collections import defaultdict from contextlib import nullcontext import datasets import random import torch from datasets import Dataset, concatenate_datasets, load_dataset from tokenizers import Tokenizer from tokenizers.models import WordLevel from tokenizers.pre_tokenizers import Whitespace from tokenizers.trainers import WordLevelTrainer from sentence_transformers import ( SentenceTransformer, SentenceTransformerModelCardData, SentenceTransformerTrainer, SentenceTransformerTrainingArguments, ) from sentence_transformers.base.sampler import BatchSamplers from sentence_transformers.sentence_transformer.evaluation import ( InformationRetrievalEvaluator, SequentialEvaluator, ) from sentence_transformers.sentence_transformer.losses import ( MatryoshkaLoss, MultipleNegativesRankingLoss, ) from sentence_transformers.sentence_transformer.modules import StaticEmbedding from transformers import EarlyStoppingCallback, TrainerCallback import time EMBEDDING_DIM = 512 # was 256; 512 gives more capacity for bigram tokens MATRYOSHKA_DIMS = [512, 256, 128, 64, 32] VOCAB_SIZE = 100_000 # was 50_000; UCI/SAN bigrams add ~20-50k vocab OUTPUT_DIR = "models/static-embedding-chess" RUN_NAME = "static-embedding-chess" HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "oneryalcin/static-embedding-chess") # TOKENIZER_PATH default lives next to the model output. On Modal, set this to # a path on the persistent volume (e.g. /cache/chess_tokenizer.json) so the # 6-min WordLevelTrainer run is amortized across launches. TOKENIZER_PATH = os.environ.get( "TOKENIZER_PATH", f"{OUTPUT_DIR}/chess_tokenizer.json" ) RETRAIN_TOKENIZER = os.environ.get("RETRAIN_TOKENIZER") == "1" SMOKE_TEST = os.environ.get("SMOKE_TEST") == "1" FORCE_CPU = os.environ.get("FORCE_CPU") == "1" # Diagnostic knobs (default: full recipe). Both MPS and T4 show monotonic # step-time growth with the full Matryoshka stack -- toggle these to isolate. DISABLE_MATRYOSHKA = os.environ.get("DISABLE_MATRYOSHKA") == "1" MAX_STEPS_OVERRIDE = int(os.environ.get("MAX_STEPS", "0")) or None EVAL_STEPS_OVERRIDE = int(os.environ.get("EVAL_STEPS", "0")) or None EVAL_QUERIES = 200 EVAL_CORPUS = 5_000 # Held-out anchor selection: pick rare combos in this freq range. Low end > 1 # keeps multi-relevant NDCG meaningful; high end caps memorization potential. HELDOUT_FREQ_MIN = 3 HELDOUT_FREQ_MAX = 30 # Balanced-dataset config: each unique anchor expands to N (anchor, sampled_pos) # rows. The original 5.8M pairs let the model memorize specific (anchor, pos) # pairings since each anchor has ~1933 distinct positives. Capping at 100 # random samples per anchor gives the model meaningful variety without the # 50x redundancy that fuels overfitting. BALANCED_POSITIVES_PER_ANCHOR = int(os.environ.get("POSITIVES_PER_ANCHOR", "100")) # Anchor token masking probability during training. 0 disables. ANCHOR_MASK_PROB = float(os.environ.get("ANCHOR_MASK_PROB", "0.15")) # Device-aware defaults. MPS (Apple Silicon) can't do bf16 and has unified- # memory pressure, so the CUDA-targeted skill template defaults (batch=2048, # bf16=True) don't apply. Scale BATCH_SIZE up if your M-series has 36GB+. IS_CUDA = torch.cuda.is_available() and not FORCE_CPU IS_MPS = (not IS_CUDA) and torch.backends.mps.is_available() and not FORCE_CPU # StaticEmbedding is a lookup+average -- no transformer activations to fit. # Memory cost is the (batch x batch) similarity matrix + (batch x seq x dim) # lookups, both tiny. CachedMultipleNegativesRankingLoss is NOT compatible # with StaticEmbedding (no encoder to GradCache through), so we just crank # the real batch. Scale up freely if your M-series has the headroom. BATCH_SIZE = 4096 if IS_CUDA else (4096 if IS_MPS else 256) MOVE_NUM_RE = re.compile(r"\d+\.+") class StepTimingCallback(TrainerCallback): """Per-step instrumentation: wall time, CUDA memory, allocator state. Costs ~1ms/step. Run-once-and-read approach to diagnosing slowdowns instead of swapping configs and rerunning. """ def on_step_begin(self, args, state, control, **kw): if torch.cuda.is_available(): torch.cuda.synchronize() self._t0 = time.perf_counter() def on_step_end(self, args, state, control, **kw): if torch.cuda.is_available(): torch.cuda.synchronize() dt = time.perf_counter() - self._t0 # Log every step for the first 20 to see startup; then every 10th. if state.global_step <= 20 or state.global_step % 10 == 0: if torch.cuda.is_available(): mem = torch.cuda.memory_allocated() / 1e6 reserved = torch.cuda.memory_reserved() / 1e6 logging.info( f"STEP {state.global_step}: dt={dt:.3f}s mem={mem:.0f}MB reserved={reserved:.0f}MB" ) else: logging.info(f"STEP {state.global_step}: dt={dt:.3f}s (cpu/mps)") def autocast_ctx(): if IS_CUDA: dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 return torch.autocast("cuda", dtype=dtype) if IS_MPS: return torch.autocast("mps", dtype=torch.float16) return nullcontext() def setup_logging(): os.makedirs("logs", exist_ok=True) os.makedirs(OUTPUT_DIR, exist_ok=True) logging.basicConfig( format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[logging.StreamHandler(), logging.FileHandler(f"logs/{RUN_NAME}.log")], force=True, ) for noisy in ("httpx", "httpcore", "huggingface_hub", "urllib3", "filelock", "fsspec"): logging.getLogger(noisy).setLevel(logging.WARNING) if torch.cuda.is_available(): torch.set_float32_matmul_precision("high") def _join_tags(tags) -> str: if not tags: return "" return " ".join(t.replace("_", " ") for t in tags) def _strip_pgn_move_numbers(pgn: str) -> str: return MOVE_NUM_RE.sub("", pgn).strip() def _bigram_token_str(moves: str) -> str: """Append bigram tokens to a whitespace-separated move sequence. "f2g3 e6e7 b2b1" -> "f2g3 e6e7 b2b1 f2g3+e6e7 e6e7+b2b1" Bigrams use `+` as the join char so they're distinct from unigrams in the WordLevel tokenizer's whitespace pretokenizer. A token-bag averaging across unigrams alone loses move ordering; adding adjacent-pair tokens lets the model learn that "e2e4 e7e5" (king's pawn opening) is its own pattern. """ tokens = moves.split() if len(tokens) < 2: return moves bigrams = " ".join(f"{a}+{b}" for a, b in zip(tokens, tokens[1:])) return f"{moves} {bigrams}" def build_puzzle_pairs(row_batch: dict) -> dict: anchors, positives = [], [] for themes, opening_tags, moves in zip( row_batch["Themes"], row_batch["OpeningTags"], row_batch["Moves"] ): themes_txt = _join_tags(themes) opening_txt = _join_tags(opening_tags) if not themes_txt: continue anchor = themes_txt + (f" {opening_txt}" if opening_txt else "") positive = f"themes {themes_txt}" if opening_txt: positive += f" opening {opening_txt}" positive += f" moves {_bigram_token_str(moves)}" anchors.append(anchor) positives.append(positive) return {"anchor": anchors, "positive": positives} def build_opening_pairs(row_batch: dict) -> dict: anchors, positives = [], [] for name, eco, pgn in zip(row_batch["name"], row_batch["eco"], row_batch["pgn"]): san = _strip_pgn_move_numbers(pgn) anchors.append(f"{name} {eco}") positives.append(f"name {name} eco {eco} pgn {_bigram_token_str(san)}") return {"anchor": anchors, "positive": positives} def load_chess_pairs() -> tuple[Dataset, Dataset]: """Returns (train, holdout) where the holdout anchors are rare combinations NEVER seen in train. Old eval used the top-200 most-common theme strings as queries. The model memorized these in training (each appears ~50k times) so eval was a recall test on memorized lookups, not generalization. Replaced with compositional held-out anchors: - Pick anchor strings with frequency in [HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX]: rare enough to be informative, common enough to have multiple positives for multi-relevant eval. - REMOVE all pairs with those anchors from train (no leakage). - Use those rare anchors as eval queries; the held-out pairs become the eval corpus. - Individual theme tokens within those anchors still appear *separately* in many other training anchors, so the model has learned each token's embedding -- it just hasn't seen this particular combination. Tests compositional generalization. """ logging.info("Loading Lichess/chess-puzzles (5.8M rows)") puzzles = load_dataset("Lichess/chess-puzzles", split="train") if SMOKE_TEST: puzzles = puzzles.select(range(2_000)) pair_puzzles = puzzles.map( build_puzzle_pairs, batched=True, batch_size=10_000, remove_columns=puzzles.column_names, desc="puzzles -> pairs", ) logging.info(f" built {len(pair_puzzles):,} puzzle pairs") logging.info("Loading Lichess/chess-openings (3.6K rows)") openings = load_dataset("Lichess/chess-openings", split="train").remove_columns(["img"]) pair_openings = openings.map( build_opening_pairs, batched=True, remove_columns=openings.column_names, desc="openings -> pairs", ) logging.info(f" built {len(pair_openings):,} opening pairs") # Count anchor frequencies across the puzzle pairs. logging.info("Computing anchor frequencies for held-out selection") anchors = pair_puzzles["anchor"] freq: dict[str, int] = defaultdict(int) for a in anchors: freq[a] += 1 logging.info(f" {len(freq):,} unique anchors in puzzle pairs") # Pick rare anchors: each appears in [HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX] pairs. # In smoke mode, lower the min so the tiny corpus still produces enough # held-out queries (smoke has ~2k puzzles, most anchors freq 1-2). min_freq = 2 if SMOKE_TEST else HELDOUT_FREQ_MIN max_freq = HELDOUT_FREQ_MAX rare_pool = sorted( ((a, c) for a, c in freq.items() if min_freq <= c <= max_freq), key=lambda kv: kv[1], # ascending: rarest first ) n_queries_target = 20 if SMOKE_TEST else EVAL_QUERIES if len(rare_pool) < n_queries_target: logging.warning( f"Only {len(rare_pool)} anchors in freq range [{HELDOUT_FREQ_MIN},{HELDOUT_FREQ_MAX}]; " f"using all of them ({n_queries_target} requested)" ) heldout_anchors = {a for a, _ in rare_pool[:n_queries_target]} logging.info( f" selected {len(heldout_anchors)} held-out anchors " f"(freq range: {rare_pool[0][1] if rare_pool else 0}..{rare_pool[min(n_queries_target, len(rare_pool))-1][1] if rare_pool else 0})" ) # Filter: pairs whose anchor is held-out -> eval; everything else -> train. held_mask = [a in heldout_anchors for a in anchors] holdout = pair_puzzles.select([i for i, h in enumerate(held_mask) if h]) train_puzzles = pair_puzzles.select([i for i, h in enumerate(held_mask) if not h]) logging.info( f" split by held-out anchors: train={len(train_puzzles):,}, holdout={len(holdout):,}" ) # Train includes the (non-held) puzzle pairs + all openings. train = concatenate_datasets([train_puzzles, pair_openings]).shuffle(seed=12) logging.info(f" train: {len(train):,} pairs | holdout: {len(holdout):,} pairs") return train, holdout def make_balanced_dataset(train: Dataset, n_per_anchor: int) -> Dataset: """Cap each anchor's positives to `n_per_anchor` random picks. Breaks the 5.8M pairs' redundancy (each anchor x ~1933 positives) so the model can't memorize specific (anchor, positive) pairings while still seeing useful positive variety per anchor. """ by_anchor: dict[str, list[str]] = defaultdict(list) for row in train: by_anchor[row["anchor"]].append(row["positive"]) rng = random.Random(12) new_anchors, new_positives = [], [] for anchor, positives in by_anchor.items(): sample = ( rng.sample(positives, n_per_anchor) if len(positives) > n_per_anchor else positives ) for p in sample: new_anchors.append(anchor) new_positives.append(p) logging.info( f"Balanced dataset: {len(by_anchor):,} unique anchors -> " f"{len(new_anchors):,} pairs (cap {n_per_anchor}/anchor)" ) return Dataset.from_dict({"anchor": new_anchors, "positive": new_positives}).shuffle(seed=12) def make_anchor_masker(mask_prob: float, rng_seed: int = 12): """Return a `set_transform` callable that randomly replaces theme tokens with [UNK] in the anchor. Token-bag dropout: forces the model to use remaining tokens instead of memorizing the exact combination.""" if mask_prob <= 0: return None rng = random.Random(rng_seed) def _mask(batch: dict) -> dict: anchors = batch["anchor"] new_anchors = [] for a in anchors: tokens = a.split() if len(tokens) <= 1: new_anchors.append(a) continue kept = [t if rng.random() >= mask_prob else "[UNK]" for t in tokens] # Guard against masking everything: if all UNK, restore one random token. if all(t == "[UNK]" for t in kept): kept[rng.randrange(len(kept))] = tokens[rng.randrange(len(tokens))] new_anchors.append(" ".join(kept)) return {"anchor": new_anchors, "positive": batch["positive"]} return _mask def train_chess_tokenizer(train: Dataset) -> Tokenizer: """Train or load a WordLevel tokenizer for the chess corpus. Every space-separated unit (theme word, opening word, ECO code, UCI move, SAN move) becomes one whole token. Compare to BERT WordPiece which fragments "f2g3" into 4 subword pieces -- a token-bag wastes capacity on subword joins that carry no chess meaning. Caching: if TOKENIZER_PATH exists, load and return it instead of rebuilding. The WordLevelTrainer is single-threaded Rust and takes ~6 min on 11.6M strings. Tokenizer is deterministic given the same corpus + config, so caching is safe. Set RETRAIN_TOKENIZER=1 to force rebuild. """ if not RETRAIN_TOKENIZER and os.path.exists(TOKENIZER_PATH): tok = Tokenizer.from_file(TOKENIZER_PATH) logging.info( f"Reusing cached tokenizer ({tok.get_vocab_size():,} tokens) from {TOKENIZER_PATH}" ) return tok logging.info(f"Training WordLevel tokenizer on {len(train):,} pairs (vocab={VOCAB_SIZE})") tok = Tokenizer(WordLevel(unk_token="[UNK]")) tok.pre_tokenizer = Whitespace() trainer = WordLevelTrainer( vocab_size=VOCAB_SIZE, special_tokens=["[UNK]", "[PAD]"], min_frequency=2, ) def text_iter(): for row in train: yield row["anchor"] yield row["positive"] tok.train_from_iterator(text_iter(), trainer=trainer, length=2 * len(train)) actual_vocab = tok.get_vocab_size() logging.info(f" tokenizer trained: {actual_vocab:,} tokens (cap was {VOCAB_SIZE:,})") os.makedirs(os.path.dirname(TOKENIZER_PATH) or ".", exist_ok=True) tok.save(TOKENIZER_PATH) logging.info(f" saved tokenizer to {TOKENIZER_PATH}") return tok def _strip_theme_echo(positive: str) -> str: """Eval corpus must not echo the themes the query asks about, or the baseline (random-init) scores high just from lexical token overlap. Keep only the moves segment.""" idx = positive.find(" moves ") return positive[idx + 1 :] if idx != -1 else positive def _build_compositional_ir_evaluator( holdout: Dataset, corpus: dict[str, str], name: str ) -> InformationRetrievalEvaluator: """Compositional: each unseen anchor string is a query.""" by_anchor: dict[str, set[str]] = defaultdict(set) for i, row in enumerate(holdout): by_anchor[row["anchor"]].add(f"d{i}") sorted_anchors = sorted(by_anchor.items(), key=lambda kv: -len(kv[1])) queries = {f"q{i}": anchor for i, (anchor, _) in enumerate(sorted_anchors)} relevant_docs = {f"q{i}": docs for i, (_, docs) in enumerate(sorted_anchors)} avg_rel = sum(len(v) for v in relevant_docs.values()) / max(1, len(relevant_docs)) logging.info( f" [{name}] {len(queries)} queries (unseen combos), avg relevant/query={avg_rel:.1f}" ) return _ir_evaluator(queries, corpus, relevant_docs, name) def _build_single_theme_ir_evaluator( holdout: Dataset, corpus: dict[str, str], name: str ) -> InformationRetrievalEvaluator: """Single-theme: each individual theme token from the held-out anchors is a query. Tests whether per-token embeddings are useful in isolation. Relevant docs for query "fork" = any held-out doc whose anchor contains the token "fork". Coarser than the compositional eval (much higher avg relevant/query) but a sharper test of token-level meaning. """ theme_to_docs: dict[str, set[str]] = defaultdict(set) for i, row in enumerate(holdout): for token in row["anchor"].split(): theme_to_docs[token].add(f"d{i}") min_relevant = 2 if SMOKE_TEST else 3 candidates = [(t, d) for t, d in theme_to_docs.items() if len(d) >= min_relevant] candidates.sort(key=lambda kv: -len(kv[1])) queries = {f"t{i}": tok for i, (tok, _) in enumerate(candidates)} relevant_docs = {f"t{i}": docs for i, (_, docs) in enumerate(candidates)} avg_rel = sum(len(v) for v in relevant_docs.values()) / max(1, len(relevant_docs)) logging.info( f" [{name}] {len(queries)} single-token queries, avg relevant/query={avg_rel:.1f}" ) return _ir_evaluator(queries, corpus, relevant_docs, name) def _ir_evaluator(queries, corpus, relevant_docs, name): return InformationRetrievalEvaluator( queries=queries, corpus=corpus, relevant_docs=relevant_docs, name=name, ndcg_at_k=[10], mrr_at_k=[10], accuracy_at_k=[1, 10], precision_recall_at_k=[1, 10], show_progress_bar=False, batch_size=256, ) def build_ir_evaluator(holdout: Dataset, name: str = "chess-ir") -> SequentialEvaluator: """Wraps two evaluators (compositional + single-theme) into a sequential pass. The compositional one's score drives best-model selection; the single-theme one is informational. """ corpus = {f"d{i}": _strip_theme_echo(row["positive"]) for i, row in enumerate(holdout)} logging.info(f"IR eval setup ({len(corpus)} corpus docs):") compositional = _build_compositional_ir_evaluator(holdout, corpus, name=name) single_theme = _build_single_theme_ir_evaluator(holdout, corpus, name=f"{name}-tokens") # First evaluator's score drives load_best_model_at_end (compositional). return SequentialEvaluator( [compositional, single_theme], main_score_function=lambda scores: scores[0], ) def main() -> None: setup_logging() train_dataset, holdout = load_chess_pairs() if SMOKE_TEST: train_dataset = train_dataset.select(range(min(500, len(train_dataset)))) # Train the tokenizer on the FULL (pre-balanced) corpus -- we want every # token to be seen as many times as possible for the vocab pass. tokenizer = train_chess_tokenizer(train_dataset) # Now down-sample to a balanced dataset for the contrastive training. train_dataset = make_balanced_dataset(train_dataset, BALANCED_POSITIVES_PER_ANCHOR) # Optional anchor-token masking applied on the fly via set_transform. masker = make_anchor_masker(ANCHOR_MASK_PROB) if masker is not None: logging.info(f"Anchor token masking enabled (p={ANCHOR_MASK_PROB})") train_dataset.set_transform(masker) logging.info(f"Random-init StaticEmbedding (dim={EMBEDDING_DIM})") static_embedding = StaticEmbedding(tokenizer, embedding_dim=EMBEDDING_DIM) model = SentenceTransformer( modules=[static_embedding], model_card_data=SentenceTransformerModelCardData( language="en", license="apache-2.0", model_name=f"Static chess embedding ({EMBEDDING_DIM}d) -- themes/openings <-> positions", ), ) evaluator = build_ir_evaluator(holdout) inner = MultipleNegativesRankingLoss(model) if DISABLE_MATRYOSHKA: logging.info("Matryoshka DISABLED -- training at single dim (diagnostic)") loss = inner else: loss = MatryoshkaLoss(model, inner, matryoshka_dims=MATRYOSHKA_DIMS) logging.info("Baseline evaluation (random init -- expect near-zero):") with autocast_ctx(): baseline_eval = evaluator(model)[evaluator.primary_metric] metric_key = f"eval_{evaluator.primary_metric}" logging.info(f" baseline {evaluator.primary_metric} = {baseline_eval:.4f}") if SMOKE_TEST: max_steps = 1 elif MAX_STEPS_OVERRIDE: max_steps = MAX_STEPS_OVERRIDE else: max_steps = -1 eval_steps = EVAL_STEPS_OVERRIDE if EVAL_STEPS_OVERRIDE else 0.05 # 20 evals/run save_steps = EVAL_STEPS_OVERRIDE if EVAL_STEPS_OVERRIDE else 0.05 args = SentenceTransformerTrainingArguments( output_dir=OUTPUT_DIR, # Balanced dataset is small (~300k pairs); need many epochs to reach # comparable total training signal. Early stopping handles excess. num_train_epochs=20, max_steps=max_steps, per_device_train_batch_size=BATCH_SIZE, per_device_eval_batch_size=BATCH_SIZE, learning_rate=1e-2, # was 5e-2 -- much slower convergence, shifts peak later weight_decay=0.01, # was 0.0 -- regularization on the embedding table warmup_steps=0.1, lr_scheduler_type="linear", bf16=IS_CUDA and torch.cuda.is_bf16_supported(), fp16=IS_CUDA and not torch.cuda.is_bf16_supported(), # was NO_DUPLICATES -- linked-list scan over deferred conflicts gives # O(epoch_progress) per-batch cost. With ~3000 unique anchors over # 5.8M pairs, dedup is fighting impossible odds. BATCH_SAMPLER (random) # is fast and accepts mild within-batch anchor duplication. batch_sampler=BatchSamplers.BATCH_SAMPLER, eval_strategy="steps", eval_steps=eval_steps, save_strategy="steps", save_steps=save_steps, save_total_limit=2, logging_steps=0.01, logging_first_step=True, load_best_model_at_end=True, metric_for_best_model=metric_key, greater_is_better=True, # Trackio crashes at first checkpoint push: empty `router_mapping` # struct can't be written to parquet. Disable. report_to="none", run_name=RUN_NAME, seed=12, # HF Jobs: container is destroyed after run -- push every checkpoint to # the Hub so partial progress survives a timeout. The end-of-run # model.push_to_hub() below is the belt to this suspenders. push_to_hub=not SMOKE_TEST, hub_model_id=HUB_MODEL_ID, hub_strategy="every_save", ) trainer = SentenceTransformerTrainer( model=model, args=args, train_dataset=train_dataset, loss=loss, evaluator=evaluator, callbacks=[ # Auto-stop if compositional NDCG@10 doesn't improve for 3 evals. # Lower lr makes curves smoother -- give it slack vs the patience=2 # we used at lr=5e-2. EarlyStoppingCallback(early_stopping_patience=3), # Per-step memory + dt logging. StepTimingCallback(), ], ) trainer.train() logging.info("Post-training evaluation:") with autocast_ctx(): score = evaluator(model)[evaluator.primary_metric] delta = score - baseline_eval verdict = "WIN" if delta >= 0.005 else "MARGINAL" if delta >= 0 else "REGRESSION" logging.info( f"VERDICT: {verdict} | score={score:.4f} | baseline={baseline_eval:.4f} | delta={delta:+.4f}" ) final_dir = f"{OUTPUT_DIR}/final" model.save_pretrained(final_dir) logging.info(f"Saved final model to {final_dir}") if SMOKE_TEST: logging.info("SMOKE_TEST=1: skipping Hub push") return try: commit_url = model.push_to_hub(HUB_MODEL_ID) logging.info(f"Pushed model to {commit_url.rsplit('/commit/', 1)[0]}") except Exception: import traceback logging.error(f"Hub push failed:\n{traceback.format_exc()}") if __name__ == "__main__": main()