oneryalcin committed on
Commit
f8392aa
·
verified ·
1 Parent(s): f194aba

Add files using upload-large-folder tool

README.md CHANGED
@@ -1,538 +1,416 @@
1
  ---
2
- language:
3
- - en
4
  license: apache-2.0
 
 
5
  tags:
6
  - sentence-transformers
7
- - sentence-similarity
8
- - feature-extraction
9
- - generated_from_trainer
10
- - dataset_size:1619946
11
- - loss:MatryoshkaLoss
12
- - loss:MultipleNegativesRankingLoss
13
- widget:
14
- - source_sentence: kingsideAttack master [UNK] mateIn1 oneMove [UNK] [UNK] Defense
15
- Sicilian Defense [UNK] Attack
16
- sentences:
17
- - themes kingsideAttack master mate mateIn1 oneMove opening opening Sicilian Defense
18
- Sicilian Defense Nyezhmetdinov-Rossolimo Attack moves f3e5 c6g2 f3e5+c6g2
19
- - themes crushing middlegame queensideAttack sacrifice veryLong moves d7c7 b3e6
20
- f7e6 e1e6 c8b8 f6d7 c7d7 e6d7 d7c7+b3e6 b3e6+f7e6 f7e6+e1e6 e1e6+c8b8 c8b8+f6d7
21
- f6d7+c7d7 c7d7+e6d7
22
- - themes advancedPawn crushing endgame veryLong zugzwang moves d4e6 c4e6 f7e6 h7g6
23
- f8g8 f6f7 g8f8 g6f6 e6e5 f6e5 d4e6+c4e6 c4e6+f7e6 f7e6+h7g6 h7g6+f8g8 f8g8+f6f7
24
- f6f7+g8f8 g8f8+g6f6 g6f6+e6e5 e6e5+f6e5
25
- - source_sentence: crushing intermezzo master middlegame sacrifice veryLong
26
- sentences:
27
- - themes crushing endgame master masterVsMaster veryLong moves f5f6 c5e6 h5g6 h7g6
28
- c3f3 d5b4 f3c6 b4c6 f5f6+c5e6 c5e6+h5g6 h5g6+h7g6 h7g6+c3f3 c3f3+d5b4 d5b4+f3c6
29
- f3c6+b4c6
30
- - themes advancedPawn advantage endgame long master promotion rookEndgame moves
31
- h3h2 g1g2 g3g2 a6a7 h2h1q a7b8q h3h2+g1g2 g1g2+g3g2 g3g2+a6a7 a6a7+h2h1q h2h1q+a7b8q
32
- - themes crushing intermezzo master middlegame sacrifice veryLong moves a6c4 d6f6
33
- f1f6 h6h1 g1f2 h8f6 f2e2 f6e7 a6c4+d6f6 d6f6+f1f6 f1f6+h6h1 h6h1+g1f2 g1f2+h8f6
34
- h8f6+f2e2 f2e2+f6e7
35
- - source_sentence: advantage hangingPiece middlegame short Nimzo-Larsen Attack Nimzo-Larsen
36
- Attack Modern [UNK]
37
- sentences:
38
- - themes hangingPiece mate mateIn1 middlegame oneMove opening Trompowsky Attack
39
- Trompowsky Attack Classical Defense moves f4g4 d8d1 f4g4+d8d1
40
- - themes advancedPawn crushing defensiveMove endgame master quietMove veryLong moves
41
- f1e1 h3h2 f8h8 f5h4 h8e5 g3g2 e5e4 h4f3 f1e1+h3h2 h3h2+f8h8 f8h8+f5h4 f5h4+h8e5
42
- h8e5+g3g2 g3g2+e5e4 e5e4+h4f3
43
- - themes advantage hangingPiece middlegame short opening Nimzo-Larsen Attack Nimzo-Larsen
44
- Attack Modern Variation moves f5d7 b5g5 e3e2 d1d2 f5d7+b5g5 b5g5+e3e2 e3e2+d1d2
45
- - source_sentence: '[UNK] defensiveMove [UNK] [UNK] veryLong'
46
- sentences:
47
- - themes advantage discoveredAttack exposedKing middlegame trappedPiece veryLong
48
- opening French Defense French Defense Orthoschnapp Gambit moves e2d1 c4e3 d2e3
49
- b5f1 d1d2 f1g2 g1e2 g2h1 e2d1+c4e3 c4e3+d2e3 d2e3+b5f1 b5f1+d1d2 d1d2+f1g2 f1g2+g1e2
50
- g1e2+g2h1
51
- - themes crushing defensiveMove enPassant middlegame veryLong moves g2e2 a3f3 f7f5
52
- e5f6 c4f4 g3f4 e2g2 f3g3 g2e2+a3f3 a3f3+f7f5 f7f5+e5f6 e5f6+c4f4 c4f4+g3f4 g3f4+e2g2
53
- e2g2+f3g3
54
- - themes advancedPawn bishopEndgame crushing defensiveMove endgame veryLong moves
55
- f3e4 a3a2 g6g7 e6f7 e5e6 f7g8 e6e7 c5e7 f3e4+a3a2 a3a2+g6g7 g6g7+e6f7 e6f7+e5e6
56
- e5e6+f7g8 f7g8+e6e7 e6e7+c5e7
57
- - source_sentence: '[UNK] deflection discoveredAttack [UNK] queensideAttack short
58
- Philidor Defense [UNK] Defense Other variations'
59
- sentences:
60
- - themes crushing middlegame pin queensideAttack short opening Sicilian Defense
61
- Sicilian Defense Najdorf Variation moves c3d5 c5b3 c1b1 b3d2 c3d5+c5b3 c5b3+c1b1
62
- c1b1+b3d2
63
- - themes crushing deflection discoveredAttack middlegame queensideAttack short opening
64
- Philidor Defense Philidor Defense Other variations moves d3c3 d4b3 c1b1 d7d1 d3c3+d4b3
65
- d4b3+c1b1 c1b1+d7d1
66
- - themes advantage discoveredAttack middlegame short opening Philidor Defense Philidor
67
- Defense Other variations moves e4d4 d3f5 c8b8 d1d4 e4d4+d3f5 d3f5+c8b8 c8b8+d1d4
68
- pipeline_tag: sentence-similarity
69
- library_name: sentence-transformers
70
- metrics:
71
- - cosine_accuracy@1
72
- - cosine_accuracy@10
73
- - cosine_precision@1
74
- - cosine_precision@10
75
- - cosine_recall@1
76
- - cosine_recall@10
77
- - cosine_ndcg@10
78
- - cosine_mrr@10
79
- - cosine_map@100
80
- model-index:
81
- - name: Static chess embedding (512d) -- themes/openings <-> positions
82
- results:
83
- - task:
84
- type: information-retrieval
85
- name: Information Retrieval
86
- dataset:
87
- name: chess ir
88
- type: chess-ir
89
- metrics:
90
- - type: cosine_accuracy@1
91
- value: 0.005
92
- name: Cosine Accuracy@1
93
- - type: cosine_accuracy@10
94
- value: 0.07
95
- name: Cosine Accuracy@10
96
- - type: cosine_precision@1
97
- value: 0.005
98
- name: Cosine Precision@1
99
- - type: cosine_precision@10
100
- value: 0.008
101
- name: Cosine Precision@10
102
- - type: cosine_recall@1
103
- value: 0.0016666666666666666
104
- name: Cosine Recall@1
105
- - type: cosine_recall@10
106
- value: 0.02666666666666666
107
- name: Cosine Recall@10
108
- - type: cosine_ndcg@10
109
- value: 0.01682968253099316
110
- name: Cosine Ndcg@10
111
- - type: cosine_mrr@10
112
- value: 0.020728174603174603
113
- name: Cosine Mrr@10
114
- - type: cosine_map@100
115
- value: 0.014144217882495914
116
- name: Cosine Map@100
117
- - task:
118
- type: information-retrieval
119
- name: Information Retrieval
120
- dataset:
121
- name: chess ir tokens
122
- type: chess-ir-tokens
123
- metrics:
124
- - type: cosine_accuracy@1
125
- value: 0.07936507936507936
126
- name: Cosine Accuracy@1
127
- - type: cosine_accuracy@10
128
- value: 0.25925925925925924
129
- name: Cosine Accuracy@10
130
- - type: cosine_precision@1
131
- value: 0.07936507936507936
132
- name: Cosine Precision@1
133
- - type: cosine_precision@10
134
- value: 0.06031746031746032
135
- name: Cosine Precision@10
136
- - type: cosine_recall@1
137
- value: 0.00224439005944158
138
- name: Cosine Recall@1
139
- - type: cosine_recall@10
140
- value: 0.023957890091684336
141
- name: Cosine Recall@10
142
- - type: cosine_ndcg@10
143
- value: 0.067202690066618
144
- name: Cosine Ndcg@10
145
- - type: cosine_mrr@10
146
- value: 0.12332031578063325
147
- name: Cosine Mrr@10
148
- - type: cosine_map@100
149
- value: 0.03321093573791526
150
- name: Cosine Map@100
151
  ---
152
 
153
- # Static chess embedding (512d) -- themes/openings <-> positions
154
 
155
- This is a [sentence-transformers](https://www.SBERT.net) model trained. It maps sentences & paragraphs to a 512-dimensional dense vector space and can be used for retrieval.
156
 
157
- ## Model Details
158
 
159
- ### Model Description
160
- - **Model Type:** Sentence Transformer
161
- <!-- - **Base model:** [Unknown](https://huggingface.co/unknown) -->
162
- - **Maximum Sequence Length:** inf tokens
163
- - **Output Dimensionality:** 512 dimensions
164
- - **Similarity Function:** Cosine Similarity
165
- - **Supported Modality:** Text
166
- <!-- - **Training Dataset:** Unknown -->
167
- - **Language:** en
168
- - **License:** apache-2.0
169
 
170
- ### Model Sources
 
 
171
 
172
- - **Documentation:** [Sentence Transformers Documentation](https://sbert.net)
173
- - **Repository:** [Sentence Transformers on GitHub](https://github.com/huggingface/sentence-transformers)
174
- - **Hugging Face:** [Sentence Transformers on Hugging Face](https://huggingface.co/models?library=sentence-transformers)
175
 
176
- ### Full Model Architecture
 
 
177
 
178
  ```
179
- SentenceTransformer(
180
- (0): StaticEmbedding({})
181
- )
182
  ```
183
 
184
- ## Usage
185
 
186
- ### Direct Usage (Sentence Transformers)
187
 
188
- First install the Sentence Transformers library:
 
189
 
190
- ```bash
191
- pip install -U sentence-transformers
192
- ```
193
- Then you can load this model and run inference.
194
- ```python
195
- from sentence_transformers import SentenceTransformer
196
 
197
- # Download from the 🤗 Hub
198
- model = SentenceTransformer("oneryalcin/static-embedding-chess")
199
- # Run inference
200
- queries = [
201
- '[UNK] deflection discoveredAttack [UNK] queensideAttack short Philidor Defense [UNK] Defense Other variations',
202
- ]
203
- documents = [
204
- 'themes crushing deflection discoveredAttack middlegame queensideAttack short opening Philidor Defense Philidor Defense Other variations moves d3c3 d4b3 c1b1 d7d1 d3c3+d4b3 d4b3+c1b1 c1b1+d7d1',
205
- 'themes advantage discoveredAttack middlegame short opening Philidor Defense Philidor Defense Other variations moves e4d4 d3f5 c8b8 d1d4 e4d4+d3f5 d3f5+c8b8 c8b8+d1d4',
206
- 'themes crushing middlegame pin queensideAttack short opening Sicilian Defense Sicilian Defense Najdorf Variation moves c3d5 c5b3 c1b1 b3d2 c3d5+c5b3 c5b3+c1b1 c1b1+b3d2',
207
- ]
208
- query_embeddings = model.encode_query(queries)
209
- document_embeddings = model.encode_document(documents)
210
- print(query_embeddings.shape, document_embeddings.shape)
211
- # [1, 512] [3, 512]
212
-
213
- # Get the similarity scores for the embeddings
214
- similarities = model.similarity(query_embeddings, document_embeddings)
215
- print(similarities)
216
- # tensor([[0.8405, 0.5061, 0.2136]])
217
- ```
218
- <!--
219
- ### Direct Usage (Transformers)
220
-
221
- <details><summary>Click to see the direct usage in Transformers</summary>
222
-
223
- </details>
224
- -->
225
-
226
- <!--
227
- ### Downstream Usage (Sentence Transformers)
228
-
229
- You can finetune this model on your own dataset.
230
-
231
- <details><summary>Click to expand</summary>
232
-
233
- </details>
234
- -->
235
-
236
- <!--
237
- ### Out-of-Scope Use
238
-
239
- *List how the model may foreseeably be misused and address what users ought not to do with the model.*
240
- -->
241
-
242
- ## Evaluation
243
-
244
- ### Metrics
245
-
246
- #### Information Retrieval
247
-
248
- * Datasets: `chess-ir` and `chess-ir-tokens`
249
- * Evaluated with [<code>InformationRetrievalEvaluator</code>](https://sbert.net/docs/package_reference/sentence_transformer/evaluation.html#sentence_transformers.sentence_transformer.evaluation.InformationRetrievalEvaluator)
250
-
251
- | Metric | chess-ir | chess-ir-tokens |
252
- |:--------------------|:-----------|:----------------|
253
- | cosine_accuracy@1 | 0.005 | 0.0794 |
254
- | cosine_accuracy@10 | 0.07 | 0.2593 |
255
- | cosine_precision@1 | 0.005 | 0.0794 |
256
- | cosine_precision@10 | 0.008 | 0.0603 |
257
- | cosine_recall@1 | 0.0017 | 0.0022 |
258
- | cosine_recall@10 | 0.0267 | 0.024 |
259
- | **cosine_ndcg@10** | **0.0168** | **0.0672** |
260
- | cosine_mrr@10 | 0.0207 | 0.1233 |
261
- | cosine_map@100 | 0.0141 | 0.0332 |
262
-
263
- <!--
264
- ## Bias, Risks and Limitations
265
-
266
- *What are the known or foreseeable issues stemming from this model? You could also flag here known failure cases or weaknesses of the model.*
267
- -->
268
-
269
- <!--
270
- ### Recommendations
271
-
272
- *What are recommendations with respect to the foreseeable issues? For example, filtering explicit content.*
273
- -->
274
-
275
- ## Training Details
276
-
277
- ### Training Dataset
278
-
279
- #### Unnamed Dataset
280
-
281
- * Size: 1,619,946 training samples
282
- * Columns: <code>anchor</code> and <code>positive</code>
283
- * Approximate statistics based on the first 100 samples:
284
- | | anchor | positive |
285
- |:---------|:------------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------|
286
- | type | string | string |
287
- | modality | text | text |
288
- | details | <ul><li>min: 21 characters</li><li>mean: 75.57 characters</li><li>max: 122 characters</li></ul> | <ul><li>min: 86 characters</li><li>mean: 158.13 characters</li><li>max: 256 characters</li></ul> |
289
- * Samples:
290
- | anchor | positive |
291
- |:---------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
292
- | <code>kingsideAttack mate mateIn1 middlegame oneMove Horwitz Defense Horwitz Defense [UNK] variations</code> | <code>themes kingsideAttack mate mateIn1 middlegame oneMove opening Horwitz Defense Horwitz Defense Other variations moves f7h8 g6g2 f7h8+g6g2</code> |
293
- | <code>backRankMate endgame mate mateIn2 short Kings Knight Opening Kings Knight Opening [UNK] [UNK]</code> | <code>themes backRankMate endgame mate mateIn2 short opening Kings Knight Opening Kings Knight Opening Other variations moves c5d4 c3c8 g5d8 c8d8 c5d4+c3c8 c3c8+g5d8 g5d8+c8d8</code> |
294
- | <code>kingsideAttack mate mateIn1 middlegame oneMove Sicilian Defense Sicilian Defense Paulsen-Basman Defense</code> | <code>themes kingsideAttack mate mateIn1 middlegame oneMove opening Sicilian Defense Sicilian Defense Paulsen-Basman Defense moves g3f3 c7h2 g3f3+c7h2</code> |
295
- * Loss: [<code>MatryoshkaLoss</code>](https://sbert.net/docs/package_reference/sentence_transformer/losses.html#matryoshkaloss) with these parameters:
296
- ```json
297
- {
298
- "loss": "MultipleNegativesRankingLoss",
299
- "matryoshka_dims": [
300
- 512,
301
- 256,
302
- 128,
303
- 64,
304
- 32
305
- ],
306
- "matryoshka_weights": [
307
- 1,
308
- 1,
309
- 1,
310
- 1,
311
- 1
312
- ],
313
- "n_dims_per_step": -1
314
- }
315
- ```
316
-
317
- ### Training Hyperparameters
318
- #### Non-Default Hyperparameters
319
-
320
- - `per_device_train_batch_size`: 4096
321
- - `num_train_epochs`: 20
322
- - `learning_rate`: 0.01
323
- - `warmup_steps`: 0.1
324
- - `weight_decay`: 0.01
325
- - `per_device_eval_batch_size`: 4096
326
- - `push_to_hub`: True
327
- - `hub_model_id`: oneryalcin/static-embedding-chess
328
- - `load_best_model_at_end`: True
329
- - `seed`: 12
330
-
331
- #### All Hyperparameters
332
- <details><summary>Click to expand</summary>
333
-
334
- - `per_device_train_batch_size`: 4096
335
- - `num_train_epochs`: 20
336
- - `max_steps`: -1
337
- - `learning_rate`: 0.01
338
- - `lr_scheduler_type`: linear
339
- - `lr_scheduler_kwargs`: None
340
- - `warmup_steps`: 0.1
341
- - `optim`: adamw_torch_fused
342
- - `optim_args`: None
343
- - `weight_decay`: 0.01
344
- - `adam_beta1`: 0.9
345
- - `adam_beta2`: 0.999
346
- - `adam_epsilon`: 1e-08
347
- - `optim_target_modules`: None
348
- - `gradient_accumulation_steps`: 1
349
- - `average_tokens_across_devices`: True
350
- - `max_grad_norm`: 1.0
351
- - `label_smoothing_factor`: 0.0
352
- - `bf16`: False
353
- - `fp16`: False
354
- - `bf16_full_eval`: False
355
- - `fp16_full_eval`: False
356
- - `tf32`: None
357
- - `gradient_checkpointing`: False
358
- - `gradient_checkpointing_kwargs`: None
359
- - `torch_compile`: False
360
- - `torch_compile_backend`: None
361
- - `torch_compile_mode`: None
362
- - `use_liger_kernel`: False
363
- - `liger_kernel_config`: None
364
- - `use_cache`: False
365
- - `neftune_noise_alpha`: None
366
- - `torch_empty_cache_steps`: None
367
- - `auto_find_batch_size`: False
368
- - `log_on_each_node`: True
369
- - `logging_nan_inf_filter`: True
370
- - `include_num_input_tokens_seen`: no
371
- - `log_level`: passive
372
- - `log_level_replica`: warning
373
- - `disable_tqdm`: False
374
- - `project`: huggingface
375
- - `trackio_space_id`: None
376
- - `trackio_bucket_id`: None
377
- - `trackio_static_space_id`: None
378
- - `per_device_eval_batch_size`: 4096
379
- - `prediction_loss_only`: True
380
- - `eval_on_start`: False
381
- - `eval_do_concat_batches`: True
382
- - `eval_use_gather_object`: False
383
- - `eval_accumulation_steps`: None
384
- - `include_for_metrics`: []
385
- - `batch_eval_metrics`: False
386
- - `save_only_model`: False
387
- - `save_on_each_node`: False
388
- - `enable_jit_checkpoint`: False
389
- - `push_to_hub`: True
390
- - `hub_private_repo`: None
391
- - `hub_model_id`: oneryalcin/static-embedding-chess
392
- - `hub_strategy`: every_save
393
- - `hub_always_push`: False
394
- - `hub_revision`: None
395
- - `load_best_model_at_end`: True
396
- - `ignore_data_skip`: False
397
- - `restore_callback_states_from_checkpoint`: False
398
- - `full_determinism`: False
399
- - `seed`: 12
400
- - `data_seed`: None
401
- - `use_cpu`: False
402
- - `accelerator_config`: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
403
- - `parallelism_config`: None
404
- - `dataloader_drop_last`: False
405
- - `dataloader_num_workers`: 0
406
- - `dataloader_pin_memory`: True
407
- - `dataloader_persistent_workers`: False
408
- - `dataloader_prefetch_factor`: None
409
- - `remove_unused_columns`: True
410
- - `label_names`: None
411
- - `train_sampling_strategy`: random
412
- - `length_column_name`: length
413
- - `ddp_find_unused_parameters`: None
414
- - `ddp_bucket_cap_mb`: None
415
- - `ddp_broadcast_buffers`: False
416
- - `ddp_static_graph`: None
417
- - `ddp_backend`: None
418
- - `ddp_timeout`: 1800
419
- - `fsdp`: []
420
- - `fsdp_config`: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
421
- - `deepspeed`: None
422
- - `debug`: []
423
- - `skip_memory_metrics`: True
424
- - `do_predict`: False
425
- - `resume_from_checkpoint`: None
426
- - `warmup_ratio`: None
427
- - `local_rank`: -1
428
- - `prompts`: None
429
- - `batch_sampler`: batch_sampler
430
- - `multi_dataset_batch_sampler`: proportional
431
- - `router_mapping`: {}
432
- - `learning_rate_mapping`: {}
433
-
434
- </details>
435
-
436
- ### Training Logs
437
- | Epoch | Step | Training Loss | chess-ir_cosine_ndcg@10 | chess-ir-tokens_cosine_ndcg@10 |
438
- |:------:|:----:|:-------------:|:-----------------------:|:------------------------------:|
439
- | -1 | -1 | - | 0.0123 | 0.0561 |
440
- | 0.0025 | 1 | 27.3123 | - | - |
441
- | 0.2020 | 80 | 26.3304 | - | - |
442
- | 0.4040 | 160 | 22.2114 | - | - |
443
- | 0.6061 | 240 | 17.4522 | - | - |
444
- | 0.8081 | 320 | 12.8864 | - | - |
445
- | 1.0 | 396 | - | 0.0800 | 0.1181 |
446
- | 1.0101 | 400 | 9.1439 | - | - |
447
- | 1.2121 | 480 | 6.5434 | - | - |
448
- | 1.4141 | 560 | 4.9138 | - | - |
449
- | 1.6162 | 640 | 3.9819 | - | - |
450
- | 1.8182 | 720 | 3.4584 | - | - |
451
- | 2.0 | 792 | - | 0.0505 | 0.0938 |
452
- | 2.0202 | 800 | 3.1303 | - | - |
453
- | 2.2222 | 880 | 2.9652 | - | - |
454
- | 2.4242 | 960 | 2.8584 | - | - |
455
- | 2.6263 | 1040 | 2.7907 | - | - |
456
- | 2.8283 | 1120 | 2.7475 | - | - |
457
- | 3.0 | 1188 | - | 0.0251 | 0.0830 |
458
- | 3.0303 | 1200 | 2.7031 | - | - |
459
- | 3.2323 | 1280 | 2.6927 | - | - |
460
- | 3.4343 | 1360 | 2.6516 | - | - |
461
- | 3.6364 | 1440 | 2.6441 | - | - |
462
- | 3.8384 | 1520 | 2.6202 | - | - |
463
- | 4.0 | 1584 | - | 0.0168 | 0.0672 |
464
-
465
-
466
- ### Training Time
467
- - **Training**: 4.1 minutes
468
- - **Evaluation**: 0.2 seconds
469
- - **Total**: 4.1 minutes
470
-
471
- ### Framework Versions
472
- - Python: 3.12.10
473
- - Sentence Transformers: 5.5.0
474
- - Transformers: 5.8.0
475
- - PyTorch: 2.11.0
476
- - Accelerate: 1.13.0
477
- - Datasets: 4.8.5
478
- - Tokenizers: 0.22.2
479
 
480
- ## Citation
 
 
 
 
 
481
 
482
- ### BibTeX
483
-
484
- #### Sentence Transformers
485
- ```bibtex
486
- @inproceedings{reimers-2019-sentence-bert,
487
- title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
488
- author = "Reimers, Nils and Gurevych, Iryna",
489
- booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
490
- month = "11",
491
- year = "2019",
492
- publisher = "Association for Computational Linguistics",
493
- url = "https://arxiv.org/abs/1908.10084",
494
- }
495
- ```
496
 
497
- #### MatryoshkaLoss
498
- ```bibtex
499
- @misc{kusupati2024matryoshka,
500
- title={Matryoshka Representation Learning},
501
- author={Aditya Kusupati and Gantavya Bhatt and Aniket Rege and Matthew Wallingford and Aditya Sinha and Vivek Ramanujan and William Howard-Snyder and Kaifeng Chen and Sham Kakade and Prateek Jain and Ali Farhadi},
502
- year={2024},
503
- eprint={2205.13147},
504
- archivePrefix={arXiv},
505
- primaryClass={cs.LG}
506
- }
507
  ```
508
 
509
- #### MultipleNegativesRankingLoss
510
- ```bibtex
511
- @misc{oord2019representationlearningcontrastivepredictive,
512
- title={Representation Learning with Contrastive Predictive Coding},
513
- author={Aaron van den Oord and Yazhe Li and Oriol Vinyals},
514
- year={2019},
515
- eprint={1807.03748},
516
- archivePrefix={arXiv},
517
- primaryClass={cs.LG},
518
- url={https://arxiv.org/abs/1807.03748},
519
- }
520
  ```
521
 
522
- <!--
523
- ## Glossary
524
 
525
- *Clearly define terms in order to be accessible across audiences.*
526
- -->
 
 
 
 
 
 
527
 
528
- <!--
529
- ## Model Card Authors
530
 
531
- *Lists the people who create the model card, providing recognition and accountability for the detailed work that goes into its construction.*
532
- -->
533
 
534
- <!--
535
- ## Model Card Contact
 
 
 
 
 
 
536
 
537
- *Provides a way for people who have updates to the Model Card, suggestions, or questions, to contact the Model Card authors.*
538
- -->
 
 
1
  ---
2
+ language: en
 
3
  license: apache-2.0
4
+ library_name: sentence-transformers
5
+ pipeline_tag: sentence-similarity
6
  tags:
7
  - sentence-transformers
8
+ - static-embedding
9
+ - chess
10
+ - retrieval
11
+ - exploratory
12
+ datasets:
13
+ - Lichess/chess-puzzles
14
+ - Lichess/chess-openings
15
+ ---
16
+
17
+ # Chess Static Embedding (v4-C2) — Open Exploration
18
+
19
+ A 4M-parameter `StaticEmbedding` model for chess content retrieval, plus the
20
+ full **open-science methodology document** describing what we tried, what
21
+ worked, what failed, and why.
22
+
23
+ This repo is **exploratory experimental work**, published as-is. The model is
24
+ genuinely useful (NDCG@10 = 0.12 on a compositional held-out eval, 50× smaller
25
+ than typical retrieval encoders) but the bigger contribution is the
26
+ **methodology narrative** below, particularly the *LLM-bridge* and
27
+ *deterministic-bridge* findings.
28
+
29
  ---
30
 
31
+ ## Quick start
32
+
33
+ ```python
34
+ from sentence_transformers import SentenceTransformer
35
+
36
+ model = SentenceTransformer("oneryalcin/static-embedding-chess")
37
+ query = "fork endgame short"
38
+ docs = [
39
+ "themes crushing endgame fork short opening Sicilian Defense moves f2g3 e6e7",
40
+ "themes mate mateIn1 oneMove opening Caro-Kann moves d2d4 e7e5",
41
+ ]
42
+ sims = model.encode(query) @ model.encode(docs).T
43
+ ```
44
+
45
+ Static embedding: lookup table + average. Sub-millisecond CPU inference. No GPU
46
+ required.
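
To make "lookup table + average" concrete, here is a minimal sketch in plain Python. The 3-dimensional `TABLE`, the `embed` helper, and the `cosine` helper are all illustrative assumptions; the real model stores one 512-dimensional row per tokenizer entry, and its tokenization is not reproduced here.

```python
from math import sqrt

# Hypothetical toy table: one vector per token (the real model uses 512 dims).
TABLE = {
    "fork": [1.0, 0.0, 0.0],
    "endgame": [0.0, 1.0, 0.0],
    "short": [0.0, 0.0, 1.0],
    "[UNK]": [0.0, 0.0, 0.0],
}

def embed(text):
    """Split on whitespace, look up each token, and average the rows."""
    rows = [TABLE.get(tok, TABLE["[UNK]"]) for tok in text.split()]
    if not rows:
        return [0.0, 0.0, 0.0]
    return [sum(col) / len(rows) for col in zip(*rows)]

def cosine(a, b):
    """Cosine similarity; returns 0.0 when either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = sqrt(sum(x * x for x in a)) * sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0
```

Because the sentence vector is only a mean of table rows, inference cost is a handful of lookups and additions per token, which is why no GPU is needed.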
47
 
48
+ ---
49
 
50
+ ## Headline result
51
 
52
+ | Variant | NDCG@10 | vs v3 baseline |
53
+ |---------|---------|---------------|
54
+ | v3 baseline (random init + MNRL) | 0.0801 | — |
55
+ | v4-A hard-neg only | 0.1000 | +25% |
56
+ | v4-B theme distill only | 0.0112 | -86% (regression — see methodology) |
57
+ | v4-C multitask 500× | 0.1154 | +44% |
58
+ | **v4-C2 multitask 5000× (this model)** | **0.1202** | **+50%** |
 
 
 
59
 
60
+ Held-out eval: 200 unseen anchor combinations × 600-doc corpus. Compositional
61
+ generalization — the model never saw these exact theme combinations during
62
+ training, only the individual tokens in other combos.
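
For readers unfamiliar with the metric, the following is a minimal sketch of NDCG@10 for a single query, assuming binary relevance (a document is either relevant or not). The function name `ndcg_at_k` and the binary-relevance assumption are illustrative; the exact evaluator settings used for the table above are not shown in this README.

```python
from math import log2

def ndcg_at_k(ranked_ids, relevant_ids, k=10):
    """NDCG@k with binary gains: each relevant hit at 0-based rank r adds 1/log2(r + 2)."""
    relevant = set(relevant_ids)
    dcg = sum(
        1.0 / log2(rank + 2)
        for rank, doc_id in enumerate(ranked_ids[:k])
        if doc_id in relevant
    )
    # Ideal ordering puts every relevant document (up to k) at the top.
    ideal_hits = min(len(relevant), k)
    idcg = sum(1.0 / log2(rank + 2) for rank in range(ideal_hits))
    return dcg / idcg if idcg else 0.0
```

A query whose top 10 contains no relevant document scores exactly 0, which is why the share of zero-score queries is a useful companion statistic to the mean.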
63
 
64
+ For **production-ready** chess search, see the **two-stage architecture** below
65
+ (static + BM25 over English-bridged docs) that delivers NDCG@10 = 0.59-0.87.
 
66
 
67
+ ---
68
+
69
+ ## What's in this repo
70
 
71
  ```
72
+ model.safetensors # 4M-param StaticEmbedding weights (~9MB)
73
+ chess_tokenizer.json # WordLevel chess tokenizer (4,336 tokens)
74
+ tokenizer.json # Same, in HF format for ST loading
75
+ config_sentence_transformers.json # Module config
76
+ modules.json # Module pipeline
77
+
78
+ data/
79
+ ├── theme_definitions.parquet # 73 chess themes + LLM-generated English defs + MPNet embeddings (the LLM-bridge teacher signal)
80
+ ├── hard_negatives_chess.parquet # 1.6M (anchor, positive, negative) triplets, chess-token format
81
+ └── hard_negatives_english.parquet # Same, English-bridged via deterministic conversion
82
+
83
+ scripts/
84
+ ├── train_chess_static.py # Main training entrypoint (multi-version, env-flag controlled)
85
+ ├── train_chess_multitask.py # The v4-C2 winning recipe (theme distill + hard-neg MNRL)
86
+ ├── convert_to_english.py # Deterministic chess→English (no LLM needed; python-chess + regex)
87
+ ├── mine_hard_negs_v2.py # Memory-bounded custom hard-negative miner
88
+ ├── generate_theme_defs.py # LLM-bridge: DeepSeek-v4-flash writes chess concept definitions
89
+ ├── compare_variants.py # Side-by-side eval framework across all variants
90
+ └── diag_ce_vs_bm25.py # The critical "is your CE really helping" diagnostic
91
  ```
92
 
93
+ ---
94
 
95
+ ## Methodology: the full experimental journey
96
 
97
+ This was 36+ hours of iterative exploration. The model is the small visible
98
+ output; the methodology is the bigger contribution.
99
 
100
+ ### 1. Problem and approach
 
 
 
 
 
101
 
102
+ **Task:** Free-text search over a chess puzzle corpus. User types something
103
+ like `"fork endgame short"` and gets matching Lichess puzzles.
104
 
105
+ **Why static embedding:** Tom Aarsen's
106
+ [static-retrieval-mrl-en-v1](https://huggingface.co/sentence-transformers/static-retrieval-mrl-en-v1)
107
+ showed StaticEmbedding can be a useful retrieval primitive with the right
108
+ training. We adapted the recipe for a chess-specific domain with a custom
109
+ WordLevel tokenizer so chess tokens (UCI moves, theme names, ECO codes) are
110
+ first-class.
111
 
112
+ **Data:** Lichess/chess-puzzles (5.8M puzzles, CC0) + Lichess/chess-openings
113
+ (3.6K openings, CC0).
114
+
115
+ ### 2. Eval design — the hardest part
116
+
117
+ **Initial mistake:** First eval used top-200 most-common theme strings as
118
+ queries. The model had seen each of these ~50,000 times in training. Baseline
119
+ NDCG@10 was inflated to 0.81 by lexical overlap before any training. Useless.
120
+
121
+ **Fixed eval (used throughout):** *Compositional held-out anchors*. Pick 200
122
+ theme-combination strings that appear exactly 3 times in the data
123
+ (rare-but-multi-relevant), remove all matching pairs from train, use those rare
124
+ combos as queries. Tests whether the model can compose meaning from individual
125
+ theme tokens it learned, without having seen the specific combination.
126
+
127
+ This is harsh — the model can never "memorize" the eval queries — and that's
128
+ the point. Random-init baseline drops to NDCG@10 ≈ 0.01.
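
The held-out split described above can be sketched as follows. This is a hedged illustration, not the repo's actual code: the function name `split_heldout`, the `(anchor, positive)` tuple layout, and the seed value are assumptions.

```python
import random
from collections import Counter

def split_heldout(pairs, occurrences=3, n_queries=200, seed=0):
    """Hold out anchors that occur exactly `occurrences` times.

    `pairs` is a list of (anchor, positive) tuples. Every pair whose anchor is
    held out is removed from train, so the model never sees that combination.
    """
    counts = Counter(anchor for anchor, _ in pairs)
    rare = sorted(a for a, c in counts.items() if c == occurrences)
    rng = random.Random(seed)
    held = set(rng.sample(rare, min(n_queries, len(rare))))
    train = [(a, p) for a, p in pairs if a not in held]
    heldout = [(a, p) for a, p in pairs if a in held]
    return train, heldout
```

Choosing anchors with exactly three occurrences gives each query several relevant documents while keeping the combination rare enough that removing it costs little training data.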
129
+
130
+ ### 3. Phase 1 — diagnostic of the v3 model (0.08 NDCG@10)
131
+
132
+ A working baseline existed. Question: **why isn't it better?**
133
+
134
+ Token-similarity probe revealed the core issue:
135
+
136
+ | Pair | v3 cosine similarity |
137
+ |---|---|
138
+ | `fork` ↔ `pin` | +0.01 |
139
+ | `fork` ↔ `skewer` | -0.12 |
140
+ | `endgame` ↔ `middlegame` | -0.30 |
141
+
142
+ **Token embeddings were essentially orthogonal.** The model learned per-token
143
+ mappings to chess-content clusters but no relationships *between* tokens.
144
+ Compositional generalization (the eval task) requires those relationships.
145
+
146
+ Also discovered: 51% of held-out queries returned zero relevant documents in the
147
+ top 10 (median NDCG@10 = 0), a bimodal failure pattern.
148
+
149
+ Also discovered: model beat BM25 by 7.5× (0.08 vs 0.01), confirming it does
150
+ real semantic work beyond keyword match.
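
The zero-hit diagnostic from this phase is easy to reproduce once per-query scores are available. A minimal sketch, assuming a list of per-query NDCG@10 values (the helper name `failure_profile` is hypothetical):

```python
from statistics import mean, median

def failure_profile(per_query_ndcg):
    """Summarize per-query scores: mean, median, and share of queries scoring 0."""
    scores = list(per_query_ndcg)
    zero_share = sum(1 for s in scores if s == 0.0) / len(scores)
    return {"mean": mean(scores), "median": median(scores), "zero_share": zero_share}
```

When more than half of the queries score 0, the median is 0 even though the mean is positive, which is the bimodal signature reported above.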
151
+
152
+ ### 4. Phase 2 — distillation from raw MPNet (DEAD END)
153
+
154
+ Hypothesis: distill student token embeddings to match teacher (MPNet)
155
+ embeddings. Teacher knows English; should know that `fork ≈ pin`.
156
+
157
+ **Result:** REGRESSION. Why? **MPNet itself scores NDCG@10 = 0.0094 on our
158
+ eval.** 95.5% of queries get zero in top-10. MPNet doesn't know chess: UCI
159
+ moves are character soup to its WordPiece tokenizer.
160
+
161
+ **You can't distill what the teacher doesn't know.** This was the first key
162
+ lesson.
163
+
164
+ ### 5. Phase 3 — LLM-bridge for theme distillation (BREAKTHROUGH)
165
+
166
+ Key insight: an LLM can read both chess (in camelCase) AND English. Use it as
167
+ a **translator** to put chess concepts into language MPNet *can* understand
168
+ semantically.
169
+
170
+ **Steps:**
171
+
172
+ 1. DeepSeek-v4-flash writes English definitions for 73 Lichess themes:
173
+ - `fork` → "A tactical motif where a single piece attacks two or more
174
+ enemy pieces simultaneously, forcing a material gain."
175
+ 2. MPNet embeds the *English definitions* (it knows English fluently).
176
+ 3. Distill the student's per-token embedding to match the definition embedding.
177
+
178
+ After step 2 alone, MPNet's `fork ↔ skewer` similarity jumps from 0.39 (raw
179
+ camelCase) to **0.87** (via definitions). Real semantic structure.
180
+
181
+ Combined with hard-negative MNRL training (v4-C2): **NDCG@10 = 0.1202**, +50%
182
+ over v3.
183
+
184
+ Cost: 73 themes × DeepSeek API ≈ $0.01 + ~1 minute generation.
185
+
186
+ This is the **LLM-bridge** pattern: when system A doesn't speak system B's
187
+ language, use an LLM as a translator. The LLM is one-shot work, not part of
188
+ inference.
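
Step 3 of the LLM-bridge can be illustrated with a toy gradient-descent loop that pulls one student token vector toward its teacher (definition) vector. This is a sketch under loud assumptions: it uses plain squared error, it assumes the teacher vector already has the student's dimensionality, and it is not the `EmbedDistillLoss` implementation, which this README does not show.

```python
def distill_step(student, teacher, lr=0.1):
    """One gradient step on ||student - teacher||^2 (gradient is 2 * (s - t))."""
    return [s - lr * 2.0 * (s - t) for s, t in zip(student, teacher)]

def distill(student, teacher, steps=50, lr=0.1):
    """Repeat the step; with lr=0.1 the remaining error shrinks by 0.8 per step."""
    for _ in range(steps):
        student = distill_step(student, teacher, lr)
    return student
```

In the multi-task setup this pull toward the teacher competes with the retrieval loss, so the student ends up near, not exactly at, the definition embedding.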
189
+
190
+ ### 6. Phase 4 — hard-negative mining
191
+
192
+ Used the v3 model to mine confusable documents per anchor. Custom
193
+ memory-bounded miner because the sentence-transformers built-in OOMs on M4 at
194
+ 327k unique anchors × 327k positives. See `scripts/mine_hard_negs_v2.py`.
195
+
196
+ 1.6M triplets mined. Positive-negative margin: 0.135 mean (good signal for
197
+ training).
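
The memory-bounding idea can be sketched as below: stream the scores for one anchor at a time and keep only the top-k, so the full anchor × document score matrix never exists in memory. This is an illustrative assumption about the approach, not the contents of `scripts/mine_hard_negs_v2.py`; the function and parameter names are hypothetical.

```python
import heapq

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def mine_hard_negatives(anchor_vecs, doc_vecs, positive_idx, k=2):
    """For each anchor, return indices of the k highest-scoring non-positive docs.

    Scores are produced lazily by a generator and reduced with heapq.nlargest,
    so at most k scored candidates per anchor are held at once.
    """
    mined = []
    for a_idx, a_vec in enumerate(anchor_vecs):
        scored = (
            (dot(a_vec, d_vec), d_idx)
            for d_idx, d_vec in enumerate(doc_vecs)
            if d_idx != positive_idx[a_idx]  # never return the true positive
        )
        mined.append([d_idx for _, d_idx in heapq.nlargest(k, scored)])
    return mined
```

A production miner would batch the dot products as matrix multiplications over document chunks; the top-k reduction per anchor stays the same.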
198
+
199
+ ### 7. Phase 5 — multi-task training (v4-C2 winner)
200
+
201
+ Multi-dataset trainer combining:
202
+ - **Chess triplets** (1.6M, MNRL loss): teaches content associations
203
+ - **Theme distillation** (73 themes × 5000 replicas via `EmbedDistillLoss`):
204
+ injects semantic structure between tokens
205
+
206
+ With proportional sampling, theme tokens see ~500 gradient updates per epoch
207
+ (via replication) vs chess pairs once. Theme distillation oversampling matters:
208
+
209
+ | Theme replicas | NDCG@10 |
210
+ |---|---|
211
+ | 500× | 0.1154 |
212
+ | 5000× | 0.1202 |
213
+
214
+ ### 8. Phase 6 — cross-encoder reranker attempts (ALL FAILED)
215
+
216
+ Tried three variants:
217
+ - MS-MARCO MiniLM (English-pretrained, 22M params) on chess-format docs
218
+ - Same, with theme echo stripped from training docs
219
+ - Fresh-init tiny BERT (5M params) with our chess tokenizer
220
+
221
+ **All regressed below static-only.** Diagnosis: trained CEs operate at
222
+ random-ordering level on the eval. Inspection of training predictions showed
223
+ the trained CE got pair-ordering wrong 2/3 of the time on sample inputs.
224
+
225
+ **Root cause:** documents are UCI move sequences (`f2g3 e6e7 ...`). To
226
+ English-pretrained CE tokenizers these are character fragments with no
227
+ meaningful representation. The CE can't learn what makes a "fork-y" move
228
+ sequence from sparse labels alone. Static embedding worked because token-bag
229
+ averaging is sample-efficient (each `fork` token gets gradients from many
230
+ examples → converges to a useful cluster); the CE's pair-level processing is
231
+ hungrier for signal not available in our data.
232
+
233
+ ### 9. Phase 7 — deterministic English bridge for documents (REVEALED THE TRUTH)
234
+
235
+ Insight: we don't need an LLM to translate documents either. `python-chess`
236
+ deterministically converts UCI → SAN with board context (`f2g3` → `Bxg3`).
237
+ Regex decamelizes themes (`backRankMate` → `back rank mate`). Free, instant,
238
+ reproducible. The `convert_to_english.py` script does the full 5.8M corpus in
239
+ ~3 minutes.
240
+
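The theme half of the bridge is a single regex. The sketch below reuses the boundary pattern from `scripts/convert_to_english.py`; the UCI → SAN half needs `python-chess` and a starting FEN, so it is left out here.

```python
import re

# Split where a lowercase letter is followed by an uppercase letter, and
# between an acronym and a following capitalised word.
_CAMEL_BOUNDARY = re.compile(r"(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])")


def decamelize(tag):
    """Insert spaces at camelCase boundaries, then lowercase."""
    return _CAMEL_BOUNDARY.sub(" ", tag).lower()


examples = {
    tag: decamelize(tag)
    for tag in ["backRankMate", "hangingPiece", "mateIn1", "fork"]
}
```

Note that digits do not create a boundary, so `mateIn1` becomes `mate in1`; the conversion script handles `mateInN` tags separately for that reason.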
241
+ Re-ran reranker training on English-bridged docs. **Untrained MS-MARCO CE hit
242
+ the oracle ceiling (0.5947 at top-100).** Massive jump.
243
+
244
+ But: ran a final diagnostic comparing trained CE vs **BM25** over the same
245
+ English docs. They were *identical*:
246
+
247
+ | K | Static | +CE | +BM25 | Oracle |
248
+ |---|---|---|---|---|
249
+ | 100 | 0.1202 | **0.5947** | **0.5947** | 0.5947 |
250
+ | 200 | 0.1202 | 0.7706 | 0.7706 | 0.7706 |
251
+ | 300 | 0.1202 | 0.8718 | 0.8718 | 0.8718 |
252
+
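The Oracle column is the best NDCG@10 any reranker could reach given which relevant docs made it into the top-K shortlist. The sketch below mirrors the computation in `scripts/diag_ce_vs_bm25.py` for a single query with binary relevance; the function name `oracle_ndcg_at_10` is introduced here for illustration.

```python
import math


def oracle_ndcg_at_10(relevant_in_shortlist, total_relevant):
    """Best achievable NDCG@10 when a perfect reranker orders the shortlist."""
    n10 = min(10, relevant_in_shortlist)
    dcg = sum(1.0 / math.log2(r + 2) for r in range(n10))
    idcg = sum(1.0 / math.log2(r + 2) for r in range(min(total_relevant, 10)))
    return dcg / idcg if idcg > 0 else 0.0


all_found = oracle_ndcg_at_10(3, 3)   # every relevant doc was shortlisted
one_found = oracle_ndcg_at_10(1, 3)   # only one of three was shortlisted
none_found = oracle_ndcg_at_10(0, 5)  # first stage missed everything
```

A reranker that matches this ceiling is limited by first-stage recall, not by its own ordering, which is why the ceiling rises as K grows.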
253
+ The "LLM-bridge effect" we observed was **lexical match enabled by the
254
+ English conversion**, not semantic CE understanding. BM25 over English docs
255
+ does the same job.
256
+
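To see why a purely lexical scorer can do this job, here is a minimal BM25 sketch in plain Python. It is a simplified illustration using the non-negative idf variant `ln(1 + (N - n + 0.5) / (n + 0.5))`; the diagnostic script uses `rank_bm25.BM25Okapi`, whose idf details may differ, and the toy docs are made up.

```python
import math
from collections import Counter


def bm25_scores(query_tokens, docs_tokens, k1=1.5, b=0.75):
    """Score every tokenised doc against the query with Okapi-style BM25."""
    n_docs = len(docs_tokens)
    avg_len = sum(len(d) for d in docs_tokens) / n_docs
    # Document frequency: number of docs containing each term.
    doc_freq = Counter(term for d in docs_tokens for term in set(d))
    scores = []
    for doc in docs_tokens:
        tf = Counter(doc)
        score = 0.0
        for term in query_tokens:
            if term not in tf:
                continue
            idf = math.log(1 + (n_docs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5))
            norm = tf[term] + k1 * (1 - b + b * len(doc) / avg_len)
            score += idf * tf[term] * (k1 + 1) / norm
        scores.append(score)
    return scores


docs = [
    "short middlegame puzzle with fork motifs moves Nxe5 Qxg2",
    "long endgame puzzle with zugzwang motifs moves Kf6 Kg8",
    "short middlegame puzzle with pin motifs moves Bg5 h6",
]
scores = bm25_scores("fork middlegame".split(), [d.split() for d in docs])
best = max(range(len(docs)), key=scores.__getitem__)
```

The doc that echoes both query words wins without any model involved, which is the same shortcut the English-bridged docs offer to a cross-encoder.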
257
+ **Stress test**: stripped theme tokens from English docs too. Forces the CE
258
+ to genuinely understand "fork query ↔ fork-pattern moves":
259
+
260
+ | K | Static | +CE | +BM25 | Oracle |
261
+ |---|---|---|---|---|
262
+ | 100 | 0.1202 | 0.0726 | 0.4327 | 0.5947 |
263
+ | 300 | 0.1202 | 0.0706 | 0.6252 | 0.8718 |
264
+
265
+ CE drops below static (negative transfer — memorized "theme overlap = match"
266
+ during training; can't generalize). BM25 still partially works via opening
267
+ name overlap.
268
+
269
+ **True semantic CE chess understanding is not achievable** with 22M-param
270
+ English-pretrained models on our training signal.
271
+
272
+ ---
273
+
274
+ ## Production recommendation
275
+
276
+ For a real chess search system, the winning architecture is:
277
 
278
  ```
279
+ Stage 1: Static embedding (this model)
280
+ - Encode chess-format corpus (4M params, ~9MB)
281
+ - Sub-millisecond CPU inference
282
+ - Retrieve top-200 candidates via cosine similarity
283
+ - Recall@200 = 93.5%
284
+
285
+ Stage 2: BM25 over English-bridged corpus
286
+ - python-chess + regex (one-time, $0)
287
+ - Index the English versions of all docs
288
+ - Rerank top-200 candidates to top-10
289
+ - NDCG@10 ≈ 0.55-0.62
290
+ ```
291
+
292
+ **Total: <10ms/query, $0 inference cost, no GPU.**
293
+
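The two stages above compose as "shortlist by cosine, then reorder the shortlist lexically". The sketch below is a toy illustration under loud assumptions: stage 2 uses raw token overlap as a stand-in for BM25, the vectors and tokens are made up, and `two_stage_search` is a hypothetical name rather than code from this repo.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def two_stage_search(query_vec, query_tokens, doc_vecs, doc_tokens, shortlist=3, top=2):
    """Stage 1: cosine shortlist. Stage 2: rerank the shortlist by token overlap."""
    order = sorted(range(len(doc_vecs)), key=lambda i: -cosine(query_vec, doc_vecs[i]))
    candidates = order[:shortlist]

    def overlap(i):
        return len(set(query_tokens) & set(doc_tokens[i]))

    # sorted() is stable, so ties keep their stage-1 order.
    return sorted(candidates, key=lambda i: -overlap(i))[:top]


doc_vecs = [[1.0, 0.0], [0.9, 0.1], [0.8, 0.2], [0.0, 1.0]]
doc_tokens = [
    ["pin", "endgame"],
    ["fork", "middlegame"],
    ["fork", "endgame"],
    ["fork", "middlegame"],
]
result = two_stage_search([1.0, 0.0], ["fork", "middlegame"], doc_vecs, doc_tokens)
```

Doc 3 is a perfect lexical match but never reaches stage 2 because stage 1 did not shortlist it, which is why first-stage recall caps the final quality.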
294
+ The cross-encoder is only worth adding if you have a GPU available AND you train
295
+ it on a fundamentally different signal (e.g., human-annotated relevance,
296
+ chess-engine strategic descriptions, or a much larger model with chess in its
297
+ pretraining data).
298
+
299
+ ---
300
 
301
+ ## Key learnings worth keeping (general, not chess-specific)
302
+
303
+ 1. **Eval methodology dominates.** Most of the time spent debugging "model isn't
304
+ improving" went to eval issues, not training issues. A compositional
305
+ held-out eval beats a top-frequent-string eval. Strip lexical leakage between
306
+ query and corpus when testing generalization.
307
+
308
+ 2. **Sentence-transformers' `NoDuplicatesBatchSampler` is O(epoch-progress)
309
+ per batch.** It walks a linked-list of deferred conflicts. For datasets
310
+ with limited unique anchors (our ~327k anchors over 5.8M pairs), this
311
+ creates monotonic step-time blowup. Switch to `BatchSamplers.BATCH_SAMPLER`.
312
+
313
+ 3. **`CachedMultipleNegativesRankingLoss` is incompatible with
314
+ `StaticEmbedding`** — explicit error. Token-bag has no transformer
315
+ activations to GradCache through.
316
+
317
+ 4. **Trackio crashes on first checkpoint push** with sentence-transformers
318
+ due to an empty `router_mapping` struct that pyarrow can't write. Use
319
+ `report_to="none"`.
320
+
321
+ 5. **The "LLM-bridge" pattern**: when system A speaks language X and system
322
+ B speaks language Y, use an LLM to translate B→X once (not at inference).
323
+ For chess: LLM writes English definitions of themes → general English
324
+ teacher can now embed them → distill into chess-specific model.
325
+
326
+ 6. **Deterministic translation often suffices** for the bridge. Don't pay LLM
327
+ API costs if `python-chess` and regex can produce the same English text.
328
+ Reserve LLMs for the parts that genuinely need understanding (concept
329
+ definitions, paraphrases, strategic narratives).
330
+
331
+ 7. **Compare your trained model against BM25** on the actual eval. If they
332
+ tie, your model is doing keyword matching, not semantic work. Diagnostic
333
+ in `scripts/diag_ce_vs_bm25.py`.
334
+
335
+ 8. **Modal `.spawn()` only survives entrypoint exit on deployed apps.** For
336
+ ephemeral `modal run`, the app dies when entrypoint returns — including
337
+ spawned calls. Use `.remote()` with `--detach`.
338
+
339
+ 9. **Apple Silicon M4 is competitive with cloud A100** for tiny models. Token
340
+ bag + small batch easily hits 17 it/s on MPS. GPU cost is wasted unless
341
+ the model is compute-bound.
342
+
343
+ ---
344
+
345
+ ## Reproducibility
346
+
347
+ Clone this repo, then with sentence-transformers v5.5+:
348
+
349
+ ```bash
350
+ # Inspect the recipe
351
+ cat scripts/train_chess_multitask.py
352
+
353
+ # Reproduce the data prep (one-time, ~10 min)
354
+ python scripts/generate_theme_defs.py # Needs DeepSeek API key in macOS keychain
355
+ python scripts/convert_to_english.py # python-chess + regex, $0
356
+ python scripts/mine_hard_negs_v2.py # ~10 min on M4 MPS
357
+
358
+ # Reproduce the winning training
359
+ python scripts/train_chess_multitask.py # ~5 min on M4 MPS
360
+
361
+ # Verify
362
+ python scripts/compare_variants.py # Side-by-side eval table
363
+ python scripts/diag_ce_vs_bm25.py # Is the rerank doing real work?
364
  ```
365
 
366
+ ---
367
+
368
+ ## Limitations and honest caveats
369
+
370
+ - **NDCG@10 = 0.12 is modest in absolute terms.** Industry retrieval encoders
371
+ reach 0.4-0.6 on similar tasks. This model is competitive on size/speed,
372
+ not absolute quality.
373
+ - **The two-stage architecture (NDCG@10 ≈ 0.6) is the production answer**
374
+ but relies on BM25 over English-converted docs, not on the cross-encoder.
375
+ - **Cross-encoder didn't add semantic value** in our setup; results came from
376
+ lexical match enabled by the English bridge.
377
+ - **Bimodal failure**: even the best model scores zero on at least half of the
378
+ queries (median NDCG@10 = 0). The architecture has fundamental limits for chess
379
+ reasoning.
380
+ - **English-pretrained models don't know chess.** Tried MPNet, MiniLM,
381
+ Jina-v5; all fail on UCI moves. Bigger English models won't fix this; only
382
+ chess-pretrained or deterministic conversion helps.
383
+ - **No engine evaluation.** "Is this puzzle a fork?" was determined by
384
+ Lichess theme tags; we never ran a chess engine. A real production system
385
+ would integrate Stockfish for ground-truth tactical pattern detection.
386
+
387
+ ---
388
+
389
+ ## What this is NOT
390
 
391
+ - Not a chess engine. See [`thomasahle/fastchess`](https://github.com/thomasahle/fastchess)
392
+ for FastText-based move prediction (closest related work).
393
+ - Not a position similarity model. See `chess2vec` lineage on GitHub for
394
+ position-level embeddings.
395
+ - Not a state-of-the-art retrieval model. It's a tiny first-stage filter
396
+ designed to pair with a reranker.
397
+
398
+ ---
399
 
400
+ ## License
 
401
 
402
+ Apache 2.0 (model + scripts). Data derived from Lichess/chess-puzzles which is
403
+ CC0 — derived parquets in this repo are also released under CC0.
404
 
405
+ ## Acknowledgments
406
+
407
+ - [Lichess](https://lichess.org) for releasing puzzles + openings under CC0.
408
+ - [Tom Aarsen](https://huggingface.co/tomaarsen) for the
409
+ `train-sentence-transformers` skill and `StaticEmbedding` recipe.
410
+ - DeepSeek for the v4-flash API used for theme definitions.
411
+
412
+ ## Citation
413
 
414
+ If this work is useful, please link to this repo. The scientific findings
415
+ (particularly the deterministic-bridge insight that BM25 over English-bridged
416
+ docs equals a trained cross-encoder for this task) are the main contribution.
data/hard_negatives_chess.parquet ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3dc7f1bfcb497ba5f5e61c1b9fffe76ca52825758454c65b3a2dc2010e3e68bb
3
+ size 161012028
data/hard_negatives_english.parquet ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50b28d80013527fcb6f27554ee0cda91116e4b3967a74472320a089a7b1fa873
3
+ size 111083130
data/theme_definitions.parquet ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f70e1629bfda29faedfca1474d2195bd527590eeb48b628fd862da12a2070f3
3
+ size 456977
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:85ba258107839fe02a04763d71797aeb5f4fa19f2a8712e73a0ed0e38b4c15ff
3
  size 8880224
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6fa4d9dd8e62c4ef6d7f288ea1822f30d5f75f3a5ab178a923c4330e3b09652d
3
  size 8880224
scripts/compare_variants.py ADDED
@@ -0,0 +1,175 @@
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # requires-python = ">=3.10"
4
+ # dependencies = [
5
+ # "sentence-transformers[train]>=5.5.0",
6
+ # "datasets>=2.19.0",
7
+ # "numpy",
8
+ # ]
9
+ # ///
10
+ """Side-by-side comparison of all chess static-embedding variants on the same
11
+ held-out compositional eval. Produces the final table for NOTES.md.
12
+ """
13
+ from __future__ import annotations
14
+
15
+ import os
16
+ import sys
17
+ from collections import defaultdict
18
+
19
+ import numpy as np
20
+ from datasets import load_dataset
21
+ from sentence_transformers import SentenceTransformer
22
+
23
+ sys.stdout.reconfigure(line_buffering=True)
24
+
25
+ VARIANTS = [
26
+ ("v3 baseline", "models/static-embedding-chess/final"),
27
+ ("v4-A hard-neg only", "models/static-embedding-chess-triplet/final"),
28
+ ("v4-B theme distill", "models/static-embedding-chess-theme-only/final"),
29
+ ("v4-C multitask 500x", "models/static-embedding-chess-multitask-500x/final"),
30
+ ("v4-C2 multitask 5000x", "models/static-embedding-chess-multitask-5000x/final"),
31
+ ]
32
+
33
+ HELDOUT_FREQ_MIN = 3
34
+ HELDOUT_FREQ_MAX = 30
35
+ EVAL_QUERIES = 200
36
+
37
+
38
+ def _join_tags(tags):
39
+ return " ".join(t.replace("_", " ") for t in tags) if tags else ""
40
+
41
+
42
+ def _bigram_token_str(moves):
43
+ toks = moves.split()
44
+ if len(toks) < 2:
45
+ return moves
46
+ return moves + " " + " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:]))
47
+
48
+
49
+ def build_puzzle_pairs(batch):
50
+ anchors, positives = [], []
51
+ for themes, op, moves in zip(batch["Themes"], batch["OpeningTags"], batch["Moves"]):
52
+ themes_txt = _join_tags(themes)
53
+ op_txt = _join_tags(op)
54
+ if not themes_txt:
55
+ continue
56
+ anchor = themes_txt + (f" {op_txt}" if op_txt else "")
57
+ positive = f"themes {themes_txt}"
58
+ if op_txt:
59
+ positive += f" opening {op_txt}"
60
+ positive += f" moves {_bigram_token_str(moves)}"
61
+ anchors.append(anchor)
62
+ positives.append(positive)
63
+ return {"anchor": anchors, "positive": positives}
64
+
65
+
66
+ def strip_theme_echo(p):
67
+ i = p.find(" moves ")
68
+ return p[i + 1 :] if i != -1 else p
69
+
70
+
71
+ def ndcg_at_k(scores, rel, k=10):
72
+ ranked = sorted(scores, key=lambda kv: -kv[1])[:k]
73
+ dcg = sum((1.0 if d in rel else 0.0) / np.log2(r + 2) for r, (d, _) in enumerate(ranked))
74
+ idcg = sum(1.0 / np.log2(r + 2) for r in range(min(len(rel), k)))
75
+ return dcg / idcg if idcg > 0 else 0.0
76
+
77
+
78
+ def main():
79
+ print("Loading + held-out selection...")
80
+ puzzles = load_dataset("Lichess/chess-puzzles", split="train")
81
+ pair_puzzles = puzzles.map(
82
+ build_puzzle_pairs,
83
+ batched=True, batch_size=20_000,
84
+ remove_columns=puzzles.column_names,
85
+ num_proc=4,
86
+ )
87
+ anchors = pair_puzzles["anchor"]
88
+ freq = defaultdict(int)
89
+ for a in anchors:
90
+ freq[a] += 1
91
+ rare_pool = sorted(
92
+ ((a, c) for a, c in freq.items() if HELDOUT_FREQ_MIN <= c <= HELDOUT_FREQ_MAX),
93
+ key=lambda kv: kv[1],
94
+ )
95
+ heldout = {a for a, _ in rare_pool[:EVAL_QUERIES]}
96
+ held_idx = [i for i, h in enumerate([a in heldout for a in anchors]) if h]
97
+ held_anchors = [anchors[i] for i in held_idx]
98
+ corpus_texts = [strip_theme_echo(pair_puzzles["positive"][i]) for i in held_idx]
99
+ corpus_ids = [f"d{i}" for i in range(len(corpus_texts))]
100
+ by_anchor = defaultdict(list)
101
+ for i, a in enumerate(held_anchors):
102
+ by_anchor[a].append(corpus_ids[i])
103
+ queries = list(by_anchor.keys())
104
+ print(f" {len(queries)} queries, {len(corpus_texts)} corpus")
105
+
106
+ results = []
107
+
108
+ for name, path in VARIANTS:
109
+ if not os.path.exists(path):
110
+ print(f"\nSKIPPING {name}: {path} not found")
111
+ continue
112
+ print(f"\n=== {name} ({path}) ===")
113
+ m = SentenceTransformer(path)
114
+ c = m.encode(corpus_texts, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
115
+ c = c / np.linalg.norm(c, axis=1, keepdims=True)
116
+ q = m.encode(queries, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
117
+ q = q / np.linalg.norm(q, axis=1, keepdims=True)
118
+ sims = q @ c.T
119
+ ndcgs = []
120
+ for qi, query in enumerate(queries):
121
+ score_pairs = [(corpus_ids[ci], float(sims[qi, ci])) for ci in range(len(corpus_ids))]
122
+ rel = set(by_anchor[query])
123
+ ndcgs.append(ndcg_at_k(score_pairs, rel, k=10))
124
+ ndcg = np.mean(ndcgs)
125
+ median = np.median(ndcgs)
126
+ zero = sum(1 for n in ndcgs if n == 0)
127
+ results.append((name, ndcg, median, zero, len(ndcgs)))
128
+ print(f" NDCG@10 = {ndcg:.4f} median = {median:.4f} zero = {zero}/{len(ndcgs)}")
129
+
130
+ print("\n" + "=" * 70)
131
+ print(f"{'Variant':<30} {'NDCG@10':>10} {'Median':>10} {'Zero/All':>15}")
132
+ print("=" * 70)
133
+ for name, ndcg, median, zero, total in results:
134
+ print(f"{name:<30} {ndcg:>10.4f} {median:>10.4f} {zero:>7}/{total:<7}")
135
+ print("=" * 70)
136
+
137
+ # === Token-similarity probe ===
138
+ # Measures the orthogonal-tokens problem from Phase 1: do related themes
139
+ # cluster in embedding space? Higher = more semantic structure.
140
+ print("\n=== Theme-token similarity (higher = more semantic clustering) ===")
141
+ PROBES = [
142
+ ("fork", "skewer"), # tactical motifs (should be close)
143
+ ("fork", "pin"),
144
+ ("backRankMate", "smotheredMate"), # mate patterns
145
+ ("kingsideAttack", "queensideAttack"),
146
+ ("endgame", "middlegame"), # phases
147
+ ("fork", "promotion"), # unrelated (control)
148
+ ]
149
+ print(f"{'Pair':<40}", end="")
150
+ for name, path in VARIANTS:
151
+ if os.path.exists(path):
152
+ print(f" {name[:14]:>16}", end="")
153
+ print()
154
+ print("-" * 70)
155
+ for a, b in PROBES:
156
+ line = f"{a} <-> {b}".ljust(40)
157
+ for name, path in VARIANTS:
158
+ if not os.path.exists(path):
159
+ continue
160
+ m = SentenceTransformer(path)
161
+ ea = m.encode([a], convert_to_numpy=True)[0]
162
+ eb = m.encode([b], convert_to_numpy=True)[0]
163
+ ea = ea / max(np.linalg.norm(ea), 1e-9)
164
+ eb = eb / max(np.linalg.norm(eb), 1e-9)
165
+ sim = float(np.dot(ea, eb))
166
+ line += f" {sim:>+16.3f}"
167
+ print(line)
168
+
169
+
170
+ if __name__ == "__main__":
171
+ main()
scripts/convert_to_english.py ADDED
@@ -0,0 +1,216 @@
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # requires-python = ">=3.10"
4
+ # dependencies = ["chess", "datasets>=2.19", "tqdm"]
5
+ # ///
6
+ """Deterministic chess→English converter for puzzles.
7
+
8
+ Generates a standardized English-readable description of each puzzle WITHOUT
9
+ any LLM. Uses python-chess for UCI→SAN conversion (with board context), regex
10
+ for decamelizing themes, and a fixed template.
11
+
12
+ For each puzzle, produces a doc like:
13
+
14
+ "White to move. Short middlegame puzzle with crushing, fork, hanging piece
15
+ motifs. Opening: Kings Pawn Game. Moves: Bxg3 Rxe7 Qb1+ Nc1 Qxc1+
16
+ Qxc1"
17
+
18
+ Pretrained English cross-encoders have seen SAN notation in chess web content
19
+ during pretraining, so this doc is semantically meaningful to them — unlike
20
+ the raw UCI form (`f2g3`) which gets fragmented into character pieces.
21
+
22
+ Output: parquet at models/puzzles_english.parquet with columns:
23
+ PuzzleId, anchor_en (themes + opening as English text), doc_en
24
+
25
+ Run:
26
+ SMOKE_TEST=1 uv run --exclude-newer=2026-05-12 convert_to_english.py
27
+ uv run --exclude-newer=2026-05-12 convert_to_english.py
28
+ """
29
+ from __future__ import annotations
30
+
31
+ import os
32
+ import re
33
+ import sys
34
+
35
+ import chess
36
+ from datasets import Dataset, load_dataset
37
+ from tqdm import tqdm
38
+
39
+ sys.stdout.reconfigure(line_buffering=True)
40
+
41
+ OUTPUT_PATH = "models/puzzles_english.parquet"
42
+ SMOKE_TEST = os.environ.get("SMOKE_TEST") == "1"
43
+
44
+ # Length tag mapping
45
+ LENGTH_MAP = {
46
+ "oneMove": "single-move",
47
+ "short": "short",
48
+ "long": "long",
49
+ "veryLong": "very long",
50
+ }
51
+ PHASE_TAGS = {"opening", "middlegame", "endgame"}
52
+ LENGTH_TAGS = set(LENGTH_MAP.keys())
53
+ # Anything matching `mateInN`, `mateIn1`, etc.
54
+ MATE_IN_PATTERN = re.compile(r"^mateIn(\d+)$")
55
+ # Specific mate-pattern names (their English form is just decamel)
56
+ # camelCase → "camel case" via regex
57
+ _CAMEL_BOUNDARY = re.compile(r"(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])")
58
+
59
+
60
+ def decamelize(tag: str) -> str:
61
+ """`backRankMate` → 'back rank mate'. `attackingF2F7` → 'attacking f2f7' (digits do not split)."""
62
+ return _CAMEL_BOUNDARY.sub(" ", tag).lower()
63
+
64
+
65
+ def themes_to_english(themes: list[str]) -> tuple[str, str, str, list[str]]:
66
+ """Returns (phase, length_phrase, side_phrase, motifs); side_phrase is always ''.
67
+
68
+ Splits themes into structural (phase, length, mate-in-N) and motif (everything else).
69
+ The motifs are returned decamelized.
70
+ """
71
+ if not themes:
72
+ return ("", "", "", [])
73
+ phase = ""
74
+ length = ""
75
+ mate_in = None
76
+ motifs = []
77
+ for t in themes:
78
+ if t in PHASE_TAGS:
79
+ phase = t
80
+ elif t in LENGTH_TAGS:
81
+ length = LENGTH_MAP[t]
82
+ elif (m := MATE_IN_PATTERN.match(t)):
83
+ mate_in = int(m.group(1))
84
+ else:
85
+ motifs.append(decamelize(t))
86
+ # Mate-in-N gets folded into motifs as natural-language phrase
87
+ if mate_in is not None:
88
+ motifs.append(f"mate in {mate_in}")
89
+ return phase, length, "", motifs # side_phrase computed separately from FEN
90
+
91
+
92
+ def opening_tags_to_english(opening_tags: list[str]) -> str:
93
+ """`['Kings_Pawn_Game', 'Kings_Pawn_Game_Leonardis_Variation']` → 'Kings Pawn Game Leonardis Variation'.
94
+ Dedupe by taking the longest matching tag."""
95
+ if not opening_tags:
96
+ return ""
97
+ # Use the longest tag (most specific) and replace underscores with spaces
98
+ longest = max(opening_tags, key=len)
99
+ return longest.replace("_", " ")
100
+
101
+
102
+ def uci_to_san_sequence(fen: str, uci_moves: str) -> str:
103
+ """Convert UCI move sequence to SAN, using board context for disambiguation."""
104
+ try:
105
+ board = chess.Board(fen)
106
+ san_moves = []
107
+ for uci in uci_moves.split():
108
+ try:
109
+ move = chess.Move.from_uci(uci)
110
+ san = board.san(move)
111
+ san_moves.append(san)
112
+ board.push(move)
113
+ except Exception:
114
+ # Invalid move — skip rest
115
+ break
116
+ return " ".join(san_moves)
117
+ except Exception:
118
+ return uci_moves # fall back to raw UCI
119
+
120
+
121
+ def side_to_move(fen: str) -> str:
122
+ parts = fen.split()
123
+ if len(parts) >= 2 and parts[1] == "w":
124
+ return "White"
125
+ return "Black"
126
+
127
+
128
+ def build_english_doc(row: dict) -> str:
129
+ """Build a deterministic English description from a Lichess puzzle row."""
130
+ side = side_to_move(row["FEN"])
131
+ phase, length, _, motifs = themes_to_english(row["Themes"] or [])
132
+ opening = opening_tags_to_english(row.get("OpeningTags") or [])
133
+ san = uci_to_san_sequence(row["FEN"], row["Moves"])
134
+
135
+ # Construct sentence
136
+ parts = []
137
+ parts.append(f"{side} to move.")
138
+
139
+ # "Short middlegame puzzle with crushing fork and hanging piece motifs."
140
+ descriptor = []
141
+ if length:
142
+ descriptor.append(length)
143
+ if phase:
144
+ descriptor.append(phase)
145
+ descriptor.append("puzzle")
146
+ descriptor_str = " ".join(descriptor)
147
+ if motifs:
148
+ motifs_str = ", ".join(motifs)
149
+ descriptor_str += f" with {motifs_str} motifs"
150
+ parts.append(descriptor_str.capitalize() + ".")
151
+
152
+ if opening:
153
+ parts.append(f"Opening: {opening}.")
154
+
155
+ if san:
156
+ parts.append(f"Moves: {san}")
157
+
158
+ return " ".join(parts)
159
+
160
+
161
+ def build_english_anchor(row: dict) -> str:
162
+ """Anchor side: same as before (themes + opening) but in deterministic English.
163
+ Used as query for retrieval/reranker training."""
164
+ phase, length, _, motifs = themes_to_english(row["Themes"] or [])
165
+ opening = opening_tags_to_english(row.get("OpeningTags") or [])
166
+ parts = []
167
+ if motifs:
168
+ parts.append(", ".join(motifs))
169
+ if length:
170
+ parts.append(length)
171
+ if phase:
172
+ parts.append(phase)
173
+ if opening:
174
+ parts.append(opening)
175
+ return " ".join(parts).strip()
176
+
177
+
178
+ def main():
179
+ print("Loading puzzles...")
180
+ puzzles = load_dataset("Lichess/chess-puzzles", split="train")
181
+ if SMOKE_TEST:
182
+ puzzles = puzzles.select(range(2_000))
183
+ print(f" {len(puzzles):,} rows")
184
+
185
+ print("Converting to English (deterministic)...")
186
+
187
+ def proc(batch):
188
+ ids, anchors, docs = [], [], []
189
+ for r in [{k: batch[k][i] for k in batch} for i in range(len(batch["PuzzleId"]))]:
190
+ if not r["Themes"]:
191
+ continue
192
+ ids.append(r["PuzzleId"])
193
+ anchors.append(build_english_anchor(r))
194
+ docs.append(build_english_doc(r))
195
+ return {"PuzzleId": ids, "anchor_en": anchors, "doc_en": docs}
196
+
197
+ out = puzzles.map(
198
+ proc, batched=True, batch_size=10_000,
199
+ remove_columns=puzzles.column_names,
200
+ num_proc=4,
201
+ )
202
+ print(f" produced {len(out):,} English-converted rows")
203
+
204
+ print("\n=== Sample conversions ===")
205
+ for i in [0, 100, 1000]:
206
+ r = out[i]
207
+ print(f"\nPuzzleId: {r['PuzzleId']}")
208
+ print(f" anchor: {r['anchor_en']!r}")
209
+ print(f" doc: {r['doc_en'][:200]!r}")
210
+
211
+ out.to_parquet(OUTPUT_PATH)
212
+ print(f"\nSaved to {OUTPUT_PATH} ({os.path.getsize(OUTPUT_PATH) / 1e6:.1f} MB)")
213
+
214
+
215
+ if __name__ == "__main__":
216
+ main()
scripts/diag_ce_vs_bm25.py ADDED
@@ -0,0 +1,145 @@
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # requires-python = ">=3.10"
4
+ # dependencies = ["sentence-transformers[train]>=5.5.0", "datasets>=2.19", "numpy", "rank-bm25", "chess"]
5
+ # ///
6
+ """Compare trained CE vs BM25 on English-bridged docs, plus top-K sweep.
7
+
8
+ Tests:
9
+ 1. Is the 0.59 CE result just lexical match that BM25 could also do?
10
+ 2. Does increasing K to 200/300 push past oracle 0.59 → 0.77 → 0.87?
11
+ """
12
+ import os
13
+ import sys
14
+ from collections import defaultdict
15
+
16
+ import numpy as np
17
+ from datasets import Dataset, load_dataset
18
+ from rank_bm25 import BM25Okapi
19
+ from sentence_transformers import CrossEncoder, SentenceTransformer
20
+
21
+ sys.stdout.reconfigure(line_buffering=True)
22
+ sys.path.insert(0, os.path.dirname(__file__))
23
+ from convert_to_english import build_english_anchor, build_english_doc
24
+
25
+ HELDOUT_FREQ_MIN = 3
26
+ HELDOUT_FREQ_MAX = 30
27
+ EVAL_QUERIES = 200
28
+
29
+
30
+ def _join_tags(tags):
31
+ return " ".join(t.replace("_", " ") for t in tags) if tags else ""
32
+
33
+
34
+ def _bigram(m):
35
+ toks = m.split()
36
+ return m + " " + " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:])) if len(toks) > 1 else m
37
+
38
+
39
+ def build_chess_anchor(themes, op):
40
+ tt = _join_tags(themes)
41
+ ot = _join_tags(op or [])
42
+ return tt + (f" {ot}" if ot else "")
43
+
44
+
45
+ def build_chess_doc_stripped(themes, op, moves):
46
+ return f"moves {_bigram(moves)}"
47
+
48
+
49
+ def ndcg_at_k(scores, rel, k=10):
50
+ r = sorted(scores, key=lambda kv: -kv[1])[:k]
51
+ dcg = sum((1.0 if d in rel else 0.0) / np.log2(rr + 2) for rr, (d, _) in enumerate(r))
52
+ idcg = sum(1.0 / np.log2(rr + 2) for rr in range(min(len(rel), k)))
53
+ return dcg / idcg if idcg > 0 else 0
54
+
55
+
56
+ def main():
57
+ print("Building eval set...")
58
+ puzzles = load_dataset("Lichess/chess-puzzles", split="train")
59
+ freq = defaultdict(int)
60
+ rows_by_anchor = defaultdict(list)
61
+ for r in puzzles:
62
+ if not r["Themes"]:
63
+ continue
64
+ ca = build_chess_anchor(r["Themes"], r["OpeningTags"])
65
+ freq[ca] += 1
66
+ rows_by_anchor[ca].append(r)
67
+ rare = sorted(((a, c) for a, c in freq.items() if HELDOUT_FREQ_MIN <= c <= HELDOUT_FREQ_MAX), key=lambda kv: kv[1])
68
+ heldout = [a for a, _ in rare[:EVAL_QUERIES]]
69
+ print(f" {len(heldout)} held-out anchors")
70
+
71
+ qchess, qen = [], []
72
+ corp_chess, corp_en = [], []
73
+ held_per_doc = []
74
+ ch_to_en = {}
75
+ for ca in heldout:
76
+ for r in rows_by_anchor[ca]:
77
+ corp_chess.append(build_chess_doc_stripped(r["Themes"], r["OpeningTags"], r["Moves"]))
78
+ corp_en.append(build_english_doc(r))
79
+ held_per_doc.append(ca)
80
+ if ca not in ch_to_en:
81
+ ch_to_en[ca] = build_english_anchor(r)
82
+ qchess = list(heldout)
83
+ qen = [ch_to_en[a] for a in qchess]
84
+ by_anchor = defaultdict(list)
85
+ for i, a in enumerate(held_per_doc):
86
+ by_anchor[a].append(i)
87
+ print(f" corpus: {len(corp_chess)} docs")
88
+
89
+ print("\nLoading static (v4-C2) for first-stage...")
90
+ static = SentenceTransformer("models/static-embedding-chess-multitask-5000x/final")
91
+ sc = static.encode(corp_chess, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
92
+ sc = sc / np.linalg.norm(sc, axis=1, keepdims=True)
93
+ sq = static.encode(qchess, batch_size=128, convert_to_numpy=True, show_progress_bar=False)
94
+ sq = sq / np.linalg.norm(sq, axis=1, keepdims=True)
95
+ static_sims = sq @ sc.T
96
+
97
+ # Load the trained CE
98
+ print("Loading trained CE...")
99
+ ce = CrossEncoder("models/chess-reranker-english/final")
100
+
101
+ # BM25 on English docs
102
+ print("Building BM25 over English docs...")
103
+ bm25 = BM25Okapi([d.split() for d in corp_en])
104
+
105
+ print("\n" + "=" * 80)
106
+ print(f" {'K':>4} {'Static':>10} {'+CE':>10} {'+BM25':>10} {'Oracle':>10}")
107
+ print("=" * 80)
108
+ for k in [10, 50, 100, 200, 300]:
109
+ if k > len(corp_chess):
110
+ continue
111
+ static_ndcg = []
112
+ ce_ndcg = []
113
+ bm25_ndcg = []
114
+ oracle_ndcg = []
115
+ for qi, q_chess in enumerate(qchess):
116
+ rel = set(by_anchor[q_chess])
117
+ # Static-only at top-10
118
+ top10 = np.argsort(-static_sims[qi])[:10]
119
+ sp = [(int(i), float(static_sims[qi, int(i)])) for i in top10]
120
+ static_ndcg.append(ndcg_at_k(sp, rel, k=10))
121
+ # Top-K shortlist
122
+ topk = np.argsort(-static_sims[qi])[:k]
123
+ # CE rerank
124
+ pairs = [[qen[qi], corp_en[int(i)]] for i in topk]
125
+ ce_scores = ce.predict(pairs, batch_size=64, show_progress_bar=False, convert_to_numpy=True)
126
+ ce_sp = [(int(topk[j]), float(ce_scores[j])) for j in range(len(topk))]
127
+ ce_ndcg.append(ndcg_at_k(ce_sp, rel, k=10))
128
+ # BM25 rerank over top-K shortlist
129
+ bm_full = bm25.get_scores(qen[qi].split())
130
+ bm_sp = [(int(topk[j]), float(bm_full[int(topk[j])])) for j in range(len(topk))]
131
+ bm25_ndcg.append(ndcg_at_k(bm_sp, rel, k=10))
132
+ # Oracle ceiling
133
+ rel_in_topk = len(rel & set(int(i) for i in topk))
134
+ n10 = min(10, rel_in_topk)
135
+ dcg = sum(1.0 / np.log2(r + 2) for r in range(n10))
136
+ idcg = sum(1.0 / np.log2(r + 2) for r in range(min(len(rel), 10)))
137
+ oracle_ndcg.append(dcg / idcg if idcg > 0 else 0)
138
+ # static stays the same regardless of K
139
+ static_v = np.mean(static_ndcg)
140
+ print(f" {k:>4} {static_v:>10.4f} {np.mean(ce_ndcg):>10.4f} {np.mean(bm25_ndcg):>10.4f} {np.mean(oracle_ndcg):>10.4f}")
141
+ print("=" * 80)
142
+
143
+
144
+ if __name__ == "__main__":
145
+ main()
scripts/generate_theme_defs.py ADDED
@@ -0,0 +1,168 @@
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # requires-python = ">=3.10"
4
+ # dependencies = [
5
+ # "datasets>=2.19.0",
6
+ # "openai>=1.0",
7
+ # "sentence-transformers[train]>=5.5.0",
8
+ # "tqdm",
9
+ # "numpy",
10
+ # ]
11
+ # ///
12
+ """Generate natural-language definitions for each Lichess theme via DeepSeek,
13
+ then embed those definitions with a general sentence-transformer (MPNet).
14
+
15
+ The resulting (theme_token, definition_embedding) pairs form a "chess-aware
16
+ teacher" — an English description of each chess concept that MPNet CAN
17
+ understand semantically. We can then distill those embeddings into our
18
+ StaticEmbedding model's token table.
19
+
20
+ Solves the "MPNet doesn't know chess" problem: MPNet can't read UCI moves,
21
+ but it CAN read English ("A tactical motif where one piece attacks two pieces
22
+ simultaneously" → semantically near "A tactic where you create a double
23
+ attack threatening two pieces at once"). Token-level semantic structure
24
+ emerges from the LLM bridge.
25
+
26
+ Run:
27
+ SMOKE_TEST=1 uv run --exclude-newer=2026-05-12 generate_theme_defs.py
28
+ uv run --exclude-newer=2026-05-12 generate_theme_defs.py
29
+ """
30
+ import json
31
+ import os
32
+ import subprocess
33
+ import sys
34
+ from collections import Counter
35
+ from concurrent.futures import ThreadPoolExecutor, as_completed
36
+
37
+ import numpy as np
38
+ from datasets import Dataset, load_dataset
39
+ from openai import OpenAI
40
+ from sentence_transformers import SentenceTransformer
41
+ from tqdm import tqdm
42
+
43
+ MODEL = "deepseek-v4-flash"
44
+ TEACHER_MODEL = "sentence-transformers/all-mpnet-base-v2"
45
+ OUTPUT_PATH = "models/theme_definitions.parquet"
46
+ SMOKE_TEST = os.environ.get("SMOKE_TEST") == "1"
47
+ PARALLEL_WORKERS = 4
48
+
49
+ SYSTEM_PROMPT = """You write concise dictionary-style definitions of chess
50
+ concepts. Given a theme/concept name (often in camelCase from Lichess.org's
51
+ puzzle tagging system), write a single English sentence of 10-25 words
52
+ explaining the concept. Be specific and use the standard chess vocabulary that
53
+ would appear in any chess textbook.
54
+
55
+ Output ONLY the definition sentence. No labels, no quotes, no commentary.
56
+
57
+ Examples:
58
+ Input: fork
59
+ Output: A tactical motif where a single piece attacks two or more enemy pieces simultaneously, forcing a material gain.
60
+
61
+ Input: backRankMate
62
+ Output: A checkmate delivered along the opponent's back rank, typically with a rook or queen, when the king is trapped by its own pawns.
63
+
64
+ Input: zugzwang
65
+ Output: A position in which any move worsens the player's position, so being forced to move becomes a disadvantage.
66
+ """
67
+
68
+
69
+ def get_deepseek_key():
70
+ r = subprocess.run(
71
+ ["security", "find-generic-password", "-s", "deepseek-api", "-w"],
72
+ capture_output=True, text=True, timeout=5,
73
+ )
74
+ return r.stdout.strip() if r.returncode == 0 else os.environ.get("DEEPSEEK_API_KEY")
75
+
76
+
77
+ def define_theme(client, theme, debug=False):
78
+ try:
79
+ resp = client.chat.completions.create(
80
+ model=MODEL,
81
+ messages=[
82
+ {"role": "system", "content": SYSTEM_PROMPT},
83
+ {"role": "user", "content": theme},
84
+ ],
85
+ temperature=0.2,
86
+ max_tokens=1500, # DeepSeek-v4-flash spends tokens on reasoning_content; obscure mate-pattern names need lots
87
+ timeout=30,
88
+ )
89
+ content = resp.choices[0].message.content
90
+ return content.strip() if content else None
91
+ except Exception as e:
92
+ if debug:
93
+ print(f" EXC for {theme!r}: {type(e).__name__}: {e}")
94
+ return None
95
+
96
+
97
+ def main():
98
+ key = get_deepseek_key()
99
+ if not key:
100
+ sys.exit("No DeepSeek API key found (checked macOS keychain and DEEPSEEK_API_KEY)")
101
+ client = OpenAI(api_key=key, base_url="https://api.deepseek.com/v1")
102
+
103
+ print("Enumerating themes from Lichess puzzles...")
104
+ puzzles = load_dataset("Lichess/chess-puzzles", split="train", streaming=True)
105
+ counter = Counter()
106
+ sample_size = 50_000 if SMOKE_TEST else 1_000_000
107
+ for i, r in enumerate(puzzles):
108
+ if i >= sample_size:
109
+ break
110
+ for t in (r["Themes"] or []):
111
+ counter[t] += 1
112
+ themes = sorted(counter.keys())
113
+ print(f" {len(themes)} unique themes")
114
+
115
+ if SMOKE_TEST:
116
+ themes = themes[:10]
117
+ print(f" SMOKE_TEST=1: limited to {len(themes)}")
118
+
119
+ print(f"\nGenerating definitions via {MODEL}...")
120
+ defs = {}
121
+ with ThreadPoolExecutor(max_workers=PARALLEL_WORKERS) as ex:
122
+ futs = {ex.submit(define_theme, client, t, True): t for t in themes}
123
+ for f in tqdm(as_completed(futs), total=len(futs)):
124
+ t = futs[f]
125
+ defs[t] = f.result()
126
+
127
+ failed = [t for t, d in defs.items() if not d]
128
+ if failed:
129
+ print(f" {len(failed)} themes failed: {failed[:5]}")
130
+ print(f" {len(defs) - len(failed)}/{len(defs)} succeeded")
131
+
132
+ print("\nSample definitions:")
133
+ for t in themes[:8]:
134
+ if defs[t]:
135
+ print(f" {t:>20s} -> {defs[t]}")
136
+
137
+ valid = [(t, defs[t]) for t in themes if defs[t]]
138
+
139
+ print(f"\nEmbedding {len(valid)} definitions with {TEACHER_MODEL}...")
140
+ teacher = SentenceTransformer(TEACHER_MODEL)
141
+ sentences = [d for _, d in valid]
142
+ embs = teacher.encode(sentences, batch_size=64, show_progress_bar=True, convert_to_numpy=True)
143
+
144
+ # Sanity: do related themes have similar embeddings?
145
+ emb_norm = embs / np.linalg.norm(embs, axis=1, keepdims=True)
146
+ sim = emb_norm @ emb_norm.T
147
+ print("\nSanity check: pairwise similarities for related themes")
148
+ name_to_idx = {t: i for i, (t, _) in enumerate(valid)}
149
+ for a, b in [
150
+ ("fork", "skewer"), ("fork", "pin"), ("backRankMate", "smotheredMate"),
151
+ ("kingsideAttack", "queensideAttack"), ("endgame", "middlegame"),
152
+ ("fork", "promotion"), # not directly related
153
+ ]:
154
+ if a in name_to_idx and b in name_to_idx:
155
+ print(f" {a!r:>20} <-> {b!r:25} = {sim[name_to_idx[a], name_to_idx[b]]:+.3f}")
156
+
157
+ out = Dataset.from_dict({
158
+ "theme": [t for t, _ in valid],
159
+ "definition": [d for _, d in valid],
160
+ "embedding": embs.tolist(),
161
+ })
162
+ os.makedirs(os.path.dirname(OUTPUT_PATH) or ".", exist_ok=True)
163
+ out.to_parquet(OUTPUT_PATH)
164
+ print(f"\nSaved {len(out)} theme definitions to {OUTPUT_PATH}")
165
+
166
+
167
+ if __name__ == "__main__":
168
+ main()
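`get_deepseek_key` in the script above shells out to the macOS `security` tool before falling back to `DEEPSEEK_API_KEY`. On a machine without that binary, `subprocess.run` would likely raise before the fallback is reached. The sketch below shows one possible portable variant, under the assumption that skipping the keychain when `security` is absent is acceptable. `lookup_api_key` and its `environ` parameter are hypothetical names introduced here for illustration.

```python
import os
import shutil
import subprocess


def lookup_api_key(service="deepseek-api", env_var="DEEPSEEK_API_KEY", environ=None):
    """Return an API key from the macOS keychain if possible, else from the environment.

    `environ` is a hypothetical injection point so the fallback can be tested
    without touching the real process environment.
    """
    environ = os.environ if environ is None else environ
    # Only try the keychain when the `security` CLI exists (macOS).
    if shutil.which("security"):
        try:
            r = subprocess.run(
                ["security", "find-generic-password", "-s", service, "-w"],
                capture_output=True, text=True, timeout=5,
            )
            if r.returncode == 0 and r.stdout.strip():
                return r.stdout.strip()
        except (OSError, subprocess.SubprocessError):
            pass  # fall through to the environment variable
    return environ.get(env_var)


# With an unknown service name the keychain lookup fails and the env value is used.
print(lookup_api_key(service="no-such-service-for-demo", environ={"DEEPSEEK_API_KEY": "demo"}))
```

Returning `None` when neither source has a key keeps the caller's `if not key: sys.exit(...)` check unchanged.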
scripts/mine_hard_negs_v2.py ADDED
@@ -0,0 +1,213 @@
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # requires-python = ">=3.10"
4
+ # dependencies = [
5
+ # "sentence-transformers[train]>=5.5.0",
6
+ # "datasets>=2.19.0",
7
+ # "numpy",
8
+ # "tqdm",
9
+ # ]
10
+ # ///
11
+ """Memory-bounded hard-negative miner. Custom impl (not sentence-transformers
12
+ util) because the library function tries to hold the full anchor × corpus similarity
13
+ matrix, which OOMs at 327k anchors × 327k positives on M4.
14
+
15
+ Algorithm:
16
+ 1. Encode all unique positives once -> N x dim float32 (~670MB at 327k x 512).
17
+ 2. Encode all unique anchors once -> M x dim float32.
18
+ 3. For each anchor batch (size B):
19
+ - scores = batch_emb @ positives_emb.T -> B x N
20
+ - per anchor: argpartition for top RANGE_MAX, exclude actual positive,
21
+ sample NUM_NEGATIVES from rank [RANGE_MIN, RANGE_MAX).
22
+ 4. Stream triplets to parquet.
23
+
24
+ Peak memory: B * N * 4 bytes for scores. With B=500, N=327k: 650MB.
25
+
26
+ Run:
27
+ SMOKE_TEST=1 uv run --exclude-newer=2026-05-12 mine_hard_negs_v2.py
28
+ uv run --exclude-newer=2026-05-12 mine_hard_negs_v2.py
29
+ """
30
+ from __future__ import annotations
31
+
32
+ import os
33
+ import random
34
+ import re
35
+ import sys
36
+ from collections import defaultdict
37
+
38
+ # Force unbuffered stdout so progress is visible when piped
39
+ sys.stdout.reconfigure(line_buffering=True)
40
+
41
+ import numpy as np
42
+ import torch
43
+ from datasets import Dataset, load_dataset
44
+ from sentence_transformers import SentenceTransformer
45
+ from tqdm import tqdm
46
+
47
+ V3_MODEL_PATH = "models/static-embedding-chess/final"
48
+ OUTPUT_PATH = "models/hard_negatives.parquet"
49
+ SMOKE_TEST = os.environ.get("SMOKE_TEST") == "1"
50
+ HELDOUT_FREQ_MIN = 3
51
+ HELDOUT_FREQ_MAX = 30
52
+ EVAL_QUERIES = 200
53
+ NUM_NEGATIVES = 5
54
+ RANGE_MIN = 10
55
+ RANGE_MAX = 50
56
+ ANCHOR_BATCH_SIZE = 500 # 500 * 327k * 4 = ~650MB scratch per batch
57
+
58
+
59
+ def _join_tags(tags):
60
+ return " ".join(t.replace("_", " ") for t in tags) if tags else ""
61
+
62
+
63
+ def _bigram_token_str(moves):
64
+ toks = moves.split()
65
+ if len(toks) < 2:
66
+ return moves
67
+ bigrams = " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:]))
68
+ return f"{moves} {bigrams}"
69
+
70
+
71
+ def build_puzzle_pairs(batch):
72
+ anchors, positives = [], []
73
+ for themes, op, moves in zip(batch["Themes"], batch["OpeningTags"], batch["Moves"]):
74
+ themes_txt = _join_tags(themes)
75
+ op_txt = _join_tags(op)
76
+ if not themes_txt:
77
+ continue
78
+ anchor = themes_txt + (f" {op_txt}" if op_txt else "")
79
+ positive = f"themes {themes_txt}"
80
+ if op_txt:
81
+ positive += f" opening {op_txt}"
82
+ positive += f" moves {_bigram_token_str(moves)}"
83
+ anchors.append(anchor)
84
+ positives.append(positive)
85
+ return {"anchor": anchors, "positive": positives}
86
+
87
+
88
+ def main():
89
+ print(f"Loading v3 model from {V3_MODEL_PATH}")
90
+ model = SentenceTransformer(V3_MODEL_PATH)
91
+
92
+ print("Loading puzzles...")
93
+ puzzles = load_dataset("Lichess/chess-puzzles", split="train")
94
+ if SMOKE_TEST:
95
+ puzzles = puzzles.select(range(100_000))
96
+ pair_puzzles = puzzles.map(
97
+ build_puzzle_pairs,
98
+ batched=True,
99
+ batch_size=20_000,
100
+ remove_columns=puzzles.column_names,
101
+ num_proc=4,
102
+ )
103
+
104
+ # Materialize columns ONCE as Python lists (HF Dataset random access is
105
+ # O(N) per call due to Arrow buffer slicing -- 5.8M iterations would take
106
+ # forever otherwise).
107
+ print("Materializing columns...")
108
+ anchors_list = pair_puzzles["anchor"]
109
+ positives_list = pair_puzzles["positive"]
110
+ print(f" done ({len(anchors_list):,} rows)")
111
+
112
+ # Remove held-out anchors
113
+ freq = defaultdict(int)
114
+ for a in anchors_list:
115
+ freq[a] += 1
116
+ rare_pool = sorted(
117
+ ((a, c) for a, c in freq.items() if HELDOUT_FREQ_MIN <= c <= HELDOUT_FREQ_MAX),
118
+ key=lambda kv: kv[1],
119
+ )
120
+ heldout = {a for a, _ in rare_pool[:EVAL_QUERIES]}
121
+
122
+ # Build one-per-anchor (use as both the anchor source AND the corpus source)
123
+ by_anchor = defaultdict(list)
124
+ for a, p in zip(anchors_list, positives_list):
125
+ if a not in heldout:
126
+ by_anchor[a].append(p)
127
+ print(f" unique anchors (post-heldout-strip): {len(by_anchor):,}")
128
+
129
+ rng = random.Random(12)
130
+ unique_anchors = list(by_anchor.keys())
131
+ if SMOKE_TEST:
132
+ unique_anchors = unique_anchors[:200]
133
+ print(f" SMOKE_TEST=1: trimmed to {len(unique_anchors)}")
134
+ # For each anchor, pick ONE random positive (skip the O(n^2) filter -- just
135
+ # iterate unique_anchors directly).
136
+ print(f" Sampling one positive per anchor...")
137
+ positives = [rng.choice(by_anchor[a]) for a in unique_anchors]
138
+ print(f" done")
139
+
140
+ # Encode anchors and positives
141
+ print(f"\nEncoding {len(unique_anchors):,} anchors...")
142
+ anchor_emb = model.encode(
143
+ unique_anchors, batch_size=512, show_progress_bar=True, convert_to_numpy=True
144
+ )
145
+ anchor_emb = anchor_emb / np.linalg.norm(anchor_emb, axis=1, keepdims=True)
146
+ print(f" anchor shape: {anchor_emb.shape}, mem: {anchor_emb.nbytes / 1e6:.1f}MB")
147
+
148
+ print(f"\nEncoding {len(positives):,} positives...")
149
+ positive_emb = model.encode(
150
+ positives, batch_size=512, show_progress_bar=True, convert_to_numpy=True
151
+ )
152
+ positive_emb = positive_emb / np.linalg.norm(positive_emb, axis=1, keepdims=True)
153
+ print(f" positive shape: {positive_emb.shape}, mem: {positive_emb.nbytes / 1e6:.1f}MB")
154
+
155
+ # Mine hard negs in chunks
156
+ print(f"\nMining hard negs (range={RANGE_MIN}..{RANGE_MAX}, num={NUM_NEGATIVES}, batch={ANCHOR_BATCH_SIZE})...")
157
+ out_anchors, out_positives, out_negatives = [], [], []
158
+ pos_scores_acc, neg_scores_acc = [], []
159
+ n_anchors = len(unique_anchors)
160
+
161
+ for start in tqdm(range(0, n_anchors, ANCHOR_BATCH_SIZE)):
162
+ end = min(start + ANCHOR_BATCH_SIZE, n_anchors)
163
+ ab = anchor_emb[start:end] # B x D
164
+ # scores: B x N. Each row i is anchor[start+i] vs all positives.
165
+ scores = ab @ positive_emb.T # B x N (float32)
166
+
167
+ # For each anchor i in batch, sort scores desc, get top RANGE_MAX
168
+ # excluding the actual positive (which is at column start+i).
169
+ # We use argpartition for efficiency.
170
+ for i in range(end - start):
171
+ anchor_idx = start + i
172
+ row = scores[i].copy()
173
+ # Mask out the actual positive (anchor's own positive is at anchor_idx)
174
+ row[anchor_idx] = -np.inf
175
+ # Take top RANGE_MAX indices
176
+ top_idx = np.argpartition(-row, min(RANGE_MAX, row.size - 1))[:RANGE_MAX]  # guard: kth must be < N
177
+ # Sort them by score
178
+ top_idx = top_idx[np.argsort(-row[top_idx])]
179
+ # Sample NUM_NEGATIVES from rank [RANGE_MIN, RANGE_MAX)
180
+ mid_range = top_idx[RANGE_MIN:RANGE_MAX]
181
+ sampled = rng.sample(list(mid_range), min(NUM_NEGATIVES, len(mid_range)))
182
+ for neg_idx in sampled:
183
+ out_anchors.append(unique_anchors[anchor_idx])
184
+ out_positives.append(positives[anchor_idx])
185
+ out_negatives.append(positives[neg_idx])
186
+ pos_scores_acc.append(float(scores[i, anchor_idx]))
187
+ neg_scores_acc.append(float(scores[i, neg_idx]))
188
+
189
+ print(f"\n output triplets: {len(out_anchors):,}")
190
+ print(f" positive scores: mean={np.mean(pos_scores_acc):.3f} std={np.std(pos_scores_acc):.3f}")
191
+ print(f" hard-neg scores: mean={np.mean(neg_scores_acc):.3f} std={np.std(neg_scores_acc):.3f}")
192
+ print(f" margin (pos - neg): mean={np.mean(np.array(pos_scores_acc) - np.array(neg_scores_acc)):.3f}")
193
+
194
+ # Save
195
+ os.makedirs(os.path.dirname(OUTPUT_PATH) or ".", exist_ok=True)
196
+ Dataset.from_dict({
197
+ "anchor": out_anchors,
198
+ "positive": out_positives,
199
+ "negative": out_negatives,
200
+ }).to_parquet(OUTPUT_PATH)
201
+ print(f" saved to {OUTPUT_PATH} ({os.path.getsize(OUTPUT_PATH) / 1e6:.1f} MB)")
202
+
203
+ # Sample
204
+ print("\n=== Sample triplets ===")
205
+ for i in [0, len(out_anchors)//2, len(out_anchors)-1]:
206
+ print(f" ANCHOR: {out_anchors[i]!r}")
207
+ print(f" POSITIVE:{out_positives[i][:100]!r}")
208
+ print(f" NEGATIVE:{out_negatives[i][:100]!r}")
209
+ print()
210
+
211
+
212
+ if __name__ == "__main__":
213
+ main()
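The per-anchor step of the miner above (mask the true positive, take the top `RANGE_MAX` by score, then sample from ranks `[RANGE_MIN, RANGE_MAX)`) can be sketched without NumPy. This is an illustrative pure-Python version, not the script's implementation: it swaps `np.argpartition` for `heapq.nlargest`, and `sample_mid_rank_negatives` is a name assumed here.

```python
import heapq
import random


def sample_mid_rank_negatives(scores, positive_idx, range_min, range_max,
                              num_negatives, rng):
    """Pick hard negatives from the mid-rank window of one anchor's score row.

    scores:       similarity of this anchor to every corpus item.
    positive_idx: index of the anchor's own positive, which is excluded.
    """
    # Drop the true positive, keep (score, index) pairs for everything else.
    candidates = ((s, j) for j, s in enumerate(scores) if j != positive_idx)
    # nlargest returns the top range_max pairs sorted by descending score.
    top = heapq.nlargest(range_max, candidates)
    # Skip the very top ranks, which are the most likely false negatives.
    window = [j for _, j in top[range_min:range_max]]
    return rng.sample(window, min(num_negatives, len(window)))


rng = random.Random(12)
scores = [100 - j for j in range(100)]  # item 0 scores highest, item 99 lowest
negs = sample_mid_rank_negatives(scores, positive_idx=0, range_min=10,
                                 range_max=50, num_negatives=5, rng=rng)
print(sorted(negs))
```

When the corpus is smaller than `range_min`, the window is empty and the function returns an empty list instead of failing, which matters for smoke-test runs on tiny subsets.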
scripts/train_chess_multitask.py ADDED
@@ -0,0 +1,287 @@
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # requires-python = ">=3.10"
4
+ # dependencies = [
5
+ # "sentence-transformers[train]>=5.5.0",
6
+ # "datasets>=2.19.0",
7
+ # "accelerate>=0.26.0",
8
+ # "tokenizers>=0.20",
9
+ # ]
10
+ # ///
11
+ """Multi-task training: chess-aware semantic structure + hard-negative MNRL.
12
+
13
+ Two simultaneous training signals:
14
+
15
+ 1. THEME-DISTILL dataset: (theme_token, mpnet_definition_emb)
16
+ - 73 rows (one per Lichess theme)
17
+ - Loss: EmbedDistillLoss (project student 512d -> 768d, match teacher)
18
+ - Effect: enc("fork") moves toward MPNet("a tactical motif where one piece...")
19
+ - Solves orthogonal-token-embeddings problem identified in Phase 1
20
+
21
+ 2. CHESS-CONTENT dataset: (anchor, positive, hard_negative)
22
+ - From mined hard-negs of v3 model
23
+ - Loss: MultipleNegativesRankingLoss (handles triplets natively)
24
+ - Effect: maintains chess-content associations, sharpens discriminative ability
25
+
26
+ Multi-task trainer interleaves batches from both datasets. The theme dataset is
27
+ tiny (73 rows) but high-impact -- it injects semantic structure into 73 token
28
+ embeddings. The chess dataset is large (1.6M+ triplets) and shapes the rest.
29
+
30
+ Run:
31
+ SMOKE_TEST=1 uv run --exclude-newer=2026-05-12 train_chess_multitask.py
32
+ uv run --exclude-newer=2026-05-12 train_chess_multitask.py
33
+ """
34
+ from __future__ import annotations
35
+
36
+ import logging
37
+ import os
38
+ import random
39
+ import re
40
+ import time
41
+ from collections import defaultdict
42
+ from contextlib import nullcontext
43
+
44
+ import numpy as np
45
+ import torch
46
+ from datasets import Dataset, concatenate_datasets, load_dataset
47
+ from tokenizers import Tokenizer
48
+
49
+ from sentence_transformers import (
50
+ SentenceTransformer,
51
+ SentenceTransformerModelCardData,
52
+ SentenceTransformerTrainer,
53
+ SentenceTransformerTrainingArguments,
54
+ )
55
+ from sentence_transformers.base.sampler import BatchSamplers, MultiDatasetBatchSamplers
56
+ from sentence_transformers.sentence_transformer.evaluation import (
57
+ InformationRetrievalEvaluator,
58
+ )
59
+ from sentence_transformers.sentence_transformer.losses import (
60
+ EmbedDistillLoss,
61
+ MultipleNegativesRankingLoss,
62
+ )
63
+ from sentence_transformers.sentence_transformer.modules import StaticEmbedding
64
+ from transformers import EarlyStoppingCallback, TrainerCallback
65
+
66
+ THEME_DEFS_PATH = "models/theme_definitions.parquet"
67
+ TRIPLETS_PATH = "models/hard_negatives.parquet"
68
+ TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", "models/static-embedding-chess/chess_tokenizer.json")
69
+ OUTPUT_DIR = "models/static-embedding-chess-multitask"
70
+ RUN_NAME = "static-embedding-chess-multitask"
71
+ SMOKE_TEST = os.environ.get("SMOKE_TEST") == "1"
72
+ EMBEDDING_DIM = 512
73
+ TEACHER_DIM = 768
74
+ HELDOUT_FREQ_MIN = 3
75
+ HELDOUT_FREQ_MAX = 30
76
+ EVAL_QUERIES = 200
77
+ THEME_REPLICAS = int(os.environ.get("THEME_REPLICAS", "500")) # oversample theme dataset
78
+
79
+ IS_CUDA = torch.cuda.is_available()
80
+ IS_MPS = (not IS_CUDA) and torch.backends.mps.is_available()
81
+ BATCH_SIZE = 4096 if (IS_CUDA or IS_MPS) else 256
82
+
83
+
84
+ def setup_logging():
85
+ os.makedirs("logs", exist_ok=True)
86
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
87
+ logging.basicConfig(
88
+ format="%(asctime)s - %(message)s",
89
+ datefmt="%Y-%m-%d %H:%M:%S",
90
+ level=logging.INFO,
91
+ handlers=[logging.StreamHandler(), logging.FileHandler(f"logs/{RUN_NAME}.log")],
92
+ force=True,
93
+ )
94
+ for noisy in ("httpx", "httpcore", "huggingface_hub", "urllib3", "filelock", "fsspec"):
95
+ logging.getLogger(noisy).setLevel(logging.WARNING)
96
+
97
+
98
+ def _join_tags(tags):
99
+ return " ".join(t.replace("_", " ") for t in tags) if tags else ""
100
+
101
+
102
+ def _bigram_token_str(moves):
103
+ toks = moves.split()
104
+ if len(toks) < 2:
105
+ return moves
106
+ bigrams = " ".join(f"{a}+{b}" for a, b in zip(toks, toks[1:]))
107
+ return f"{moves} {bigrams}"
108
+
109
+
110
+ def build_puzzle_pairs(batch):
111
+ anchors, positives = [], []
112
+ for themes, op, moves in zip(batch["Themes"], batch["OpeningTags"], batch["Moves"]):
113
+ themes_txt = _join_tags(themes)
114
+ op_txt = _join_tags(op)
115
+ if not themes_txt:
116
+ continue
117
+ anchor = themes_txt + (f" {op_txt}" if op_txt else "")
118
+ positive = f"themes {themes_txt}"
119
+ if op_txt:
120
+ positive += f" opening {op_txt}"
121
+ positive += f" moves {_bigram_token_str(moves)}"
122
+ anchors.append(anchor)
123
+ positives.append(positive)
124
+ return {"anchor": anchors, "positive": positives}
125
+
126
+
127
+ def strip_theme_echo(p):
128
+ i = p.find(" moves ")
129
+ return p[i + 1 :] if i != -1 else p
130
+
131
+
132
+ def build_evaluator(holdout):
133
+ corpus = {f"d{i}": strip_theme_echo(row["positive"]) for i, row in enumerate(holdout)}
134
+ by_anchor = defaultdict(set)
135
+ for i, row in enumerate(holdout):
136
+ by_anchor[row["anchor"]].add(f"d{i}")
137
+ sorted_a = sorted(by_anchor.items(), key=lambda kv: -len(kv[1]))
138
+ queries = {f"q{i}": a for i, (a, _) in enumerate(sorted_a)}
139
+ relevant = {f"q{i}": ids for i, (_, ids) in enumerate(sorted_a)}
140
+ return InformationRetrievalEvaluator(
141
+ queries=queries, corpus=corpus, relevant_docs=relevant,
142
+ name="chess-ir", ndcg_at_k=[10], mrr_at_k=[10],
143
+ accuracy_at_k=[1, 10], precision_recall_at_k=[1, 10],
144
+ show_progress_bar=False, batch_size=256,
145
+ )
146
+
147
+
148
+ def autocast_ctx():
149
+ if IS_CUDA:
150
+ dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
151
+ return torch.autocast("cuda", dtype=dtype)
152
+ if IS_MPS:
153
+ return torch.autocast("mps", dtype=torch.float16)
154
+ return nullcontext()
155
+
156
+
157
+ def main():
158
+ setup_logging()
159
+
160
+ logging.info(f"Loading tokenizer from {TOKENIZER_PATH}")
161
+ tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
162
+ logging.info(f" vocab: {tokenizer.get_vocab_size():,}")
163
+
164
+ logging.info(f"Building random-init StaticEmbedding (dim={EMBEDDING_DIM})")
165
+ static = StaticEmbedding(tokenizer, embedding_dim=EMBEDDING_DIM)
166
+ model = SentenceTransformer(
167
+ modules=[static],
168
+ model_card_data=SentenceTransformerModelCardData(
169
+ language="en", license="apache-2.0",
170
+ model_name=f"Static chess embedding ({EMBEDDING_DIM}d) -- multi-task (theme distill + hard-neg MNRL)",
171
+ ),
172
+ )
173
+
174
+ # === Dataset A: theme distillation ===
175
+ logging.info(f"Loading theme definitions from {THEME_DEFS_PATH}")
176
+ theme_ds_full = Dataset.from_parquet(THEME_DEFS_PATH)
177
+ # EmbedDistillLoss expects columns: sentence, label
178
+ theme_ds = theme_ds_full.rename_columns({"theme": "sentence", "embedding": "label"}).remove_columns(["definition"])
179
+ # Oversample to be seen alongside the much-larger chess dataset
180
+ if not SMOKE_TEST:
181
+ theme_ds = concatenate_datasets([theme_ds] * THEME_REPLICAS).shuffle(seed=12)
182
+ logging.info(f" {len(theme_ds):,} theme rows (after oversampling)")
183
+
184
+ # === Dataset B: chess triplets ===
185
+ logging.info(f"Loading triplets from {TRIPLETS_PATH}")
186
+ triplet_ds = Dataset.from_parquet(TRIPLETS_PATH)
187
+ if SMOKE_TEST:
188
+ triplet_ds = triplet_ds.select(range(min(500, len(triplet_ds))))
189
+ logging.info(f" {len(triplet_ds):,} triplets, columns: {triplet_ds.column_names}")
190
+
191
+ # === Build eval (same as previous runs) ===
192
+ logging.info("Building held-out eval")
193
+ puzzles = load_dataset("Lichess/chess-puzzles", split="train")
194
+ if SMOKE_TEST:
195
+ puzzles = puzzles.select(range(2_000))
196
+ pair_puzzles = puzzles.map(
197
+ build_puzzle_pairs, batched=True, batch_size=20_000,
198
+ remove_columns=puzzles.column_names, num_proc=4,
199
+ )
200
+ anchors = pair_puzzles["anchor"]
201
+ freq = defaultdict(int)
202
+ for a in anchors:
203
+ freq[a] += 1
204
+ rare_pool = sorted(
205
+ ((a, c) for a, c in freq.items() if HELDOUT_FREQ_MIN <= c <= HELDOUT_FREQ_MAX),
206
+ key=lambda kv: kv[1],
207
+ )
208
+ n_eval = 20 if SMOKE_TEST else EVAL_QUERIES
209
+ heldout = {a for a, _ in rare_pool[:n_eval]}
210
+ held_idx = [i for i, a in enumerate(anchors) if a in heldout]
211
+ holdout = pair_puzzles.select(held_idx)
212
+ logging.info(f" holdout: {len(holdout)}")
213
+ evaluator = build_evaluator(holdout)
214
+
215
+ logging.info("Baseline eval (random init):")
216
+ with autocast_ctx():
217
+ baseline = evaluator(model)[evaluator.primary_metric]
218
+ metric_key = f"eval_{evaluator.primary_metric}"
219
+ logging.info(f" baseline {evaluator.primary_metric} = {baseline:.4f}")
220
+
221
+ # === Multi-task setup ===
222
+ train_datasets = {
223
+ "chess": triplet_ds,
224
+ "themes": theme_ds,
225
+ }
226
+ losses = {
227
+ "chess": MultipleNegativesRankingLoss(model),
228
+ "themes": EmbedDistillLoss(model, distance_metric="cosine", projection_dim=TEACHER_DIM),
229
+ }
230
+
231
+ args = SentenceTransformerTrainingArguments(
232
+ output_dir=OUTPUT_DIR,
233
+ num_train_epochs=5,
234
+ max_steps=1 if SMOKE_TEST else -1,
235
+ per_device_train_batch_size=BATCH_SIZE,
236
+ per_device_eval_batch_size=BATCH_SIZE,
237
+ learning_rate=1e-2,
238
+ weight_decay=0.01,
239
+ warmup_steps=0.1,
240
+ lr_scheduler_type="linear",
241
+ bf16=IS_CUDA and torch.cuda.is_bf16_supported(),
242
+ fp16=IS_CUDA and not torch.cuda.is_bf16_supported(),
243
+ batch_sampler=BatchSamplers.BATCH_SAMPLER,
244
+ multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
245
+ eval_strategy="steps",
246
+ eval_steps=0.05,
247
+ save_strategy="steps",
248
+ save_steps=0.05,
249
+ save_total_limit=2,
250
+ logging_steps=0.02,
251
+ logging_first_step=True,
252
+ load_best_model_at_end=True,
253
+ metric_for_best_model=metric_key,
254
+ greater_is_better=True,
255
+ report_to="none",
256
+ run_name=RUN_NAME,
257
+ seed=12,
258
+ push_to_hub=False,
259
+ )
260
+
261
+ trainer = SentenceTransformerTrainer(
262
+ model=model, args=args,
263
+ train_dataset=train_datasets, loss=losses, evaluator=evaluator,
264
+ callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
265
+ )
266
+ trainer.train()
267
+
268
+ logging.info("Post-training eval:")
269
+ with autocast_ctx():
270
+ score = evaluator(model)[evaluator.primary_metric]
271
+ delta = score - baseline
272
+ verdict = "WIN" if delta >= 0.005 else "MARGINAL" if delta >= 0 else "REGRESSION"
273
+ logging.info(
274
+ f"VERDICT: {verdict} | score={score:.4f} | baseline={baseline:.4f} | delta={delta:+.4f}"
275
+ )
276
+
277
+ # Also report the absolute score against the v3 baseline (0.0801)
278
+ v3_baseline = 0.0801
279
+ logging.info(f" vs v3 ({v3_baseline:.4f}): delta = {score - v3_baseline:+.4f}")
280
+
281
+ final_dir = f"{OUTPUT_DIR}/final"
282
+ model.save_pretrained(final_dir)
283
+ logging.info(f"Saved final model to {final_dir}")
284
+
285
+
286
+ if __name__ == "__main__":
287
+ main()
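The script oversamples the 73-row theme dataset by `THEME_REPLICAS` so that it is actually seen next to the much larger triplet dataset. A rough back-of-the-envelope sketch of the resulting mix is below. It assumes the "1.6M+ triplets" figure from the docstring as a round 1,600,000, and treats row share as an approximation of how often theme batches appear under proportional multi-dataset sampling. `theme_row_share` is a hypothetical helper, not part of the script.

```python
def theme_row_share(n_themes, replicas, n_triplets):
    """Fraction of all training rows that come from the theme dataset."""
    theme_rows = n_themes * replicas
    return theme_rows / (theme_rows + n_triplets)


# Assumed sizes: 73 themes, default 500 replicas, ~1.6M mined triplets.
with_oversampling = theme_row_share(n_themes=73, replicas=500, n_triplets=1_600_000)
without_oversampling = theme_row_share(n_themes=73, replicas=1, n_triplets=1_600_000)
print(f"with replicas:    {with_oversampling:.2%}")
print(f"without replicas: {without_oversampling:.4%}")
```

Under these assumptions the distillation signal goes from a negligible fraction of rows to a few percent, which is the motivation for the oversampling knob.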
scripts/train_chess_static.py ADDED
@@ -0,0 +1,640 @@
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # requires-python = ">=3.10"
4
+ # dependencies = [
5
+ # "sentence-transformers[train]>=5.5.0",
6
+ # "datasets>=2.19.0",
7
+ # "accelerate>=0.26.0",
8
+ # "tokenizers>=0.20",
9
+ # "trackio",
10
+ # ]
11
+ # ///
12
+ """Train a StaticEmbedding model for chess retrieval.
13
+
14
+ Pair shape:
15
+ anchor = "<themes> [<opening words>]"
16
+ positive = "themes <themes> [opening <words>] moves <uci>" (puzzles)
17
+ "name <words> eco <code> pgn <san>" (openings)
18
+
19
+ Datasets:
20
+ - Lichess/chess-puzzles (5.8M rows; themes + opening tags + UCI moves)
21
+ - Lichess/chess-openings (3.6K rows; opening name + ECO + SAN moves)
22
+
23
+ Use case: free-text search over a chess corpus. "fork endgame short" -> puzzles
24
+ with that motif; "Sicilian Najdorf" -> matching openings.
25
+
26
+ Design choices:
27
+ - Custom WordLevel + Whitespace tokenizer trained on the corpus. Every chess
28
+ token (UCI move e2e4, SAN move Nxd4, ECO code B90, theme name, opening word)
29
+ is one whole token -- BERT WordPiece would shred them 4-way.
30
+ - FEN dropped: position-as-character-soup doesn't fit a token-bag.
31
+ - PGN move numbers stripped ("1. e4 c5" -> "e4 c5") so SAN moves are high-freq.
32
+ - IR eval is custom (themes -> puzzles), not NanoBEIR -- general-English IR
33
+ benchmarks don't measure chess retrieval.
34
+
35
+ Run:
36
+ SMOKE_TEST=1 uv run --exclude-newer=2026-05-12 train_chess_static.py
37
+ uv run --exclude-newer=2026-05-12 train_chess_static.py
38
+ """
39
+
40
+ from __future__ import annotations
41
+
42
+ import logging
43
+ import os
44
+ import re
45
+ from collections import defaultdict
46
+ from contextlib import nullcontext
47
+
48
+ import datasets
49
+ import random
50
+ import torch
51
+ from datasets import Dataset, concatenate_datasets, load_dataset
52
+ from tokenizers import Tokenizer
53
+ from tokenizers.models import WordLevel
54
+ from tokenizers.pre_tokenizers import Whitespace
55
+ from tokenizers.trainers import WordLevelTrainer
56
+
57
+ from sentence_transformers import (
58
+ SentenceTransformer,
59
+ SentenceTransformerModelCardData,
60
+ SentenceTransformerTrainer,
61
+ SentenceTransformerTrainingArguments,
62
+ )
63
+ from sentence_transformers.base.sampler import BatchSamplers
64
+ from sentence_transformers.sentence_transformer.evaluation import (
65
+ InformationRetrievalEvaluator,
66
+ SequentialEvaluator,
67
+ )
68
+ from sentence_transformers.sentence_transformer.losses import (
69
+ MatryoshkaLoss,
70
+ MultipleNegativesRankingLoss,
71
+ )
72
+ from sentence_transformers.sentence_transformer.modules import StaticEmbedding
73
+ from transformers import EarlyStoppingCallback, TrainerCallback
74
+ import time
75
+
76
+
77
+ EMBEDDING_DIM = 512 # was 256; 512 gives more capacity for bigram tokens
78
+ MATRYOSHKA_DIMS = [512, 256, 128, 64, 32]
79
+ VOCAB_SIZE = 100_000 # was 50_000; UCI/SAN bigrams add ~20-50k vocab
80
+
81
+ OUTPUT_DIR = "models/static-embedding-chess"
82
+ RUN_NAME = "static-embedding-chess"
83
+ HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "oneryalcin/static-embedding-chess")
84
+ # TOKENIZER_PATH default lives next to the model output. On Modal, set this to
85
+ # a path on the persistent volume (e.g. /cache/chess_tokenizer.json) so the
86
+ # 6-min WordLevelTrainer run is amortized across launches.
87
+ TOKENIZER_PATH = os.environ.get(
88
+ "TOKENIZER_PATH", f"{OUTPUT_DIR}/chess_tokenizer.json"
89
+ )
90
+ RETRAIN_TOKENIZER = os.environ.get("RETRAIN_TOKENIZER") == "1"
91
+ SMOKE_TEST = os.environ.get("SMOKE_TEST") == "1"
92
+ FORCE_CPU = os.environ.get("FORCE_CPU") == "1"
93
+ # Diagnostic knobs (default: full recipe). Both MPS and T4 show monotonic
94
+ # step-time growth with the full Matryoshka stack -- toggle these to isolate.
95
+ DISABLE_MATRYOSHKA = os.environ.get("DISABLE_MATRYOSHKA") == "1"
96
+ MAX_STEPS_OVERRIDE = int(os.environ.get("MAX_STEPS", "0")) or None
97
+ EVAL_STEPS_OVERRIDE = int(os.environ.get("EVAL_STEPS", "0")) or None
98
+
99
+ EVAL_QUERIES = 200
100
+ EVAL_CORPUS = 5_000
101
+ # Held-out anchor selection: pick rare combos in this freq range. Low end > 1
102
+ # keeps multi-relevant NDCG meaningful; high end caps memorization potential.
103
+ HELDOUT_FREQ_MIN = 3
104
+ HELDOUT_FREQ_MAX = 30
105
+ # Balanced-dataset config: each unique anchor expands to N (anchor, sampled_pos)
106
+ # rows. The original 5.8M pairs let the model memorize specific (anchor, pos)
107
+ # pairings since each anchor has ~1933 distinct positives. Capping at 100
108
+ # random samples per anchor gives the model meaningful variety without the
109
+ # 50x redundancy that fuels overfitting.
110
+ BALANCED_POSITIVES_PER_ANCHOR = int(os.environ.get("POSITIVES_PER_ANCHOR", "100"))
111
+ # Anchor token masking probability during training. 0 disables.
112
+ ANCHOR_MASK_PROB = float(os.environ.get("ANCHOR_MASK_PROB", "0.15"))
113
+
114
+ # Device-aware defaults. MPS (Apple Silicon) can't do bf16 and has unified-
115
+ # memory pressure, so the CUDA-targeted skill template defaults (batch=2048,
116
+ # bf16=True) don't apply. Scale BATCH_SIZE up if your M-series has 36GB+.
117
+ IS_CUDA = torch.cuda.is_available() and not FORCE_CPU
118
+ IS_MPS = (not IS_CUDA) and torch.backends.mps.is_available() and not FORCE_CPU
119
+ # StaticEmbedding is a lookup+average -- no transformer activations to fit.
120
+ # Memory cost is the (batch x batch) similarity matrix + (batch x seq x dim)
121
+ # lookups, both tiny. CachedMultipleNegativesRankingLoss is NOT compatible
122
+ # with StaticEmbedding (no encoder to GradCache through), so we just crank
123
+ # the real batch. Scale up freely if your M-series has the headroom.
124
+ BATCH_SIZE = 4096 if IS_CUDA else (4096 if IS_MPS else 256)
125
+
126
+ MOVE_NUM_RE = re.compile(r"\d+\.+")
127
+
128
+
129
+ class StepTimingCallback(TrainerCallback):
130
+ """Per-step instrumentation: wall time, CUDA memory, allocator state.
131
+ Costs ~1ms/step. Run-once-and-read approach to diagnosing slowdowns
132
+ instead of swapping configs and rerunning.
133
+ """
134
+
135
+ def on_step_begin(self, args, state, control, **kw):
136
+ if torch.cuda.is_available():
137
+ torch.cuda.synchronize()
138
+ self._t0 = time.perf_counter()
139
+
140
+ def on_step_end(self, args, state, control, **kw):
141
+ if torch.cuda.is_available():
142
+ torch.cuda.synchronize()
143
+ dt = time.perf_counter() - self._t0
144
+ # Log every step for the first 20 to see startup; then every 10th.
145
+ if state.global_step <= 20 or state.global_step % 10 == 0:
146
+ if torch.cuda.is_available():
147
+ mem = torch.cuda.memory_allocated() / 1e6
148
+ reserved = torch.cuda.memory_reserved() / 1e6
149
+ logging.info(
150
+ f"STEP {state.global_step}: dt={dt:.3f}s mem={mem:.0f}MB reserved={reserved:.0f}MB"
151
+ )
152
+ else:
153
+ logging.info(f"STEP {state.global_step}: dt={dt:.3f}s (cpu/mps)")
154
+
155
+
156
+ def autocast_ctx():
157
+ if IS_CUDA:
158
+ dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
159
+ return torch.autocast("cuda", dtype=dtype)
160
+ if IS_MPS:
161
+ return torch.autocast("mps", dtype=torch.float16)
162
+ return nullcontext()
163
+
164
+
165
+ def setup_logging():
166
+ os.makedirs("logs", exist_ok=True)
167
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
168
+ logging.basicConfig(
169
+ format="%(asctime)s - %(message)s",
170
+ datefmt="%Y-%m-%d %H:%M:%S",
171
+ level=logging.INFO,
172
+ handlers=[logging.StreamHandler(), logging.FileHandler(f"logs/{RUN_NAME}.log")],
173
+ force=True,
174
+ )
175
+ for noisy in ("httpx", "httpcore", "huggingface_hub", "urllib3", "filelock", "fsspec"):
176
+ logging.getLogger(noisy).setLevel(logging.WARNING)
177
+ if torch.cuda.is_available():
178
+ torch.set_float32_matmul_precision("high")
179
+
180
+
181
+ def _join_tags(tags) -> str:
182
+ if not tags:
183
+ return ""
184
+ return " ".join(t.replace("_", " ") for t in tags)
185
+
186
+
187
+ def _strip_pgn_move_numbers(pgn: str) -> str:
188
+ return MOVE_NUM_RE.sub("", pgn).strip()
189
+
190
+
191
+ def _bigram_token_str(moves: str) -> str:
192
+ """Append bigram tokens to a whitespace-separated move sequence.
193
+
194
+ "f2g3 e6e7 b2b1" -> "f2g3 e6e7 b2b1 f2g3+e6e7 e6e7+b2b1"
195
+
196
+ Bigrams use `+` as the join char so they're distinct from unigrams in the
197
+ WordLevel tokenizer's whitespace pretokenizer. A token-bag averaging across
198
+ unigrams alone loses move ordering; adding adjacent-pair tokens lets the
199
+ model learn that "e2e4 e7e5" (king's pawn opening) is its own pattern.
200
+ """
201
+ tokens = moves.split()
202
+ if len(tokens) < 2:
203
+ return moves
204
+ bigrams = " ".join(f"{a}+{b}" for a, b in zip(tokens, tokens[1:]))
205
+ return f"{moves} {bigrams}"
206
+
207
+
208
+ def build_puzzle_pairs(row_batch: dict) -> dict:
209
+ anchors, positives = [], []
210
+ for themes, opening_tags, moves in zip(
211
+ row_batch["Themes"], row_batch["OpeningTags"], row_batch["Moves"]
212
+ ):
213
+ themes_txt = _join_tags(themes)
214
+ opening_txt = _join_tags(opening_tags)
215
+ if not themes_txt:
216
+ continue
217
+ anchor = themes_txt + (f" {opening_txt}" if opening_txt else "")
218
+ positive = f"themes {themes_txt}"
219
+ if opening_txt:
220
+ positive += f" opening {opening_txt}"
221
+ positive += f" moves {_bigram_token_str(moves)}"
222
+ anchors.append(anchor)
223
+ positives.append(positive)
224
+ return {"anchor": anchors, "positive": positives}
225
+
226
+
227
+ def build_opening_pairs(row_batch: dict) -> dict:
228
+ anchors, positives = [], []
229
+ for name, eco, pgn in zip(row_batch["name"], row_batch["eco"], row_batch["pgn"]):
230
+ san = _strip_pgn_move_numbers(pgn)
231
+ anchors.append(f"{name} {eco}")
232
+ positives.append(f"name {name} eco {eco} pgn {_bigram_token_str(san)}")
233
+ return {"anchor": anchors, "positive": positives}
234
+
235
+
236
+ def load_chess_pairs() -> tuple[Dataset, Dataset]:
237
+ """Returns (train, holdout) where the holdout anchors are rare combinations
238
+ NEVER seen in train.
239
+
240
+ Old eval used the top-200 most-common theme strings as queries. The model
241
+ memorized these in training (each appears ~50k times) so eval was a recall
242
+ test on memorized lookups, not generalization. Replaced with compositional
243
+ held-out anchors:
244
+
245
+ - Pick anchor strings with frequency in [HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX]:
246
+ rare enough to be informative, common enough to have multiple positives
247
+ for multi-relevant eval.
248
+ - REMOVE all pairs with those anchors from train (no leakage).
249
+ - Use those rare anchors as eval queries; the held-out pairs become the
250
+ eval corpus.
251
+ - Individual theme tokens within those anchors still appear *separately*
252
+ in many other training anchors, so the model has learned each token's
253
+ embedding -- it just hasn't seen this particular combination. Tests
254
+ compositional generalization.
255
+ """
256
+ logging.info("Loading Lichess/chess-puzzles (5.8M rows)")
257
+ puzzles = load_dataset("Lichess/chess-puzzles", split="train")
258
+ if SMOKE_TEST:
259
+ puzzles = puzzles.select(range(2_000))
260
+ pair_puzzles = puzzles.map(
261
+ build_puzzle_pairs,
262
+ batched=True,
263
+ batch_size=10_000,
264
+ remove_columns=puzzles.column_names,
265
+ desc="puzzles -> pairs",
266
+ )
267
+ logging.info(f" built {len(pair_puzzles):,} puzzle pairs")
268
+
269
+ logging.info("Loading Lichess/chess-openings (3.6K rows)")
270
+ openings = load_dataset("Lichess/chess-openings", split="train").remove_columns(["img"])
271
+ pair_openings = openings.map(
272
+ build_opening_pairs,
273
+ batched=True,
274
+ remove_columns=openings.column_names,
275
+ desc="openings -> pairs",
276
+ )
277
+ logging.info(f" built {len(pair_openings):,} opening pairs")
278
+
279
+ # Count anchor frequencies across the puzzle pairs.
280
+ logging.info("Computing anchor frequencies for held-out selection")
281
+ anchors = pair_puzzles["anchor"]
282
+ freq: dict[str, int] = defaultdict(int)
283
+ for a in anchors:
284
+ freq[a] += 1
285
+ logging.info(f" {len(freq):,} unique anchors in puzzle pairs")
286
+
287
+ # Pick rare anchors: each appears in [HELDOUT_FREQ_MIN, HELDOUT_FREQ_MAX] pairs.
288
+ # In smoke mode, lower the min so the tiny corpus still produces enough
289
+ # held-out queries (smoke has ~2k puzzles, most anchors freq 1-2).
290
+ min_freq = 2 if SMOKE_TEST else HELDOUT_FREQ_MIN
291
+ max_freq = HELDOUT_FREQ_MAX
292
+ rare_pool = sorted(
293
+ ((a, c) for a, c in freq.items() if min_freq <= c <= max_freq),
294
+ key=lambda kv: kv[1], # ascending: rarest first
295
+ )
296
+ n_queries_target = 20 if SMOKE_TEST else EVAL_QUERIES
297
+ if len(rare_pool) < n_queries_target:
298
+ logging.warning(
299
+ f"Only {len(rare_pool)} anchors in freq range [{min_freq},{max_freq}]; "
300
+ f"using all of them ({n_queries_target} requested)"
301
+ )
302
+ heldout_anchors = {a for a, _ in rare_pool[:n_queries_target]}
303
+ logging.info(
304
+ f" selected {len(heldout_anchors)} held-out anchors "
305
+ f"(freq range: {rare_pool[0][1] if rare_pool else 0}..{rare_pool[min(n_queries_target, len(rare_pool))-1][1] if rare_pool else 0})"
306
+ )
307
+
308
+ # Filter: pairs whose anchor is held-out -> eval; everything else -> train.
309
+ held_mask = [a in heldout_anchors for a in anchors]
310
+ holdout = pair_puzzles.select([i for i, h in enumerate(held_mask) if h])
311
+ train_puzzles = pair_puzzles.select([i for i, h in enumerate(held_mask) if not h])
312
+ logging.info(
313
+ f" split by held-out anchors: train={len(train_puzzles):,}, holdout={len(holdout):,}"
314
+ )
315
+
316
+ # Train includes the (non-held) puzzle pairs + all openings.
317
+ train = concatenate_datasets([train_puzzles, pair_openings]).shuffle(seed=12)
318
+ logging.info(f" train: {len(train):,} pairs | holdout: {len(holdout):,} pairs")
319
+ return train, holdout
320
+
321
+
322
+ def make_balanced_dataset(train: Dataset, n_per_anchor: int) -> Dataset:
323
+ """Cap each anchor's positives to `n_per_anchor` random picks. Breaks the
324
+ 5.8M pairs' redundancy (each anchor x ~1933 positives) so the model can't
325
+ memorize specific (anchor, positive) pairings while still seeing useful
326
+ positive variety per anchor.
327
+ """
328
+ by_anchor: dict[str, list[str]] = defaultdict(list)
329
+ for row in train:
330
+ by_anchor[row["anchor"]].append(row["positive"])
331
+ rng = random.Random(12)
332
+ new_anchors, new_positives = [], []
333
+ for anchor, positives in by_anchor.items():
334
+ sample = (
335
+ rng.sample(positives, n_per_anchor)
336
+ if len(positives) > n_per_anchor
337
+ else positives
338
+ )
339
+ for p in sample:
340
+ new_anchors.append(anchor)
341
+ new_positives.append(p)
342
+ logging.info(
343
+ f"Balanced dataset: {len(by_anchor):,} unique anchors -> "
344
+ f"{len(new_anchors):,} pairs (cap {n_per_anchor}/anchor)"
345
+ )
346
+ return Dataset.from_dict({"anchor": new_anchors, "positive": new_positives}).shuffle(seed=12)
347
+
348
+
349
+ def make_anchor_masker(mask_prob: float, rng_seed: int = 12):
350
+ """Return a `set_transform` callable that randomly replaces anchor tokens
350
+ (theme and opening words) with [UNK]. Token-bag dropout: forces the model to use
352
+ remaining tokens instead of memorizing the exact combination."""
353
+ if mask_prob <= 0:
354
+ return None
355
+ rng = random.Random(rng_seed)
356
+
357
+ def _mask(batch: dict) -> dict:
358
+ anchors = batch["anchor"]
359
+ new_anchors = []
360
+ for a in anchors:
361
+ tokens = a.split()
362
+ if len(tokens) <= 1:
363
+ new_anchors.append(a)
364
+ continue
365
+ kept = [t if rng.random() >= mask_prob else "[UNK]" for t in tokens]
366
+ # Guard against masking everything: if all UNK, restore one random token.
367
+ if all(t == "[UNK]" for t in kept):
368
+ kept[rng.randrange(len(kept))] = tokens[rng.randrange(len(tokens))]
369
+ new_anchors.append(" ".join(kept))
370
+ return {"anchor": new_anchors, "positive": batch["positive"]}
371
+
372
+ return _mask
373
+
374
+
375
+ def train_chess_tokenizer(train: Dataset) -> Tokenizer:
376
+ """Train or load a WordLevel tokenizer for the chess corpus.
377
+
378
+ Every space-separated unit (theme word, opening word, ECO code, UCI move,
379
+ SAN move) becomes one whole token. Compare to BERT WordPiece which fragments
380
+ "f2g3" into 4 subword pieces -- a token-bag wastes capacity on subword joins
381
+ that carry no chess meaning.
382
+
383
+ Caching: if TOKENIZER_PATH exists, load and return it instead of rebuilding.
384
+ The WordLevelTrainer is single-threaded Rust and takes ~6 min on 11.6M
385
+ strings. Tokenizer is deterministic given the same corpus + config, so
386
+ caching is safe. Set RETRAIN_TOKENIZER=1 to force rebuild.
387
+ """
388
+ if not RETRAIN_TOKENIZER and os.path.exists(TOKENIZER_PATH):
389
+ tok = Tokenizer.from_file(TOKENIZER_PATH)
390
+ logging.info(
391
+ f"Reusing cached tokenizer ({tok.get_vocab_size():,} tokens) from {TOKENIZER_PATH}"
392
+ )
393
+ return tok
394
+
395
+ logging.info(f"Training WordLevel tokenizer on {len(train):,} pairs (vocab={VOCAB_SIZE})")
396
+ tok = Tokenizer(WordLevel(unk_token="[UNK]"))
397
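+ tok.pre_tokenizer = Whitespace()  # NOTE: if this is tokenizers' Whitespace (regex \w+|[^\w\s]+), it also splits on "+" and "-", so "a+b" bigrams become three tokens; WhitespaceSplit() splits on whitespace only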
398
+ trainer = WordLevelTrainer(
399
+ vocab_size=VOCAB_SIZE,
400
+ special_tokens=["[UNK]", "[PAD]"],
401
+ min_frequency=2,
402
+ )
403
+
404
+ def text_iter():
405
+ for row in train:
406
+ yield row["anchor"]
407
+ yield row["positive"]
408
+
409
+ tok.train_from_iterator(text_iter(), trainer=trainer, length=2 * len(train))
410
+ actual_vocab = tok.get_vocab_size()
411
+ logging.info(f" tokenizer trained: {actual_vocab:,} tokens (cap was {VOCAB_SIZE:,})")
412
+ os.makedirs(os.path.dirname(TOKENIZER_PATH) or ".", exist_ok=True)
413
+ tok.save(TOKENIZER_PATH)
414
+ logging.info(f" saved tokenizer to {TOKENIZER_PATH}")
415
+ return tok
416
+
417
+
418
+ def _strip_theme_echo(positive: str) -> str:
419
+ """Eval corpus must not echo the themes the query asks about, or the
420
+ baseline (random-init) scores high just from lexical token overlap. Keep
421
+ only the moves segment."""
422
+ idx = positive.find(" moves ")
423
+ return positive[idx + 1 :] if idx != -1 else positive
424
+
425
+
426
+ def _build_compositional_ir_evaluator(
427
+ holdout: Dataset, corpus: dict[str, str], name: str
428
+ ) -> InformationRetrievalEvaluator:
429
+ """Compositional: each unseen anchor string is a query."""
430
+ by_anchor: dict[str, set[str]] = defaultdict(set)
431
+ for i, row in enumerate(holdout):
432
+ by_anchor[row["anchor"]].add(f"d{i}")
433
+ sorted_anchors = sorted(by_anchor.items(), key=lambda kv: -len(kv[1]))
434
+ queries = {f"q{i}": anchor for i, (anchor, _) in enumerate(sorted_anchors)}
435
+ relevant_docs = {f"q{i}": docs for i, (_, docs) in enumerate(sorted_anchors)}
436
+ avg_rel = sum(len(v) for v in relevant_docs.values()) / max(1, len(relevant_docs))
437
+ logging.info(
438
+ f" [{name}] {len(queries)} queries (unseen combos), avg relevant/query={avg_rel:.1f}"
439
+ )
440
+ return _ir_evaluator(queries, corpus, relevant_docs, name)
441
+
442
+
443
+ def _build_single_theme_ir_evaluator(
444
+ holdout: Dataset, corpus: dict[str, str], name: str
445
+ ) -> InformationRetrievalEvaluator:
446
+ """Single-theme: each individual theme token from the held-out anchors is
447
+ a query. Tests whether per-token embeddings are useful in isolation.
448
+
449
+ Relevant docs for query "fork" = any held-out doc whose anchor contains
450
+ the token "fork". Coarser than the compositional eval (much higher avg
451
+ relevant/query) but a sharper test of token-level meaning.
452
+ """
453
+ theme_to_docs: dict[str, set[str]] = defaultdict(set)
454
+ for i, row in enumerate(holdout):
455
+ for token in row["anchor"].split():
456
+ theme_to_docs[token].add(f"d{i}")
457
+ min_relevant = 2 if SMOKE_TEST else 3
458
+ candidates = [(t, d) for t, d in theme_to_docs.items() if len(d) >= min_relevant]
459
+ candidates.sort(key=lambda kv: -len(kv[1]))
460
+ queries = {f"t{i}": tok for i, (tok, _) in enumerate(candidates)}
461
+ relevant_docs = {f"t{i}": docs for i, (_, docs) in enumerate(candidates)}
462
+ avg_rel = sum(len(v) for v in relevant_docs.values()) / max(1, len(relevant_docs))
463
+ logging.info(
464
+ f" [{name}] {len(queries)} single-token queries, avg relevant/query={avg_rel:.1f}"
465
+ )
466
+ return _ir_evaluator(queries, corpus, relevant_docs, name)
467
+
468
+
469
+ def _ir_evaluator(queries, corpus, relevant_docs, name):
470
+ return InformationRetrievalEvaluator(
471
+ queries=queries,
472
+ corpus=corpus,
473
+ relevant_docs=relevant_docs,
474
+ name=name,
475
+ ndcg_at_k=[10],
476
+ mrr_at_k=[10],
477
+ accuracy_at_k=[1, 10],
478
+ precision_recall_at_k=[1, 10],
479
+ show_progress_bar=False,
480
+ batch_size=256,
481
+ )
482
+
483
+
484
+ def build_ir_evaluator(holdout: Dataset, name: str = "chess-ir") -> SequentialEvaluator:
485
+ """Wraps two evaluators (compositional + single-theme) into a sequential
486
+ pass. The compositional one's score drives best-model selection; the
487
+ single-theme one is informational.
488
+ """
489
+ corpus = {f"d{i}": _strip_theme_echo(row["positive"]) for i, row in enumerate(holdout)}
490
+ logging.info(f"IR eval setup ({len(corpus)} corpus docs):")
491
+ compositional = _build_compositional_ir_evaluator(holdout, corpus, name=name)
492
+ single_theme = _build_single_theme_ir_evaluator(holdout, corpus, name=f"{name}-tokens")
493
+ # First evaluator's score drives load_best_model_at_end (compositional).
494
+ return SequentialEvaluator(
495
+ [compositional, single_theme],
496
+ main_score_function=lambda scores: scores[0],
497
+ )
498
+
499
+
500
+ def main() -> None:
501
+ setup_logging()
502
+
503
+ train_dataset, holdout = load_chess_pairs()
504
+ if SMOKE_TEST:
505
+ train_dataset = train_dataset.select(range(min(500, len(train_dataset))))
506
+
507
+ # Train the tokenizer on the FULL (pre-balanced) corpus -- we want every
508
+ # token to be seen as many times as possible for the vocab pass.
509
+ tokenizer = train_chess_tokenizer(train_dataset)
510
+
511
+ # Now down-sample to a balanced dataset for the contrastive training.
512
+ train_dataset = make_balanced_dataset(train_dataset, BALANCED_POSITIVES_PER_ANCHOR)
513
+
514
+ # Optional anchor-token masking applied on the fly via set_transform.
515
+ masker = make_anchor_masker(ANCHOR_MASK_PROB)
516
+ if masker is not None:
517
+ logging.info(f"Anchor token masking enabled (p={ANCHOR_MASK_PROB})")
518
+ train_dataset.set_transform(masker)
519
+
520
+ logging.info(f"Random-init StaticEmbedding (dim={EMBEDDING_DIM})")
521
+ static_embedding = StaticEmbedding(tokenizer, embedding_dim=EMBEDDING_DIM)
522
+ model = SentenceTransformer(
523
+ modules=[static_embedding],
524
+ model_card_data=SentenceTransformerModelCardData(
525
+ language="en",
526
+ license="apache-2.0",
527
+ model_name=f"Static chess embedding ({EMBEDDING_DIM}d) -- themes/openings <-> positions",
528
+ ),
529
+ )
530
+
531
+ evaluator = build_ir_evaluator(holdout)
532
+ inner = MultipleNegativesRankingLoss(model)
533
+ if DISABLE_MATRYOSHKA:
534
+ logging.info("Matryoshka DISABLED -- training at single dim (diagnostic)")
535
+ loss = inner
536
+ else:
537
+ loss = MatryoshkaLoss(model, inner, matryoshka_dims=MATRYOSHKA_DIMS)
538
+
539
+ logging.info("Baseline evaluation (random init -- expect near-zero):")
540
+ with autocast_ctx():
541
+ baseline_eval = evaluator(model)[evaluator.primary_metric]
542
+ metric_key = f"eval_{evaluator.primary_metric}"
543
+ logging.info(f" baseline {evaluator.primary_metric} = {baseline_eval:.4f}")
544
+
545
+ if SMOKE_TEST:
546
+ max_steps = 1
547
+ elif MAX_STEPS_OVERRIDE:
548
+ max_steps = MAX_STEPS_OVERRIDE
549
+ else:
550
+ max_steps = -1
551
+ eval_steps = EVAL_STEPS_OVERRIDE if EVAL_STEPS_OVERRIDE else 0.05 # 20 evals/run
552
+ save_steps = EVAL_STEPS_OVERRIDE if EVAL_STEPS_OVERRIDE else 0.05
553
+
554
+ args = SentenceTransformerTrainingArguments(
555
+ output_dir=OUTPUT_DIR,
556
+ # Balanced dataset is small (~300k pairs); need many epochs to reach
557
+ # comparable total training signal. Early stopping handles excess.
558
+ num_train_epochs=20,
559
+ max_steps=max_steps,
560
+ per_device_train_batch_size=BATCH_SIZE,
561
+ per_device_eval_batch_size=BATCH_SIZE,
562
+ learning_rate=1e-2, # was 5e-2; the lower rate converges much more slowly and shifts the peak later
563
+ weight_decay=0.01, # was 0.0 -- regularization on the embedding table
564
+ warmup_steps=0.1,
565
+ lr_scheduler_type="linear",
566
+ bf16=IS_CUDA and torch.cuda.is_bf16_supported(),
567
+ fp16=IS_CUDA and not torch.cuda.is_bf16_supported(),
568
+ # was NO_DUPLICATES -- linked-list scan over deferred conflicts gives
569
+ # O(epoch_progress) per-batch cost. With ~3000 unique anchors over
570
+ # 5.8M pairs, dedup is fighting impossible odds. BATCH_SAMPLER (random)
571
+ # is fast and accepts mild within-batch anchor duplication.
572
+ batch_sampler=BatchSamplers.BATCH_SAMPLER,
573
+ eval_strategy="steps",
574
+ eval_steps=eval_steps,
575
+ save_strategy="steps",
576
+ save_steps=save_steps,
577
+ save_total_limit=2,
578
+ logging_steps=0.01,
579
+ logging_first_step=True,
580
+ load_best_model_at_end=True,
581
+ metric_for_best_model=metric_key,
582
+ greater_is_better=True,
583
+ # Trackio crashes at first checkpoint push: empty `router_mapping`
584
+ # struct can't be written to parquet. Disable.
585
+ report_to="none",
586
+ run_name=RUN_NAME,
587
+ seed=12,
588
+ # HF Jobs: container is destroyed after run -- push every checkpoint to
589
+ # the Hub so partial progress survives a timeout. The end-of-run
590
+ # model.push_to_hub() below is the belt to these suspenders.
591
+ push_to_hub=not SMOKE_TEST,
592
+ hub_model_id=HUB_MODEL_ID,
593
+ hub_strategy="every_save",
594
+ )
595
+
596
+ trainer = SentenceTransformerTrainer(
597
+ model=model,
598
+ args=args,
599
+ train_dataset=train_dataset,
600
+ loss=loss,
601
+ evaluator=evaluator,
602
+ callbacks=[
603
+ # Auto-stop if compositional NDCG@10 doesn't improve for 3 evals.
604
+ # Lower lr makes curves smoother -- give it slack vs the patience=2
605
+ # we used at lr=5e-2.
606
+ EarlyStoppingCallback(early_stopping_patience=3),
607
+ # Per-step memory + dt logging.
608
+ StepTimingCallback(),
609
+ ],
610
+ )
611
+ trainer.train()
612
+
613
+ logging.info("Post-training evaluation:")
614
+ with autocast_ctx():
615
+ score = evaluator(model)[evaluator.primary_metric]
616
+ delta = score - baseline_eval
617
+ verdict = "WIN" if delta >= 0.005 else "MARGINAL" if delta >= 0 else "REGRESSION"
618
+ logging.info(
619
+ f"VERDICT: {verdict} | score={score:.4f} | baseline={baseline_eval:.4f} | delta={delta:+.4f}"
620
+ )
621
+
622
+ final_dir = f"{OUTPUT_DIR}/final"
623
+ model.save_pretrained(final_dir)
624
+ logging.info(f"Saved final model to {final_dir}")
625
+
626
+ if SMOKE_TEST:
627
+ logging.info("SMOKE_TEST=1: skipping Hub push")
628
+ return
629
+
630
+ try:
631
+ commit_url = model.push_to_hub(HUB_MODEL_ID)
632
+ logging.info(f"Pushed model to {commit_url.rsplit('/commit/', 1)[0]}")
633
+ except Exception:
634
+ import traceback
635
+
636
+ logging.error(f"Hub push failed:\n{traceback.format_exc()}")
637
+
638
+
639
+ if __name__ == "__main__":
640
+ main()