Sentence Similarity
sentence-transformers
Safetensors
English
static-embedding
chess
retrieval
exploratory
Add files using upload-large-folder tool
- README.md +384 -506
- data/hard_negatives_chess.parquet +3 -0
- data/hard_negatives_english.parquet +3 -0
- data/theme_definitions.parquet +3 -0
- model.safetensors +1 -1
- scripts/compare_variants.py +175 -0
- scripts/convert_to_english.py +216 -0
- scripts/diag_ce_vs_bm25.py +145 -0
- scripts/generate_theme_defs.py +168 -0
- scripts/mine_hard_negs_v2.py +213 -0
- scripts/train_chess_multitask.py +287 -0
- scripts/train_chess_static.py +640 -0
README.md
CHANGED
|
@@ -1,538 +1,416 @@
|
|
| 1 |
---
|
| 2 |
-
language:
|
| 3 |
-
- en
|
| 4 |
license: apache-2.0
|
|
|
|
|
|
|
| 5 |
tags:
|
| 6 |
- sentence-transformers
|
| 7 |
-
-
|
| 8 |
-
-
|
| 9 |
-
-
|
| 10 |
-
-
|
| 11 |
-
|
| 12 |
-
-
|
| 13 |
-
|
| 14 |
-
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
c3f3 d5b4 f3c6 b4c6 f5f6+c5e6 c5e6+h5g6 h5g6+h7g6 h7g6+c3f3 c3f3+d5b4 d5b4+f3c6
|
| 29 |
-
f3c6+b4c6
|
| 30 |
-
- themes advancedPawn advantage endgame long master promotion rookEndgame moves
|
| 31 |
-
h3h2 g1g2 g3g2 a6a7 h2h1q a7b8q h3h2+g1g2 g1g2+g3g2 g3g2+a6a7 a6a7+h2h1q h2h1q+a7b8q
|
| 32 |
-
- themes crushing intermezzo master middlegame sacrifice veryLong moves a6c4 d6f6
|
| 33 |
-
f1f6 h6h1 g1f2 h8f6 f2e2 f6e7 a6c4+d6f6 d6f6+f1f6 f1f6+h6h1 h6h1+g1f2 g1f2+h8f6
|
| 34 |
-
h8f6+f2e2 f2e2+f6e7
|
| 35 |
-
- source_sentence: advantage hangingPiece middlegame short Nimzo-Larsen Attack Nimzo-Larsen
|
| 36 |
-
Attack Modern [UNK]
|
| 37 |
-
sentences:
|
| 38 |
-
- themes hangingPiece mate mateIn1 middlegame oneMove opening Trompowsky Attack
|
| 39 |
-
Trompowsky Attack Classical Defense moves f4g4 d8d1 f4g4+d8d1
|
| 40 |
-
- themes advancedPawn crushing defensiveMove endgame master quietMove veryLong moves
|
| 41 |
-
f1e1 h3h2 f8h8 f5h4 h8e5 g3g2 e5e4 h4f3 f1e1+h3h2 h3h2+f8h8 f8h8+f5h4 f5h4+h8e5
|
| 42 |
-
h8e5+g3g2 g3g2+e5e4 e5e4+h4f3
|
| 43 |
-
- themes advantage hangingPiece middlegame short opening Nimzo-Larsen Attack Nimzo-Larsen
|
| 44 |
-
Attack Modern Variation moves f5d7 b5g5 e3e2 d1d2 f5d7+b5g5 b5g5+e3e2 e3e2+d1d2
|
| 45 |
-
- source_sentence: '[UNK] defensiveMove [UNK] [UNK] veryLong'
|
| 46 |
-
sentences:
|
| 47 |
-
- themes advantage discoveredAttack exposedKing middlegame trappedPiece veryLong
|
| 48 |
-
opening French Defense French Defense Orthoschnapp Gambit moves e2d1 c4e3 d2e3
|
| 49 |
-
b5f1 d1d2 f1g2 g1e2 g2h1 e2d1+c4e3 c4e3+d2e3 d2e3+b5f1 b5f1+d1d2 d1d2+f1g2 f1g2+g1e2
|
| 50 |
-
g1e2+g2h1
|
| 51 |
-
- themes crushing defensiveMove enPassant middlegame veryLong moves g2e2 a3f3 f7f5
|
| 52 |
-
e5f6 c4f4 g3f4 e2g2 f3g3 g2e2+a3f3 a3f3+f7f5 f7f5+e5f6 e5f6+c4f4 c4f4+g3f4 g3f4+e2g2
|
| 53 |
-
e2g2+f3g3
|
| 54 |
-
- themes advancedPawn bishopEndgame crushing defensiveMove endgame veryLong moves
|
| 55 |
-
f3e4 a3a2 g6g7 e6f7 e5e6 f7g8 e6e7 c5e7 f3e4+a3a2 a3a2+g6g7 g6g7+e6f7 e6f7+e5e6
|
| 56 |
-
e5e6+f7g8 f7g8+e6e7 e6e7+c5e7
|
| 57 |
-
- source_sentence: '[UNK] deflection discoveredAttack [UNK] queensideAttack short
|
| 58 |
-
Philidor Defense [UNK] Defense Other variations'
|
| 59 |
-
sentences:
|
| 60 |
-
- themes crushing middlegame pin queensideAttack short opening Sicilian Defense
|
| 61 |
-
Sicilian Defense Najdorf Variation moves c3d5 c5b3 c1b1 b3d2 c3d5+c5b3 c5b3+c1b1
|
| 62 |
-
c1b1+b3d2
|
| 63 |
-
- themes crushing deflection discoveredAttack middlegame queensideAttack short opening
|
| 64 |
-
Philidor Defense Philidor Defense Other variations moves d3c3 d4b3 c1b1 d7d1 d3c3+d4b3
|
| 65 |
-
d4b3+c1b1 c1b1+d7d1
|
| 66 |
-
- themes advantage discoveredAttack middlegame short opening Philidor Defense Philidor
|
| 67 |
-
Defense Other variations moves e4d4 d3f5 c8b8 d1d4 e4d4+d3f5 d3f5+c8b8 c8b8+d1d4
|
| 68 |
-
pipeline_tag: sentence-similarity
|
| 69 |
-
library_name: sentence-transformers
|
| 70 |
-
metrics:
|
| 71 |
-
- cosine_accuracy@1
|
| 72 |
-
- cosine_accuracy@10
|
| 73 |
-
- cosine_precision@1
|
| 74 |
-
- cosine_precision@10
|
| 75 |
-
- cosine_recall@1
|
| 76 |
-
- cosine_recall@10
|
| 77 |
-
- cosine_ndcg@10
|
| 78 |
-
- cosine_mrr@10
|
| 79 |
-
- cosine_map@100
|
| 80 |
-
model-index:
|
| 81 |
-
- name: Static chess embedding (512d) -- themes/openings <-> positions
|
| 82 |
-
results:
|
| 83 |
-
- task:
|
| 84 |
-
type: information-retrieval
|
| 85 |
-
name: Information Retrieval
|
| 86 |
-
dataset:
|
| 87 |
-
name: chess ir
|
| 88 |
-
type: chess-ir
|
| 89 |
-
metrics:
|
| 90 |
-
- type: cosine_accuracy@1
|
| 91 |
-
value: 0.005
|
| 92 |
-
name: Cosine Accuracy@1
|
| 93 |
-
- type: cosine_accuracy@10
|
| 94 |
-
value: 0.07
|
| 95 |
-
name: Cosine Accuracy@10
|
| 96 |
-
- type: cosine_precision@1
|
| 97 |
-
value: 0.005
|
| 98 |
-
name: Cosine Precision@1
|
| 99 |
-
- type: cosine_precision@10
|
| 100 |
-
value: 0.008
|
| 101 |
-
name: Cosine Precision@10
|
| 102 |
-
- type: cosine_recall@1
|
| 103 |
-
value: 0.0016666666666666666
|
| 104 |
-
name: Cosine Recall@1
|
| 105 |
-
- type: cosine_recall@10
|
| 106 |
-
value: 0.02666666666666666
|
| 107 |
-
name: Cosine Recall@10
|
| 108 |
-
- type: cosine_ndcg@10
|
| 109 |
-
value: 0.01682968253099316
|
| 110 |
-
name: Cosine Ndcg@10
|
| 111 |
-
- type: cosine_mrr@10
|
| 112 |
-
value: 0.020728174603174603
|
| 113 |
-
name: Cosine Mrr@10
|
| 114 |
-
- type: cosine_map@100
|
| 115 |
-
value: 0.014144217882495914
|
| 116 |
-
name: Cosine Map@100
|
| 117 |
-
- task:
|
| 118 |
-
type: information-retrieval
|
| 119 |
-
name: Information Retrieval
|
| 120 |
-
dataset:
|
| 121 |
-
name: chess ir tokens
|
| 122 |
-
type: chess-ir-tokens
|
| 123 |
-
metrics:
|
| 124 |
-
- type: cosine_accuracy@1
|
| 125 |
-
value: 0.07936507936507936
|
| 126 |
-
name: Cosine Accuracy@1
|
| 127 |
-
- type: cosine_accuracy@10
|
| 128 |
-
value: 0.25925925925925924
|
| 129 |
-
name: Cosine Accuracy@10
|
| 130 |
-
- type: cosine_precision@1
|
| 131 |
-
value: 0.07936507936507936
|
| 132 |
-
name: Cosine Precision@1
|
| 133 |
-
- type: cosine_precision@10
|
| 134 |
-
value: 0.06031746031746032
|
| 135 |
-
name: Cosine Precision@10
|
| 136 |
-
- type: cosine_recall@1
|
| 137 |
-
value: 0.00224439005944158
|
| 138 |
-
name: Cosine Recall@1
|
| 139 |
-
- type: cosine_recall@10
|
| 140 |
-
value: 0.023957890091684336
|
| 141 |
-
name: Cosine Recall@10
|
| 142 |
-
- type: cosine_ndcg@10
|
| 143 |
-
value: 0.067202690066618
|
| 144 |
-
name: Cosine Ndcg@10
|
| 145 |
-
- type: cosine_mrr@10
|
| 146 |
-
value: 0.12332031578063325
|
| 147 |
-
name: Cosine Mrr@10
|
| 148 |
-
- type: cosine_map@100
|
| 149 |
-
value: 0.03321093573791526
|
| 150 |
-
name: Cosine Map@100
|
| 151 |
---
|
| 152 |
|
| 153 |
-
#
|
| 154 |
|
| 155 |
-
|
| 156 |
|
| 157 |
-
##
|
| 158 |
|
| 159 |
-
|
| 160 |
-
-
|
| 161 |
-
|
| 162 |
-
-
|
| 163 |
-
-
|
| 164 |
-
-
|
| 165 |
-
- **
|
| 166 |
-
<!-- - **Training Dataset:** Unknown -->
|
| 167 |
-
- **Language:** en
|
| 168 |
-
- **License:** apache-2.0
|
| 169 |
|
| 170 |
-
|
|
|
|
|
|
|
| 171 |
|
| 172 |
-
|
| 173 |
-
-
|
| 174 |
-
- **Hugging Face:** [Sentence Transformers on Hugging Face](https://huggingface.co/models?library=sentence-transformers)
|
| 175 |
|
| 176 |
-
|
|
|
|
|
|
|
| 177 |
|
| 178 |
```
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
```
|
| 183 |
|
| 184 |
-
|
| 185 |
|
| 186 |
-
##
|
| 187 |
|
| 188 |
-
|
|
|
|
| 189 |
|
| 190 |
-
|
| 191 |
-
pip install -U sentence-transformers
|
| 192 |
-
```
|
| 193 |
-
Then you can load this model and run inference.
|
| 194 |
-
```python
|
| 195 |
-
from sentence_transformers import SentenceTransformer
|
| 196 |
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
# Run inference
|
| 200 |
-
queries = [
|
| 201 |
-
'[UNK] deflection discoveredAttack [UNK] queensideAttack short Philidor Defense [UNK] Defense Other variations',
|
| 202 |
-
]
|
| 203 |
-
documents = [
|
| 204 |
-
'themes crushing deflection discoveredAttack middlegame queensideAttack short opening Philidor Defense Philidor Defense Other variations moves d3c3 d4b3 c1b1 d7d1 d3c3+d4b3 d4b3+c1b1 c1b1+d7d1',
|
| 205 |
-
'themes advantage discoveredAttack middlegame short opening Philidor Defense Philidor Defense Other variations moves e4d4 d3f5 c8b8 d1d4 e4d4+d3f5 d3f5+c8b8 c8b8+d1d4',
|
| 206 |
-
'themes crushing middlegame pin queensideAttack short opening Sicilian Defense Sicilian Defense Najdorf Variation moves c3d5 c5b3 c1b1 b3d2 c3d5+c5b3 c5b3+c1b1 c1b1+b3d2',
|
| 207 |
-
]
|
| 208 |
-
query_embeddings = model.encode_query(queries)
|
| 209 |
-
document_embeddings = model.encode_document(documents)
|
| 210 |
-
print(query_embeddings.shape, document_embeddings.shape)
|
| 211 |
-
# [1, 512] [3, 512]
|
| 212 |
-
|
| 213 |
-
# Get the similarity scores for the embeddings
|
| 214 |
-
similarities = model.similarity(query_embeddings, document_embeddings)
|
| 215 |
-
print(similarities)
|
| 216 |
-
# tensor([[0.8405, 0.5061, 0.2136]])
|
| 217 |
-
```
|
| 218 |
-
<!--
|
| 219 |
-
### Direct Usage (Transformers)
|
| 220 |
-
|
| 221 |
-
<details><summary>Click to see the direct usage in Transformers</summary>
|
| 222 |
-
|
| 223 |
-
</details>
|
| 224 |
-
-->
|
| 225 |
-
|
| 226 |
-
<!--
|
| 227 |
-
### Downstream Usage (Sentence Transformers)
|
| 228 |
-
|
| 229 |
-
You can finetune this model on your own dataset.
|
| 230 |
-
|
| 231 |
-
<details><summary>Click to expand</summary>
|
| 232 |
-
|
| 233 |
-
</details>
|
| 234 |
-
-->
|
| 235 |
-
|
| 236 |
-
<!--
|
| 237 |
-
### Out-of-Scope Use
|
| 238 |
-
|
| 239 |
-
*List how the model may foreseeably be misused and address what users ought not to do with the model.*
|
| 240 |
-
-->
|
| 241 |
-
|
| 242 |
-
## Evaluation
|
| 243 |
-
|
| 244 |
-
### Metrics
|
| 245 |
-
|
| 246 |
-
#### Information Retrieval
|
| 247 |
-
|
| 248 |
-
* Datasets: `chess-ir` and `chess-ir-tokens`
|
| 249 |
-
* Evaluated with [<code>InformationRetrievalEvaluator</code>](https://sbert.net/docs/package_reference/sentence_transformer/evaluation.html#sentence_transformers.sentence_transformer.evaluation.InformationRetrievalEvaluator)
|
| 250 |
-
|
| 251 |
-
| Metric | chess-ir | chess-ir-tokens |
|
| 252 |
-
|:--------------------|:-----------|:----------------|
|
| 253 |
-
| cosine_accuracy@1 | 0.005 | 0.0794 |
|
| 254 |
-
| cosine_accuracy@10 | 0.07 | 0.2593 |
|
| 255 |
-
| cosine_precision@1 | 0.005 | 0.0794 |
|
| 256 |
-
| cosine_precision@10 | 0.008 | 0.0603 |
|
| 257 |
-
| cosine_recall@1 | 0.0017 | 0.0022 |
|
| 258 |
-
| cosine_recall@10 | 0.0267 | 0.024 |
|
| 259 |
-
| **cosine_ndcg@10** | **0.0168** | **0.0672** |
|
| 260 |
-
| cosine_mrr@10 | 0.0207 | 0.1233 |
|
| 261 |
-
| cosine_map@100 | 0.0141 | 0.0332 |
|
| 262 |
-
|
| 263 |
-
<!--
|
| 264 |
-
## Bias, Risks and Limitations
|
| 265 |
-
|
| 266 |
-
*What are the known or foreseeable issues stemming from this model? You could also flag here known failure cases or weaknesses of the model.*
|
| 267 |
-
-->
|
| 268 |
-
|
| 269 |
-
<!--
|
| 270 |
-
### Recommendations
|
| 271 |
-
|
| 272 |
-
*What are recommendations with respect to the foreseeable issues? For example, filtering explicit content.*
|
| 273 |
-
-->
|
| 274 |
-
|
| 275 |
-
## Training Details
|
| 276 |
-
|
| 277 |
-
### Training Dataset
|
| 278 |
-
|
| 279 |
-
#### Unnamed Dataset
|
| 280 |
-
|
| 281 |
-
* Size: 1,619,946 training samples
|
| 282 |
-
* Columns: <code>anchor</code> and <code>positive</code>
|
| 283 |
-
* Approximate statistics based on the first 100 samples:
|
| 284 |
-
| | anchor | positive |
|
| 285 |
-
|:---------|:------------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------|
|
| 286 |
-
| type | string | string |
|
| 287 |
-
| modality | text | text |
|
| 288 |
-
| details | <ul><li>min: 21 characters</li><li>mean: 75.57 characters</li><li>max: 122 characters</li></ul> | <ul><li>min: 86 characters</li><li>mean: 158.13 characters</li><li>max: 256 characters</li></ul> |
|
| 289 |
-
* Samples:
|
| 290 |
-
| anchor | positive |
|
| 291 |
-
|:---------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
| 292 |
-
| <code>kingsideAttack mate mateIn1 middlegame oneMove Horwitz Defense Horwitz Defense [UNK] variations</code> | <code>themes kingsideAttack mate mateIn1 middlegame oneMove opening Horwitz Defense Horwitz Defense Other variations moves f7h8 g6g2 f7h8+g6g2</code> |
|
| 293 |
-
| <code>backRankMate endgame mate mateIn2 short Kings Knight Opening Kings Knight Opening [UNK] [UNK]</code> | <code>themes backRankMate endgame mate mateIn2 short opening Kings Knight Opening Kings Knight Opening Other variations moves c5d4 c3c8 g5d8 c8d8 c5d4+c3c8 c3c8+g5d8 g5d8+c8d8</code> |
|
| 294 |
-
| <code>kingsideAttack mate mateIn1 middlegame oneMove Sicilian Defense Sicilian Defense Paulsen-Basman Defense</code> | <code>themes kingsideAttack mate mateIn1 middlegame oneMove opening Sicilian Defense Sicilian Defense Paulsen-Basman Defense moves g3f3 c7h2 g3f3+c7h2</code> |
|
| 295 |
-
* Loss: [<code>MatryoshkaLoss</code>](https://sbert.net/docs/package_reference/sentence_transformer/losses.html#matryoshkaloss) with these parameters:
|
| 296 |
-
```json
|
| 297 |
-
{
|
| 298 |
-
"loss": "MultipleNegativesRankingLoss",
|
| 299 |
-
"matryoshka_dims": [
|
| 300 |
-
512,
|
| 301 |
-
256,
|
| 302 |
-
128,
|
| 303 |
-
64,
|
| 304 |
-
32
|
| 305 |
-
],
|
| 306 |
-
"matryoshka_weights": [
|
| 307 |
-
1,
|
| 308 |
-
1,
|
| 309 |
-
1,
|
| 310 |
-
1,
|
| 311 |
-
1
|
| 312 |
-
],
|
| 313 |
-
"n_dims_per_step": -1
|
| 314 |
-
}
|
| 315 |
-
```
|
| 316 |
-
|
| 317 |
-
### Training Hyperparameters
|
| 318 |
-
#### Non-Default Hyperparameters
|
| 319 |
-
|
| 320 |
-
- `per_device_train_batch_size`: 4096
|
| 321 |
-
- `num_train_epochs`: 20
|
| 322 |
-
- `learning_rate`: 0.01
|
| 323 |
-
- `warmup_steps`: 0.1
|
| 324 |
-
- `weight_decay`: 0.01
|
| 325 |
-
- `per_device_eval_batch_size`: 4096
|
| 326 |
-
- `push_to_hub`: True
|
| 327 |
-
- `hub_model_id`: oneryalcin/static-embedding-chess
|
| 328 |
-
- `load_best_model_at_end`: True
|
| 329 |
-
- `seed`: 12
|
| 330 |
-
|
| 331 |
-
#### All Hyperparameters
|
| 332 |
-
<details><summary>Click to expand</summary>
|
| 333 |
-
|
| 334 |
-
- `per_device_train_batch_size`: 4096
|
| 335 |
-
- `num_train_epochs`: 20
|
| 336 |
-
- `max_steps`: -1
|
| 337 |
-
- `learning_rate`: 0.01
|
| 338 |
-
- `lr_scheduler_type`: linear
|
| 339 |
-
- `lr_scheduler_kwargs`: None
|
| 340 |
-
- `warmup_steps`: 0.1
|
| 341 |
-
- `optim`: adamw_torch_fused
|
| 342 |
-
- `optim_args`: None
|
| 343 |
-
- `weight_decay`: 0.01
|
| 344 |
-
- `adam_beta1`: 0.9
|
| 345 |
-
- `adam_beta2`: 0.999
|
| 346 |
-
- `adam_epsilon`: 1e-08
|
| 347 |
-
- `optim_target_modules`: None
|
| 348 |
-
- `gradient_accumulation_steps`: 1
|
| 349 |
-
- `average_tokens_across_devices`: True
|
| 350 |
-
- `max_grad_norm`: 1.0
|
| 351 |
-
- `label_smoothing_factor`: 0.0
|
| 352 |
-
- `bf16`: False
|
| 353 |
-
- `fp16`: False
|
| 354 |
-
- `bf16_full_eval`: False
|
| 355 |
-
- `fp16_full_eval`: False
|
| 356 |
-
- `tf32`: None
|
| 357 |
-
- `gradient_checkpointing`: False
|
| 358 |
-
- `gradient_checkpointing_kwargs`: None
|
| 359 |
-
- `torch_compile`: False
|
| 360 |
-
- `torch_compile_backend`: None
|
| 361 |
-
- `torch_compile_mode`: None
|
| 362 |
-
- `use_liger_kernel`: False
|
| 363 |
-
- `liger_kernel_config`: None
|
| 364 |
-
- `use_cache`: False
|
| 365 |
-
- `neftune_noise_alpha`: None
|
| 366 |
-
- `torch_empty_cache_steps`: None
|
| 367 |
-
- `auto_find_batch_size`: False
|
| 368 |
-
- `log_on_each_node`: True
|
| 369 |
-
- `logging_nan_inf_filter`: True
|
| 370 |
-
- `include_num_input_tokens_seen`: no
|
| 371 |
-
- `log_level`: passive
|
| 372 |
-
- `log_level_replica`: warning
|
| 373 |
-
- `disable_tqdm`: False
|
| 374 |
-
- `project`: huggingface
|
| 375 |
-
- `trackio_space_id`: None
|
| 376 |
-
- `trackio_bucket_id`: None
|
| 377 |
-
- `trackio_static_space_id`: None
|
| 378 |
-
- `per_device_eval_batch_size`: 4096
|
| 379 |
-
- `prediction_loss_only`: True
|
| 380 |
-
- `eval_on_start`: False
|
| 381 |
-
- `eval_do_concat_batches`: True
|
| 382 |
-
- `eval_use_gather_object`: False
|
| 383 |
-
- `eval_accumulation_steps`: None
|
| 384 |
-
- `include_for_metrics`: []
|
| 385 |
-
- `batch_eval_metrics`: False
|
| 386 |
-
- `save_only_model`: False
|
| 387 |
-
- `save_on_each_node`: False
|
| 388 |
-
- `enable_jit_checkpoint`: False
|
| 389 |
-
- `push_to_hub`: True
|
| 390 |
-
- `hub_private_repo`: None
|
| 391 |
-
- `hub_model_id`: oneryalcin/static-embedding-chess
|
| 392 |
-
- `hub_strategy`: every_save
|
| 393 |
-
- `hub_always_push`: False
|
| 394 |
-
- `hub_revision`: None
|
| 395 |
-
- `load_best_model_at_end`: True
|
| 396 |
-
- `ignore_data_skip`: False
|
| 397 |
-
- `restore_callback_states_from_checkpoint`: False
|
| 398 |
-
- `full_determinism`: False
|
| 399 |
-
- `seed`: 12
|
| 400 |
-
- `data_seed`: None
|
| 401 |
-
- `use_cpu`: False
|
| 402 |
-
- `accelerator_config`: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
|
| 403 |
-
- `parallelism_config`: None
|
| 404 |
-
- `dataloader_drop_last`: False
|
| 405 |
-
- `dataloader_num_workers`: 0
|
| 406 |
-
- `dataloader_pin_memory`: True
|
| 407 |
-
- `dataloader_persistent_workers`: False
|
| 408 |
-
- `dataloader_prefetch_factor`: None
|
| 409 |
-
- `remove_unused_columns`: True
|
| 410 |
-
- `label_names`: None
|
| 411 |
-
- `train_sampling_strategy`: random
|
| 412 |
-
- `length_column_name`: length
|
| 413 |
-
- `ddp_find_unused_parameters`: None
|
| 414 |
-
- `ddp_bucket_cap_mb`: None
|
| 415 |
-
- `ddp_broadcast_buffers`: False
|
| 416 |
-
- `ddp_static_graph`: None
|
| 417 |
-
- `ddp_backend`: None
|
| 418 |
-
- `ddp_timeout`: 1800
|
| 419 |
-
- `fsdp`: []
|
| 420 |
-
- `fsdp_config`: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
|
| 421 |
-
- `deepspeed`: None
|
| 422 |
-
- `debug`: []
|
| 423 |
-
- `skip_memory_metrics`: True
|
| 424 |
-
- `do_predict`: False
|
| 425 |
-
- `resume_from_checkpoint`: None
|
| 426 |
-
- `warmup_ratio`: None
|
| 427 |
-
- `local_rank`: -1
|
| 428 |
-
- `prompts`: None
|
| 429 |
-
- `batch_sampler`: batch_sampler
|
| 430 |
-
- `multi_dataset_batch_sampler`: proportional
|
| 431 |
-
- `router_mapping`: {}
|
| 432 |
-
- `learning_rate_mapping`: {}
|
| 433 |
-
|
| 434 |
-
</details>
|
| 435 |
-
|
| 436 |
-
### Training Logs
|
| 437 |
-
| Epoch | Step | Training Loss | chess-ir_cosine_ndcg@10 | chess-ir-tokens_cosine_ndcg@10 |
|
| 438 |
-
|:------:|:----:|:-------------:|:-----------------------:|:------------------------------:|
|
| 439 |
-
| -1 | -1 | - | 0.0123 | 0.0561 |
|
| 440 |
-
| 0.0025 | 1 | 27.3123 | - | - |
|
| 441 |
-
| 0.2020 | 80 | 26.3304 | - | - |
|
| 442 |
-
| 0.4040 | 160 | 22.2114 | - | - |
|
| 443 |
-
| 0.6061 | 240 | 17.4522 | - | - |
|
| 444 |
-
| 0.8081 | 320 | 12.8864 | - | - |
|
| 445 |
-
| 1.0 | 396 | - | 0.0800 | 0.1181 |
|
| 446 |
-
| 1.0101 | 400 | 9.1439 | - | - |
|
| 447 |
-
| 1.2121 | 480 | 6.5434 | - | - |
|
| 448 |
-
| 1.4141 | 560 | 4.9138 | - | - |
|
| 449 |
-
| 1.6162 | 640 | 3.9819 | - | - |
|
| 450 |
-
| 1.8182 | 720 | 3.4584 | - | - |
|
| 451 |
-
| 2.0 | 792 | - | 0.0505 | 0.0938 |
|
| 452 |
-
| 2.0202 | 800 | 3.1303 | - | - |
|
| 453 |
-
| 2.2222 | 880 | 2.9652 | - | - |
|
| 454 |
-
| 2.4242 | 960 | 2.8584 | - | - |
|
| 455 |
-
| 2.6263 | 1040 | 2.7907 | - | - |
|
| 456 |
-
| 2.8283 | 1120 | 2.7475 | - | - |
|
| 457 |
-
| 3.0 | 1188 | - | 0.0251 | 0.0830 |
|
| 458 |
-
| 3.0303 | 1200 | 2.7031 | - | - |
|
| 459 |
-
| 3.2323 | 1280 | 2.6927 | - | - |
|
| 460 |
-
| 3.4343 | 1360 | 2.6516 | - | - |
|
| 461 |
-
| 3.6364 | 1440 | 2.6441 | - | - |
|
| 462 |
-
| 3.8384 | 1520 | 2.6202 | - | - |
|
| 463 |
-
| 4.0 | 1584 | - | 0.0168 | 0.0672 |
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
### Training Time
|
| 467 |
-
- **Training**: 4.1 minutes
|
| 468 |
-
- **Evaluation**: 0.2 seconds
|
| 469 |
-
- **Total**: 4.1 minutes
|
| 470 |
-
|
| 471 |
-
### Framework Versions
|
| 472 |
-
- Python: 3.12.10
|
| 473 |
-
- Sentence Transformers: 5.5.0
|
| 474 |
-
- Transformers: 5.8.0
|
| 475 |
-
- PyTorch: 2.11.0
|
| 476 |
-
- Accelerate: 1.13.0
|
| 477 |
-
- Datasets: 4.8.5
|
| 478 |
-
- Tokenizers: 0.22.2
|
| 479 |
|
| 480 |
-
|
| 481 |
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
|
| 497 |
-
#### MatryoshkaLoss
|
| 498 |
-
```bibtex
|
| 499 |
-
@misc{kusupati2024matryoshka,
|
| 500 |
-
title={Matryoshka Representation Learning},
|
| 501 |
-
author={Aditya Kusupati and Gantavya Bhatt and Aniket Rege and Matthew Wallingford and Aditya Sinha and Vivek Ramanujan and William Howard-Snyder and Kaifeng Chen and Sham Kakade and Prateek Jain and Ali Farhadi},
|
| 502 |
-
year={2024},
|
| 503 |
-
eprint={2205.13147},
|
| 504 |
-
archivePrefix={arXiv},
|
| 505 |
-
primaryClass={cs.LG}
|
| 506 |
-
}
|
| 507 |
```
|
| 508 |
|
| 509 |
-
##
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
```
|
| 521 |
|
| 522 |
-
|
| 523 |
-
|
| 524 |
|
| 525 |
-
|
| 526 |
-
-
|
| 527 |
|
| 528 |
-
|
| 529 |
-
## Model Card Authors
|
| 530 |
|
| 531 |
-
|
| 532 |
-
|
| 533 |
|
| 534 |
-
|
| 535 |
-
|
| 536 |
|
| 537 |
-
|
| 538 |
-
--
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
language: en
|
|
|
|
| 3 |
license: apache-2.0
|
| 4 |
+
library_name: sentence-transformers
|
| 5 |
+
pipeline_tag: sentence-similarity
|
| 6 |
tags:
|
| 7 |
- sentence-transformers
|
| 8 |
+
- static-embedding
|
| 9 |
+
- chess
|
| 10 |
+
- retrieval
|
| 11 |
+
- exploratory
|
| 12 |
+
datasets:
|
| 13 |
+
- Lichess/chess-puzzles
|
| 14 |
+
- Lichess/chess-openings
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
# Chess Static Embedding (v4-C2) — Open Exploration
|
| 18 |
+
|
| 19 |
+
A 4M-parameter `StaticEmbedding` model for chess content retrieval, plus the
|
| 20 |
+
full **open-science methodology document** describing what we tried, what
|
| 21 |
+
worked, what failed, and why.
|
| 22 |
+
|
| 23 |
+
This repo is **exploratory experimental work**, published as-is. The model is
|
| 24 |
+
genuinely useful (NDCG@10 = 0.12 on a compositional held-out eval, 50× smaller
|
| 25 |
+
than typical retrieval encoders) but the bigger contribution is the
|
| 26 |
+
**methodology narrative** below — particularly the *LLM-bridge* and
|
| 27 |
+
*deterministic-bridge* findings.
|
| 28 |
+
|
| 29 |
---
|
| 30 |
|
| 31 |
+
## Quick start
|
| 32 |
+
|
| 33 |
+
```python
|
| 34 |
+
from sentence_transformers import SentenceTransformer
|
| 35 |
+
|
| 36 |
+
model = SentenceTransformer("oneryalcin/static-embedding-chess")
|
| 37 |
+
query = "fork endgame short"
|
| 38 |
+
docs = [
|
| 39 |
+
"themes crushing endgame fork short opening Sicilian Defense moves f2g3 e6e7",
|
| 40 |
+
"themes mate mateIn1 oneMove opening Caro-Kann moves d2d4 e7e5",
|
| 41 |
+
]
|
| 42 |
+
sims = model.encode(query) @ model.encode(docs).T  # raw dot-product scores, one per doc (shape (2,))
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
Static embedding: lookup table + average. Sub-millisecond CPU inference. No GPU
|
| 46 |
+
required.
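
To make "lookup table + average" concrete, here is a minimal sketch of the idea in NumPy. It is not the model's actual code: the toy vocabulary, the 8-dimensional random table, and whitespace tokenization are all assumptions for illustration, and the helper names `embed` and `cosine` are ours.

```python
import numpy as np

# Hypothetical toy vocabulary and embedding table (illustration only).
vocab = {"[UNK]": 0, "fork": 1, "endgame": 2, "short": 3}
rng = np.random.default_rng(0)
table = rng.normal(size=(len(vocab), 8)).astype(np.float32)  # one row per token


def embed(text: str) -> np.ndarray:
    """Split on whitespace, look up each token's row, and mean-pool the rows."""
    ids = [vocab.get(tok, vocab["[UNK]"]) for tok in text.split()]
    if not ids:
        return np.zeros(table.shape[1], dtype=np.float32)
    return table[ids].mean(axis=0)


def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


query_vec = embed("fork endgame short")
doc_vec = embed("short endgame fork")  # same tokens, different order
print(cosine(query_vec, doc_vec))
```

Because the pooling is a plain average, token order does not affect the result, which is also why inference reduces to a few table reads and one mean.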
|
| 47 |
|
| 48 |
+
---
|
| 49 |
|
| 50 |
+
## Headline result
|
| 51 |
|
| 52 |
+
| Variant | NDCG@10 | vs v3 baseline (random init) |
|
| 53 |
+
|---------|---------|---------------|
|
| 54 |
+
| v3 baseline (random init + MNRL) | 0.0801 | — |
|
| 55 |
+
| v4-A hard-neg only | 0.1000 | +25% |
|
| 56 |
+
| v4-B theme distill only | 0.0112 | -86% (regression — see methodology) |
|
| 57 |
+
| v4-C multitask 500× | 0.1154 | +44% |
|
| 58 |
+
| **v4-C2 multitask 5000× (this model)** | **0.1202** | **+50%** |
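
As a quick check of the last column, the percentages appear to be relative changes against the v3 baseline score of 0.0801. A minimal sketch of that arithmetic, using only the NDCG@10 values from the table (the helper name `relative_change` is ours):

```python
baseline = 0.0801  # v3 baseline NDCG@10, taken from the table

variants = {
    "v4-A": 0.1000,
    "v4-B": 0.0112,
    "v4-C": 0.1154,
    "v4-C2": 0.1202,
}


def relative_change(score: float, base: float) -> float:
    """Relative change of `score` versus `base`, in percent."""
    return (score - base) / base * 100.0


for name, score in variants.items():
    print(f"{name}: {relative_change(score, baseline):+.0f}%")
```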
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
+
Held-out eval: 200 unseen anchor combinations × 600-doc corpus. Compositional
|
| 61 |
+
generalization — the model never saw these exact theme combinations during
|
| 62 |
+
training, only the individual tokens in other combos.
|
| 63 |
|
| 64 |
+
For **production-ready** chess search, see the **two-stage architecture** below
|
| 65 |
+
(static + BM25 over English-bridged docs) that delivers NDCG@10 = 0.59-0.87.
|
|
|
|
| 66 |
|
| 67 |
+
---
|
| 68 |
+
|
| 69 |
+
## What's in this repo
|
| 70 |
|
| 71 |
```
|
| 72 |
+
model.safetensors # 4M-param StaticEmbedding weights (~9MB)
|
| 73 |
+
chess_tokenizer.json # WordLevel chess tokenizer (4,336 tokens)
|
| 74 |
+
tokenizer.json # Same, in HF format for ST loading
|
| 75 |
+
config_sentence_transformers.json # Module config
|
| 76 |
+
modules.json # Module pipeline
|
| 77 |
+
|
| 78 |
+
data/
|
| 79 |
+
├── theme_definitions.parquet # 73 chess themes + LLM-generated English defs + MPNet embeddings (the LLM-bridge teacher signal)
|
| 80 |
+
├── hard_negatives_chess.parquet # 1.6M (anchor, positive, negative) triplets, chess-token format
|
| 81 |
+
└── hard_negatives_english.parquet # Same, English-bridged via deterministic conversion
|
| 82 |
+
|
| 83 |
+
scripts/
|
| 84 |
+
├── train_chess_static.py # Main training entrypoint (multi-version, env-flag controlled)
|
| 85 |
+
├── train_chess_multitask.py # The v4-C2 winning recipe (theme distill + hard-neg MNRL)
|
| 86 |
+
├── convert_to_english.py # Deterministic chess→English (no LLM needed; python-chess + regex)
|
| 87 |
+
├── mine_hard_negs_v2.py # Memory-bounded custom hard-negative miner
|
| 88 |
+
├── generate_theme_defs.py # LLM-bridge: DeepSeek-v4-flash writes chess concept definitions
|
| 89 |
+
├── compare_variants.py # Side-by-side eval framework across all variants
|
| 90 |
+
└── diag_ce_vs_bm25.py # The critical "is your CE really helping" diagnostic
|
| 91 |
```
|
| 92 |
|
| 93 |
+
---
|
| 94 |
|
| 95 |
+
## Methodology — the full experimental journey
|
| 96 |
|
| 97 |
+
This was 36+ hours of iterative exploration. The model is the small visible
|
| 98 |
+
output; the methodology is the bigger contribution.
|
| 99 |
|
| 100 |
+
### 1. Problem and approach
|
| 101 |
|
| 102 |
+
**Task:** Free-text search over a chess puzzle corpus. User types something
|
| 103 |
+
like `"fork endgame short"` and gets matching Lichess puzzles.
|
| 104 |
|
| 105 |
+
**Why static embedding:** Tom Aarsen's
|
| 106 |
+
[static-retrieval-mrl-en-v1](https://huggingface.co/sentence-transformers/static-retrieval-mrl-en-v1)
|
| 107 |
+
showed StaticEmbedding can be a useful retrieval primitive with the right
|
| 108 |
+
training. We adapted the recipe for a chess-specific domain with a custom
|
| 109 |
+
WordLevel tokenizer so chess tokens (UCI moves, theme names, ECO codes) are
|
| 110 |
+
first-class.
|
| 111 |
|
| 112 |
+
**Data:** Lichess/chess-puzzles (5.8M puzzles, CC0) + Lichess/chess-openings
|
| 113 |
+
(3.6K openings, CC0).
|
| 114 |
+
|
| 115 |
+
### 2. Eval design — the hardest part
|
| 116 |
+
|
| 117 |
+
**Initial mistake:** First eval used top-200 most-common theme strings as
|
| 118 |
+
queries. The model had seen each of these ~50,000 times in training. Baseline
|
| 119 |
+
NDCG@10 was inflated to 0.81 by lexical overlap before any training. Useless.
|
| 120 |
+
|
| 121 |
+
**Fixed eval (used throughout):** *Compositional held-out anchors*. Pick 200
|
| 122 |
+
theme-combination strings that appear exactly 3 times in the data
|
| 123 |
+
(rare-but-multi-relevant), remove all matching pairs from train, use those rare
|
| 124 |
+
combos as queries. Tests whether the model can compose meaning from individual
|
| 125 |
+
theme tokens it learned, without having seen the specific combination.
|
| 126 |
+
|
| 127 |
+
This is harsh — the model can never "memorize" the eval queries — and that's
|
| 128 |
+
the point. Random-init baseline drops to NDCG@10 ≈ 0.01.
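
The held-out selection described above can be sketched roughly as follows. This is a minimal illustration, assuming anchors are plain strings; `select_heldout` and its parameters are hypothetical names, and the alphabetical tie-break is an assumption added here for determinism.

```python
from collections import Counter


def select_heldout(anchors, n_queries=200, freq_min=3, freq_max=30):
    """Pick the rarest anchor strings (within a frequency band) as eval queries.

    Returns the held-out anchor set plus the row indices that stay in train
    and the row indices that form the eval corpus.
    """
    freq = Counter(anchors)
    rare = sorted(
        (a for a, c in freq.items() if freq_min <= c <= freq_max),
        key=lambda a: (freq[a], a),  # rarest first; alphabetical tie-break (assumption)
    )
    heldout = set(rare[:n_queries])
    # Every pair whose anchor is held out is removed from training.
    train_idx = [i for i, a in enumerate(anchors) if a not in heldout]
    eval_idx = [i for i, a in enumerate(anchors) if a in heldout]
    return heldout, train_idx, eval_idx


# Toy data: one combo seen 3 times, one seen 4 times, one common, one singleton.
anchors = ["fork short"] * 3 + ["pin long"] * 50 + ["skewer endgame"] * 4 + ["mate"]
heldout, train_idx, eval_idx = select_heldout(anchors, n_queries=1)
print(heldout, len(train_idx), eval_idx)
```

Because held-out anchors are removed from train entirely, any score on these queries has to come from composing individual theme tokens rather than from recalling the full string.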
|
| 129 |
+
|
| 130 |
+
### 3. Phase 1 — diagnostic of the v3 model (0.08 NDCG@10)
|
| 131 |
+
|
| 132 |
+
A working baseline existed. Question: **why isn't it better?**
|
| 133 |
+
|
| 134 |
+
Token-similarity probe revealed the core issue:
|
| 135 |
+
|
| 136 |
+
| Pair | v3 cosine similarity |
|
| 137 |
+
|---|---|
|
| 138 |
+
| `fork` ↔ `pin` | +0.01 |
|
| 139 |
+
| `fork` ↔ `skewer` | -0.12 |
|
| 140 |
+
| `endgame` ↔ `middlegame` | -0.30 |
|
| 141 |
+
|
| 142 |
+
**Token embeddings were essentially orthogonal.** The model learned per-token
|
| 143 |
+
mappings to chess-content clusters but no relationships *between* tokens.
|
| 144 |
+
Compositional generalization (the eval task) requires those relationships.
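
For readers who want to reproduce this kind of probe on their own vectors, here is a minimal sketch of the cosine computation in pure Python. The three toy vectors are invented for illustration and are not taken from any model in this repo.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / max(norm_u * norm_v, 1e-12)  # guard against zero vectors


# Hypothetical token embeddings: "fork" and "pin" are orthogonal (the failure
# mode above), while "fork" and "skewer" share most of their direction.
toy = {
    "fork": [1.0, 0.0, 0.0],
    "pin": [0.0, 1.0, 0.0],
    "skewer": [0.9, 0.1, 0.0],
}

for a, b in [("fork", "pin"), ("fork", "skewer")]:
    print(f"{a} <-> {b}: {cosine(toy[a], toy[b]):+.3f}")
```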
|
| 145 |
+
|
| 146 |
+
Also discovered: 51% of held-out queries returned zero relevant documents in the top 10
|
| 147 |
+
(median NDCG@10 = 0). Bimodal failure pattern.
|
| 148 |
+
|
| 149 |
+
Also discovered: model beat BM25 by 7.5× (0.08 vs 0.01), confirming it does
|
| 150 |
+
real semantic work beyond keyword match.
|
| 151 |
+
|
| 152 |
+
### 4. Phase 2 — distillation from raw MPNet (DEAD END)
|
| 153 |
+
|
| 154 |
+
Hypothesis: distill student token embeddings to match teacher (MPNet)
|
| 155 |
+
embeddings. Teacher knows English; should know that `fork ≈ pin`.
|
| 156 |
+
|
| 157 |
+
**Result:** REGRESSION. Why? **MPNet itself scores NDCG@10 = 0.0094 on our
|
| 158 |
+
eval.** 95.5% of queries get zero in top-10. MPNet doesn't know chess: UCI
|
| 159 |
+
moves are character soup to its WordPiece tokenizer.
|
| 160 |
+
|
| 161 |
+
**You can't distill what the teacher doesn't know.** This was the first key
|
| 162 |
+
lesson.
|
| 163 |
+
|
| 164 |
+
### 5. Phase 3 — LLM-bridge for theme distillation (BREAKTHROUGH)
|
| 165 |
+
|
| 166 |
+
Key insight: an LLM can read both chess (in camelCase) AND English. Use it as
|
| 167 |
+
a **translator** to put chess concepts into language MPNet *can* understand
|
| 168 |
+
semantically.
|
| 169 |
+
|
| 170 |
+
**Steps:**
|
| 171 |
+
|
| 172 |
+
1. DeepSeek-v4-flash writes English definitions for 73 Lichess themes:
|
| 173 |
+
- `fork` → "A tactical motif where a single piece attacks two or more
|
| 174 |
+
enemy pieces simultaneously, forcing a material gain."
|
| 175 |
+
2. MPNet embeds the *English definitions* (it knows English fluently).
|
| 176 |
+
3. Distill the student's per-token embedding to match the definition embedding.
|
| 177 |
+
|
| 178 |
+
After step 2 alone, MPNet's `fork ↔ skewer` similarity jumps from 0.39 (raw
|
| 179 |
+
camelCase) to **0.87** (via definitions). Real semantic structure.
|
| 180 |
+
|
| 181 |
+
Combined with hard-negative MNRL training (v4-C2): **NDCG@10 = 0.1202**, +50%
|
| 182 |
+
over v3.
|
| 183 |
+
|
| 184 |
+
Cost: 73 themes × DeepSeek API ≈ $0.01 + ~1 minute generation.
|
| 185 |
+
|
| 186 |
+
This is the **LLM-bridge** pattern: when system A doesn't speak system B's
|
| 187 |
+
language, use an LLM as a translator. The LLM is one-shot work, not part of
|
| 188 |
+
inference.
|
| 189 |
+
|
| 190 |
+
### 6. Phase 4 — hard-negative mining
|
| 191 |
+
|
| 192 |
+
Used the v3 model to mine confusable documents per anchor. Custom
|
| 193 |
+
memory-bounded miner because the sentence-transformers built-in OOMs on M4 at
|
| 194 |
+
327k unique anchors × 327k positives. See `scripts/mine_hard_negs_v2.py`.
|
| 195 |
+
|
| 196 |
+
1.6M triplets mined. Positive-negative margin: 0.135 mean (good signal for
|
| 197 |
+
training).
|
| 198 |
+
|
| 199 |
+
### 7. Phase 5 — multi-task training (v4-C2 winner)
|
| 200 |
+
|
| 201 |
+
Multi-dataset trainer combining:
|
| 202 |
+
- **Chess triplets** (1.6M, MNRL loss): teaches content associations
|
| 203 |
+
- **Theme distillation** (73 themes × 5000 replicas via `EmbedDistillLoss`):
|
| 204 |
+
injects semantic structure between tokens
|
| 205 |
+
|
| 206 |
+
With proportional sampling, theme tokens see ~500 gradient updates per epoch
|
| 207 |
+
(via replication) vs chess pairs once. Theme distillation oversampling matters:
|
| 208 |
+
|
| 209 |
+
| Theme replicas | NDCG@10 |
|
| 210 |
+
|---|---|
|
| 211 |
+
| 500× | 0.1154 |
|
| 212 |
+
| 5000× | 0.1202 |
|
| 213 |
+
|
| 214 |
+
### 8. Phase 6 — cross-encoder reranker attempts (ALL FAILED)
|
| 215 |
+
|
| 216 |
+
Tried three variants:
|
| 217 |
+
- MS-MARCO MiniLM (English-pretrained, 22M params) on chess-format docs
|
| 218 |
+
- Same, with theme echo stripped from training docs
|
| 219 |
+
- Fresh-init tiny BERT (5M params) with our chess tokenizer
|
| 220 |
+
|
| 221 |
+
**All regressed below static-only.** Diagnosis: trained CEs operate at
|
| 222 |
+
random-ordering level on the eval. Inspection of training predictions showed
|
| 223 |
+
the trained CE got pair-ordering wrong 2/3 of the time on sample inputs.
|
| 224 |
+
|
| 225 |
+
**Root cause:** documents are UCI move sequences (`f2g3 e6e7 ...`). To
|
| 226 |
+
English-pretrained CE tokenizers these are character fragments with no
|
| 227 |
+
meaningful representation. The CE can't learn what makes a "fork-y" move
|
| 228 |
+
sequence from sparse labels alone. Static embedding worked because token-bag
|
| 229 |
+
averaging is sample-efficient (each `fork` token gets gradients from many
|
| 230 |
+
examples → converges to a useful cluster); the CE's pair-level processing is
|
| 231 |
+
hungrier for signal not available in our data.
|
| 232 |
+
|
| 233 |
+
### 9. Phase 7 — deterministic English bridge for documents (REVEALED THE TRUTH)
|
| 234 |
+
|
| 235 |
+
Insight: we don't need an LLM to translate documents either. `python-chess`
|
| 236 |
+
deterministically converts UCI → SAN with board context (`f2g3` → `Bxg3`).
|
| 237 |
+
Regex decamelizes themes (`backRankMate` → `back rank mate`). Free, instant,
|
| 238 |
+
reproducible. The `convert_to_english.py` script does the full 5.8M corpus in
|
| 239 |
+
~3 minutes.
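
The theme half of this bridge needs only the standard library. The sketch below uses the same boundary regex as `scripts/convert_to_english.py`; `tags_to_phrase` is a hypothetical helper added for illustration. The UCI → SAN half needs `python-chess` and a starting FEN, so it is not shown here.

```python
import re

# Split at lower→Upper boundaries, and before the last capital of an
# uppercase run that is followed by a lowercase letter.
_CAMEL_BOUNDARY = re.compile(r"(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])")


def decamelize(tag: str) -> str:
    """Turn a camelCase theme tag into lowercase space-separated words."""
    return _CAMEL_BOUNDARY.sub(" ", tag).lower()


def tags_to_phrase(tags):
    """Join several decamelized tags into one comma-separated phrase."""
    return ", ".join(decamelize(t) for t in tags)


print(tags_to_phrase(["backRankMate", "hangingPiece", "fork"]))
# back rank mate, hanging piece, fork
```

Note that this regex does not split after digits, so a tag such as `attackingF2F7` becomes `attacking f2f7`.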
|
| 240 |
+
|
| 241 |
+
Re-ran reranker training on English-bridged docs. **Untrained MS-MARCO CE hit
|
| 242 |
+
the oracle ceiling (0.5947 at top-100).** Massive jump.
|
| 243 |
+
|
| 244 |
+
But: ran a final diagnostic comparing trained CE vs **BM25** over the same
|
| 245 |
+
English docs. They were *identical*:
|
| 246 |
+
|
| 247 |
+
| K | Static | +CE | +BM25 | Oracle |
|
| 248 |
+
|---|---|---|---|---|
|
| 249 |
+
| 100 | 0.1202 | **0.5947** | **0.5947** | 0.5947 |
|
| 250 |
+
| 200 | 0.1202 | 0.7706 | 0.7706 | 0.7706 |
|
| 251 |
+
| 300 | 0.1202 | 0.8718 | 0.8718 | 0.8718 |
|
| 252 |
+
|
| 253 |
+
The "LLM-bridge effect" we observed was **lexical match enabled by the
|
| 254 |
+
English conversion**, not semantic CE understanding. BM25 over English docs
|
| 255 |
+
does the same job.
|
| 256 |
+
|
| 257 |
+
**Stress test**: stripped theme tokens from English docs too. Forces the CE
|
| 258 |
+
to genuinely understand "fork query ↔ fork-pattern moves":
|
| 259 |
+
|
| 260 |
+
| K | Static | +CE | +BM25 | Oracle |
|
| 261 |
+
|---|---|---|---|---|
|
| 262 |
+
| 100 | 0.1202 | 0.0726 | 0.4327 | 0.5947 |
|
| 263 |
+
| 300 | 0.1202 | 0.0706 | 0.6252 | 0.8718 |
|
| 264 |
+
|
| 265 |
+
CE drops below static (negative transfer — memorized "theme overlap = match"
|
| 266 |
+
during training; can't generalize). BM25 still partially works via opening
|
| 267 |
+
name overlap.
|
| 268 |
+
|
| 269 |
+
**True semantic CE chess understanding is not achievable** with 22M-param
|
| 270 |
+
English-pretrained models on our training signal.
|
| 271 |
+
|
| 272 |
+
---
|
| 273 |
+
|
| 274 |
+
## Production recommendation
|
| 275 |
+
|
| 276 |
+
For a real chess search system, the winning architecture is:
|
| 277 |
|
| 278 |
```
|
| 279 |
+
Stage 1: Static embedding (this model)
|
| 280 |
+
- Encode chess-format corpus (4M params, ~9MB)
|
| 281 |
+
- Sub-millisecond CPU inference
|
| 282 |
+
- Retrieve top-200 candidates via cosine similarity
|
| 283 |
+
- Recall@200 = 93.5%
|
| 284 |
+
|
| 285 |
+
Stage 2: BM25 over English-bridged corpus
|
| 286 |
+
- python-chess + regex (one-time, $0)
|
| 287 |
+
- Index the English versions of all docs
|
| 288 |
+
- Rerank top-200 candidates to top-10
|
| 289 |
+
- NDCG@10 ≈ 0.55-0.62
|
| 290 |
+
```
|
| 291 |
+
|
| 292 |
+
**Total: <10ms/query, $0 inference cost, no GPU.**
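
A minimal sketch of the Stage 2 rerank, assuming Stage 1 has already produced a candidate list of `(doc_id, english_text)` pairs. The BM25 here is a plain Okapi implementation written for illustration with whitespace tokenization; the function names, the `k1`/`b` defaults, and the toy documents are assumptions, not the code in `scripts/diag_ce_vs_bm25.py`.

```python
import math
from collections import Counter


def bm25_scores(query_tokens, docs_tokens, k1=1.5, b=0.75):
    """Okapi BM25 score of each tokenized document against the query."""
    n = len(docs_tokens)
    avgdl = sum(len(d) for d in docs_tokens) / max(n, 1)
    # Document frequency: in how many docs does each token appear?
    df = Counter(t for d in docs_tokens for t in set(d))
    scores = []
    for d in docs_tokens:
        tf = Counter(d)
        s = 0.0
        for t in query_tokens:
            if t not in tf:
                continue
            idf = math.log(1 + (n - df[t] + 0.5) / (df[t] + 0.5))
            s += idf * tf[t] * (k1 + 1) / (tf[t] + k1 * (1 - b + b * len(d) / avgdl))
        scores.append(s)
    return scores


def rerank(query, candidates, top_n=10):
    """Rerank Stage 1 candidates by BM25 over their English text."""
    toks = [text.lower().split() for _, text in candidates]
    scores = bm25_scores(query.lower().split(), toks)
    order = sorted(range(len(candidates)), key=lambda i: -scores[i])
    return [candidates[i][0] for i in order[:top_n]]


# Toy candidates standing in for the top-200 from the static embedding.
candidates = [
    ("d0", "White to move. Short endgame puzzle with pin motifs. Moves: Rxe7"),
    ("d1", "Black to move. Short middlegame puzzle with fork motifs. Moves: Nc2+"),
    ("d2", "White to move. Long middlegame puzzle with skewer motifs. Moves: Bb5+"),
]
print(rerank("fork middlegame", candidates, top_n=2))
```

In a real system the BM25 statistics would more likely be computed once over the full English-bridged corpus rather than per candidate list; computing them over the candidates keeps the sketch self-contained.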
|
| 293 |
+
|
| 294 |
+
The cross-encoder is only worth adding if you have GPU available AND you train
|
| 295 |
+
it on a fundamentally different signal (e.g., human-annotated relevance,
|
| 296 |
+
chess-engine strategic descriptions, or far more parameters with chess in
|
| 297 |
+
pretraining).
|
| 298 |
+
|
| 299 |
+
---
|
| 300 |
|
| 301 |
+
## Key learnings worth keeping (general, not chess-specific)
|
| 302 |
+
|
| 303 |
+
1. **Eval methodology dominates.** Most of the time spent debugging "the model isn't
|
| 304 |
+
improving" turned out to be eval issues, not training issues. Compositional
|
| 305 |
+
held-out > top-frequent-string eval. Strip lexical leakage between query
|
| 306 |
+
and corpus when testing generalization.
|
| 307 |
+
|
| 308 |
+
2. **Sentence-transformers' `NoDuplicatesBatchSampler` is O(epoch-progress)
|
| 309 |
+
per batch.** It walks a linked-list of deferred conflicts. For datasets
|
| 310 |
+
with limited unique anchors (our ~327k anchors over 5.8M pairs), this
|
| 311 |
+
creates monotonic step-time blowup. Switch to `BatchSamplers.BATCH_SAMPLER`.
|
| 312 |
+
|
| 313 |
+
3. **`CachedMultipleNegativesRankingLoss` is incompatible with
|
| 314 |
+
`StaticEmbedding`** — explicit error. Token-bag has no transformer
|
| 315 |
+
activations to GradCache through.
|
| 316 |
+
|
| 317 |
+
4. **Trackio crashes on first checkpoint push** with sentence-transformers
|
| 318 |
+
due to an empty `router_mapping` struct that pyarrow can't write. Use
|
| 319 |
+
`report_to="none"`.
|
| 320 |
+
|
| 321 |
+
5. **The "LLM-bridge" pattern**: when system A speaks language X and system
|
| 322 |
+
B speaks language Y, use an LLM to translate B→X once (not at inference).
|
| 323 |
+
For chess: LLM writes English definitions of themes → general English
|
| 324 |
+
teacher can now embed them → distill into chess-specific model.
|
| 325 |
+
|
| 326 |
+
6. **Deterministic translation often suffices** for the bridge. Don't pay LLM
|
| 327 |
+
API costs if `python-chess` and regex can produce the same English text.
|
| 328 |
+
Reserve LLMs for the parts that genuinely need understanding (concept
|
| 329 |
+
definitions, paraphrases, strategic narratives).
|
| 330 |
+
|
| 331 |
+
7. **Compare your trained model against BM25** on the actual eval. If they
|
| 332 |
+
tie, your model is doing keyword matching, not semantic work. Diagnostic
|
| 333 |
+
in `scripts/diag_ce_vs_bm25.py`.
|
| 334 |
+
|
| 335 |
+
8. **Modal `.spawn()` only survives entrypoint exit on deployed apps.** For
|
| 336 |
+
ephemeral `modal run`, the app dies when entrypoint returns — including
|
| 337 |
+
spawned calls. Use `.remote()` with `--detach`.
|
| 338 |
+
|
| 339 |
+
9. **Apple Silicon M4 is competitive with a cloud A100** for tiny models. Token
|
| 340 |
+
bag + small batch easily hits 17 it/s on MPS. GPU cost is wasted unless
|
| 341 |
+
the model is compute-bound.
|
| 342 |
+
|
| 343 |
+
---
|
| 344 |
+
|
| 345 |
+
## Reproducibility
|
| 346 |
+
|
| 347 |
+
Clone this repo, then with sentence-transformers v5.5+:
|
| 348 |
+
|
| 349 |
+
```bash
|
| 350 |
+
# Inspect the recipe
|
| 351 |
+
cat scripts/train_chess_multitask.py
|
| 352 |
+
|
| 353 |
+
# Reproduce the data prep (one-time, ~10 min)
|
| 354 |
+
python scripts/generate_theme_defs.py # Needs DeepSeek API key in macOS keychain
|
| 355 |
+
python scripts/convert_to_english.py # python-chess + regex, $0
|
| 356 |
+
python scripts/mine_hard_negs_v2.py # ~10 min on M4 MPS
|
| 357 |
+
|
| 358 |
+
# Reproduce the winning training
|
| 359 |
+
python scripts/train_chess_multitask.py # ~5 min on M4 MPS
|
| 360 |
+
|
| 361 |
+
# Verify
|
| 362 |
+
python scripts/compare_variants.py # Side-by-side eval table
|
| 363 |
+
python scripts/diag_ce_vs_bm25.py # Is the rerank doing real work?
|
| 364 |
```
|
| 365 |
|
| 366 |
+
---
|
| 367 |
+
|
| 368 |
+
## Limitations and honest caveats
|
| 369 |
+
|
| 370 |
+
- **NDCG@10 = 0.12 is modest in absolute terms.** Industry retrieval encoders
|
| 371 |
+
reach 0.4-0.6 on similar tasks. This model is competitive on size/speed,
|
| 372 |
+
not absolute quality.
|
| 373 |
+
- **The two-stage architecture (NDCG@10 ≈ 0.6) is the production answer**
|
| 374 |
+
but relies on BM25 over English-converted docs, not on the cross-encoder.
|
| 375 |
+
- **Cross-encoder didn't add semantic value** in our setup; results came from
|
| 376 |
+
lexical match enabled by the English bridge.
|
| 377 |
+
- **Bimodal failure**: even the best model misses half of queries entirely
|
| 378 |
+
(median NDCG@10 = 0). The architecture has fundamental limits for chess
|
| 379 |
+
reasoning.
|
| 380 |
+
- **English-pretrained models don't know chess.** Tried MPNet, MiniLM,
|
| 381 |
+
Jina-v5; all fail on UCI moves. Bigger English models won't fix this; only
|
| 382 |
+
chess-specific pretraining or deterministic conversion helps.
|
| 383 |
+
- **No engine evaluation.** "Is this puzzle a fork?" was determined by
|
| 384 |
+
Lichess theme tags; we never ran a chess engine. A real production system
|
| 385 |
+
would integrate Stockfish for ground-truth tactical pattern detection.
|
| 386 |
+
|
| 387 |
+
---
|
| 388 |
+
|
| 389 |
+
## What this is NOT
|
| 390 |
|
| 391 |
+
- Not a chess engine. See [`thomasahle/fastchess`](https://github.com/thomasahle/fastchess)
|
| 392 |
+
for FastText-based move prediction (closest related work).
|
| 393 |
+
- Not a position similarity model. See `chess2vec` lineage on GitHub for
|
| 394 |
+
position-level embeddings.
|
| 395 |
+
- Not a state-of-the-art retrieval model. It's a tiny first-stage filter
|
| 396 |
+
designed to pair with a reranker.
|
| 397 |
+
|
| 398 |
+
---
|
| 399 |
|
| 400 |
+
## License
|
|
|
|
| 401 |
|
| 402 |
+
Apache 2.0 (model + scripts). Data derived from Lichess/chess-puzzles which is
|
| 403 |
+
CC0 — derived parquets in this repo are also released under CC0.
|
| 404 |
|
| 405 |
+
## Acknowledgments
|
| 406 |
+
|
| 407 |
+
- [Lichess](https://lichess.org) for releasing puzzles + openings under CC0.
|
| 408 |
+
- [Tom Aarsen](https://huggingface.co/tomaarsen) for the
|
| 409 |
+
`train-sentence-transformers` skill and `StaticEmbedding` recipe.
|
| 410 |
+
- DeepSeek for the v4-flash API used for theme definitions.
|
| 411 |
+
|
| 412 |
+
## Citation
|
| 413 |
|
| 414 |
+
If this work is useful, please link to this repo. The scientific findings
|
| 415 |
+
(particularly the deterministic-bridge insight that BM25 over English-bridged
|
| 416 |
+
docs equals a trained cross-encoder for this task) are the main contribution.
|
data/hard_negatives_chess.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3dc7f1bfcb497ba5f5e61c1b9fffe76ca52825758454c65b3a2dc2010e3e68bb
|
| 3 |
+
size 161012028
|
data/hard_negatives_english.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:50b28d80013527fcb6f27554ee0cda91116e4b3967a74472320a089a7b1fa873
|
| 3 |
+
size 111083130
|
data/theme_definitions.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6f70e1629bfda29faedfca1474d2195bd527590eeb48b628fd862da12a2070f3
|
| 3 |
+
size 456977
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 8880224
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6fa4d9dd8e62c4ef6d7f288ea1822f30d5f75f3a5ab178a923c4330e3b09652d
|
| 3 |
size 8880224
|
scripts/compare_variants.py
ADDED
|
@@ -0,0 +1,175 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# /// script
|
| 3 |
+
# requires-python = ">=3.10"
|
| 4 |
+
# dependencies = [
|
| 5 |
+
# "sentence-transformers[train]>=5.5.0",
|
| 6 |
+
# "datasets>=2.19.0",
|
| 7 |
+
# "numpy",
|
| 8 |
+
# ]
|
| 9 |
+
# ///
|
| 10 |
+
"""Side-by-side comparison of all chess static-embedding variants on the same
|
| 11 |
+
held-out compositional eval. Produces the final table for NOTES.md.
|
| 12 |
+
"""
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
from collections import defaultdict
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
from datasets import load_dataset
|
| 21 |
+
from sentence_transformers import SentenceTransformer
|
| 22 |
+
|
| 23 |
+
sys.stdout.reconfigure(line_buffering=True)
|
| 24 |
+
|
| 25 |
+
VARIANTS = [
|
| 26 |
+
("v3 baseline", "models/static-embedding-chess/final"),
|
| 27 |
+
("v4-A hard-neg only", "models/static-embedding-chess-triplet/final"),
|
| 28 |
+
("v4-B theme distill", "models/static-embedding-chess-theme-only/final"),
|
| 29 |
+
("v4-C multitask 500x", "models/static-embedding-chess-multitask-500x/final"),
|
| 30 |
+
("v4-C2 multitask 5000x", "models/static-embedding-chess-multitask-5000x/final"),
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
HELDOUT_FREQ_MIN = 3
|
| 34 |
+
HELDOUT_FREQ_MAX = 30
|
| 35 |
+
EVAL_QUERIES = 200
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _join_tags(tags):
|
| 39 |
+
return " ".join(t.replace("_", " ") for t in tags) if tags else ""
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _bigram_token_str(moves):
|
| 43 |
+
toks = moves.split()
|
| 44 |
+
if len(toks) < 2:
|
| 45 |
+
return moves
|
| 46 |
+
return moves + " " + " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:]))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def build_puzzle_pairs(batch):
|
| 50 |
+
anchors, positives = [], []
|
| 51 |
+
for themes, op, moves in zip(batch["Themes"], batch["OpeningTags"], batch["Moves"]):
|
| 52 |
+
themes_txt = _join_tags(themes)
|
| 53 |
+
op_txt = _join_tags(op)
|
| 54 |
+
if not themes_txt:
|
| 55 |
+
continue
|
| 56 |
+
anchor = themes_txt + (f" {op_txt}" if op_txt else "")
|
| 57 |
+
positive = f"themes {themes_txt}"
|
| 58 |
+
if op_txt:
|
| 59 |
+
positive += f" opening {op_txt}"
|
| 60 |
+
positive += f" moves {_bigram_token_str(moves)}"
|
| 61 |
+
anchors.append(anchor)
|
| 62 |
+
positives.append(positive)
|
| 63 |
+
return {"anchor": anchors, "positive": positives}
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def strip_theme_echo(p):
|
| 67 |
+
i = p.find(" moves ")
|
| 68 |
+
return p[i + 1 :] if i != -1 else p
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def ndcg_at_k(scores, rel, k=10):
|
| 72 |
+
ranked = sorted(scores, key=lambda kv: -kv[1])[:k]
|
| 73 |
+
dcg = sum((1.0 if d in rel else 0.0) / np.log2(r + 2) for r, (d, _) in enumerate(ranked))
|
| 74 |
+
idcg = sum(1.0 / np.log2(r + 2) for r in range(min(len(rel), k)))
|
| 75 |
+
return dcg / idcg if idcg > 0 else 0.0
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def main():
|
| 79 |
+
print("Loading + held-out selection...")
|
| 80 |
+
puzzles = load_dataset("Lichess/chess-puzzles", split="train")
|
| 81 |
+
pair_puzzles = puzzles.map(
|
| 82 |
+
build_puzzle_pairs,
|
| 83 |
+
batched=True, batch_size=20_000,
|
| 84 |
+
remove_columns=puzzles.column_names,
|
| 85 |
+
num_proc=4,
|
| 86 |
+
)
|
| 87 |
+
anchors = pair_puzzles["anchor"]
|
| 88 |
+
freq = defaultdict(int)
|
| 89 |
+
for a in anchors:
|
| 90 |
+
freq[a] += 1
|
| 91 |
+
rare_pool = sorted(
|
| 92 |
+
((a, c) for a, c in freq.items() if HELDOUT_FREQ_MIN <= c <= HELDOUT_FREQ_MAX),
|
| 93 |
+
key=lambda kv: kv[1],
|
| 94 |
+
)
|
| 95 |
+
heldout = {a for a, _ in rare_pool[:EVAL_QUERIES]}
|
| 96 |
+
held_idx = [i for i, a in enumerate(anchors) if a in heldout]
|
| 97 |
+
held_anchors = [anchors[i] for i in held_idx]
|
| 98 |
+
corpus_texts = [strip_theme_echo(pair_puzzles["positive"][i]) for i in held_idx]
|
| 99 |
+
corpus_ids = [f"d{i}" for i in range(len(corpus_texts))]
|
| 100 |
+
by_anchor = defaultdict(list)
|
| 101 |
+
for i, a in enumerate(held_anchors):
|
| 102 |
+
by_anchor[a].append(corpus_ids[i])
|
| 103 |
+
queries = list(by_anchor.keys())
|
| 104 |
+
print(f" {len(queries)} queries, {len(corpus_texts)} corpus")
|
| 105 |
+
|
| 106 |
+
results = []
|
| 107 |
+
|
| 108 |
+
for name, path in VARIANTS:
|
| 109 |
+
if not os.path.exists(path):
|
| 110 |
+
print(f"\nSKIPPING {name}: {path} not found")
|
| 111 |
+
continue
|
| 112 |
+
print(f"\n=== {name} ({path}) ===")
|
| 113 |
+
m = SentenceTransformer(path)
|
| 114 |
+
c = m.encode(corpus_texts, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
|
| 115 |
+
c = c / np.linalg.norm(c, axis=1, keepdims=True)
|
| 116 |
+
q = m.encode(queries, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
|
| 117 |
+
q = q / np.linalg.norm(q, axis=1, keepdims=True)
|
| 118 |
+
sims = q @ c.T
|
| 119 |
+
ndcgs = []
|
| 120 |
+
for qi, query in enumerate(queries):
|
| 121 |
+
score_pairs = [(corpus_ids[ci], float(sims[qi, ci])) for ci in range(len(corpus_ids))]
|
| 122 |
+
rel = set(by_anchor[query])
|
| 123 |
+
ndcgs.append(ndcg_at_k(score_pairs, rel, k=10))
|
| 124 |
+
ndcg = np.mean(ndcgs)
|
| 125 |
+
median = np.median(ndcgs)
|
| 126 |
+
zero = sum(1 for n in ndcgs if n == 0)
|
| 127 |
+
results.append((name, ndcg, median, zero, len(ndcgs)))
|
| 128 |
+
print(f" NDCG@10 = {ndcg:.4f} median = {median:.4f} zero = {zero}/{len(ndcgs)}")
|
| 129 |
+
|
| 130 |
+
print("\n" + "=" * 70)
|
| 131 |
+
print(f"{'Variant':<30} {'NDCG@10':>10} {'Median':>10} {'Zero/All':>15}")
|
| 132 |
+
print("=" * 70)
|
| 133 |
+
for name, ndcg, median, zero, total in results:
|
| 134 |
+
print(f"{name:<30} {ndcg:>10.4f} {median:>10.4f} {zero:>7}/{total:<7}")
|
| 135 |
+
print("=" * 70)
|
| 136 |
+
|
| 137 |
+
# === Token-similarity probe ===
|
| 138 |
+
# Measures the orthogonal-tokens problem from Phase 1: do related themes
|
| 139 |
+
# cluster in embedding space? Higher = more semantic structure.
|
| 140 |
+
print("\n=== Theme-token similarity (higher = more semantic clustering) ===")
|
| 141 |
+
PROBES = [
|
| 142 |
+
("fork", "skewer"), # tactical motifs (should be close)
|
| 143 |
+
("fork", "pin"),
|
| 144 |
+
("backRankMate", "smotheredMate"), # mate patterns
|
| 145 |
+
("kingsideAttack", "queensideAttack"),
|
| 146 |
+
("endgame", "middlegame"), # phases
|
| 147 |
+
("fork", "promotion"), # unrelated (control)
|
| 148 |
+
]
|
| 149 |
+
print(f"{'Pair':<40}", end="")
|
| 150 |
+
for name, path in VARIANTS:
|
| 151 |
+
if os.path.exists(path):
|
| 152 |
+
print(f" {name[:14]:>16}", end="")
|
| 153 |
+
print()
|
| 154 |
+
print("-" * 70)
|
| 155 |
+
for a, b in PROBES:
|
| 156 |
+
line = f"{a} <-> {b}".ljust(40)
|
| 157 |
+
for name, path in VARIANTS:
|
| 158 |
+
if not os.path.exists(path):
|
| 159 |
+
continue
|
| 160 |
+
m = SentenceTransformer(path)
|
| 161 |
+
ea = m.encode([a], convert_to_numpy=True)[0]
|
| 162 |
+
eb = m.encode([b], convert_to_numpy=True)[0]
|
| 163 |
+
ea = ea / max(np.linalg.norm(ea), 1e-9)
|
| 164 |
+
eb = eb / max(np.linalg.norm(eb), 1e-9)
|
| 165 |
+
sim = float(np.dot(ea, eb))
|
| 166 |
+
line += f" {sim:>+16.3f}"
|
| 167 |
+
print(line)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
if __name__ == "__main__":
|
| 171 |
+
main()
|
scripts/convert_to_english.py
ADDED
|
@@ -0,0 +1,216 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# /// script
|
| 3 |
+
# requires-python = ">=3.10"
|
| 4 |
+
# dependencies = ["chess", "datasets>=2.19", "tqdm"]
|
| 5 |
+
# ///
|
| 6 |
+
"""Deterministic chess→English converter for puzzles.
|
| 7 |
+
|
| 8 |
+
Generates a standardized English-readable description of each puzzle WITHOUT
|
| 9 |
+
any LLM. Uses python-chess for UCI→SAN conversion (with board context), regex
|
| 10 |
+
for decamelizing themes, and a fixed template.
|
| 11 |
+
|
| 12 |
+
For each puzzle, produces a doc like:
|
| 13 |
+
|
| 14 |
+
"White to move. Short middlegame puzzle with crushing fork and hanging
|
| 15 |
+
piece motifs. Opening: King's Pawn Game. Moves: Bxg3 Rxe7 Qb1+ Nc1 Qxc1+
|
| 16 |
+
Qxc1"
|
| 17 |
+
|
| 18 |
+
Pretrained English cross-encoders have seen SAN notation in chess web content
|
| 19 |
+
during pretraining, so this doc is semantically meaningful to them — unlike
|
| 20 |
+
the raw UCI form (`f2g3`) which gets fragmented into character pieces.
|
| 21 |
+
|
| 22 |
+
Output: parquet at models/puzzles_english.parquet with columns:
|
| 23 |
+
PuzzleId, anchor (original themes+opening str), english_doc
|
| 24 |
+
|
| 25 |
+
Run:
|
| 26 |
+
SMOKE_TEST=1 uv run --exclude-newer=2026-05-12 convert_to_english.py
|
| 27 |
+
uv run --exclude-newer=2026-05-12 convert_to_english.py
|
| 28 |
+
"""
|
| 29 |
+
from __future__ import annotations
|
| 30 |
+
|
| 31 |
+
import os
|
| 32 |
+
import re
|
| 33 |
+
import sys
|
| 34 |
+
|
| 35 |
+
import chess
|
| 36 |
+
from datasets import Dataset, load_dataset
|
| 37 |
+
from tqdm import tqdm
|
| 38 |
+
|
| 39 |
+
sys.stdout.reconfigure(line_buffering=True)
|
| 40 |
+
|
| 41 |
+
OUTPUT_PATH = "models/puzzles_english.parquet"
|
| 42 |
+
SMOKE_TEST = os.environ.get("SMOKE_TEST") == "1"
|
| 43 |
+
|
| 44 |
+
# Length tag mapping
|
| 45 |
+
LENGTH_MAP = {
|
| 46 |
+
"oneMove": "single-move",
|
| 47 |
+
"short": "short",
|
| 48 |
+
"long": "long",
|
| 49 |
+
"veryLong": "very long",
|
| 50 |
+
}
|
| 51 |
+
PHASE_TAGS = {"opening", "middlegame", "endgame"}
|
| 52 |
+
LENGTH_TAGS = set(LENGTH_MAP.keys())
|
| 53 |
+
# Anything matching `mateInN`, `mateIn1`, etc.
|
| 54 |
+
MATE_IN_PATTERN = re.compile(r"^mateIn(\d+)$")
|
| 55 |
+
# Specific mate-pattern names (their English form is just decamel)
|
| 56 |
+
# camelCase → "camel case" via regex
|
| 57 |
+
_CAMEL_BOUNDARY = re.compile(r"(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def decamelize(tag: str) -> str:
|
| 61 |
+
"""`backRankMate` → 'back rank mate'. `attackingF2F7` → 'attacking f2f7' (no split after digits)."""
|
| 62 |
+
return _CAMEL_BOUNDARY.sub(" ", tag).lower()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def themes_to_english(themes: list[str]) -> tuple[str, str, str, list[str]]:
|
| 66 |
+
"""Returns (phase, length_phrase, side_phrase, decamelized_other_themes).
|
| 67 |
+
|
| 68 |
+
Splits themes into structural (phase, length, mate-in-N) and motif (everything else).
|
| 69 |
+
The motifs are returned decamelized.
|
| 70 |
+
"""
|
| 71 |
+
if not themes:
|
| 72 |
+
return ("", "", "", [])
|
| 73 |
+
phase = ""
|
| 74 |
+
length = ""
|
| 75 |
+
mate_in = None
|
| 76 |
+
motifs = []
|
| 77 |
+
for t in themes:
|
| 78 |
+
if t in PHASE_TAGS:
|
| 79 |
+
phase = t
|
| 80 |
+
elif t in LENGTH_TAGS:
|
| 81 |
+
length = LENGTH_MAP[t]
|
| 82 |
+
elif (m := MATE_IN_PATTERN.match(t)):
|
| 83 |
+
mate_in = int(m.group(1))
|
| 84 |
+
else:
|
| 85 |
+
motifs.append(decamelize(t))
|
| 86 |
+
# Mate-in-N gets folded into motifs as natural-language phrase
|
| 87 |
+
if mate_in is not None:
|
| 88 |
+
motifs.append(f"mate in {mate_in}")
|
| 89 |
+
return phase, length, "", motifs # side_phrase computed separately from FEN
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def opening_tags_to_english(opening_tags: list[str]) -> str:
|
| 93 |
+
"""`['Kings_Pawn_Game', 'Kings_Pawn_Game_Leonardis_Variation']` → 'Kings Pawn Game Leonardis Variation'.
|
| 94 |
+
Dedupe by taking the longest matching tag."""
|
| 95 |
+
if not opening_tags:
|
| 96 |
+
return ""
|
| 97 |
+
# Use the longest tag (most specific) and replace underscores with spaces
|
| 98 |
+
longest = max(opening_tags, key=len)
|
| 99 |
+
return longest.replace("_", " ")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def uci_to_san_sequence(fen: str, uci_moves: str) -> str:
|
| 103 |
+
"""Convert UCI move sequence to SAN, using board context for disambiguation."""
|
| 104 |
+
try:
|
| 105 |
+
board = chess.Board(fen)
|
| 106 |
+
san_moves = []
|
| 107 |
+
for uci in uci_moves.split():
|
| 108 |
+
try:
|
| 109 |
+
move = chess.Move.from_uci(uci)
|
| 110 |
+
san = board.san(move)
|
| 111 |
+
san_moves.append(san)
|
| 112 |
+
board.push(move)
|
| 113 |
+
except Exception:
|
| 114 |
+
# Invalid move — skip rest
|
| 115 |
+
break
|
| 116 |
+
return " ".join(san_moves)
|
| 117 |
+
except Exception:
|
| 118 |
+
return uci_moves # fall back to raw UCI
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def side_to_move(fen: str) -> str:
|
| 122 |
+
parts = fen.split()
|
| 123 |
+
if len(parts) >= 2 and parts[1] == "w":
|
| 124 |
+
return "White"
|
| 125 |
+
return "Black"
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def build_english_doc(row: dict) -> str:
|
| 129 |
+
"""Build a deterministic English description from a Lichess puzzle row."""
|
| 130 |
+
side = side_to_move(row["FEN"])
|
| 131 |
+
phase, length, _, motifs = themes_to_english(row["Themes"] or [])
|
| 132 |
+
opening = opening_tags_to_english(row.get("OpeningTags") or [])
|
| 133 |
+
san = uci_to_san_sequence(row["FEN"], row["Moves"])
|
| 134 |
+
|
| 135 |
+
# Construct sentence
|
| 136 |
+
parts = []
|
| 137 |
+
parts.append(f"{side} to move.")
|
| 138 |
+
|
| 139 |
+
# "Short middlegame puzzle with crushing fork and hanging piece motifs."
|
| 140 |
+
descriptor = []
|
| 141 |
+
if length:
|
| 142 |
+
descriptor.append(length)
|
| 143 |
+
if phase:
|
| 144 |
+
descriptor.append(phase)
|
| 145 |
+
descriptor.append("puzzle")
|
| 146 |
+
descriptor_str = " ".join(descriptor)
|
| 147 |
+
if motifs:
|
| 148 |
+
motifs_str = ", ".join(motifs)
|
| 149 |
+
descriptor_str += f" with {motifs_str} motifs"
|
| 150 |
+
parts.append(descriptor_str.capitalize() + ".")
|
| 151 |
+
|
| 152 |
+
if opening:
|
| 153 |
+
parts.append(f"Opening: {opening}.")
|
| 154 |
+
|
| 155 |
+
if san:
|
| 156 |
+
parts.append(f"Moves: {san}")
|
| 157 |
+
|
| 158 |
+
return " ".join(parts)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def build_english_anchor(row: dict) -> str:
|
| 162 |
+
"""Anchor side: same as before (themes + opening) but in deterministic English.
|
| 163 |
+
Used as query for retrieval/reranker training."""
|
| 164 |
+
phase, length, _, motifs = themes_to_english(row["Themes"] or [])
|
| 165 |
+
opening = opening_tags_to_english(row.get("OpeningTags") or [])
|
| 166 |
+
parts = []
|
| 167 |
+
if motifs:
|
| 168 |
+
parts.append(", ".join(motifs))
|
| 169 |
+
if length:
|
| 170 |
+
parts.append(length)
|
| 171 |
+
if phase:
|
| 172 |
+
parts.append(phase)
|
| 173 |
+
if opening:
|
| 174 |
+
parts.append(opening)
|
| 175 |
+
return " ".join(parts).strip()
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def main():
|
| 179 |
+
print("Loading puzzles...")
|
| 180 |
+
puzzles = load_dataset("Lichess/chess-puzzles", split="train")
|
| 181 |
+
if SMOKE_TEST:
|
| 182 |
+
puzzles = puzzles.select(range(2_000))
|
| 183 |
+
print(f" {len(puzzles):,} rows")
|
| 184 |
+
|
| 185 |
+
print("Converting to English (deterministic)...")
|
| 186 |
+
|
| 187 |
+
def proc(batch):
|
| 188 |
+
ids, anchors, docs = [], [], []
|
| 189 |
+
for r in [{k: batch[k][i] for k in batch} for i in range(len(batch["PuzzleId"]))]:
|
| 190 |
+
if not r["Themes"]:
|
| 191 |
+
continue
|
| 192 |
+
ids.append(r["PuzzleId"])
|
| 193 |
+
anchors.append(build_english_anchor(r))
|
| 194 |
+
docs.append(build_english_doc(r))
|
| 195 |
+
return {"PuzzleId": ids, "anchor_en": anchors, "doc_en": docs}
|
| 196 |
+
|
| 197 |
+
out = puzzles.map(
|
| 198 |
+
proc, batched=True, batch_size=10_000,
|
| 199 |
+
remove_columns=puzzles.column_names,
|
| 200 |
+
num_proc=4,
|
| 201 |
+
)
|
| 202 |
+
print(f" produced {len(out):,} English-converted rows")
|
| 203 |
+
|
| 204 |
+
print("\n=== Sample conversions ===")
|
| 205 |
+
for i in [i for i in (0, 100, 1000) if i < len(out)]:
|
| 206 |
+
r = out[i]
|
| 207 |
+
print(f"\nPuzzleId: {r['PuzzleId']}")
|
| 208 |
+
print(f" anchor: {r['anchor_en']!r}")
|
| 209 |
+
print(f" doc: {r['doc_en'][:200]!r}")
|
| 210 |
+
|
| 211 |
+
out.to_parquet(OUTPUT_PATH)
|
| 212 |
+
print(f"\nSaved to {OUTPUT_PATH} ({os.path.getsize(OUTPUT_PATH) / 1e6:.1f} MB)")
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
if __name__ == "__main__":
|
| 216 |
+
main()
|
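The conversion script above builds each English document from a handful of small, deterministic helpers. As a minimal sketch of two of them (stand-alone copies for illustration, with a hypothetical starting-position FEN and tag list as inputs), the following shows that the opening helper only swaps underscores for spaces and does not restore apostrophes:

```python
# Stand-alone copies of two helpers from the script above, for illustration only.

def side_to_move(fen: str) -> str:
    """Read the active-colour field (second space-separated field) of a FEN."""
    parts = fen.split()
    if len(parts) >= 2 and parts[1] == "w":
        return "White"
    return "Black"


def opening_tags_to_english(opening_tags: list[str]) -> str:
    """Keep the longest (most specific) tag and turn underscores into spaces."""
    if not opening_tags:
        return ""
    return max(opening_tags, key=len).replace("_", " ")


# Hypothetical inputs: the standard starting position and a two-tag opening list.
START_FEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
TAGS = ["Kings_Pawn_Game", "Kings_Pawn_Game_Leonardis_Variation"]

side = side_to_move(START_FEN)
opening = opening_tags_to_english(TAGS)
sentence = f"{side} to move. Opening: {opening}."
print(sentence)
```

Because every step is a pure string operation, the same puzzle row should always yield the same anchor and document text, which is what makes the mined pairs reproducible.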
scripts/diag_ce_vs_bm25.py
ADDED
|
@@ -0,0 +1,145 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# /// script
|
| 3 |
+
# requires-python = ">=3.10"
|
| 4 |
+
# dependencies = ["sentence-transformers[train]>=5.5.0", "datasets>=2.19", "numpy", "rank-bm25", "chess"]
|
| 5 |
+
# ///
|
| 6 |
+
"""Compare trained CE vs BM25 on English-bridged docs, plus top-K sweep.
|
| 7 |
+
|
| 8 |
+
Tests:
|
| 9 |
+
1. Is the 0.59 CE result just lexical match that BM25 could also do?
|
| 10 |
+
2. Does increasing K to 200/300 push past oracle 0.59 → 0.77 → 0.87?
|
| 11 |
+
"""
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
from collections import defaultdict
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
from datasets import Dataset, load_dataset
|
| 18 |
+
from rank_bm25 import BM25Okapi
|
| 19 |
+
from sentence_transformers import CrossEncoder, SentenceTransformer
|
| 20 |
+
|
| 21 |
+
sys.stdout.reconfigure(line_buffering=True)
|
| 22 |
+
sys.path.insert(0, os.path.dirname(__file__))
|
| 23 |
+
from convert_to_english import build_english_anchor, build_english_doc
|
| 24 |
+
|
| 25 |
+
HELDOUT_FREQ_MIN = 3
|
| 26 |
+
HELDOUT_FREQ_MAX = 30
|
| 27 |
+
EVAL_QUERIES = 200
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _join_tags(tags):
|
| 31 |
+
return " ".join(t.replace("_", " ") for t in tags) if tags else ""
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _bigram(m):
|
| 35 |
+
toks = m.split()
|
| 36 |
+
return m + " " + " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:])) if len(toks) > 1 else m
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def build_chess_anchor(themes, op):
|
| 40 |
+
tt = _join_tags(themes)
|
| 41 |
+
ot = _join_tags(op or [])
|
| 42 |
+
return tt + (f" {ot}" if ot else "")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def build_chess_doc_stripped(themes, op, moves):
|
| 46 |
+
return f"moves {_bigram(moves)}"
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def ndcg_at_k(scores, rel, k=10):
|
| 50 |
+
r = sorted(scores, key=lambda kv: -kv[1])[:k]
|
| 51 |
+
dcg = sum((1.0 if d in rel else 0.0) / np.log2(rr + 2) for rr, (d, _) in enumerate(r))
|
| 52 |
+
idcg = sum(1.0 / np.log2(rr + 2) for rr in range(min(len(rel), k)))
|
| 53 |
+
return dcg / idcg if idcg > 0 else 0
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def main():
|
| 57 |
+
print("Building eval set...")
|
| 58 |
+
puzzles = load_dataset("Lichess/chess-puzzles", split="train")
|
| 59 |
+
freq = defaultdict(int)
|
| 60 |
+
rows_by_anchor = defaultdict(list)
|
| 61 |
+
for r in puzzles:
|
| 62 |
+
if not r["Themes"]:
|
| 63 |
+
continue
|
| 64 |
+
ca = build_chess_anchor(r["Themes"], r["OpeningTags"])
|
| 65 |
+
freq[ca] += 1
|
| 66 |
+
rows_by_anchor[ca].append(r)
|
| 67 |
+
rare = sorted(((a, c) for a, c in freq.items() if HELDOUT_FREQ_MIN <= c <= HELDOUT_FREQ_MAX), key=lambda kv: kv[1])
|
| 68 |
+
heldout = [a for a, _ in rare[:EVAL_QUERIES]]
|
| 69 |
+
print(f" {len(heldout)} held-out anchors")
|
| 70 |
+
|
| 71 |
+
qchess, qen = [], []
|
| 72 |
+
corp_chess, corp_en = [], []
|
| 73 |
+
held_per_doc = []
|
| 74 |
+
ch_to_en = {}
|
| 75 |
+
for ca in heldout:
|
| 76 |
+
for r in rows_by_anchor[ca]:
|
| 77 |
+
corp_chess.append(build_chess_doc_stripped(r["Themes"], r["OpeningTags"], r["Moves"]))
|
| 78 |
+
corp_en.append(build_english_doc(r))
|
| 79 |
+
held_per_doc.append(ca)
|
| 80 |
+
if ca not in ch_to_en:
|
| 81 |
+
ch_to_en[ca] = build_english_anchor(r)
|
| 82 |
+
qchess = list(heldout)
|
| 83 |
+
qen = [ch_to_en[a] for a in qchess]
|
| 84 |
+
by_anchor = defaultdict(list)
|
| 85 |
+
for i, a in enumerate(held_per_doc):
|
| 86 |
+
by_anchor[a].append(i)
|
| 87 |
+
print(f" corpus: {len(corp_chess)} docs")
|
| 88 |
+
|
| 89 |
+
print("\nLoading static (v4-C2) for first-stage...")
|
| 90 |
+
static = SentenceTransformer("models/static-embedding-chess-multitask-5000x/final")
|
| 91 |
+
sc = static.encode(corp_chess, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
|
| 92 |
+
sc = sc / np.linalg.norm(sc, axis=1, keepdims=True)
|
| 93 |
+
sq = static.encode(qchess, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
|
| 94 |
+
sq = sq / np.linalg.norm(sq, axis=1, keepdims=True)
|
| 95 |
+
static_sims = sq @ sc.T
|
| 96 |
+
|
| 97 |
+
# Load trained CE
|
| 98 |
+
print("Loading trained CE...")
|
| 99 |
+
ce = CrossEncoder("models/chess-reranker-english/final")
|
| 100 |
+
|
| 101 |
+
# BM25 on English docs
|
| 102 |
+
print("Building BM25 over English docs...")
|
| 103 |
+
bm25 = BM25Okapi([d.split() for d in corp_en])
|
| 104 |
+
|
| 105 |
+
print("\n" + "=" * 80)
|
| 106 |
+
print(f" {'K':>4} {'Static':>10} {'+CE':>10} {'+BM25':>10} {'Oracle':>10}")
|
| 107 |
+
print("=" * 80)
|
| 108 |
+
for k in [10, 50, 100, 200, 300]:
|
| 109 |
+
if k > len(corp_chess):
|
| 110 |
+
continue
|
| 111 |
+
static_ndcg = []
|
| 112 |
+
ce_ndcg = []
|
| 113 |
+
bm25_ndcg = []
|
| 114 |
+
oracle_ndcg = []
|
| 115 |
+
for qi, q_chess in enumerate(qchess):
|
| 116 |
+
rel = set(by_anchor[q_chess])
|
| 117 |
+
# Static-only at top-10
|
| 118 |
+
top10 = np.argsort(-static_sims[qi])[:10]
|
| 119 |
+
sp = [(int(i), float(static_sims[qi, int(i)])) for i in top10]
|
| 120 |
+
static_ndcg.append(ndcg_at_k(sp, rel, k=10))
|
| 121 |
+
# Top-K shortlist
|
| 122 |
+
topk = np.argsort(-static_sims[qi])[:k]
|
| 123 |
+
# CE rerank
|
| 124 |
+
pairs = [[qen[qi], corp_en[int(i)]] for i in topk]
|
| 125 |
+
ce_scores = ce.predict(pairs, batch_size=64, show_progress_bar=False, convert_to_numpy=True)
|
| 126 |
+
ce_sp = [(int(topk[j]), float(ce_scores[j])) for j in range(len(topk))]
|
| 127 |
+
ce_ndcg.append(ndcg_at_k(ce_sp, rel, k=10))
|
| 128 |
+
# BM25 rerank over top-K shortlist
|
| 129 |
+
bm_full = bm25.get_scores(qen[qi].split())
|
| 130 |
+
bm_sp = [(int(topk[j]), float(bm_full[int(topk[j])])) for j in range(len(topk))]
|
| 131 |
+
bm25_ndcg.append(ndcg_at_k(bm_sp, rel, k=10))
|
| 132 |
+
# Oracle ceiling
|
| 133 |
+
rel_in_topk = len(rel & set(int(i) for i in topk))
|
| 134 |
+
n10 = min(10, rel_in_topk)
|
| 135 |
+
dcg = sum(1.0 / np.log2(r + 2) for r in range(n10))
|
| 136 |
+
idcg = sum(1.0 / np.log2(r + 2) for r in range(min(len(rel), 10)))
|
| 137 |
+
oracle_ndcg.append(dcg / idcg if idcg > 0 else 0)
|
| 138 |
+
# static stays the same regardless of K
|
| 139 |
+
static_v = np.mean(static_ndcg)
|
| 140 |
+
print(f" {k:>4} {static_v:>10.4f} {np.mean(ce_ndcg):>10.4f} {np.mean(bm25_ndcg):>10.4f} {np.mean(oracle_ndcg):>10.4f}")
|
| 141 |
+
print("=" * 80)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
if __name__ == "__main__":
|
| 145 |
+
main()
|
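The diagnostic script above reports binary-relevance nDCG@10 for each reranker next to an "oracle" ceiling, meaning the score a perfect reranker would reach given only the relevant documents that made it into the top-K shortlist. A minimal sketch of both quantities, using the standard library instead of NumPy and made-up toy data (the function name `oracle_ndcg` is introduced here for illustration):

```python
import math


def ndcg_at_k(scored, relevant, k=10):
    """Binary-relevance nDCG@k over (doc_id, score) pairs."""
    ranked = sorted(scored, key=lambda pair: -pair[1])[:k]
    dcg = sum(
        (1.0 if doc_id in relevant else 0.0) / math.log2(rank + 2)
        for rank, (doc_id, _) in enumerate(ranked)
    )
    idcg = sum(1.0 / math.log2(rank + 2) for rank in range(min(len(relevant), k)))
    return dcg / idcg if idcg > 0 else 0.0


def oracle_ndcg(shortlist, relevant, k=10):
    """Best nDCG@k a reranker could reach: shortlisted relevant docs ranked first."""
    found = min(k, len(relevant & set(shortlist)))
    dcg = sum(1.0 / math.log2(rank + 2) for rank in range(found))
    idcg = sum(1.0 / math.log2(rank + 2) for rank in range(min(len(relevant), k)))
    return dcg / idcg if idcg > 0 else 0.0


# Toy data (made up): docs 1 and 3 are relevant; the shortlist contains only doc 1.
relevant = {1, 3}
scored = [(0, 0.9), (1, 0.8), (2, 0.1)]
shortlist = [doc_id for doc_id, _ in scored]

print(round(ndcg_at_k(scored, relevant), 4))
print(round(oracle_ndcg(shortlist, relevant), 4))
```

The gap between the two numbers is what reranking can still recover. The gap between the oracle and 1.0 can only be closed by a larger K or a better first stage, which is why the script sweeps K.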
scripts/generate_theme_defs.py
ADDED
|
@@ -0,0 +1,168 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# /// script
|
| 3 |
+
# requires-python = ">=3.10"
|
| 4 |
+
# dependencies = [
|
| 5 |
+
# "datasets>=2.19.0",
|
| 6 |
+
# "openai>=1.0",
|
| 7 |
+
# "sentence-transformers[train]>=5.5.0",
|
| 8 |
+
# "tqdm",
|
| 9 |
+
# "numpy",
|
| 10 |
+
# ]
|
| 11 |
+
# ///
|
| 12 |
+
"""Generate natural-language definitions for each Lichess theme via DeepSeek,
|
| 13 |
+
then embed those definitions with a general sentence-transformer (MPNet).
|
| 14 |
+
|
| 15 |
+
The resulting (theme_token, definition_embedding) pairs form a "chess-aware
|
| 16 |
+
teacher" — an English description of each chess concept that MPNet CAN
|
| 17 |
+
understand semantically. We can then distill those embeddings into our
|
| 18 |
+
StaticEmbedding model's token table.
|
| 19 |
+
|
| 20 |
+
Solves the "MPNet doesn't know chess" problem: MPNet can't read UCI moves,
|
| 21 |
+
but it CAN read English ("A tactical motif where one piece attacks two pieces
|
| 22 |
+
simultaneously" → semantically near "A tactic where you create a double
|
| 23 |
+
attack threatening two pieces at once"). Token-level semantic structure
|
| 24 |
+
emerges from the LLM bridge.
|
| 25 |
+
|
| 26 |
+
Run:
|
| 27 |
+
SMOKE_TEST=1 uv run --exclude-newer=2026-05-12 generate_theme_defs.py
|
| 28 |
+
uv run --exclude-newer=2026-05-12 generate_theme_defs.py
|
| 29 |
+
"""
|
| 30 |
+
import json
|
| 31 |
+
import os
|
| 32 |
+
import subprocess
|
| 33 |
+
import sys
|
| 34 |
+
from collections import Counter
|
| 35 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 36 |
+
|
| 37 |
+
import numpy as np
|
| 38 |
+
from datasets import Dataset, load_dataset
|
| 39 |
+
from openai import OpenAI
|
| 40 |
+
from sentence_transformers import SentenceTransformer
|
| 41 |
+
from tqdm import tqdm
|
| 42 |
+
|
| 43 |
+
MODEL = "deepseek-v4-flash"
|
| 44 |
+
TEACHER_MODEL = "sentence-transformers/all-mpnet-base-v2"
|
| 45 |
+
OUTPUT_PATH = "models/theme_definitions.parquet"
|
| 46 |
+
SMOKE_TEST = os.environ.get("SMOKE_TEST") == "1"
|
| 47 |
+
PARALLEL_WORKERS = 4
|
| 48 |
+
|
| 49 |
+
SYSTEM_PROMPT = """You write concise dictionary-style definitions of chess
|
| 50 |
+
concepts. Given a theme/concept name (often in camelCase from Lichess.org's
|
| 51 |
+
puzzle tagging system), write a single English sentence of 10-25 words
|
| 52 |
+
explaining the concept. Be specific and use the standard chess vocabulary that
|
| 53 |
+
would appear in any chess textbook.
|
| 54 |
+
|
| 55 |
+
Output ONLY the definition sentence. No labels, no quotes, no commentary.
|
| 56 |
+
|
| 57 |
+
Examples:
|
| 58 |
+
Input: fork
|
| 59 |
+
Output: A tactical motif where a single piece attacks two or more enemy pieces simultaneously, forcing a material gain.
|
| 60 |
+
|
| 61 |
+
Input: backRankMate
|
| 62 |
+
Output: A checkmate delivered along the opponent's back rank, typically with a rook or queen, when the king is trapped by its own pawns.
|
| 63 |
+
|
| 64 |
+
Input: zugzwang
|
| 65 |
+
Output: A position in which any move worsens the player's position, so being forced to move becomes a disadvantage.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def get_deepseek_key():
    # Keychain first; env var if `security` is missing (non-macOS), times out, or has no entry.
    try:
        r = subprocess.run(
            ["security", "find-generic-password", "-s", "deepseek-api", "-w"],
            capture_output=True, text=True, timeout=5,
        )
        if r.returncode == 0 and r.stdout.strip():
            return r.stdout.strip()
    except (OSError, subprocess.TimeoutExpired):
        pass
    return os.environ.get("DEEPSEEK_API_KEY")
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def define_theme(client, theme, debug=False):
|
| 78 |
+
try:
|
| 79 |
+
resp = client.chat.completions.create(
|
| 80 |
+
model=MODEL,
|
| 81 |
+
messages=[
|
| 82 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 83 |
+
{"role": "user", "content": theme},
|
| 84 |
+
],
|
| 85 |
+
temperature=0.2,
|
| 86 |
+
max_tokens=1500, # DeepSeek-v4-flash spends tokens on reasoning_content; obscure mate-pattern names need lots
|
| 87 |
+
timeout=30,
|
| 88 |
+
)
|
| 89 |
+
content = resp.choices[0].message.content
|
| 90 |
+
return content.strip() if content else None
|
| 91 |
+
except Exception as e:
|
| 92 |
+
if debug:
|
| 93 |
+
print(f" EXC for {theme!r}: {type(e).__name__}: {e}")
|
| 94 |
+
return None
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def main():
|
| 98 |
+
key = get_deepseek_key()
|
| 99 |
+
if not key:
|
| 100 |
+
sys.exit("No DeepSeek API key in keychain or DEEPSEEK_API_KEY")
|
| 101 |
+
client = OpenAI(api_key=key, base_url="https://api.deepseek.com/v1")
|
| 102 |
+
|
| 103 |
+
print("Enumerating themes from Lichess puzzles...")
|
| 104 |
+
puzzles = load_dataset("Lichess/chess-puzzles", split="train", streaming=True)
|
| 105 |
+
counter = Counter()
|
| 106 |
+
sample_size = 50_000 if SMOKE_TEST else 1_000_000
|
| 107 |
+
for i, r in enumerate(puzzles):
|
| 108 |
+
if i >= sample_size:
|
| 109 |
+
break
|
| 110 |
+
for t in (r["Themes"] or []):
|
| 111 |
+
counter[t] += 1
|
| 112 |
+
themes = sorted(counter.keys())
|
| 113 |
+
print(f" {len(themes)} unique themes")
|
| 114 |
+
|
| 115 |
+
if SMOKE_TEST:
|
| 116 |
+
themes = themes[:10]
|
| 117 |
+
print(f" SMOKE_TEST=1: limited to {len(themes)}")
|
| 118 |
+
|
| 119 |
+
print(f"\nGenerating definitions via {MODEL}...")
|
| 120 |
+
defs = {}
|
| 121 |
+
with ThreadPoolExecutor(max_workers=PARALLEL_WORKERS) as ex:
|
| 122 |
+
futs = {ex.submit(define_theme, client, t, True): t for t in themes}
|
| 123 |
+
for f in tqdm(as_completed(futs), total=len(futs)):
|
| 124 |
+
t = futs[f]
|
| 125 |
+
defs[t] = f.result()
|
| 126 |
+
|
| 127 |
+
failed = [t for t, d in defs.items() if not d]
|
| 128 |
+
if failed:
|
| 129 |
+
print(f" {len(failed)} themes failed: {failed[:5]}")
|
| 130 |
+
print(f" {len(defs) - len(failed)}/{len(defs)} succeeded")
|
| 131 |
+
|
| 132 |
+
print("\nSample definitions:")
|
| 133 |
+
for t in themes[:8]:
|
| 134 |
+
if defs[t]:
|
| 135 |
+
print(f" {t:>20s} -> {defs[t]}")
|
| 136 |
+
|
| 137 |
+
valid = [(t, defs[t]) for t in themes if defs[t]]
|
| 138 |
+
|
| 139 |
+
print(f"\nEmbedding {len(valid)} definitions with {TEACHER_MODEL}...")
|
| 140 |
+
teacher = SentenceTransformer(TEACHER_MODEL)
|
| 141 |
+
sentences = [d for _, d in valid]
|
| 142 |
+
embs = teacher.encode(sentences, batch_size=64, show_progress_bar=True, convert_to_numpy=True)
|
| 143 |
+
|
| 144 |
+
# Sanity: do related themes have similar embeddings?
|
| 145 |
+
emb_norm = embs / np.linalg.norm(embs, axis=1, keepdims=True)
|
| 146 |
+
sim = emb_norm @ emb_norm.T
|
| 147 |
+
print("\nSanity check: pairwise similarities for related themes")
|
| 148 |
+
name_to_idx = {t: i for i, (t, _) in enumerate(valid)}
|
| 149 |
+
for a, b in [
|
| 150 |
+
("fork", "skewer"), ("fork", "pin"), ("backRankMate", "smotheredMate"),
|
| 151 |
+
("kingsideAttack", "queensideAttack"), ("endgame", "middlegame"),
|
| 152 |
+
("fork", "promotion"), # not directly related
|
| 153 |
+
]:
|
| 154 |
+
if a in name_to_idx and b in name_to_idx:
|
| 155 |
+
print(f" {a!r:>20} <-> {b!r:25} = {sim[name_to_idx[a], name_to_idx[b]]:+.3f}")
|
| 156 |
+
|
| 157 |
+
out = Dataset.from_dict({
|
| 158 |
+
"theme": [t for t, _ in valid],
|
| 159 |
+
"definition": [d for _, d in valid],
|
| 160 |
+
"embedding": embs.tolist(),
|
| 161 |
+
})
|
| 162 |
+
os.makedirs(os.path.dirname(OUTPUT_PATH) or ".", exist_ok=True)
|
| 163 |
+
out.to_parquet(OUTPUT_PATH)
|
| 164 |
+
print(f"\nSaved {len(out)} theme definitions to {OUTPUT_PATH}")
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
if __name__ == "__main__":
|
| 168 |
+
main()
|
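The sanity check near the end of the script above row-normalizes the teacher embeddings and multiplies the matrix by its transpose, which yields pairwise cosine similarities. A small sketch of the same idea in plain Python, assuming made-up three-dimensional vectors in place of the real MPNet embeddings (the helper names are hypothetical):

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def cosine_matrix(vectors):
    """Pairwise cosine similarity: dot products of unit-length rows."""
    unit = [l2_normalize(v) for v in vectors]
    return [[sum(a * b for a, b in zip(u, v)) for v in unit] for u in unit]


# Made-up 3-d "embeddings"; the real script uses MPNet sentence vectors.
toy = {
    "fork": [1.0, 1.0, 0.0],
    "skewer": [1.0, 0.8, 0.1],
    "promotion": [0.0, 0.1, 1.0],
}
names = list(toy)
sim = cosine_matrix([toy[n] for n in names])
idx = {n: i for i, n in enumerate(names)}

for a, b in [("fork", "skewer"), ("fork", "promotion")]:
    print(f"{a!r:>12} <-> {b!r:12} = {sim[idx[a]][idx[b]]:+.3f}")
```

If the definitions carry real semantic structure, related pairs such as fork and skewer should score clearly higher than an unrelated pair such as fork and promotion. The toy vectors are chosen to mimic that pattern and say nothing about the actual results.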
scripts/mine_hard_negs_v2.py
ADDED
|
@@ -0,0 +1,213 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# /// script
|
| 3 |
+
# requires-python = ">=3.10"
|
| 4 |
+
# dependencies = [
|
| 5 |
+
# "sentence-transformers[train]>=5.5.0",
|
| 6 |
+
# "datasets>=2.19.0",
|
| 7 |
+
# "numpy",
|
| 8 |
+
# "tqdm",
|
| 9 |
+
# ]
|
| 10 |
+
# ///
|
| 11 |
+
"""Memory-bounded hard-negative miner. Custom impl (not sentence-transformers
|
| 12 |
+
util) because that function tries to hold the full anchor × corpus similarity
|
| 13 |
+
matrix, which OOMs at 327k anchors × 327k positives on M4.
|
| 14 |
+
|
| 15 |
+
Algorithm:
|
| 16 |
+
1. Encode all unique positives once -> N x dim float32 (~670MB at 327k x 512).
|
| 17 |
+
2. Encode all unique anchors once -> M x dim float32.
|
| 18 |
+
3. For each anchor batch (size B):
|
| 19 |
+
- scores = batch_emb @ positives_emb.T -> B x N
|
| 20 |
+
- per anchor: argpartition for top RANGE_MAX, exclude actual positive,
|
| 21 |
+
sample NUM_NEGATIVES from rank [RANGE_MIN, RANGE_MAX).
|
| 22 |
+
4. Stream triplets to parquet.
|
| 23 |
+
|
| 24 |
+
Peak memory: B * N * 4 bytes for scores. With B=500, N=327k: 650MB.
|
| 25 |
+
|
| 26 |
+
Run:
|
| 27 |
+
SMOKE_TEST=1 uv run --exclude-newer=2026-05-12 mine_hard_negs_v2.py
|
| 28 |
+
uv run --exclude-newer=2026-05-12 mine_hard_negs_v2.py
|
| 29 |
+
"""
|
| 30 |
+
from __future__ import annotations
|
| 31 |
+
|
| 32 |
+
import os
|
| 33 |
+
import random
|
| 34 |
+
import re
|
| 35 |
+
import sys
|
| 36 |
+
from collections import defaultdict
|
| 37 |
+
|
| 38 |
+
# Force unbuffered stdout so progress is visible when piped
|
| 39 |
+
sys.stdout.reconfigure(line_buffering=True)
|
| 40 |
+
|
| 41 |
+
import numpy as np
|
| 42 |
+
import torch
|
| 43 |
+
from datasets import Dataset, load_dataset
|
| 44 |
+
from sentence_transformers import SentenceTransformer
|
| 45 |
+
from tqdm import tqdm
|
| 46 |
+
|
| 47 |
+
V3_MODEL_PATH = "models/static-embedding-chess/final"
|
| 48 |
+
OUTPUT_PATH = "models/hard_negatives.parquet"
|
| 49 |
+
SMOKE_TEST = os.environ.get("SMOKE_TEST") == "1"
|
| 50 |
+
HELDOUT_FREQ_MIN = 3
|
| 51 |
+
HELDOUT_FREQ_MAX = 30
|
| 52 |
+
EVAL_QUERIES = 200
|
| 53 |
+
NUM_NEGATIVES = 5
|
| 54 |
+
RANGE_MIN = 10
|
| 55 |
+
RANGE_MAX = 50
|
| 56 |
+
ANCHOR_BATCH_SIZE = 500 # 500 * 327k * 4 = ~650MB scratch per batch
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _join_tags(tags):
|
| 60 |
+
return " ".join(t.replace("_", " ") for t in tags) if tags else ""
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _bigram_token_str(moves):
|
| 64 |
+
toks = moves.split()
|
| 65 |
+
if len(toks) < 2:
|
| 66 |
+
return moves
|
| 67 |
+
bigrams = " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:]))
|
| 68 |
+
return f"{moves} {bigrams}"
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def build_puzzle_pairs(batch):
|
| 72 |
+
anchors, positives = [], []
|
| 73 |
+
for themes, op, moves in zip(batch["Themes"], batch["OpeningTags"], batch["Moves"]):
|
| 74 |
+
themes_txt = _join_tags(themes)
|
| 75 |
+
op_txt = _join_tags(op)
|
| 76 |
+
if not themes_txt:
|
| 77 |
+
continue
|
| 78 |
+
anchor = themes_txt + (f" {op_txt}" if op_txt else "")
|
| 79 |
+
positive = f"themes {themes_txt}"
|
| 80 |
+
if op_txt:
|
| 81 |
+
positive += f" opening {op_txt}"
|
| 82 |
+
positive += f" moves {_bigram_token_str(moves)}"
|
| 83 |
+
anchors.append(anchor)
|
| 84 |
+
positives.append(positive)
|
| 85 |
+
return {"anchor": anchors, "positive": positives}
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def main():
|
| 89 |
+
print(f"Loading v3 model from {V3_MODEL_PATH}")
|
| 90 |
+
model = SentenceTransformer(V3_MODEL_PATH)
|
| 91 |
+
|
| 92 |
+
print("Loading puzzles...")
|
| 93 |
+
puzzles = load_dataset("Lichess/chess-puzzles", split="train")
|
| 94 |
+
if SMOKE_TEST:
|
| 95 |
+
puzzles = puzzles.select(range(100_000))
|
| 96 |
+
pair_puzzles = puzzles.map(
|
| 97 |
+
build_puzzle_pairs,
|
| 98 |
+
batched=True,
|
| 99 |
+
batch_size=20_000,
|
| 100 |
+
remove_columns=puzzles.column_names,
|
| 101 |
+
num_proc=4,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# Materialize columns ONCE as Python lists (HF Dataset random access is
|
| 105 |
+
# O(N) per call due to Arrow buffer slicing -- 5.8M iterations would take
|
| 106 |
+
# forever otherwise).
|
| 107 |
+
print("Materializing columns...")
|
| 108 |
+
anchors_list = list(pair_puzzles["anchor"])
|
| 109 |
+
positives_list = list(pair_puzzles["positive"])
|
| 110 |
+
print(f" done ({len(anchors_list):,} rows)")
|
| 111 |
+
|
| 112 |
+
# Remove held-out anchors
|
| 113 |
+
freq = defaultdict(int)
|
| 114 |
+
for a in anchors_list:
|
| 115 |
+
freq[a] += 1
|
| 116 |
+
rare_pool = sorted(
|
| 117 |
+
((a, c) for a, c in freq.items() if HELDOUT_FREQ_MIN <= c <= HELDOUT_FREQ_MAX),
|
| 118 |
+
key=lambda kv: kv[1],
|
| 119 |
+
)
|
| 120 |
+
heldout = {a for a, _ in rare_pool[:EVAL_QUERIES]}
|
| 121 |
+
|
| 122 |
+
# Build one-per-anchor (use as both the anchor source AND the corpus source)
|
| 123 |
+
by_anchor = defaultdict(list)
|
| 124 |
+
for a, p in zip(anchors_list, positives_list):
|
| 125 |
+
if a not in heldout:
|
| 126 |
+
by_anchor[a].append(p)
|
| 127 |
+
print(f" unique anchors (post-heldout-strip): {len(by_anchor):,}")
|
| 128 |
+
|
| 129 |
+
rng = random.Random(12)
|
| 130 |
+
unique_anchors = list(by_anchor.keys())
|
| 131 |
+
if SMOKE_TEST:
|
| 132 |
+
unique_anchors = unique_anchors[:200]
|
| 133 |
+
print(f" SMOKE_TEST=1: trimmed to {len(unique_anchors)}")
|
| 134 |
+
# For each anchor, pick ONE random positive (skip the O(n^2) filter -- just
|
| 135 |
+
# iterate unique_anchors directly).
|
| 136 |
+
print(f" Sampling one positive per anchor...")
|
| 137 |
+
positives = [rng.choice(by_anchor[a]) for a in unique_anchors]
|
| 138 |
+
print(f" done")
|
| 139 |
+
|
| 140 |
+
# Encode anchors and positives
|
| 141 |
+
print(f"\nEncoding {len(unique_anchors):,} anchors...")
|
| 142 |
+
anchor_emb = model.encode(
|
| 143 |
+
unique_anchors, batch_size=512, show_progress_bar=True, convert_to_numpy=True
|
| 144 |
+
)
|
| 145 |
+
anchor_emb = anchor_emb / np.linalg.norm(anchor_emb, axis=1, keepdims=True)
|
| 146 |
+
print(f" anchor shape: {anchor_emb.shape}, mem: {anchor_emb.nbytes / 1e6:.1f}MB")
|
| 147 |
+
|
| 148 |
+
print(f"\nEncoding {len(positives):,} positives...")
|
| 149 |
+
positive_emb = model.encode(
|
| 150 |
+
positives, batch_size=512, show_progress_bar=True, convert_to_numpy=True
|
| 151 |
+
)
|
| 152 |
+
positive_emb = positive_emb / np.linalg.norm(positive_emb, axis=1, keepdims=True)
|
| 153 |
+
print(f" positive shape: {positive_emb.shape}, mem: {positive_emb.nbytes / 1e6:.1f}MB")
|
| 154 |
+
|
| 155 |
+
# Mine hard negs in chunks
|
| 156 |
+
print(f"\nMining hard negs (range={RANGE_MIN}..{RANGE_MAX}, num={NUM_NEGATIVES}, batch={ANCHOR_BATCH_SIZE})...")
|
| 157 |
+
out_anchors, out_positives, out_negatives = [], [], []
|
| 158 |
+
pos_scores_acc, neg_scores_acc = [], []
|
| 159 |
+
n_anchors = len(unique_anchors)
|
| 160 |
+
|
| 161 |
+
for start in tqdm(range(0, n_anchors, ANCHOR_BATCH_SIZE)):
|
| 162 |
+
end = min(start + ANCHOR_BATCH_SIZE, n_anchors)
|
| 163 |
+
ab = anchor_emb[start:end] # B x D
|
| 164 |
+
# scores: B x N. Each row i is anchor[start+i] vs all positives.
|
| 165 |
+
scores = ab @ positive_emb.T # B x N (float32)
|
| 166 |
+
|
| 167 |
+
# For each anchor i in batch, sort scores desc, get top RANGE_MAX
|
| 168 |
+
# excluding the actual positive (which is at column start+i).
|
| 169 |
+
# We use argpartition for efficiency.
|
| 170 |
+
for i in range(end - start):
|
| 171 |
+
anchor_idx = start + i
|
| 172 |
+
row = scores[i].copy()
|
| 173 |
+
# Mask out the actual positive (anchor's own positive is at anchor_idx)
|
| 174 |
+
row[anchor_idx] = -np.inf
|
| 175 |
+
# Take top RANGE_MAX indices
|
| 176 |
+
top_idx = np.argpartition(-row, RANGE_MAX)[:RANGE_MAX]
|
| 177 |
+
# Sort them by score
|
| 178 |
+
top_idx = top_idx[np.argsort(-row[top_idx])]
|
| 179 |
+
# Sample NUM_NEGATIVES from rank [RANGE_MIN, RANGE_MAX)
|
| 180 |
+
mid_range = top_idx[RANGE_MIN:RANGE_MAX]
|
| 181 |
+
sampled = rng.sample(list(mid_range), min(NUM_NEGATIVES, len(mid_range)))
|
| 182 |
+
for neg_idx in sampled:
|
| 183 |
+
out_anchors.append(unique_anchors[anchor_idx])
|
| 184 |
+
out_positives.append(positives[anchor_idx])
|
| 185 |
+
out_negatives.append(positives[neg_idx])
|
| 186 |
+
pos_scores_acc.append(float(scores[i, anchor_idx]))
|
| 187 |
+
neg_scores_acc.append(float(scores[i, neg_idx]))
|
| 188 |
+
|
| 189 |
+
print(f"\n output triplets: {len(out_anchors):,}")
|
| 190 |
+
print(f" positive scores: mean={np.mean(pos_scores_acc):.3f} std={np.std(pos_scores_acc):.3f}")
|
| 191 |
+
print(f" hard-neg scores: mean={np.mean(neg_scores_acc):.3f} std={np.std(neg_scores_acc):.3f}")
|
| 192 |
+
print(f" margin (pos - neg): mean={np.mean(np.array(pos_scores_acc) - np.array(neg_scores_acc)):.3f}")
|
| 193 |
+
|
| 194 |
+
# Save
|
| 195 |
+
os.makedirs(os.path.dirname(OUTPUT_PATH) or ".", exist_ok=True)
|
| 196 |
+
Dataset.from_dict({
|
| 197 |
+
"anchor": out_anchors,
|
| 198 |
+
"positive": out_positives,
|
| 199 |
+
"negative": out_negatives,
|
| 200 |
+
}).to_parquet(OUTPUT_PATH)
|
| 201 |
+
print(f" saved to {OUTPUT_PATH} ({os.path.getsize(OUTPUT_PATH) / 1e6:.1f} MB)")
|
| 202 |
+
|
| 203 |
+
# Sample
|
| 204 |
+
print("\n=== Sample triplets ===")
|
| 205 |
+
for i in [0, len(out_anchors)//2, len(out_anchors)-1]:
|
| 206 |
+
print(f" ANCHOR: {out_anchors[i]!r}")
|
| 207 |
+
print(f" POSITIVE:{out_positives[i][:100]!r}")
|
| 208 |
+
print(f" NEGATIVE:{out_negatives[i][:100]!r}")
|
| 209 |
+
print()
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
if __name__ == "__main__":
|
| 213 |
+
main()
|
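The per-anchor step of the miner above can be sketched on a single score row: drop the anchor's own positive, rank the remaining candidates, skip the very top ranks, and sample negatives from the window behind them. This sketch uses a full sort rather than the script's `argpartition` plus `-inf` mask, with toy window sizes and made-up scores. The helper name `mine_row` is hypothetical.

```python
import random

RANGE_MIN, RANGE_MAX, NUM_NEGATIVES = 2, 6, 3  # toy values; the script uses 10, 50, 5


def mine_row(scores, own_idx, rng):
    """Return indices of hard negatives for one anchor's score row."""
    # Rank every candidate except the anchor's own positive, best score first.
    ranked = sorted(
        (j for j in range(len(scores)) if j != own_idx),
        key=lambda j: -scores[j],
    )
    # Skip the top ranks (presumably to avoid near-duplicates) and sample from the window.
    window = ranked[RANGE_MIN:RANGE_MAX]
    return rng.sample(window, min(NUM_NEGATIVES, len(window)))


rng = random.Random(12)
scores = [0.99, 0.95, 0.90, 0.80, 0.70, 0.60, 0.50, 0.40, 0.10]  # made-up similarities
negatives = mine_row(scores, own_idx=0, rng=rng)
print(negatives)

# Scratch memory for one B x N float32 score block, as estimated in the docstring.
B, N = 500, 327_000
scratch_bytes = B * N * 4
print(f"{scratch_bytes / 1e6:.0f} MB")
```

Processing anchors in blocks of B keeps only one B x N score block in memory at a time, instead of the full anchor-by-corpus matrix that the docstring says runs out of memory.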
scripts/train_chess_multitask.py
ADDED
|
@@ -0,0 +1,287 @@
+#!/usr/bin/env python3
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#   "sentence-transformers[train]>=5.5.0",
+#   "datasets>=2.19.0",
+#   "accelerate>=0.26.0",
+#   "tokenizers>=0.20",
+# ]
+# ///
+"""Multi-task training: chess-aware semantic structure + hard-negative MNRL.
+
+Two simultaneous training signals:
+
+1. THEME-DISTILL dataset: (theme_token, mpnet_definition_emb)
+   - 73 rows (one per Lichess theme)
+   - Loss: EmbedDistillLoss (project student 512d -> 768d, match teacher)
+   - Effect: enc("fork") moves toward MPNet("a tactical motif where one piece...")
+   - Solves orthogonal-token-embeddings problem identified in Phase 1
+
+2. CHESS-CONTENT dataset: (anchor, positive, hard_negative)
+   - From mined hard-negs of v3 model
+   - Loss: MultipleNegativesRankingLoss (handles triplets natively)
+   - Effect: maintains chess-content associations, sharpens discriminative ability
+
+Multi-task trainer interleaves batches from both datasets. The theme dataset is
+tiny (73 rows) but high-impact -- it injects semantic structure into 73 token
+embeddings. The chess dataset is large (1.6M+ triplets) and shapes the rest.
+
+Run:
+    SMOKE_TEST=1 uv run --exclude-newer=2026-05-12 train_chess_multitask.py
+    uv run --exclude-newer=2026-05-12 train_chess_multitask.py
+"""
+from __future__ import annotations
+
+import logging
+import os
+import random
+import re
+import time
+from collections import defaultdict
+from contextlib import nullcontext
+
+import numpy as np
+import torch
+from datasets import Dataset, concatenate_datasets, load_dataset
+from tokenizers import Tokenizer
+
+from sentence_transformers import (
+    SentenceTransformer,
+    SentenceTransformerModelCardData,
+    SentenceTransformerTrainer,
+    SentenceTransformerTrainingArguments,
+)
+from sentence_transformers.base.sampler import BatchSamplers, MultiDatasetBatchSamplers
+from sentence_transformers.sentence_transformer.evaluation import (
+    InformationRetrievalEvaluator,
+)
+from sentence_transformers.sentence_transformer.losses import (
+    EmbedDistillLoss,
+    MultipleNegativesRankingLoss,
+)
+from sentence_transformers.sentence_transformer.modules import StaticEmbedding
+from transformers import EarlyStoppingCallback, TrainerCallback
+
+THEME_DEFS_PATH = "models/theme_definitions.parquet"
+TRIPLETS_PATH = "models/hard_negatives.parquet"
+TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", "models/static-embedding-chess/chess_tokenizer.json")
+OUTPUT_DIR = "models/static-embedding-chess-multitask"
+RUN_NAME = "static-embedding-chess-multitask"
+SMOKE_TEST = os.environ.get("SMOKE_TEST") == "1"
+EMBEDDING_DIM = 512
+TEACHER_DIM = 768
+HELDOUT_FREQ_MIN = 3
+HELDOUT_FREQ_MAX = 30
+EVAL_QUERIES = 200
+THEME_REPLICAS = int(os.environ.get("THEME_REPLICAS", "500"))  # oversample theme dataset
+
+IS_CUDA = torch.cuda.is_available()
+IS_MPS = (not IS_CUDA) and torch.backends.mps.is_available()
+BATCH_SIZE = 4096 if IS_CUDA else (4096 if IS_MPS else 256)
+
+
+def setup_logging():
+    os.makedirs("logs", exist_ok=True)
+    os.makedirs(OUTPUT_DIR, exist_ok=True)
+    logging.basicConfig(
+        format="%(asctime)s - %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+        level=logging.INFO,
+        handlers=[logging.StreamHandler(), logging.FileHandler(f"logs/{RUN_NAME}.log")],
+        force=True,
+    )
+    for noisy in ("httpx", "httpcore", "huggingface_hub", "urllib3", "filelock", "fsspec"):
+        logging.getLogger(noisy).setLevel(logging.WARNING)
+
+
+def _join_tags(tags):
+    return " ".join(t.replace("_", " ") for t in tags) if tags else ""
+
+
+def _bigram_token_str(moves):
+    toks = moves.split()
+    if len(toks) < 2:
+        return moves
+    bigrams = " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:]))
+    return f"{moves} {bigrams}"
+
+
+def build_puzzle_pairs(batch):
+    anchors, positives = [], []
+    for themes, op, moves in zip(batch["Themes"], batch["OpeningTags"], batch["Moves"]):
+        themes_txt = _join_tags(themes)
+        op_txt = _join_tags(op)
+        if not themes_txt:
+            continue
+        anchor = themes_txt + (f" {op_txt}" if op_txt else "")
+        positive = f"themes {themes_txt}"
+        if op_txt:
+            positive += f" opening {op_txt}"
+        positive += f" moves {_bigram_token_str(moves)}"
+        anchors.append(anchor)
+        positives.append(positive)
+    return {"anchor": anchors, "positive": positives}
+
+
+def strip_theme_echo(p):
+    i = p.find(" moves ")
+    return p[i + 1 :] if i != -1 else p
+
+
+def build_evaluator(holdout):
+    corpus = {f"d{i}": strip_theme_echo(row["positive"]) for i, row in enumerate(holdout)}
+    by_anchor = defaultdict(set)
+    for i, row in enumerate(holdout):
+        by_anchor[row["anchor"]].add(f"d{i}")
+    sorted_a = sorted(by_anchor.items(), key=lambda kv: -len(kv[1]))
+    queries = {f"q{i}": a for i, (a, _) in enumerate(sorted_a)}
+    relevant = {f"q{i}": ids for i, (_, ids) in enumerate(sorted_a)}
+    return InformationRetrievalEvaluator(
+        queries=queries, corpus=corpus, relevant_docs=relevant,
+        name="chess-ir", ndcg_at_k=[10], mrr_at_k=[10],
+        accuracy_at_k=[1, 10], precision_recall_at_k=[1, 10],
+        show_progress_bar=False, batch_size=256,
+    )
+
+
+def autocast_ctx():
+    if IS_CUDA:
+        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        return torch.autocast("cuda", dtype=dtype)
+    if IS_MPS:
+        return torch.autocast("mps", dtype=torch.float16)
+    return nullcontext()
+
+
+def main():
+    setup_logging()
+
+    logging.info(f"Loading tokenizer from {TOKENIZER_PATH}")
+    tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
+    logging.info(f" vocab: {tokenizer.get_vocab_size():,}")
+
+    logging.info(f"Building random-init StaticEmbedding (dim={EMBEDDING_DIM})")
+    static = StaticEmbedding(tokenizer, embedding_dim=EMBEDDING_DIM)
+    model = SentenceTransformer(
+        modules=[static],
+        model_card_data=SentenceTransformerModelCardData(
+            language="en", license="apache-2.0",
+            model_name=f"Static chess embedding ({EMBEDDING_DIM}d) -- multi-task (theme distill + hard-neg MNRL)",
+        ),
+    )
+
+    # === Dataset A: theme distillation ===
+    logging.info(f"Loading theme definitions from {THEME_DEFS_PATH}")
+    theme_ds_full = Dataset.from_parquet(THEME_DEFS_PATH)
+    # EmbedDistillLoss expects columns: sentence, label
+    theme_ds = theme_ds_full.rename_columns({"theme": "sentence", "embedding": "label"}).remove_columns(["definition"])
+    # Oversample to be seen alongside the much-larger chess dataset
+    if not SMOKE_TEST:
+        theme_ds = concatenate_datasets([theme_ds] * THEME_REPLICAS).shuffle(seed=12)
+    logging.info(f" {len(theme_ds):,} theme rows (after oversampling)")
+
+    # === Dataset B: chess triplets ===
+    logging.info(f"Loading triplets from {TRIPLETS_PATH}")
+    triplet_ds = Dataset.from_parquet(TRIPLETS_PATH)
+    if SMOKE_TEST:
+        triplet_ds = triplet_ds.select(range(min(500, len(triplet_ds))))
+    logging.info(f" {len(triplet_ds):,} triplets, columns: {triplet_ds.column_names}")
+
+    # === Build eval (same as previous runs) ===
+    logging.info("Building held-out eval")
+    puzzles = load_dataset("Lichess/chess-puzzles", split="train")
+    if SMOKE_TEST:
+        puzzles = puzzles.select(range(2_000))
+    pair_puzzles = puzzles.map(
+        build_puzzle_pairs, batched=True, batch_size=20_000,
+        remove_columns=puzzles.column_names, num_proc=4,
+    )
+    anchors = pair_puzzles["anchor"]
+    freq = defaultdict(int)
+    for a in anchors:
+        freq[a] += 1
+    rare_pool = sorted(
+        ((a, c) for a, c in freq.items() if HELDOUT_FREQ_MIN <= c <= HELDOUT_FREQ_MAX),
+        key=lambda kv: kv[1],
+    )
+    n_eval = 20 if SMOKE_TEST else EVAL_QUERIES
+    heldout = {a for a, _ in rare_pool[:n_eval]}
+    held_idx = [i for i, h in enumerate([a in heldout for a in anchors]) if h]
+    holdout = pair_puzzles.select(held_idx)
+    logging.info(f" holdout: {len(holdout)}")
+    evaluator = build_evaluator(holdout)
+
+    logging.info("Baseline eval (random init):")
+    with autocast_ctx():
+        baseline = evaluator(model)[evaluator.primary_metric]
+    metric_key = f"eval_{evaluator.primary_metric}"
+    logging.info(f" baseline {evaluator.primary_metric} = {baseline:.4f}")
+
+    # === Multi-task setup ===
+    train_datasets = {
+        "chess": triplet_ds,
+        "themes": theme_ds,
+    }
+    losses = {
+        "chess": MultipleNegativesRankingLoss(model),
+        "themes": EmbedDistillLoss(model, distance_metric="cosine", projection_dim=TEACHER_DIM),
+    }
+
+    args = SentenceTransformerTrainingArguments(
+        output_dir=OUTPUT_DIR,
+        num_train_epochs=5,
+        max_steps=1 if SMOKE_TEST else -1,
+        per_device_train_batch_size=BATCH_SIZE,
+        per_device_eval_batch_size=BATCH_SIZE,
+        learning_rate=1e-2,
+        weight_decay=0.01,
+        warmup_steps=0.1,
+        lr_scheduler_type="linear",
+        bf16=IS_CUDA and torch.cuda.is_bf16_supported(),
+        fp16=IS_CUDA and not torch.cuda.is_bf16_supported(),
+        batch_sampler=BatchSamplers.BATCH_SAMPLER,
+        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
+        eval_strategy="steps",
+        eval_steps=0.05,
+        save_strategy="steps",
+        save_steps=0.05,
+        save_total_limit=2,
+        logging_steps=0.02,
+        logging_first_step=True,
+        load_best_model_at_end=True,
+        metric_for_best_model=metric_key,
+        greater_is_better=True,
+        report_to="none",
+        run_name=RUN_NAME,
+        seed=12,
+        push_to_hub=False,
+    )
+
+    trainer = SentenceTransformerTrainer(
+        model=model, args=args,
+        train_dataset=train_datasets, loss=losses, evaluator=evaluator,
+        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
+    )
+    trainer.train()
+
+    logging.info("Post-training eval:")
+    with autocast_ctx():
+        score = evaluator(model)[evaluator.primary_metric]
+    delta = score - baseline
+    verdict = "WIN" if delta >= 0.005 else "MARGINAL" if delta >= 0 else "REGRESSION"
+    logging.info(
+        f"VERDICT: {verdict} | score={score:.4f} | baseline={baseline:.4f} | delta={delta:+.4f}"
+    )
+
+    # Also report current absolute vs v3 baseline (0.080)
+    v3_baseline = 0.0801
+    logging.info(f" vs v3 (0.0801): delta = {score - v3_baseline:+.4f}")
+
+    final_dir = f"{OUTPUT_DIR}/final"
+    model.save_pretrained(final_dir)
+    logging.info(f"Saved final model to {final_dir}")
+
+
+if __name__ == "__main__":
+    main()
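The two data-shaping ideas that scripts/train_chess_multitask.py relies on, bigram move tokens and rare-anchor held-out selection, can be tried in isolation without any training dependencies. The sketch below is a standalone approximation, not the script itself: `select_heldout_anchors` and the toy anchor list are hypothetical names and data introduced for illustration, and `Counter` stands in for the script's `defaultdict(int)` counting.

```python
from collections import Counter


def bigram_token_str(moves: str) -> str:
    """Append adjacent-pair tokens (joined with '+') to a move string."""
    toks = moves.split()
    if len(toks) < 2:
        return moves
    bigrams = " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:]))
    return f"{moves} {bigrams}"


def select_heldout_anchors(anchors, freq_min=3, freq_max=30, n_queries=200):
    """Pick the rarest anchors whose frequency lies in [freq_min, freq_max].

    Hypothetical helper: mirrors the inline selection logic, with the
    thresholds passed as parameters instead of module constants.
    """
    freq = Counter(anchors)
    rare_pool = sorted(
        ((a, c) for a, c in freq.items() if freq_min <= c <= freq_max),
        key=lambda kv: kv[1],  # ascending: rarest first
    )
    return {a for a, _ in rare_pool[:n_queries]}


# Toy data (made up): "fork endgame" is too frequent and "pin" too rare
# to qualify, so only the two mid-frequency anchors are held out.
toy_anchors = (
    ["fork endgame"] * 40 + ["mate short"] * 5 + ["skewer long"] * 3 + ["pin"]
)
heldout = select_heldout_anchors(toy_anchors, n_queries=2)
holdout_idx = [i for i, a in enumerate(toy_anchors) if a in heldout]

print(bigram_token_str("f2g3 e6e7 b2b1"))
print(sorted(heldout), len(holdout_idx))
```

Sorting by ascending frequency means the evaluation queries are drawn from the rarest qualifying theme combinations first, which is what makes the eval a test of composition rather than recall.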
scripts/train_chess_static.py
ADDED
|
@@ -0,0 +1,640 @@
+#!/usr/bin/env python3
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#   "sentence-transformers[train]>=5.5.0",
+#   "datasets>=2.19.0",
+#   "accelerate>=0.26.0",
+#   "tokenizers>=0.20",
+#   "trackio",
+# ]
+# ///
+"""Train a StaticEmbedding model for chess retrieval.
+
+Pair shape:
+    anchor = "<themes> [<opening words>]"
+    positive = "themes <themes> [opening <words>] moves <uci>" (puzzles)
+               "name <words> eco <code> pgn <san>" (openings)
+
+Datasets:
+- Lichess/chess-puzzles (5.8M rows; themes + opening tags + UCI moves)
+- Lichess/chess-openings (3.6K rows; opening name + ECO + SAN moves)
+
+Use case: free-text search over a chess corpus. "fork endgame short" -> puzzles
+with that motif; "Sicilian Najdorf" -> matching openings.
+
+Design choices:
+- Custom WordLevel + Whitespace tokenizer trained on the corpus. Every chess
+  token (UCI move e2e4, SAN move Nxd4, ECO code B90, theme name, opening word)
+  is one whole token -- BERT WordPiece would shred them 4-way.
+- FEN dropped: position-as-character-soup doesn't fit a token-bag.
+- PGN move numbers stripped ("1. e4 c5" -> "e4 c5") so SAN moves are high-freq.
+- IR eval is custom (themes -> puzzles), not NanoBEIR -- general-English IR
+  benchmarks don't measure chess retrieval.
+
+Run:
+    SMOKE_TEST=1 uv run --exclude-newer=2026-05-12 train_chess_static.py
+    uv run --exclude-newer=2026-05-12 train_chess_static.py
+"""
+
+from __future__ import annotations
+
+import logging
+import os
+import re
+from collections import defaultdict
+from contextlib import nullcontext
+
+import datasets
+import random
+import torch
+from datasets import Dataset, concatenate_datasets, load_dataset
+from tokenizers import Tokenizer
+from tokenizers.models import WordLevel
+from tokenizers.pre_tokenizers import Whitespace
+from tokenizers.trainers import WordLevelTrainer
+
+from sentence_transformers import (
+    SentenceTransformer,
+    SentenceTransformerModelCardData,
+    SentenceTransformerTrainer,
+    SentenceTransformerTrainingArguments,
+)
+from sentence_transformers.base.sampler import BatchSamplers
+from sentence_transformers.sentence_transformer.evaluation import (
+    InformationRetrievalEvaluator,
+    SequentialEvaluator,
+)
+from sentence_transformers.sentence_transformer.losses import (
+    MatryoshkaLoss,
+    MultipleNegativesRankingLoss,
+)
+from sentence_transformers.sentence_transformer.modules import StaticEmbedding
+from transformers import EarlyStoppingCallback, TrainerCallback
+import time
+
+
+EMBEDDING_DIM = 512  # was 256; 512 gives more capacity for bigram tokens
+MATRYOSHKA_DIMS = [512, 256, 128, 64, 32]
+VOCAB_SIZE = 100_000  # was 50_000; UCI/SAN bigrams add ~20-50k vocab
+
+OUTPUT_DIR = "models/static-embedding-chess"
+RUN_NAME = "static-embedding-chess"
+HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "oneryalcin/static-embedding-chess")
+# TOKENIZER_PATH default lives next to the model output. On Modal, set this to
+# a path on the persistent volume (e.g. /cache/chess_tokenizer.json) so the
+# 6-min WordLevelTrainer run is amortized across launches.
+TOKENIZER_PATH = os.environ.get(
+    "TOKENIZER_PATH", f"{OUTPUT_DIR}/chess_tokenizer.json"
+)
+RETRAIN_TOKENIZER = os.environ.get("RETRAIN_TOKENIZER") == "1"
+SMOKE_TEST = os.environ.get("SMOKE_TEST") == "1"
+FORCE_CPU = os.environ.get("FORCE_CPU") == "1"
+# Diagnostic knobs (default: full recipe). Both MPS and T4 show monotonic
+# step-time growth with the full Matryoshka stack -- toggle these to isolate.
+DISABLE_MATRYOSHKA = os.environ.get("DISABLE_MATRYOSHKA") == "1"
+MAX_STEPS_OVERRIDE = int(os.environ.get("MAX_STEPS", "0")) or None
+EVAL_STEPS_OVERRIDE = int(os.environ.get("EVAL_STEPS", "0")) or None
+
+EVAL_QUERIES = 200
+EVAL_CORPUS = 5_000
+# Held-out anchor selection: pick rare combos in this freq range. Low end > 1
+# keeps multi-relevant NDCG meaningful; high end caps memorization potential.
+HELDOUT_FREQ_MIN = 3
+HELDOUT_FREQ_MAX = 30
+# Balanced-dataset config: each unique anchor expands to N (anchor, sampled_pos)
+# rows. The original 5.8M pairs let the model memorize specific (anchor, pos)
+# pairings since each anchor has ~1933 distinct positives. Capping at 100
+# random samples per anchor gives the model meaningful variety without the
+# 50x redundancy that fuels overfitting.
+BALANCED_POSITIVES_PER_ANCHOR = int(os.environ.get("POSITIVES_PER_ANCHOR", "100"))
+# Anchor token masking probability during training. 0 disables.
+ANCHOR_MASK_PROB = float(os.environ.get("ANCHOR_MASK_PROB", "0.15"))
+
+# Device-aware defaults. MPS (Apple Silicon) can't do bf16 and has unified-
+# memory pressure, so the CUDA-targeted skill template defaults (batch=2048,
+# bf16=True) don't apply. Scale BATCH_SIZE up if your M-series has 36GB+.
+IS_CUDA = torch.cuda.is_available() and not FORCE_CPU
+IS_MPS = (not IS_CUDA) and torch.backends.mps.is_available() and not FORCE_CPU
+# StaticEmbedding is a lookup+average -- no transformer activations to fit.
+# Memory cost is the (batch x batch) similarity matrix + (batch x seq x dim)
+# lookups, both tiny. CachedMultipleNegativesRankingLoss is NOT compatible
+# with StaticEmbedding (no encoder to GradCache through), so we just crank
+# the real batch. Scale up freely if your M-series has the headroom.
+BATCH_SIZE = 4096 if IS_CUDA else (4096 if IS_MPS else 256)
+
+MOVE_NUM_RE = re.compile(r"\d+\.+")
+
+
+class StepTimingCallback(TrainerCallback):
+    """Per-step instrumentation: wall time, CUDA memory, allocator state.
+    Costs ~1ms/step. Run-once-and-read approach to diagnosing slowdowns
+    instead of swapping configs and rerunning.
+    """
+
+    def on_step_begin(self, args, state, control, **kw):
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        self._t0 = time.perf_counter()
+
+    def on_step_end(self, args, state, control, **kw):
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        dt = time.perf_counter() - self._t0
+        # Log every step for the first 20 to see startup; then every 10th.
+        if state.global_step <= 20 or state.global_step % 10 == 0:
+            if torch.cuda.is_available():
+                mem = torch.cuda.memory_allocated() / 1e6
+                reserved = torch.cuda.memory_reserved() / 1e6
+                logging.info(
+                    f"STEP {state.global_step}: dt={dt:.3f}s mem={mem:.0f}MB reserved={reserved:.0f}MB"
+                )
+            else:
+                logging.info(f"STEP {state.global_step}: dt={dt:.3f}s (cpu/mps)")
+
+
+def autocast_ctx():
+    if IS_CUDA:
+        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        return torch.autocast("cuda", dtype=dtype)
+    if IS_MPS:
+        return torch.autocast("mps", dtype=torch.float16)
+    return nullcontext()
+
+
+def setup_logging():
+    os.makedirs("logs", exist_ok=True)
+    os.makedirs(OUTPUT_DIR, exist_ok=True)
+    logging.basicConfig(
+        format="%(asctime)s - %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+        level=logging.INFO,
+        handlers=[logging.StreamHandler(), logging.FileHandler(f"logs/{RUN_NAME}.log")],
+        force=True,
+    )
+    for noisy in ("httpx", "httpcore", "huggingface_hub", "urllib3", "filelock", "fsspec"):
+        logging.getLogger(noisy).setLevel(logging.WARNING)
+    if torch.cuda.is_available():
+        torch.set_float32_matmul_precision("high")
+
+
+def _join_tags(tags) -> str:
+    if not tags:
+        return ""
+    return " ".join(t.replace("_", " ") for t in tags)
+
+
+def _strip_pgn_move_numbers(pgn: str) -> str:
+    return MOVE_NUM_RE.sub("", pgn).strip()
+
+
+def _bigram_token_str(moves: str) -> str:
+    """Append bigram tokens to a whitespace-separated move sequence.
+
+        "f2g3 e6e7 b2b1" -> "f2g3 e6e7 b2b1 f2g3+e6e7 e6e7+b2b1"
+
+    Bigrams use `+` as the join char so they're distinct from unigrams in the
+    WordLevel tokenizer's whitespace pretokenizer. A token-bag averaging across
+    unigrams alone loses move ordering; adding adjacent-pair tokens lets the
+    model learn that "e2e4 e7e5" (king's pawn opening) is its own pattern.
+    """
+    tokens = moves.split()
+    if len(tokens) < 2:
+        return moves
+    bigrams = " ".join(f"{a}+{b}" for a, b in zip(tokens, tokens[1:]))
+    return f"{moves} {bigrams}"
+
+
+def build_puzzle_pairs(row_batch: dict) -> dict:
+    anchors, positives = [], []
+    for themes, opening_tags, moves in zip(
+        row_batch["Themes"], row_batch["OpeningTags"], row_batch["Moves"]
+    ):
+        themes_txt = _join_tags(themes)
+        opening_txt = _join_tags(opening_tags)
+        if not themes_txt:
+            continue
+        anchor = themes_txt + (f" {opening_txt}" if opening_txt else "")
+        positive = f"themes {themes_txt}"
+        if opening_txt:
+            positive += f" opening {opening_txt}"
+        positive += f" moves {_bigram_token_str(moves)}"
+        anchors.append(anchor)
+        positives.append(positive)
+    return {"anchor": anchors, "positive": positives}
+
+
+def build_opening_pairs(row_batch: dict) -> dict:
+    anchors, positives = [], []
+    for name, eco, pgn in zip(row_batch["name"], row_batch["eco"], row_batch["pgn"]):
+        san = _strip_pgn_move_numbers(pgn)
+        anchors.append(f"{name} {eco}")
+        positives.append(f"name {name} eco {eco} pgn {_bigram_token_str(san)}")
+    return {"anchor": anchors, "positive": positives}
+
+
| 236 |
+
def load_chess_pairs() -> tuple[Dataset, Dataset]:
|
| 237 |
+
"""Returns (train, holdout) where the holdout anchors are rare combinations
|
| 238 |
+
NEVER seen in train.
|
| 239 |
+
|
| 240 |
+
Old eval used the top-200 most-common theme strings as queries. The model
|
| 241 |
+
memorized these in training (each appears ~50k times) so eval was a recall
|
| 242 |
+
test on memorized lookups, not generalization. Replaced with compositional
|
| 243 |
+
held-out anchors:
|
| 244 |
+
|
| 245 |
+
- Pick anchor strings with frequency in [HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX]:
|
| 246 |
+
rare enough to be informative, common enough to have multiple positives
|
| 247 |
+
for multi-relevant eval.
|
| 248 |
+
- REMOVE all pairs with those anchors from train (no leakage).
|
| 249 |
+
- Use those rare anchors as eval queries; the held-out pairs become the
|
| 250 |
+
eval corpus.
|
| 251 |
+
- Individual theme tokens within those anchors still appear *separately*
|
| 252 |
+
in many other training anchors, so the model has learned each token's
|
| 253 |
+
embedding -- it just hasn't seen this particular combination. Tests
|
| 254 |
+
compositional generalization.
|
| 255 |
+
"""
|
| 256 |
+
logging.info("Loading Lichess/chess-puzzles (5.8M rows)")
|
| 257 |
+
puzzles = load_dataset("Lichess/chess-puzzles", split="train")
|
| 258 |
+
if SMOKE_TEST:
|
| 259 |
+
puzzles = puzzles.select(range(2_000))
|
| 260 |
+
pair_puzzles = puzzles.map(
|
| 261 |
+
build_puzzle_pairs,
|
| 262 |
+
batched=True,
|
| 263 |
+
batch_size=10_000,
|
| 264 |
+
remove_columns=puzzles.column_names,
|
| 265 |
+
desc="puzzles -> pairs",
|
| 266 |
+
)
|
| 267 |
+
logging.info(f" built {len(pair_puzzles):,} puzzle pairs")
|
| 268 |
+
|
| 269 |
+
logging.info("Loading Lichess/chess-openings (3.6K rows)")
|
| 270 |
+
openings = load_dataset("Lichess/chess-openings", split="train").remove_columns(["img"])
|
| 271 |
+
pair_openings = openings.map(
|
| 272 |
+
build_opening_pairs,
|
| 273 |
+
batched=True,
|
| 274 |
+
remove_columns=openings.column_names,
|
| 275 |
+
desc="openings -> pairs",
|
| 276 |
+
)
|
| 277 |
+
logging.info(f" built {len(pair_openings):,} opening pairs")
|
| 278 |
+
|
| 279 |
+
# Count anchor frequencies across the puzzle pairs.
|
| 280 |
+
logging.info("Computing anchor frequencies for held-out selection")
|
| 281 |
+
anchors = pair_puzzles["anchor"]
|
| 282 |
+
    freq: dict[str, int] = defaultdict(int)
    for a in anchors:
        freq[a] += 1
    logging.info(f" {len(freq):,} unique anchors in puzzle pairs")

    # Pick rare anchors: each appears in [HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX] pairs.
    # In smoke mode, lower the min so the tiny corpus still produces enough
    # held-out queries (smoke has ~2k puzzles, most anchors freq 1-2).
    min_freq = 2 if SMOKE_TEST else HELDOUT_FREQ_MIN
    max_freq = HELDOUT_FREQ_MAX
    rare_pool = sorted(
        ((a, c) for a, c in freq.items() if min_freq <= c <= max_freq),
        key=lambda kv: kv[1],  # ascending: rarest first
    )
    n_queries_target = 20 if SMOKE_TEST else EVAL_QUERIES
    if len(rare_pool) < n_queries_target:
        # Report the bounds actually applied (min_freq differs in smoke mode).
        logging.warning(
            f"Only {len(rare_pool)} anchors in freq range [{min_freq},{max_freq}]; "
            f"using all of them ({n_queries_target} requested)"
        )
    heldout_anchors = {a for a, _ in rare_pool[:n_queries_target]}
    logging.info(
        f" selected {len(heldout_anchors)} held-out anchors "
        f"(freq range: {rare_pool[0][1] if rare_pool else 0}..{rare_pool[min(n_queries_target, len(rare_pool))-1][1] if rare_pool else 0})"
    )

    # Filter: pairs whose anchor is held-out -> eval; everything else -> train.
    held_mask = [a in heldout_anchors for a in anchors]
    holdout = pair_puzzles.select([i for i, h in enumerate(held_mask) if h])
    train_puzzles = pair_puzzles.select([i for i, h in enumerate(held_mask) if not h])
    logging.info(
        f" split by held-out anchors: train={len(train_puzzles):,}, holdout={len(holdout):,}"
    )

    # Train includes the (non-held) puzzle pairs + all openings.
    train = concatenate_datasets([train_puzzles, pair_openings]).shuffle(seed=12)
    logging.info(f" train: {len(train):,} pairs | holdout: {len(holdout):,} pairs")
    return train, holdout


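The frequency-based held-out split above can be tried on a toy anchor list. This is a minimal sketch under stated assumptions, not the script's code path: `pick_heldout_anchors` and the sample anchors are hypothetical, and plain Python lists stand in for the `datasets.Dataset` objects.

```python
from collections import Counter


def pick_heldout_anchors(anchors, min_freq, max_freq, n_target):
    """Return up to n_target anchors whose pair count lies in [min_freq, max_freq], rarest first."""
    freq = Counter(anchors)
    rare_pool = sorted(
        ((a, c) for a, c in freq.items() if min_freq <= c <= max_freq),
        key=lambda kv: kv[1],  # ascending: rarest first
    )
    return {a for a, _ in rare_pool[:n_target]}


# Hypothetical anchors: "fork" appears 3x, "pin" 2x, "mate" 5x, "skewer" once.
anchors = ["fork"] * 3 + ["pin"] * 2 + ["mate"] * 5 + ["skewer"]
heldout = pick_heldout_anchors(anchors, min_freq=2, max_freq=3, n_target=2)

# Every pair whose anchor is held out goes to eval; the rest stays in train.
eval_idx = [i for i, a in enumerate(anchors) if a in heldout]
train_idx = [i for i, a in enumerate(anchors) if a not in heldout]
```

Splitting by anchor string rather than by row means no held-out query text ever appears as a training anchor, which is what makes the later retrieval eval a test of unseen combinations.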
def make_balanced_dataset(train: Dataset, n_per_anchor: int) -> Dataset:
    """Cap each anchor's positives to `n_per_anchor` random picks. Breaks the
    5.8M pairs' redundancy (each anchor x ~1933 positives) so the model can't
    memorize specific (anchor, positive) pairings while still seeing useful
    positive variety per anchor.
    """
    by_anchor: dict[str, list[str]] = defaultdict(list)
    for row in train:
        by_anchor[row["anchor"]].append(row["positive"])
    rng = random.Random(12)
    new_anchors, new_positives = [], []
    for anchor, positives in by_anchor.items():
        sample = (
            rng.sample(positives, n_per_anchor)
            if len(positives) > n_per_anchor
            else positives
        )
        for p in sample:
            new_anchors.append(anchor)
            new_positives.append(p)
    logging.info(
        f"Balanced dataset: {len(by_anchor):,} unique anchors -> "
        f"{len(new_anchors):,} pairs (cap {n_per_anchor}/anchor)"
    )
    return Dataset.from_dict({"anchor": new_anchors, "positive": new_positives}).shuffle(seed=12)


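To see what the per-anchor cap in `make_balanced_dataset` does to a skewed corpus, here is a small sketch on plain tuples. The helper name `cap_per_anchor` and the toy pairs are assumptions made for illustration; the script itself works on a `datasets.Dataset`.

```python
import random
from collections import defaultdict


def cap_per_anchor(pairs, n_per_anchor, seed=12):
    """Keep at most n_per_anchor randomly chosen positives for each anchor."""
    by_anchor = defaultdict(list)
    for anchor, positive in pairs:
        by_anchor[anchor].append(positive)
    rng = random.Random(seed)
    balanced = []
    for anchor, positives in by_anchor.items():
        # Sample only when the anchor exceeds the cap; small anchors are kept whole.
        keep = rng.sample(positives, n_per_anchor) if len(positives) > n_per_anchor else positives
        balanced.extend((anchor, p) for p in keep)
    return balanced


# Hypothetical toy pairs: one heavily repeated anchor and one rare anchor.
pairs = [("fork", f"moves m{i}") for i in range(100)] + [("pin", "moves a1a2")]
balanced = cap_per_anchor(pairs, n_per_anchor=10)
```

The frequent anchor shrinks from 100 pairs to 10 while the rare anchor keeps its single pair, so rare anchors carry proportionally more weight in each epoch.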
def make_anchor_masker(mask_prob: float, rng_seed: int = 12):
    """Return a `set_transform` callable that randomly replaces theme tokens
    with [UNK] in the anchor. Token-bag dropout: forces the model to use
    remaining tokens instead of memorizing the exact combination."""
    if mask_prob <= 0:
        return None
    rng = random.Random(rng_seed)

    def _mask(batch: dict) -> dict:
        anchors = batch["anchor"]
        new_anchors = []
        for a in anchors:
            tokens = a.split()
            if len(tokens) <= 1:
                new_anchors.append(a)
                continue
            kept = [t if rng.random() >= mask_prob else "[UNK]" for t in tokens]
            # Guard against masking everything: if all UNK, restore one random token.
            if all(t == "[UNK]" for t in kept):
                kept[rng.randrange(len(kept))] = tokens[rng.randrange(len(tokens))]
            new_anchors.append(" ".join(kept))
        return {"anchor": new_anchors, "positive": batch["positive"]}

    return _mask


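The token-bag dropout in `make_anchor_masker` can be exercised on a single string. The sketch below is a simplified, hypothetical variant (`mask_tokens` is not in the script): it restores the token at its original position when everything was masked, whereas the script draws the position and the token independently, which is equivalent for an order-free token bag.

```python
import random


def mask_tokens(anchor, mask_prob, rng):
    """Replace each whitespace token with [UNK] with probability mask_prob, keeping at least one real token."""
    tokens = anchor.split()
    if len(tokens) <= 1:
        return anchor
    kept = [t if rng.random() >= mask_prob else "[UNK]" for t in tokens]
    if all(t == "[UNK]" for t in kept):
        # Assumption for this sketch: restore a token at its own position.
        idx = rng.randrange(len(tokens))
        kept[idx] = tokens[idx]
    return " ".join(kept)


rng = random.Random(12)
anchor = "themes fork middlegame short"  # hypothetical anchor string
masked = [mask_tokens(anchor, mask_prob=0.3, rng=rng) for _ in range(5)]
fully_masked = mask_tokens(anchor, mask_prob=1.0, rng=rng)
```

With `mask_prob=1.0` every token is masked first and the guard then restores exactly one, so the anchor never degenerates into a string of `[UNK]` only.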
def train_chess_tokenizer(train: Dataset) -> Tokenizer:
    """Train or load a WordLevel tokenizer for the chess corpus.

    Every space-separated unit (theme word, opening word, ECO code, UCI move,
    SAN move) becomes one whole token. Compare to BERT WordPiece which fragments
    "f2g3" into 4 subword pieces -- a token-bag wastes capacity on subword joins
    that carry no chess meaning.

    Caching: if TOKENIZER_PATH exists, load and return it instead of rebuilding.
    The WordLevelTrainer is single-threaded Rust and takes ~6 min on 11.6M
    strings. Tokenizer is deterministic given the same corpus + config, so
    caching is safe. Set RETRAIN_TOKENIZER=1 to force rebuild.
    """
    if not RETRAIN_TOKENIZER and os.path.exists(TOKENIZER_PATH):
        tok = Tokenizer.from_file(TOKENIZER_PATH)
        logging.info(
            f"Reusing cached tokenizer ({tok.get_vocab_size():,} tokens) from {TOKENIZER_PATH}"
        )
        return tok

    logging.info(f"Training WordLevel tokenizer on {len(train):,} pairs (vocab={VOCAB_SIZE})")
    tok = Tokenizer(WordLevel(unk_token="[UNK]"))
    # Caveat: Whitespace() splits on the regex \w+|[^\w\s]+, so punctuation is
    # separated from word characters (e.g. "h3h2+g1g2" yields "h3h2", "+",
    # "g1g2"). WhitespaceSplit() would split on whitespace only.
    tok.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(
        vocab_size=VOCAB_SIZE,
        special_tokens=["[UNK]", "[PAD]"],
        min_frequency=2,
    )

    def text_iter():
        for row in train:
            yield row["anchor"]
            yield row["positive"]

    tok.train_from_iterator(text_iter(), trainer=trainer, length=2 * len(train))
    actual_vocab = tok.get_vocab_size()
    logging.info(f" tokenizer trained: {actual_vocab:,} tokens (cap was {VOCAB_SIZE:,})")
    os.makedirs(os.path.dirname(TOKENIZER_PATH) or ".", exist_ok=True)
    tok.save(TOKENIZER_PATH)
    logging.info(f" saved tokenizer to {TOKENIZER_PATH}")
    return tok


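The idea behind a word-level vocabulary (whole units in, one id out, rare units mapped to `[UNK]`) can be shown without the `tokenizers` library. This is a pure-Python stand-in under stated assumptions, not the library's API: `build_word_level_vocab` and `encode` are hypothetical helpers, they split on whitespace only, and they ignore details such as tie-breaking among equally frequent tokens.

```python
from collections import Counter


def build_word_level_vocab(texts, vocab_size, min_frequency=2, special_tokens=("[UNK]", "[PAD]")):
    """Map whole whitespace-separated tokens to ids; special tokens take the first ids."""
    counts = Counter(tok for text in texts for tok in text.split())
    vocab = {tok: i for i, tok in enumerate(special_tokens)}
    for tok, count in counts.most_common():
        if len(vocab) >= vocab_size:
            break
        if count >= min_frequency:
            vocab[tok] = len(vocab)
    return vocab


def encode(text, vocab):
    """Look up each token; anything outside the vocabulary maps to [UNK]."""
    return [vocab.get(tok, vocab["[UNK]"]) for tok in text.split()]


# Hypothetical two-document corpus: "f2g3" occurs twice, "fork" only once.
texts = ["themes fork moves f2g3 e7e5", "themes pin moves f2g3 a7a6"]
vocab = build_word_level_vocab(texts, vocab_size=100)
ids = encode("themes fork moves f2g3", vocab)
```

Here the move `f2g3` stays a single vocabulary entry, while `fork` falls below `min_frequency=2` and is encoded as `[UNK]`.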
def _strip_theme_echo(positive: str) -> str:
    """Eval corpus must not echo the themes the query asks about, or the
    baseline (random-init) scores high just from lexical token overlap. Keep
    only the moves segment."""
    idx = positive.find(" moves ")
    return positive[idx + 1 :] if idx != -1 else positive


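A quick check of what the stripping above keeps and drops, as a standalone sketch. The function body mirrors `_strip_theme_echo`; the sample strings are hypothetical and only follow the `themes ... moves ...` layout used by the positives.

```python
def strip_theme_echo(positive: str) -> str:
    """Drop everything before the ' moves ' marker so theme words cannot leak into the eval corpus."""
    idx = positive.find(" moves ")
    # idx points at the space before "moves"; idx + 1 keeps the marker itself.
    return positive[idx + 1 :] if idx != -1 else positive


doc = "themes advancedPawn endgame promotion moves h3h2 g1g2 g3g2"
stripped = strip_theme_echo(doc)
unchanged = strip_theme_echo("e2e4 e7e5")  # no marker, returned as-is
```

Note that the marker is matched with surrounding spaces, so a string that merely ends in `moves` or starts with it is returned untouched.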
def _build_compositional_ir_evaluator(
    holdout: Dataset, corpus: dict[str, str], name: str
) -> InformationRetrievalEvaluator:
    """Compositional: each unseen anchor string is a query."""
    by_anchor: dict[str, set[str]] = defaultdict(set)
    for i, row in enumerate(holdout):
        by_anchor[row["anchor"]].add(f"d{i}")
    sorted_anchors = sorted(by_anchor.items(), key=lambda kv: -len(kv[1]))
    queries = {f"q{i}": anchor for i, (anchor, _) in enumerate(sorted_anchors)}
    relevant_docs = {f"q{i}": docs for i, (_, docs) in enumerate(sorted_anchors)}
    avg_rel = sum(len(v) for v in relevant_docs.values()) / max(1, len(relevant_docs))
    logging.info(
        f" [{name}] {len(queries)} queries (unseen combos), avg relevant/query={avg_rel:.1f}"
    )
    return _ir_evaluator(queries, corpus, relevant_docs, name)


def _build_single_theme_ir_evaluator(
    holdout: Dataset, corpus: dict[str, str], name: str
) -> InformationRetrievalEvaluator:
    """Single-theme: each individual theme token from the held-out anchors is
    a query. Tests whether per-token embeddings are useful in isolation.

    Relevant docs for query "fork" = any held-out doc whose anchor contains
    the token "fork". Coarser than the compositional eval (much higher avg
    relevant/query) but a sharper test of token-level meaning.
    """
    theme_to_docs: dict[str, set[str]] = defaultdict(set)
    for i, row in enumerate(holdout):
        for token in row["anchor"].split():
            theme_to_docs[token].add(f"d{i}")
    min_relevant = 2 if SMOKE_TEST else 3
    candidates = [(t, d) for t, d in theme_to_docs.items() if len(d) >= min_relevant]
    candidates.sort(key=lambda kv: -len(kv[1]))
    queries = {f"t{i}": tok for i, (tok, _) in enumerate(candidates)}
    relevant_docs = {f"t{i}": docs for i, (_, docs) in enumerate(candidates)}
    avg_rel = sum(len(v) for v in relevant_docs.values()) / max(1, len(relevant_docs))
    logging.info(
        f" [{name}] {len(queries)} single-token queries, avg relevant/query={avg_rel:.1f}"
    )
    return _ir_evaluator(queries, corpus, relevant_docs, name)


def _ir_evaluator(queries, corpus, relevant_docs, name):
    return InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=name,
        ndcg_at_k=[10],
        mrr_at_k=[10],
        accuracy_at_k=[1, 10],
        precision_recall_at_k=[1, 10],
        show_progress_bar=False,
        batch_size=256,
    )


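Since NDCG@10 is the metric that drives model selection here, a worked computation may help in reading the scores. The sketch below assumes binary relevance (a document is either in the query's relevant set or not); `ndcg_at_k` is a hypothetical helper, and the library evaluator may differ in implementation details.

```python
import math


def ndcg_at_k(ranked_doc_ids, relevant, k=10):
    """Binary-relevance NDCG@k: DCG of the ranking divided by the DCG of an ideal ranking."""
    dcg = sum(
        1.0 / math.log2(rank + 2)  # rank is 0-based, so position 1 gets 1 / log2(2) = 1
        for rank, doc_id in enumerate(ranked_doc_ids[:k])
        if doc_id in relevant
    )
    ideal_hits = min(len(relevant), k)
    idcg = sum(1.0 / math.log2(rank + 2) for rank in range(ideal_hits))
    return dcg / idcg if idcg > 0 else 0.0


# Both relevant docs ranked first: ideal ordering.
perfect = ndcg_at_k(["d1", "d2", "d3"], {"d1", "d2"})
# The single relevant doc appears at position 2 instead of position 1.
late = ndcg_at_k(["d3", "d1"], {"d1"})
```

A single relevant document pushed from position 1 to position 2 already costs roughly a third of the score, which is why small ranking changes show up clearly in this metric.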
def build_ir_evaluator(holdout: Dataset, name: str = "chess-ir") -> SequentialEvaluator:
    """Wraps two evaluators (compositional + single-theme) into a sequential
    pass. The compositional one's score drives best-model selection; the
    single-theme one is informational.
    """
    corpus = {f"d{i}": _strip_theme_echo(row["positive"]) for i, row in enumerate(holdout)}
    logging.info(f"IR eval setup ({len(corpus)} corpus docs):")
    compositional = _build_compositional_ir_evaluator(holdout, corpus, name=name)
    single_theme = _build_single_theme_ir_evaluator(holdout, corpus, name=f"{name}-tokens")
    # First evaluator's score drives load_best_model_at_end (compositional).
    return SequentialEvaluator(
        [compositional, single_theme],
        main_score_function=lambda scores: scores[0],
    )


def main() -> None:
    setup_logging()

    train_dataset, holdout = load_chess_pairs()
    if SMOKE_TEST:
        train_dataset = train_dataset.select(range(min(500, len(train_dataset))))

    # Train the tokenizer on the FULL (pre-balanced) corpus -- we want every
    # token to be seen as many times as possible for the vocab pass.
    tokenizer = train_chess_tokenizer(train_dataset)

    # Now down-sample to a balanced dataset for the contrastive training.
    train_dataset = make_balanced_dataset(train_dataset, BALANCED_POSITIVES_PER_ANCHOR)

    # Optional anchor-token masking applied on the fly via set_transform.
    masker = make_anchor_masker(ANCHOR_MASK_PROB)
    if masker is not None:
        logging.info(f"Anchor token masking enabled (p={ANCHOR_MASK_PROB})")
        train_dataset.set_transform(masker)

    logging.info(f"Random-init StaticEmbedding (dim={EMBEDDING_DIM})")
    static_embedding = StaticEmbedding(tokenizer, embedding_dim=EMBEDDING_DIM)
    model = SentenceTransformer(
        modules=[static_embedding],
        model_card_data=SentenceTransformerModelCardData(
            language="en",
            license="apache-2.0",
            model_name=f"Static chess embedding ({EMBEDDING_DIM}d) -- themes/openings <-> positions",
        ),
    )

    evaluator = build_ir_evaluator(holdout)
    inner = MultipleNegativesRankingLoss(model)
    if DISABLE_MATRYOSHKA:
        logging.info("Matryoshka DISABLED -- training at single dim (diagnostic)")
        loss = inner
    else:
        loss = MatryoshkaLoss(model, inner, matryoshka_dims=MATRYOSHKA_DIMS)

    logging.info("Baseline evaluation (random init -- expect near-zero):")
    with autocast_ctx():
        baseline_eval = evaluator(model)[evaluator.primary_metric]
    metric_key = f"eval_{evaluator.primary_metric}"
    logging.info(f" baseline {evaluator.primary_metric} = {baseline_eval:.4f}")

    if SMOKE_TEST:
        max_steps = 1
    elif MAX_STEPS_OVERRIDE:
        max_steps = MAX_STEPS_OVERRIDE
    else:
        max_steps = -1
    eval_steps = EVAL_STEPS_OVERRIDE if EVAL_STEPS_OVERRIDE else 0.05  # 20 evals/run
    save_steps = EVAL_STEPS_OVERRIDE if EVAL_STEPS_OVERRIDE else 0.05

    args = SentenceTransformerTrainingArguments(
        output_dir=OUTPUT_DIR,
        # Balanced dataset is small (~300k pairs); need many epochs to reach
        # comparable total training signal. Early stopping handles excess.
        num_train_epochs=20,
        max_steps=max_steps,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=1e-2,  # was 5e-2 -- much slower convergence, shifts peak later
        weight_decay=0.01,  # was 0.0 -- regularization on the embedding table
        warmup_steps=0.1,
        lr_scheduler_type="linear",
        bf16=IS_CUDA and torch.cuda.is_bf16_supported(),
        fp16=IS_CUDA and not torch.cuda.is_bf16_supported(),
        # was NO_DUPLICATES -- linked-list scan over deferred conflicts gives
        # O(epoch_progress) per-batch cost. With ~3000 unique anchors over
        # 5.8M pairs, dedup is fighting impossible odds. BATCH_SAMPLER (random)
        # is fast and accepts mild within-batch anchor duplication.
        batch_sampler=BatchSamplers.BATCH_SAMPLER,
        eval_strategy="steps",
        eval_steps=eval_steps,
        save_strategy="steps",
        save_steps=save_steps,
        save_total_limit=2,
        logging_steps=0.01,
        logging_first_step=True,
        load_best_model_at_end=True,
        metric_for_best_model=metric_key,
        greater_is_better=True,
        # Trackio crashes at first checkpoint push: empty `router_mapping`
        # struct can't be written to parquet. Disable.
        report_to="none",
        run_name=RUN_NAME,
        seed=12,
        # HF Jobs: container is destroyed after run -- push every checkpoint to
        # the Hub so partial progress survives a timeout. The end-of-run
        # model.push_to_hub() below is the belt to this suspenders.
        push_to_hub=not SMOKE_TEST,
        hub_model_id=HUB_MODEL_ID,
        hub_strategy="every_save",
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
        evaluator=evaluator,
        callbacks=[
            # Auto-stop if compositional NDCG@10 doesn't improve for 3 evals.
            # Lower lr makes curves smoother -- give it slack vs the patience=2
            # we used at lr=5e-2.
            EarlyStoppingCallback(early_stopping_patience=3),
            # Per-step memory + dt logging.
            StepTimingCallback(),
        ],
    )
    trainer.train()

    logging.info("Post-training evaluation:")
    with autocast_ctx():
        score = evaluator(model)[evaluator.primary_metric]
    delta = score - baseline_eval
    verdict = "WIN" if delta >= 0.005 else "MARGINAL" if delta >= 0 else "REGRESSION"
    logging.info(
        f"VERDICT: {verdict} | score={score:.4f} | baseline={baseline_eval:.4f} | delta={delta:+.4f}"
    )

    final_dir = f"{OUTPUT_DIR}/final"
    model.save_pretrained(final_dir)
    logging.info(f"Saved final model to {final_dir}")

    if SMOKE_TEST:
        logging.info("SMOKE_TEST=1: skipping Hub push")
        return

    try:
        commit_url = model.push_to_hub(HUB_MODEL_ID)
        logging.info(f"Pushed model to {commit_url.rsplit('/commit/', 1)[0]}")
    except Exception:
        import traceback

        logging.error(f"Hub push failed:\n{traceback.format_exc()}")


if __name__ == "__main__":
    main()
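The chained conditional that produces the final verdict in `main()` is compact but easy to misread. Spelled out as a standalone function it looks roughly like this; `classify_run` and its `win_margin` parameter are hypothetical names, with the 0.005 threshold taken from the script.

```python
def classify_run(score: float, baseline: float, win_margin: float = 0.005) -> str:
    """Label a run by how far the post-training score moved from the random-init baseline."""
    delta = score - baseline
    if delta >= win_margin:
        return "WIN"
    # Non-negative but below the margin counts as marginal; anything lower is a regression.
    return "MARGINAL" if delta >= 0 else "REGRESSION"


# Hypothetical scores, for illustration only.
labels = [
    classify_run(0.30, 0.02),
    classify_run(0.022, 0.02),
    classify_run(0.01, 0.02),
]
```

Because the baseline is a random-init model, the margin only guards against declaring a win on evaluation noise; it says nothing about absolute retrieval quality.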