Sentence Similarity
sentence-transformers
Safetensors
English
static-embedding
chess
retrieval
exploratory
Instructions to use oneryalcin/static-embedding-chess with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use oneryalcin/static-embedding-chess with sentence-transformers:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("oneryalcin/static-embedding-chess")
sentences = [
    "That is a happy person",
    "That is a happy dog",
    "That is a very happy person",
    "Today is a sunny day"
]
embeddings = model.encode(sentences)
similarities = model.similarity(embeddings, embeddings)
print(similarities.shape)
# [4, 4]
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| # /// script | |
| # requires-python = ">=3.10" | |
| # dependencies = [ | |
| # "sentence-transformers[train]>=5.5.0", | |
| # "datasets>=2.19.0", | |
| # "accelerate>=0.26.0", | |
| # "tokenizers>=0.20", | |
| # ] | |
| # /// | |
| """Multi-task training: chess-aware semantic structure + hard-negative MNRL. | |
| Two simultaneous training signals: | |
| 1. THEME-DISTILL dataset: (theme_token, mpnet_definition_emb) | |
| - 73 rows (one per Lichess theme) | |
| - Loss: EmbedDistillLoss (project student 512d -> 768d, match teacher) | |
| - Effect: enc("fork") moves toward MPNet("a tactical motif where one piece...") | |
| - Solves orthogonal-token-embeddings problem identified in Phase 1 | |
| 2. CHESS-CONTENT dataset: (anchor, positive, hard_negative) | |
| - From mined hard-negs of v3 model | |
| - Loss: MultipleNegativesRankingLoss (handles triplets natively) | |
| - Effect: maintains chess-content associations, sharpens discriminative ability | |
| Multi-task trainer interleaves batches from both datasets. The theme dataset is | |
| tiny (73 rows) but high-impact -- it injects semantic structure into 73 token | |
| embeddings. The chess dataset is large (1.6M+ triplets) and shapes the rest. | |
| Run: | |
| SMOKE_TEST=1 uv run --exclude-newer=2026-05-12 train_chess_multitask.py | |
| uv run --exclude-newer=2026-05-12 train_chess_multitask.py | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import os | |
| import random | |
| import re | |
| import time | |
| from collections import defaultdict | |
| from contextlib import nullcontext | |
| import numpy as np | |
| import torch | |
| from datasets import Dataset, concatenate_datasets, load_dataset | |
| from tokenizers import Tokenizer | |
| from sentence_transformers import ( | |
| SentenceTransformer, | |
| SentenceTransformerModelCardData, | |
| SentenceTransformerTrainer, | |
| SentenceTransformerTrainingArguments, | |
| ) | |
| from sentence_transformers.base.sampler import BatchSamplers, MultiDatasetBatchSamplers | |
| from sentence_transformers.sentence_transformer.evaluation import ( | |
| InformationRetrievalEvaluator, | |
| ) | |
| from sentence_transformers.sentence_transformer.losses import ( | |
| EmbedDistillLoss, | |
| MultipleNegativesRankingLoss, | |
| ) | |
| from sentence_transformers.sentence_transformer.modules import StaticEmbedding | |
| from transformers import EarlyStoppingCallback, TrainerCallback | |
# --- Input / output locations ---
THEME_DEFS_PATH = "models/theme_definitions.parquet"  # columns: theme, definition, embedding
TRIPLETS_PATH = "models/hard_negatives.parquet"  # mined (anchor, positive, hard_negative) rows
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", "models/static-embedding-chess/chess_tokenizer.json")
OUTPUT_DIR = "models/static-embedding-chess-multitask"
RUN_NAME = "static-embedding-chess-multitask"

# SMOKE_TEST=1 shrinks every dataset and runs a single training step.
SMOKE_TEST = os.environ.get("SMOKE_TEST") == "1"

# --- Model geometry ---
EMBEDDING_DIM = 512  # width of the student StaticEmbedding
TEACHER_DIM = 768  # width of the teacher definition embeddings; EmbedDistillLoss projects to this

# --- Held-out evaluation ---
# Anchors whose occurrence count lies in [HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX] are
# eligible as eval queries; at most EVAL_QUERIES of them are used.
HELDOUT_FREQ_MIN = 3
HELDOUT_FREQ_MAX = 30
EVAL_QUERIES = 200

# How many times the small theme dataset is replicated so it is sampled regularly
# next to the much larger triplet dataset.
THEME_REPLICAS = int(os.environ.get("THEME_REPLICAS", "500"))  # oversample theme dataset

# --- Hardware ---
IS_CUDA = torch.cuda.is_available()
IS_MPS = (not IS_CUDA) and torch.backends.mps.is_available()
# Per-device batch size: 4096 on any accelerator (CUDA or MPS), 256 on CPU.
# Can be overridden with BATCH_SIZE=<int>, like THEME_REPLICAS / TOKENIZER_PATH above.
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 4096 if (IS_CUDA or IS_MPS) else 256))
def setup_logging():
    """Send INFO-level logs to both stderr and ``logs/<RUN_NAME>.log``.

    Creates the ``logs`` directory and ``OUTPUT_DIR`` first (no error if they
    already exist). Existing root handlers are replaced (``force=True``), and a
    handful of chatty HTTP/filesystem library loggers are raised to WARNING.
    """
    for directory in ("logs", OUTPUT_DIR):
        os.makedirs(directory, exist_ok=True)

    log_handlers = [
        logging.StreamHandler(),
        logging.FileHandler(f"logs/{RUN_NAME}.log"),
    ]
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=log_handlers,
        force=True,
    )

    quiet_loggers = ("httpx", "httpcore", "huggingface_hub", "urllib3", "filelock", "fsspec")
    for logger_name in quiet_loggers:
        logging.getLogger(logger_name).setLevel(logging.WARNING)
| def _join_tags(tags): | |
| return " ".join(t.replace("_", " ") for t in tags) if tags else "" | |
| def _bigram_token_str(moves): | |
| toks = moves.split() | |
| if len(toks) < 2: | |
| return moves | |
| bigrams = " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:])) | |
| return f"{moves} {bigrams}" | |
def build_puzzle_pairs(batch):
    """Turn a batch of puzzle rows into (anchor, positive) text pairs.

    Meant for ``Dataset.map(batched=True)``; reads the ``Themes``,
    ``OpeningTags`` and ``Moves`` columns. Per row:

    * anchor   = ``"<themes>[ <opening>]"``
    * positive = ``"themes <themes>[ opening <opening>] moves <moves + bigrams>"``

    Rows whose themes render to an empty string are skipped, so the returned
    columns can be shorter than the input batch.
    """
    out = {"anchor": [], "positive": []}
    rows = zip(batch["Themes"], batch["OpeningTags"], batch["Moves"])
    for theme_tags, opening_tags, move_str in rows:
        themes_txt = _join_tags(theme_tags)
        opening_txt = _join_tags(opening_tags)
        if not themes_txt:
            continue

        anchor_parts = [themes_txt]
        positive_parts = ["themes", themes_txt]
        if opening_txt:
            anchor_parts.append(opening_txt)
            positive_parts.extend(["opening", opening_txt])
        positive_parts.extend(["moves", _bigram_token_str(move_str)])

        out["anchor"].append(" ".join(anchor_parts))
        out["positive"].append(" ".join(positive_parts))
    return out
def strip_theme_echo(p):
    """Return ``p`` from its first ``" moves "`` marker onward, minus the leading space.

    This drops the ``themes ... [opening ...]`` prefix of a positive text. In
    this script's eval the corpus documents therefore do not repeat the
    theme/opening words used as queries. If the marker is absent, ``p`` is
    returned unchanged.
    """
    _, marker, rest = p.partition(" moves ")
    if not marker:
        return p
    return "moves " + rest
def build_evaluator(holdout):
    """Build the chess IR evaluator from held-out (anchor, positive) rows.

    Corpus: one document per row, id ``d<row index>``. Its text is the row's
    positive with the theme/opening prefix removed by ``strip_theme_echo``.

    Queries: one per distinct anchor, ids ``q0, q1, ...``. They are ordered by
    descending number of relevant documents; ties keep first-seen order because
    the sort is stable. A query's relevant set is every row sharing its anchor.
    """
    corpus = {}
    docs_for_anchor = defaultdict(set)
    for row_idx, row in enumerate(holdout):
        doc_id = f"d{row_idx}"
        corpus[doc_id] = strip_theme_echo(row["positive"])
        docs_for_anchor[row["anchor"]].add(doc_id)

    ranked = sorted(docs_for_anchor.items(), key=lambda item: len(item[1]), reverse=True)
    queries, relevant = {}, {}
    for q_idx, (anchor_text, doc_ids) in enumerate(ranked):
        query_id = f"q{q_idx}"
        queries[query_id] = anchor_text
        relevant[query_id] = doc_ids

    return InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant,
        name="chess-ir",
        ndcg_at_k=[10],
        mrr_at_k=[10],
        accuracy_at_k=[1, 10],
        precision_recall_at_k=[1, 10],
        show_progress_bar=False,
        batch_size=256,
    )
def autocast_ctx():
    """Return a mixed-precision autocast context for the detected device.

    CUDA: bfloat16 if the GPU supports it, otherwise float16.
    MPS: float16.
    Anything else: a no-op ``nullcontext()``.
    """
    if IS_CUDA:
        device_type = "cuda"
        amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    elif IS_MPS:
        device_type, amp_dtype = "mps", torch.float16
    else:
        return nullcontext()
    return torch.autocast(device_type, dtype=amp_dtype)
# Reference score of the previous (v3) model on the same held-out eval.
V3_BASELINE = 0.0801
# Minimum improvement over the random-init baseline to call a run a "WIN".
WIN_MARGIN = 0.005


def _load_theme_dataset():
    """Load theme definitions as an EmbedDistillLoss dataset (columns: sentence, label).

    Outside SMOKE_TEST the rows are replicated THEME_REPLICAS times and shuffled,
    so the small theme set is still drawn regularly next to the triplet set.
    """
    logging.info(f"Loading theme definitions from {THEME_DEFS_PATH}")
    # EmbedDistillLoss expects columns: sentence, label
    theme_ds = (
        Dataset.from_parquet(THEME_DEFS_PATH)
        .rename_columns({"theme": "sentence", "embedding": "label"})
        .remove_columns(["definition"])
    )
    if not SMOKE_TEST:
        theme_ds = concatenate_datasets([theme_ds] * THEME_REPLICAS).shuffle(seed=12)
    logging.info(f" {len(theme_ds):,} theme rows (after oversampling)")
    return theme_ds


def _load_triplet_dataset():
    """Load mined (anchor, positive, hard_negative) triplets; SMOKE_TEST keeps at most 500."""
    logging.info(f"Loading triplets from {TRIPLETS_PATH}")
    triplet_ds = Dataset.from_parquet(TRIPLETS_PATH)
    if SMOKE_TEST:
        triplet_ds = triplet_ds.select(range(min(500, len(triplet_ds))))
    logging.info(f" {len(triplet_ds):,} triplets, columns: {triplet_ds.column_names}")
    return triplet_ds


def _build_holdout():
    """Select held-out puzzle pairs whose anchor text is rare.

    Anchors occurring between HELDOUT_FREQ_MIN and HELDOUT_FREQ_MAX times
    (inclusive) qualify. The rarest ``n_eval`` of them are chosen, and every
    pair carrying one of those anchors goes into the holdout.

    Raises:
        RuntimeError: if no anchor qualifies. Without this guard the IR
            evaluator would be built with an empty query set.
    """
    from collections import Counter

    logging.info("Building held-out eval")
    puzzles = load_dataset("Lichess/chess-puzzles", split="train")
    if SMOKE_TEST:
        puzzles = puzzles.select(range(2_000))
    pair_puzzles = puzzles.map(
        build_puzzle_pairs, batched=True, batch_size=20_000,
        remove_columns=puzzles.column_names, num_proc=4,
    )
    anchors = pair_puzzles["anchor"]

    # Counter preserves first-seen order, and sorted() is stable, so ties
    # between equally rare anchors resolve by first appearance.
    freq = Counter(anchors)
    rare_pool = sorted(
        ((a, c) for a, c in freq.items() if HELDOUT_FREQ_MIN <= c <= HELDOUT_FREQ_MAX),
        key=lambda kv: kv[1],
    )
    n_eval = 20 if SMOKE_TEST else EVAL_QUERIES
    heldout = {a for a, _ in rare_pool[:n_eval]}
    held_idx = [i for i, a in enumerate(anchors) if a in heldout]
    if not held_idx:
        raise RuntimeError(
            f"No anchors with frequency in [{HELDOUT_FREQ_MIN}, {HELDOUT_FREQ_MAX}]; "
            "cannot build the held-out IR evaluator"
        )
    holdout = pair_puzzles.select(held_idx)
    logging.info(f" holdout: {len(holdout)}")
    return holdout


def main():
    """Train the multi-task static chess embedding and report how it compares.

    Steps:
      1. Build a randomly initialised StaticEmbedding model from the chess tokenizer.
      2. Load the theme-distillation and hard-negative triplet datasets.
      3. Build the held-out IR evaluator and record a random-init baseline.
      4. Train with MNRL on triplets plus EmbedDistillLoss on themes, using a
         proportional multi-dataset sampler, early stopping, and best-model reload.
      5. Re-evaluate, log the verdict relative to the baseline and to v3, and
         save the model to ``OUTPUT_DIR/final``.
    """
    setup_logging()

    logging.info(f"Loading tokenizer from {TOKENIZER_PATH}")
    tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
    logging.info(f" vocab: {tokenizer.get_vocab_size():,}")
    logging.info(f"Building random-init StaticEmbedding (dim={EMBEDDING_DIM})")
    static = StaticEmbedding(tokenizer, embedding_dim=EMBEDDING_DIM)
    model = SentenceTransformer(
        modules=[static],
        model_card_data=SentenceTransformerModelCardData(
            language="en", license="apache-2.0",
            model_name=f"Static chess embedding ({EMBEDDING_DIM}d) -- multi-task (theme distill + hard-neg MNRL)",
        ),
    )

    # === Dataset A: theme distillation / Dataset B: chess triplets ===
    theme_ds = _load_theme_dataset()
    triplet_ds = _load_triplet_dataset()

    # === Held-out eval (same as previous runs) ===
    evaluator = build_evaluator(_build_holdout())
    logging.info("Baseline eval (random init):")
    with autocast_ctx():
        baseline = evaluator(model)[evaluator.primary_metric]
    # NOTE(review): primary_metric is read only after the first evaluator call,
    # matching the original flow. sentence-transformers may prefix it with the
    # evaluator name during that call; confirm before moving this line earlier.
    metric_key = f"eval_{evaluator.primary_metric}"
    logging.info(f" baseline {evaluator.primary_metric} = {baseline:.4f}")

    # === Multi-task setup: dataset name -> dataset and dataset name -> loss ===
    train_datasets = {
        "chess": triplet_ds,
        "themes": theme_ds,
    }
    losses = {
        "chess": MultipleNegativesRankingLoss(model),
        # Projects the EMBEDDING_DIM student output to TEACHER_DIM before comparing.
        "themes": EmbedDistillLoss(model, distance_metric="cosine", projection_dim=TEACHER_DIM),
    }

    args = SentenceTransformerTrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=5,
        max_steps=1 if SMOKE_TEST else -1,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=1e-2,
        weight_decay=0.01,
        # NOTE(review): a float < 1 means "fraction of total steps" only on
        # transformers versions that accept it. Older versions expect an int
        # here (or warmup_ratio instead); confirm against the resolved version.
        warmup_steps=0.1,
        lr_scheduler_type="linear",
        bf16=IS_CUDA and torch.cuda.is_bf16_supported(),
        fp16=IS_CUDA and not torch.cuda.is_bf16_supported(),
        batch_sampler=BatchSamplers.BATCH_SAMPLER,
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
        # eval and save steps must line up for load_best_model_at_end.
        eval_strategy="steps",
        eval_steps=0.05,
        save_strategy="steps",
        save_steps=0.05,
        save_total_limit=2,
        logging_steps=0.02,
        logging_first_step=True,
        load_best_model_at_end=True,
        metric_for_best_model=metric_key,
        greater_is_better=True,
        report_to="none",
        run_name=RUN_NAME,
        seed=12,
        push_to_hub=False,
    )
    trainer = SentenceTransformerTrainer(
        model=model, args=args,
        train_dataset=train_datasets, loss=losses, evaluator=evaluator,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )
    trainer.train()

    logging.info("Post-training eval:")
    with autocast_ctx():
        score = evaluator(model)[evaluator.primary_metric]
    delta = score - baseline
    verdict = "WIN" if delta >= WIN_MARGIN else "MARGINAL" if delta >= 0 else "REGRESSION"
    logging.info(
        f"VERDICT: {verdict} | score={score:.4f} | baseline={baseline:.4f} | delta={delta:+.4f}"
    )
    # Also report the absolute score against the v3 model's result.
    logging.info(f" vs v3 ({V3_BASELINE:.4f}): delta = {score - V3_BASELINE:+.4f}")

    final_dir = f"{OUTPUT_DIR}/final"
    model.save_pretrained(final_dir)
    logging.info(f"Saved final model to {final_dir}")
if __name__ == "__main__":
    # Run training only when executed as a script (e.g. via `uv run`), not on import.
    main()