smsk-01 commited on
Commit
2b8b4b2
·
verified ·
1 Parent(s): 50dda7c

Chess Challenge submission by smsk-01

Browse files
README.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags:
4
+ - chess
5
+ - llm-course
6
+ - chess-challenge
7
+ license: mit
8
+ ---
9
+
10
+ # chess-chess-smsk01
11
+
12
+ Chess model submitted to the LLM Course Chess Challenge.
13
+
14
+ ## Submission Info
15
+
16
+ - **Submitted by**: [smsk-01](https://huggingface.co/smsk-01)
17
+ - **Parameters**: 887,936
18
+ - **Organization**: LLM-course
19
+
20
+ ## Usage
21
+
22
+ ```python
23
+ from transformers import AutoModelForCausalLM, AutoTokenizer
24
+
25
+ model = AutoModelForCausalLM.from_pretrained("LLM-course/chess-chess-smsk01", trust_remote_code=True)
26
+ tokenizer = AutoTokenizer.from_pretrained("LLM-course/chess-chess-smsk01", trust_remote_code=True)
27
+ ```
28
+
29
+ ## Evaluation
30
+
31
+ This model is evaluated at the [Chess Challenge Arena](https://huggingface.co/spaces/LLM-course/Chess1MChallenge).
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ChessForCausalLM"
4
+ ],
5
+ "bos_token_id": 1,
6
+ "dropout": 0.1,
7
+ "dtype": "float32",
8
+ "eos_token_id": 2,
9
+ "layer_norm_epsilon": 1e-05,
10
+ "model_type": "chess_transformer",
11
+ "n_ctx": 320,
12
+ "n_embd": 128,
13
+ "n_head": 4,
14
+ "n_inner": 384,
15
+ "n_layer": 5,
16
+ "pad_token_id": 0,
17
+ "tie_weights": true,
18
+ "transformers_version": "4.57.6",
19
+ "vocab_size": 155,
20
+ "auto_map": {
21
+ "AutoConfig": "model.ChessConfig",
22
+ "AutoModelForCausalLM": "model.ChessForCausalLM"
23
+ }
24
+ }
generation_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 23,
4
+ "eos_token_id": [
5
+ 24
6
+ ],
7
+ "pad_token_id": 25,
8
+ "transformers_version": "4.57.6"
9
+ }
model.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chess Transformer Model for the Chess Challenge.
3
+ This module provides a simple GPT-style transformer architecture
4
+ designed to fit within the 1M parameter constraint.
5
+ Key components:
6
+ - ChessConfig: Configuration class for model hyperparameters
7
+ - ChessForCausalLM: The main model class for next-move prediction
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import math
13
+ from dataclasses import dataclass
14
+ from typing import Optional, Tuple, Union
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from transformers import PretrainedConfig, PreTrainedModel
20
+ from transformers.modeling_outputs import CausalLMOutputWithPast
21
+
22
+
23
class ChessConfig(PretrainedConfig):
    """
    Hyperparameter container for the Chess Transformer.

    The defaults are chosen so the model lands near the 1M-parameter budget:

    - Token embeddings: 1200 x 128 = 153,600
    - Position embeddings: 256 x 128 = 32,768
    - Transformer layers: 6 x ~120,000 = ~720,000
    - LM head: 0 extra when tied to the token embeddings
    - Total: ~906,000 parameters

    Attributes:
        vocab_size: Number of distinct tokens (chess sub-moves).
        n_embd: Model width (embedding dimension).
        n_layer: Number of stacked transformer blocks.
        n_head: Attention heads per block.
        n_ctx: Maximum context length.
        n_inner: Hidden width of the feed-forward sublayer
            (defaults to 3 * n_embd instead of the usual 4x to save parameters).
        dropout: Dropout probability used throughout the model.
        layer_norm_epsilon: Epsilon for all LayerNorm modules.
        tie_weights: Share the LM head weights with the token embeddings.
    """

    model_type = "chess_transformer"

    def __init__(
        self,
        vocab_size: int = 1200,
        n_embd: int = 128,
        n_layer: int = 6,
        n_head: int = 4,
        n_ctx: int = 256,
        n_inner: Optional[int] = None,
        dropout: float = 0.1,
        layer_norm_epsilon: float = 1e-5,
        tie_weights: bool = True,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        # Leaner 3x FFN (rather than the conventional 4x) keeps the
        # parameter count under the challenge's 1M budget.
        if n_inner is None:
            n_inner = 3 * n_embd

        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_ctx = n_ctx
        self.n_inner = n_inner
        self.dropout = dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.tie_weights = tie_weights
        # Let the HF base class know whether embeddings are shared.
        self.tie_word_embeddings = bool(tie_weights)
85
+
86
+
87
class MultiHeadAttention(nn.Module):
    """
    Causal multi-head self-attention.

    Standard scaled dot-product attention with a lower-triangular mask so
    each position may only attend to itself and earlier positions, plus an
    optional padding mask supplied by the caller.
    """

    def __init__(self, config: ChessConfig):
        super().__init__()

        assert config.n_embd % config.n_head == 0, \
            f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head

        # Single fused projection producing Q, K and V in one matmul.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)

        self.dropout = nn.Dropout(config.dropout)

        # Precomputed lower-triangular causal mask, sliced per forward call.
        # Not persisted: it is derived purely from n_ctx.
        causal = torch.tril(torch.ones(config.n_ctx, config.n_ctx))
        self.register_buffer(
            "bias",
            causal.view(1, 1, config.n_ctx, config.n_ctx),
            persistent=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply causal self-attention to x of shape (batch, seq, n_embd)."""
        bsz, tgt_len = x.size(0), x.size(1)

        # Project once, then split the fused output into Q, K, V.
        query, key, value = self.c_attn(x).split(self.n_embd, dim=2)

        # (batch, seq, embd) -> (batch, head, seq, head_dim)
        def _split_heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(bsz, tgt_len, self.n_head, self.head_dim).transpose(1, 2)

        query = _split_heads(query)
        key = _split_heads(key)
        value = _split_heads(value)

        # Scaled attention scores.
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Forbid attending to future positions.
        scores = scores.masked_fill(
            self.bias[:, :, :tgt_len, :tgt_len] == 0, float("-inf")
        )

        # Optional padding mask: (batch, seq) broadcast to (batch, 1, 1, seq).
        if attention_mask is not None:
            pad = attention_mask.unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(pad == 0, float("-inf"))

        weights = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum over values, then merge heads back together.
        context = torch.matmul(weights, value)
        context = context.transpose(1, 2).contiguous().view(bsz, tgt_len, self.n_embd)

        return self.c_proj(context)
164
+
165
+
166
class FeedForward(nn.Module):
    """
    Position-wise feed-forward network.

    Two linear layers with a GELU nonlinearity in between, followed by
    dropout: n_embd -> n_inner -> n_embd.
    """

    def __init__(self, config: ChessConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.n_inner)
        self.c_proj = nn.Linear(config.n_inner, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand, activate, project back, then apply dropout."""
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
186
+
187
+
188
class TransformerBlock(nn.Module):
    """
    One pre-norm transformer block.

    LayerNorm precedes each sublayer (attention, then MLP) and both
    sublayers are wrapped in residual connections — the pre-norm layout
    that tends to train more stably than post-norm.
    """

    def __init__(self, config: ChessConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = MultiHeadAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = FeedForward(config)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Residual attention sublayer followed by residual MLP sublayer."""
        x = x + self.attn(self.ln_1(x), attention_mask=attention_mask)
        return x + self.mlp(self.ln_2(x))
214
+
215
+
216
class ChessForCausalLM(PreTrainedModel):
    """
    Chess Transformer for Causal Language Modeling (next-move prediction).

    GPT-style decoder-only architecture:
    - Token embeddings (wte) plus learned positional embeddings (wpe)
    - n_layer pre-norm TransformerBlocks
    - Final LayerNorm and a bias-free linear LM head

    The LM head can be weight-tied to the token embeddings (config.tie_weights)
    to save parameters.

    Example:
        >>> config = ChessConfig(vocab_size=1200, n_embd=128, n_layer=6)
        >>> model = ChessForCausalLM(config)
        >>> inputs = {"input_ids": torch.tensor([[1, 42, 87]])}
        >>> outputs = model(**inputs)
        >>> next_move_logits = outputs.logits[:, -1, :]
    """

    config_class = ChessConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    # Fix: PreTrainedModel reads the underscore-prefixed class attribute.
    # The previous public name "keys_to_ignore_on_load_missing" is not
    # consulted by the HF loader, so the missing-key warning for the tied
    # lm_head was never actually suppressed.
    _keys_to_ignore_on_load_missing = ["lm_head.weight"]

    def __init__(self, config: ChessConfig):
        super().__init__(config)

        # Token and position embeddings.
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_ctx, config.n_embd)

        self.drop = nn.Dropout(config.dropout)

        # Stack of transformer blocks.
        self.h = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layer)
        ])

        # Final layer norm.
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # Output head; bias-free so its weight matrix can be shared with wte.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Declare tied weights before post_init so serialization treats
        # lm_head.weight as shared rather than as an independent tensor.
        if config.tie_weights:
            self._tied_weights_keys = ["lm_head.weight"]

        # Initialize weights (dispatches to _init_weights via HF machinery).
        self.post_init()

        # Tie weights if configured.
        if config.tie_weights:
            self.tie_weights()

    def get_input_embeddings(self) -> nn.Module:
        return self.wte

    def set_input_embeddings(self, new_embeddings: nn.Module):
        # Re-tie after swapping embeddings so the head stays consistent.
        self.wte = new_embeddings
        if getattr(self.config, "tie_weights", False):
            self.tie_weights()

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module):
        self.lm_head = new_embeddings

    def tie_weights(self):
        # Use the HF helper, which ties or clones depending on the config
        # (e.g. torchscript export requires cloning).
        if getattr(self.config, "tie_weights", False) or getattr(self.config, "tie_word_embeddings", False):
            self._tie_or_clone_weights(self.lm_head, self.wte)

    def _init_weights(self, module: nn.Module):
        """Initialize weights following GPT-2 style (normal std=0.02)."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass of the model.

        Args:
            input_ids: Token IDs of shape (batch_size, seq_len).
            attention_mask: Padding mask of shape (batch_size, seq_len);
                positions with 0 are masked out of attention.
            position_ids: Position IDs of shape (batch_size, seq_len);
                defaults to 0..seq_len-1 for every batch row.
            labels: Labels for the language-modeling loss; positions with
                -100 are ignored.
            return_dict: Whether to return a ModelOutput object.

        Returns:
            CausalLMOutputWithPast containing loss (if labels provided)
            and logits, or the equivalent tuple when return_dict is False.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_len = input_ids.size()
        device = input_ids.device

        # Default position IDs: 0..seq_len-1 broadcast over the batch.
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)

        # Sum token and position embeddings, then apply dropout.
        token_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = self.drop(token_embeds + position_embeds)

        # Run the transformer stack.
        for block in self.h:
            hidden_states = block(hidden_states, attention_mask=attention_mask)

        # Final layer norm and projection to vocabulary logits.
        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)

        # Standard shifted next-token cross-entropy when labels are given.
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    @torch.no_grad()
    def generate_move(
        self,
        input_ids: torch.LongTensor,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> int:
        """
        Sample the next move token given a sequence of moves.

        Args:
            input_ids: Token IDs of shape (1, seq_len).
            temperature: Sampling temperature (1.0 = unchanged logits).
            top_k: If set, only sample from the top k tokens.
            top_p: If set, use nucleus sampling with this threshold.

        Returns:
            The token ID of the predicted next move.
        """
        self.eval()

        # Logits for the final position only, temperature-scaled.
        outputs = self(input_ids)
        logits = outputs.logits[:, -1, :] / temperature

        # Top-k filtering: drop everything below the k-th largest logit.
        if top_k is not None:
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = float("-inf")

        # Top-p (nucleus) filtering: keep the smallest prefix of the sorted
        # distribution whose cumulative probability reaches top_p.
        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept,
            # and always keep the single most probable token.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=-1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits[indices_to_remove] = float("-inf")

        # Sample one token from the filtered distribution.
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        return next_token.item()
430
+
431
+
432
# Register the model with Auto classes for easy loading:
# after importing this module, AutoConfig/AutoModelForCausalLM can resolve
# the custom "chess_transformer" model_type (used with trust_remote_code=True).
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("chess_transformer", ChessConfig)
AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c33601ca3c98d74b4916655709c5fef8eb0129b6201c3336e2c21e0725c8983
3
+ size 3557168
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "[BOS]",
3
+ "eos_token": "[EOS]",
4
+ "pad_token": "[PAD]",
5
+ "unk_token": "[UNK]"
6
+ }
tokenizer.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Version (Player (Color + Piece), Source_S, Destination_D, Suffix)
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ from pathlib import Path
8
+ from typing import Dict, List, Optional
9
+ from transformers import PreTrainedTokenizer
10
+
11
+
12
class ChessTokenizer(PreTrainedTokenizer):
    """
    Sub-move tokenizer for chess moves using extended UCI notation.

    Each move is split into atomic components:
    - Player (color + piece): WP, WN, WB, WR, WQ, WK, BP, ...
    - Source square tagged with "_S": e.g. "e2_S"
    - Destination square tagged with "_D": e.g. "e4_D"
    - Optional suffix group: (x) capture, (+) check, (*) checkmate,
      (o)/(O) castling, and combinations.

    Example:
        Move "WPe2e4(x+)" -> ["WP", "e2_S", "e4_D", "(x+)"]
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    # Atomic suffix tokens for the default (grammar-based) vocab
    SUFFIX_TOKENS = ["(x)", "(+)", "(*)", "(o)", "(O)", "(+*)", "(x+)"]

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        # Special-token attributes must exist before super().__init__,
        # which may call tokenization hooks.
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Drop any caller-supplied special tokens; this tokenizer's set is fixed.
        kwargs.pop("pad_token", None)
        kwargs.pop("bos_token", None)
        kwargs.pop("eos_token", None)
        kwargs.pop("unk_token", None)

        # Vocab resolution order: explicit dict > vocab file > built-in grammar.
        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._create_default_vocab()

        # Reverse mapping for id -> token lookups.
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _create_default_vocab(self) -> Dict[str, int]:
        """
        Build a fixed vocab from the chess sub-move grammar.

        Useful when a predefined grammar is preferred over a dataset-derived
        vocabulary. Special tokens occupy ids 0-3.
        """
        colors = ["W", "B"]
        pieces = ["P", "N", "B", "R", "Q", "K"]
        files = ["a", "b", "c", "d", "e", "f", "g", "h"]
        ranks = ["1", "2", "3", "4", "5", "6", "7", "8"]
        squares = [f + r for f in files for r in ranks]

        players = [c + p for c in colors for p in pieces]

        # Role-tagged square tokens.
        sources = [sq + "_S" for sq in squares]
        dests = [sq + "_D" for sq in squares]

        vocab_tokens = players + sources + dests + self.SUFFIX_TOKENS

        # Special tokens first so they get the canonical low ids.
        special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
        return {token: idx for idx, token in enumerate(special_tokens + vocab_tokens)}

    def _tokenize(self, text: str) -> List[str]:
        """
        Split a whitespace-separated string of moves into sub-move tokens.

        Assumes each move has the shape "CCssdd[suffix]" (2-char player,
        2-char source, 2-char destination, optional suffix) — shorter
        strings produce truncated tokens that map to [UNK] at encode time.
        """
        tokens: List[str] = []
        for move in text.strip().split():
            if not move:
                continue

            # Color + piece (e.g. WP, BN).
            tokens.append(move[:2])
            # Source square, tagged so it cannot collide with destinations.
            tokens.append(move[2:4] + "_S")
            # Destination square.
            tokens.append(move[4:6] + "_D")
            # Optional suffix group such as "(x+)".
            if len(move) > 6:
                tokens.append(move[6:])

        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        # Unknown tokens fall back to [UNK]'s id (0 if even that is missing).
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """
        Reassemble sub-move tokens into move text and return the first move.

        Special tokens are dropped and the "_S"/"_D" role tags stripped; a
        space is inserted before each color marker ("W"/"B") that follows a
        rank digit or a closing ")", i.e. at every move boundary. Only the
        first reconstructed move is returned (generation predicts one move
        at a time). Returns "" when no non-special tokens are present
        (previously this raised IndexError).
        """
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        clean_tokens = [
            t.split("_")[0] if "_" in t else t
            for t in tokens
            if t not in special
        ]

        joined = "".join(clean_tokens)
        result = ""
        # `ch` instead of the original loop variable `str`, which shadowed
        # the builtin.
        for i, ch in enumerate(joined):
            if ch in ("W", "B") and result and (joined[i - 1].isnumeric() or joined[i - 1] == ")"):
                result += " " + ch
            else:
                result += ch

        moves = result.split()
        return moves[0] if moves else ""

    @property
    def vocab_size(self) -> int:
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        # Return a copy so callers cannot mutate internal state.
        return dict(self._vocab)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        """Write vocab.json into save_directory; returns the written path."""
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)

    @classmethod
    def build_vocab_from_iterator(cls, iterator, min_frequency: int = 1) -> "ChessTokenizer":
        """
        Build a vocab from an iterable of game strings.

        Sub-move tokens seen fewer than `min_frequency` times are dropped;
        kept tokens are sorted so ids are deterministic.
        """
        from collections import Counter

        # One throwaway tokenizer for sub-tokenization; the original built a
        # brand-new tokenizer per game, which was pure overhead.
        probe = cls()
        token_counts = Counter()
        for game in iterator:
            token_counts.update(probe._tokenize(game))

        kept = sorted(
            token for token, count in token_counts.items() if count >= min_frequency
        )
        special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
        vocab = {token: idx for idx, token in enumerate(special_tokens + kept)}
        return cls(vocab=vocab)

    @classmethod
    def build_vocab_from_dataset(
        cls,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        min_frequency: int = 500,
        max_samples: Optional[int] = 100000,
    ) -> "ChessTokenizer":
        """Build a vocab by streaming games from a HF dataset column."""
        from datasets import load_dataset

        dataset = load_dataset(dataset_name, split=split)
        if max_samples is not None:
            dataset = dataset.select(range(min(max_samples, len(dataset))))

        def game_iterator():
            for example in dataset:
                yield example[column]

        return cls.build_vocab_from_iterator(game_iterator(), min_frequency=min_frequency)
212
+
213
+
214
def count_vocab_from_dataset(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    split: str = "train",
    column: str = "text",
    max_samples: Optional[int] = 10000,
) -> Dict[str, int]:
    """
    Count sub-move token frequencies in a dataset (useful for vocab analysis).

    Args:
        dataset_name: HF dataset repo id to load.
        split: Dataset split to scan.
        column: Column holding the whitespace-separated game text.
        max_samples: Optional cap on the number of examples scanned.

    Returns:
        Mapping from sub-move token to its observed frequency.
    """
    from collections import Counter
    from datasets import load_dataset

    dataset = load_dataset(dataset_name, split=split)
    if max_samples is not None:
        dataset = dataset.select(range(min(max_samples, len(dataset))))

    # Hoisted out of the loop: the original constructed a fresh
    # ChessTokenizer for every example.
    tokenizer = ChessTokenizer()
    token_counts = Counter()
    for example in dataset:
        # _tokenize splits on whitespace itself, so the whole game string can
        # be fed at once (equivalent to the old per-move inner loop).
        token_counts.update(tokenizer._tokenize(example[column]))
    return dict(token_counts)
238
+
tokenizer_config.json ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "[BOS]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "[EOS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "[UNK]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ }
35
+ },
36
+ "bos_token": "[BOS]",
37
+ "clean_up_tokenization_spaces": false,
38
+ "eos_token": "[EOS]",
39
+ "extra_special_tokens": {},
40
+ "model_max_length": 1000000000000000019884624838656,
41
+ "pad_token": "[PAD]",
42
+ "tokenizer_class": "ChessTokenizer",
43
+ "unk_token": "[UNK]",
44
+ "auto_map": {
45
+ "AutoTokenizer": [
46
+ "tokenizer.ChessTokenizer",
47
+ null
48
+ ]
49
+ }
50
+ }
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aad1e9e8effe3ddabe5124448f632b712204b96c7fc0e0f26b222456f6791b23
3
+ size 5777
vocab.json ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "[PAD]": 0,
3
+ "[BOS]": 1,
4
+ "[EOS]": 2,
5
+ "[UNK]": 3,
6
+ "(+)": 4,
7
+ "(+*)": 5,
8
+ "(+Q)": 6,
9
+ "(O)": 7,
10
+ "(Q)": 8,
11
+ "(o)": 9,
12
+ "(x)": 10,
13
+ "(x+)": 11,
14
+ "(x+*)": 12,
15
+ "(x+Q)": 13,
16
+ "(xE)": 14,
17
+ "BB": 15,
18
+ "BK": 16,
19
+ "BN": 17,
20
+ "BP": 18,
21
+ "BQ": 19,
22
+ "BR": 20,
23
+ "WB": 21,
24
+ "WK": 22,
25
+ "WN": 23,
26
+ "WP": 24,
27
+ "WQ": 25,
28
+ "WR": 26,
29
+ "a1_D": 27,
30
+ "a1_S": 28,
31
+ "a2_D": 29,
32
+ "a2_S": 30,
33
+ "a3_D": 31,
34
+ "a3_S": 32,
35
+ "a4_D": 33,
36
+ "a4_S": 34,
37
+ "a5_D": 35,
38
+ "a5_S": 36,
39
+ "a6_D": 37,
40
+ "a6_S": 38,
41
+ "a7_D": 39,
42
+ "a7_S": 40,
43
+ "a8_D": 41,
44
+ "a8_S": 42,
45
+ "b1_D": 43,
46
+ "b1_S": 44,
47
+ "b2_D": 45,
48
+ "b2_S": 46,
49
+ "b3_D": 47,
50
+ "b3_S": 48,
51
+ "b4_D": 49,
52
+ "b4_S": 50,
53
+ "b5_D": 51,
54
+ "b5_S": 52,
55
+ "b6_D": 53,
56
+ "b6_S": 54,
57
+ "b7_D": 55,
58
+ "b7_S": 56,
59
+ "b8_D": 57,
60
+ "b8_S": 58,
61
+ "c1_D": 59,
62
+ "c1_S": 60,
63
+ "c2_D": 61,
64
+ "c2_S": 62,
65
+ "c3_D": 63,
66
+ "c3_S": 64,
67
+ "c4_D": 65,
68
+ "c4_S": 66,
69
+ "c5_D": 67,
70
+ "c5_S": 68,
71
+ "c6_D": 69,
72
+ "c6_S": 70,
73
+ "c7_D": 71,
74
+ "c7_S": 72,
75
+ "c8_D": 73,
76
+ "c8_S": 74,
77
+ "d1_D": 75,
78
+ "d1_S": 76,
79
+ "d2_D": 77,
80
+ "d2_S": 78,
81
+ "d3_D": 79,
82
+ "d3_S": 80,
83
+ "d4_D": 81,
84
+ "d4_S": 82,
85
+ "d5_D": 83,
86
+ "d5_S": 84,
87
+ "d6_D": 85,
88
+ "d6_S": 86,
89
+ "d7_D": 87,
90
+ "d7_S": 88,
91
+ "d8_D": 89,
92
+ "d8_S": 90,
93
+ "e1_D": 91,
94
+ "e1_S": 92,
95
+ "e2_D": 93,
96
+ "e2_S": 94,
97
+ "e3_D": 95,
98
+ "e3_S": 96,
99
+ "e4_D": 97,
100
+ "e4_S": 98,
101
+ "e5_D": 99,
102
+ "e5_S": 100,
103
+ "e6_D": 101,
104
+ "e6_S": 102,
105
+ "e7_D": 103,
106
+ "e7_S": 104,
107
+ "e8_D": 105,
108
+ "e8_S": 106,
109
+ "f1_D": 107,
110
+ "f1_S": 108,
111
+ "f2_D": 109,
112
+ "f2_S": 110,
113
+ "f3_D": 111,
114
+ "f3_S": 112,
115
+ "f4_D": 113,
116
+ "f4_S": 114,
117
+ "f5_D": 115,
118
+ "f5_S": 116,
119
+ "f6_D": 117,
120
+ "f6_S": 118,
121
+ "f7_D": 119,
122
+ "f7_S": 120,
123
+ "f8_D": 121,
124
+ "f8_S": 122,
125
+ "g1_D": 123,
126
+ "g1_S": 124,
127
+ "g2_D": 125,
128
+ "g2_S": 126,
129
+ "g3_D": 127,
130
+ "g3_S": 128,
131
+ "g4_D": 129,
132
+ "g4_S": 130,
133
+ "g5_D": 131,
134
+ "g5_S": 132,
135
+ "g6_D": 133,
136
+ "g6_S": 134,
137
+ "g7_D": 135,
138
+ "g7_S": 136,
139
+ "g8_D": 137,
140
+ "g8_S": 138,
141
+ "h1_D": 139,
142
+ "h1_S": 140,
143
+ "h2_D": 141,
144
+ "h2_S": 142,
145
+ "h3_D": 143,
146
+ "h3_S": 144,
147
+ "h4_D": 145,
148
+ "h4_S": 146,
149
+ "h5_D": 147,
150
+ "h5_S": 148,
151
+ "h6_D": 149,
152
+ "h6_S": 150,
153
+ "h7_D": 151,
154
+ "h7_S": 152,
155
+ "h8_D": 153,
156
+ "h8_S": 154
157
+ }