khanghoang0902 committed
Commit 5912985 · verified · 1 Parent(s): 4a397b6

update model

Files changed (7)
  1. __init__.py +0 -0
  2. config.json +11 -6
  3. model.py +362 -0
  4. model.safetensors +2 -2
  5. tokenizer.py +119 -193
  6. tokenizer_config.json +6 -0
  7. vocab.json +80 -145
__init__.py ADDED
File without changes
config.json CHANGED
@@ -1,4 +1,5 @@
 {
+  "_name_or_path": "./output/final_model",
   "architectures": [
     "ChessForCausalLM"
   ],
@@ -8,13 +9,17 @@
   "eos_token_id": 2,
   "layer_norm_epsilon": 1e-05,
   "model_type": "chess_transformer",
-  "n_ctx": 360,
-  "n_embd": 102,
-  "n_head": 6,
-  "n_inner": 360,
-  "n_layer": 8,
+  "auto_map": {
+    "AutoConfig": "model.ChessConfig",
+    "AutoModelForCausalLM": "model.ChessForCausalLM"
+  },
+  "n_ctx": 512,
+  "n_embd": 128,
+  "n_head": 4,
+  "n_inner": 256,
+  "n_layer": 7,
   "pad_token_id": 0,
   "tie_weights": true,
   "transformers_version": "4.57.5",
-  "vocab_size": 149
+  "vocab_size": 84
 }
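
With the auto_map entries above, the checkpoint should be loadable through the Hugging Face Auto classes. A minimal loading sketch; the repository id below is a placeholder, and trust_remote_code is needed because the architecture is defined in model.py:

from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "khanghoang0902/chess-model"  # placeholder: substitute this repository's actual id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

print(config.n_layer, config.n_embd, config.n_ctx)  # 7 128 512 as of this commit
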
model.py ADDED
@@ -0,0 +1,362 @@
"""
Chess Transformer model for the Chess Challenge.

Lightweight GPT-style architecture sized to stay within a ~1M parameter budget.
Key pieces:
- ChessConfig: hyperparameter container
- ChessForCausalLM: autoregressive model for next-move prediction
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast


class ChessConfig(PretrainedConfig):
    """
    Configuration for the small chess transformer.

    Defaults target roughly 1M parameters; adjust values to explore variants.
    """
    model_type = "chess_transformer"

    def __init__(
        self,
        vocab_size: int = 1200,
        n_embd: int = 128,
        n_layer: int = 6,
        n_head: int = 4,
        n_ctx: int = 256,
        n_inner: Optional[int] = None,
        dropout: float = 0.1,
        layer_norm_epsilon: float = 1e-5,
        tie_weights: bool = True,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_ctx = n_ctx
        self.n_inner = n_inner if n_inner is not None else 3 * n_embd  # Reduced from 4x to 3x
        self.dropout = dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.tie_weights = tie_weights
        # Inform HF base class about tying behavior
        self.tie_word_embeddings = bool(tie_weights)


class MultiHeadAttention(nn.Module):
    """Standard masked self-attention with combined QKV projection."""
    def __init__(self, config: ChessConfig):
        super().__init__()

        assert config.n_embd % config.n_head == 0, \
            f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head

        # Combined QKV projection for efficiency
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        self.dropout = nn.Dropout(config.dropout)

        # Causal mask buffer, precomputed up to n_ctx (not saved in the state dict)
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.n_ctx, config.n_ctx)).view(
                1, 1, config.n_ctx, config.n_ctx
            ),
            persistent=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, seq_len, _ = x.size()

        # Compute Q, K, V
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)

        # Reshape for multi-head attention
        q = q.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Apply causal mask
        causal_mask = self.bias[:, :, :seq_len, :seq_len]
        attn_weights = attn_weights.masked_fill(causal_mask == 0, float("-inf"))

        # Apply attention mask (for padding)
        if attention_mask is not None:
            # attention_mask shape: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len)
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attn_weights = attn_weights.masked_fill(attention_mask == 0, float("-inf"))

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, v)

        # Reshape back
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.n_embd
        )

        # Output projection
        attn_output = self.c_proj(attn_output)

        return attn_output


class FeedForward(nn.Module):
    """Two-layer MLP with GELU and dropout."""
    def __init__(self, config: ChessConfig):
        super().__init__()

        self.c_fc = nn.Linear(config.n_embd, config.n_inner)
        self.c_proj = nn.Linear(config.n_inner, config.n_embd, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = F.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class TransformerBlock(nn.Module):
    """Pre-norm attention + MLP block."""
    def __init__(self, config: ChessConfig):
        super().__init__()

        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = MultiHeadAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = FeedForward(config)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Pre-norm attention
        x = x + self.attn(self.ln_1(x), attention_mask=attention_mask)
        # Pre-norm FFN
        x = x + self.mlp(self.ln_2(x))
        return x


class ChessForCausalLM(PreTrainedModel):
    """
    GPT-style causal LM for chess move prediction.

    Stacks transformer blocks over token and position embeddings; ties output
    head to embeddings when configured.
    """
    config_class = ChessConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    # Suppress missing-key warning for tied lm_head when loading
    keys_to_ignore_on_load_missing = ["lm_head.weight"]

    def __init__(self, config: ChessConfig):
        super().__init__(config)

        # Token and position embeddings
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_ctx, config.n_embd)

        self.drop = nn.Dropout(config.dropout)

        # Transformer blocks
        self.h = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layer)
        ])

        # Final layer norm
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # Output head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Declare tied weights for proper serialization
        if config.tie_weights:
            self._tied_weights_keys = ["lm_head.weight"]

        # Initialize weights
        self.post_init()

        # Tie weights if configured
        if config.tie_weights:
            self.tie_weights()

    def get_input_embeddings(self) -> nn.Module:
        return self.wte

    def set_input_embeddings(self, new_embeddings: nn.Module):
        self.wte = new_embeddings
        if getattr(self.config, "tie_weights", False):
            self.tie_weights()

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module):
        self.lm_head = new_embeddings

    def tie_weights(self):
        if getattr(self.config, "tie_weights", False) or getattr(self.config, "tie_word_embeddings", False):
            self._tie_or_clone_weights(self.lm_head, self.wte)

    def _init_weights(self, module: nn.Module):
        """GPT-2 style init."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Forward pass with optional label loss."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_len = input_ids.size()
        device = input_ids.device

        # Create position IDs if not provided
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)

        # Get embeddings
        token_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = self.drop(token_embeds + position_embeds)

        # Pass through transformer blocks
        for block in self.h:
            hidden_states = block(hidden_states, attention_mask=attention_mask)

        # Final layer norm
        hidden_states = self.ln_f(hidden_states)

        # Get logits
        logits = self.lm_head(hidden_states)

        # Compute loss if labels are provided
        loss = None
        if labels is not None:
            # Shift logits and labels for next-token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Flatten for cross-entropy
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    @torch.no_grad()
    def generate_move(
        self,
        input_ids: torch.LongTensor,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> int:
        """
        Sample a next move token ID from logits, with optional top-k/p filtering.
        Expects input_ids shaped (1, seq_len).
        """
        self.eval()

        # Get logits for the last position
        outputs = self(input_ids)
        logits = outputs.logits[:, -1, :] / temperature

        # Apply top-k filtering
        if top_k is not None:
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = float("-inf")

        # Apply top-p (nucleus) filtering
        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=-1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits[indices_to_remove] = float("-inf")

        # Sample from the distribution
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        return next_token.item()


# Register the model with Auto classes for easy loading
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("chess_transformer", ChessConfig)
AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
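
As a sanity check on the "~1M parameter budget" noted in the module docstring, a minimal sketch that builds the model with this commit's config.json hyperparameters and counts weights; with tie_weights=True the output head shares the embedding matrix, so the total should land just under one million (the numbers in comments are expectations, not captured output):

import torch
from model import ChessConfig, ChessForCausalLM

# Hyperparameters copied from this commit's config.json.
config = ChessConfig(
    vocab_size=84, n_embd=128, n_layer=7, n_head=4,
    n_ctx=512, n_inner=256, tie_weights=True,
)
model = ChessForCausalLM(config)

n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params:,} parameters")  # expected: roughly 999k with tied embeddings

# generate_move expects a (1, seq_len) tensor and returns one sampled token id.
prompt = torch.tensor([[config.bos_token_id]])
next_id = model.generate_move(prompt, temperature=1.0, top_k=5)
print(next_id)
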
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:27e3c4ddd2a3cfa3cf53001092e6c540678540427ddc709e73bc309aa23687a8
-size 3939664
+oid sha256:f9f6c2d3a9303683f3e532a8d924fe30f9a77e70bea8d1577b4a39eb167183e6
+size 4003376
tokenizer.py CHANGED
@@ -1,87 +1,66 @@
-"""
-Custom Chess Tokenizer for the Chess Challenge.
-
-This tokenizer uses a Factorized strategy (Triple Tokenization) to split moves
-into atomic components: [Piece, From_Square, To_Square, Suffix].
-
-Example: "WPe2e4" -> ["WP", "e2_f", "e4_t"]
-
-This drastically reduces vocabulary size (~155 tokens vs 1700), allowing
-for deeper models within the 1M parameter budget.
-"""
-
 from __future__ import annotations
 
 import json
 import os
 import re
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple, Any, Union, Sequence
 
 from transformers import PreTrainedTokenizer
 
 
 class ChessTokenizer(PreTrainedTokenizer):
-    """
-    A custom tokenizer for chess moves using factorized notation.
-
-    This tokenizer maps chess concepts (Pieces, Squares) to unique token IDs.
-    The vocabulary is fixed and does not need to be built from a dataset.
-
-    Example:
-        >>> tokenizer = ChessTokenizer()
-        >>> tokenizer.tokenize("WPe2e4")
-        ['WP', 'e2_f', 'e4_t']
-    """
-
     model_input_names = ["input_ids", "attention_mask"]
     vocab_files_names = {"vocab_file": "vocab.json"}
-
-    # Special tokens
+
     PAD_TOKEN = "[PAD]"
     BOS_TOKEN = "[BOS]"
     EOS_TOKEN = "[EOS]"
     UNK_TOKEN = "[UNK]"
-
+
+    SIDE_W = "SIDE_W"
+    SIDE_B = "SIDE_B"
+    PROMO_PREFIX = "PROMO_"
+
+    CAPTURE = "CAPTURE"
+    CHECK = "CHECK"
+    MATE = "MATE"
+    CASTLE = "CASTLE"
+
+    PIECES = ["P", "N", "B", "R", "Q", "K"]
+
+    MOVE_RE = re.compile(
+        r"^(?P<side>[WB])"
+        r"(?P<piece>[PNBRQK])"
+        r"(?P<from>[a-h][1-8])"
+        r"(?P<to>[a-h][1-8])"
+        r"(?P<rest>.*)$"
+    )
+
     def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
-        """
-        Initialize the chess tokenizer.
-
-        Args:
-            vocab_file: Path to a JSON file containing the vocabulary mapping.
-            vocab: Dictionary mapping tokens to IDs (alternative to vocab_file).
-            **kwargs: Additional arguments passed to PreTrainedTokenizer.
-        """
-        # Initialize special tokens
+        kwargs.pop("pad_token", None)
+        kwargs.pop("bos_token", None)
+        kwargs.pop("eos_token", None)
+        kwargs.pop("unk_token", None)
+
         self._pad_token = self.PAD_TOKEN
         self._bos_token = self.BOS_TOKEN
         self._eos_token = self.EOS_TOKEN
         self._unk_token = self.UNK_TOKEN
 
-        # Remove duplicate special-token entries passed through kwargs
-        kwargs.pop("pad_token", None)
-        kwargs.pop("bos_token", None)
-        kwargs.pop("eos_token", None)
-        kwargs.pop("unk_token", None)
-
-        # Load or create vocabulary
         if vocab is not None:
             self._vocab = vocab
         elif vocab_file is not None and os.path.exists(vocab_file):
             with open(vocab_file, "r", encoding="utf-8") as f:
                 self._vocab = json.load(f)
         else:
-            # Create the Factorized Vocabulary (Fixed ~155 tokens)
-            self._vocab = self._create_default_vocab()
-
-        # Create reverse mapping
+            self._vocab = self._build_fixed_vocab()
+
         self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
-
-        # Call parent init
+
         super().__init__(
             pad_token=self._pad_token,
             bos_token=self._bos_token,
@@ -89,167 +68,114 @@ class ChessTokenizer(PreTrainedTokenizer):
             unk_token=self._unk_token,
             **kwargs,
         )
-
-    def _create_default_vocab(self) -> Dict[str, int]:
-        """
-        Create the fixed factorized vocabulary.
-        Includes Special Tokens, Pieces, From-Squares, To-Squares, and Suffixes.
-        """
-        special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
-        pieces = ["WP", "WN", "WB", "WR", "WQ", "WK", "BP", "BN", "BB", "BR", "BQ", "BK"]
-        suffixes = ["(x)", "(+)", "(+*)", "(o)", "(O)"]
-
-        cols = "abcdefgh"
-        rows = "12345678"
-        squares = [f"{c}{r}" for r in rows for c in cols]  # a1...h8
-
-        vocab = {token: idx for idx, token in enumerate(special_tokens)}
-
-        # Helper to add tokens sequentially
-        def add_tokens(token_list, suffix=""):
-            offset = len(vocab)
-            for i, t in enumerate(token_list):
-                vocab[f"{t}{suffix}"] = offset + i
-
-        add_tokens(pieces)         # WP, WN...
-        add_tokens(squares, "_f")  # a1_f, b1_f... (From Squares)
-        add_tokens(squares, "_t")  # a1_t, b1_t... (To Squares)
-        add_tokens(suffixes)       # (x), (+)...
-
-        return vocab
-
-    @classmethod
-    def build_vocab_from_iterator(cls, iterator, min_frequency: int = 1) -> "ChessTokenizer":
-        return cls()
-
-    @classmethod
-    def build_vocab_from_dataset(cls, **kwargs) -> "ChessTokenizer":
-        return cls()
-
+
+    def _build_fixed_vocab(self) -> Dict[str, int]:
+        special = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
+        sides = [self.SIDE_W, self.SIDE_B]
+        pieces = [f"PIECE_{p}" for p in self.PIECES]
+        squares = [f"SQ_{file}{rank}" for file in "abcdefgh" for rank in "12345678"]
+        promos = [f"{self.PROMO_PREFIX}{p}" for p in ["Q", "R", "B", "N"]]
+        flags = [self.CAPTURE, self.CHECK, self.MATE, self.CASTLE]
+
+        tokens = special + sides + pieces + squares + promos + flags
+        return {tok: i for i, tok in enumerate(tokens)}
+
     @property
     def vocab_size(self) -> int:
         return len(self._vocab)
-
+
     def get_vocab(self) -> Dict[str, int]:
         return dict(self._vocab)
-
+
     def _tokenize(self, text: str) -> List[str]:
-        """
-        Tokenize a string of moves into factorized tokens.
-        Robustly handles both "WPe2e4" and "e2e4".
-        """
-        raw_chunks = text.strip().split()
-        tokens = []
-
-        for chunk in raw_chunks:
-            if chunk in self._vocab:
-                tokens.append(chunk)
-                continue
-
-            # Regex Parsing: Matches Optional Piece + From + To + Optional Suffix
-            # Matches "WPe2e4" OR "e2e4" (robust)
-            match = re.match(r'([WB]?[PRNBQK]?)?([a-h][1-8])([a-h][1-8])(.*)', chunk)
-
-            if match:
-                p, f, t, s = match.groups()
-                if p: tokens.append(p)               # Piece (if present)
-                tokens.extend([f"{f}_f", f"{t}_t"])  # From + To
-                if s: tokens.append(s)               # Suffix (if present)
-
-            # Castling Special Case
-            elif "(o)" in chunk or "(O)" in chunk:
-                match_castle = re.match(r'([WB]K)?([a-h][1-8])([a-h][1-8])(.*)', chunk)
-                if match_castle:
-                    p, f, t, s = match_castle.groups()
-                    if p: tokens.append(p)
-                    tokens.extend([f"{f}_f", f"{t}_t", s])
-                else:
-                    tokens.append(self.unk_token)
-            else:
-                tokens.append(self.unk_token)
-
+        out: List[str] = []
+        for move in text.strip().split():
+            out.extend(self._tokenize_move(move))
+        return out
+
+    def _tokenize_move(self, move: str) -> List[str]:
+        m = self.MOVE_RE.match(move)
+        if not m: return [self.UNK_TOKEN]
+
+        side = m.group("side")
+        piece = m.group("piece")
+        frm = m.group("from")
+        to = m.group("to")
+        rest = m.group("rest") or ""
+
+        tokens: List[str] = []
+        tokens.append(self.SIDE_W if side == "W" else self.SIDE_B)
+        tokens.append(f"PIECE_{piece}")
+        tokens.append(f"SQ_{frm}")
+        tokens.append(f"SQ_{to}")
+
+        promo = self._parse_promotion(rest)
+        if promo is not None:
+            tokens.append(f"{self.PROMO_PREFIX}{promo}")
+
+        if "(x)" in rest: tokens.append(self.CAPTURE)
+
+        if "++" in rest or "(+*)" in rest or "#" in rest:
+            tokens.append(self.MATE)
+        elif "+" in rest or "(+)" in rest:
+            tokens.append(self.CHECK)
+
+        if "(o)" in rest or "(O)" in rest:
+            tokens.append(self.CASTLE)
+
         return tokens
-
+
+    def _parse_promotion(self, rest: str) -> Optional[str]:
+        m = re.search(r"=([QRBNqrbn])", rest)
+        if m: return m.group(1).upper()
+
+        m2 = re.search(r"([QRBNqrbn])", rest)
+        if m2 and "(" not in rest:
+            if rest.strip() in ["Q", "R", "B", "N", "q", "r", "b", "n"]:
+                return rest.strip().upper()
+        return None
+
     def _convert_token_to_id(self, token: str) -> int:
-        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))
-
+        return self._vocab.get(token, self._vocab[self.UNK_TOKEN])
+
     def _convert_id_to_token(self, index: int) -> str:
         return self._ids_to_tokens.get(index, self.UNK_TOKEN)
-
+
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """
-        Converts tokens back to a single string for evaluation.
-
-        Logic:
-        1. Strips '_f' and '_t' suffixes.
-        2. Joins parts without spaces (e.g. 'WP' + 'e2' + 'e4' -> 'WPe2e4').
-        3. Inserts a space ONLY before a new Piece token to separate moves.
-        """
         output = []
         special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
 
-        # 1. Clean tokens (remove special tokens and suffixes)
-        clean_tokens = []
         for t in tokens:
             if t in special: continue
-            clean_tokens.append(t.replace("_f", "").replace("_t", ""))
-
-        # 2. Join intelligently
-        final_str = ""
-        for i, token in enumerate(clean_tokens):
-            # Check if this token is a Piece (starts with W or B, length 2)
-            # This marks the start of a new move.
-            is_new_move = (len(token) == 2 and token[0] in "WB" and token[1] in "PRNBQK")
-
-            # Add space if it's a new move (and not the very first token)
-            if i > 0 and is_new_move:
-                final_str += " " + token
-            else:
-                final_str += token
-
-        return final_str.strip()
-
-    def save_vocabulary(
-        self,
-        save_directory: str,
-        filename_prefix: Optional[str] = None,
-    ) -> tuple:
-        if not os.path.isdir(save_directory):
-            os.makedirs(save_directory, exist_ok=True)
-
-        vocab_file = os.path.join(
-            save_directory,
-            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
-        )
-
-        with open(vocab_file, "w", encoding="utf-8") as f:
-            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
-
-        return (vocab_file,)
-
-
-def count_vocab_from_dataset(
-    dataset_name: str = "dlouapre/lichess_2025-01_1M",
-    split: str = "train",
-    column: str = "text",
-    max_samples: Optional[int] = 10000,
-) -> Dict[str, int]:
-    """
-    Count token frequencies in a dataset.
-    """
-    from collections import Counter
-    from datasets import load_dataset
-
-    tokenizer = ChessTokenizer()
-    dataset = load_dataset(dataset_name, split=split)
-
-    if max_samples is not None:
-        dataset = dataset.select(range(min(max_samples, len(dataset))))
-
-    token_counts = Counter()
-
-    for example in dataset:
-        tokens = tokenizer.tokenize(example[column])
-        token_counts.update(tokens)
-
-    return dict(token_counts)
+
+            if t == self.SIDE_W: output.append("W")
+            elif t == self.SIDE_B: output.append("B")
+            elif t.startswith("PIECE_"): output.append(t.replace("PIECE_", ""))
+            elif t.startswith("SQ_"): output.append(t.replace("SQ_", ""))
+            elif t.startswith(self.PROMO_PREFIX): output.append("=" + t.replace(self.PROMO_PREFIX, ""))
+            elif t == self.CAPTURE: output.append("(x)")
+            elif t == self.CHECK: output.append("(+)")
+            elif t == self.MATE: output.append("(+*)")
+            elif t == self.CASTLE: output.append("(o)")
+            else:
+                pass
+
+        return "".join(output)
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
+        if not os.path.isdir(save_directory): os.makedirs(save_directory, exist_ok=True)
+        vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json")
+        with open(vocab_file, "w", encoding="utf-8") as f: json.dump(self._vocab, f, ensure_ascii=False, indent=2)
+        return (vocab_file,)
+
+    def decode(self, token_ids: Union[int, Sequence[int]], skip_special_tokens: bool = False, **kwargs) -> str:
+        if isinstance(token_ids, int): ids = [token_ids]
+        elif "torch" in str(type(token_ids)): ids = token_ids.detach().cpu().flatten().tolist()
+        else: ids = list(token_ids)
+
+        toks = [self._convert_id_to_token(i) for i in ids]
+        if skip_special_tokens:
+            special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
+            toks = [t for t in toks if t not in special]
+
+        return self.convert_tokens_to_string(toks)
tokenizer_config.json CHANGED
@@ -40,5 +40,11 @@
   "model_max_length": 1000000000000000019884624838656,
   "pad_token": "[PAD]",
   "tokenizer_class": "ChessTokenizer",
+  "auto_map": {
+    "AutoTokenizer": ["tokenizer.ChessTokenizer", null]
+  },
+  "tokenizer_auto_map": {
+    "AutoTokenizer": ["tokenizer.ChessTokenizer", null]
+  },
   "unk_token": "[UNK]"
 }
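
With the auto_map entry added here, the custom tokenizer should also be loadable through AutoTokenizer; a minimal sketch with the same placeholder repository id as above:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("khanghoang0902/chess-model", trust_remote_code=True)  # placeholder id
print(tokenizer.vocab_size)  # 84 as of this commit
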
vocab.json CHANGED
@@ -3,149 +3,84 @@
   "[BOS]": 1,
   "[EOS]": 2,
   "[UNK]": 3,
-  "WP": 4,
-  "WN": 5,
-  "WB": 6,
-  "WR": 7,
-  "WQ": 8,
-  "WK": 9,
-  "BP": 10,
-  "BN": 11,
-  "BB": 12,
-  "BR": 13,
-  "BQ": 14,
-  "BK": 15,
-  "a1_f": 16,
-  "b1_f": 17,
-  "c1_f": 18,
-  "d1_f": 19,
-  "e1_f": 20,
-  "f1_f": 21,
-  "g1_f": 22,
-  "h1_f": 23,
-  "a2_f": 24,
-  "b2_f": 25,
-  "c2_f": 26,
-  "d2_f": 27,
-  "e2_f": 28,
-  "f2_f": 29,
-  "g2_f": 30,
-  "h2_f": 31,
-  "a3_f": 32,
-  "b3_f": 33,
-  "c3_f": 34,
-  "d3_f": 35,
-  "e3_f": 36,
-  "f3_f": 37,
-  "g3_f": 38,
-  "h3_f": 39,
-  "a4_f": 40,
-  "b4_f": 41,
-  "c4_f": 42,
-  "d4_f": 43,
-  "e4_f": 44,
-  "f4_f": 45,
-  "g4_f": 46,
-  "h4_f": 47,
-  "a5_f": 48,
-  "b5_f": 49,
-  "c5_f": 50,
-  "d5_f": 51,
-  "e5_f": 52,
-  "f5_f": 53,
-  "g5_f": 54,
-  "h5_f": 55,
-  "a6_f": 56,
-  "b6_f": 57,
-  "c6_f": 58,
-  "d6_f": 59,
-  "e6_f": 60,
-  "f6_f": 61,
-  "g6_f": 62,
-  "h6_f": 63,
-  "a7_f": 64,
-  "b7_f": 65,
-  "c7_f": 66,
-  "d7_f": 67,
-  "e7_f": 68,
-  "f7_f": 69,
-  "g7_f": 70,
-  "h7_f": 71,
-  "a8_f": 72,
-  "b8_f": 73,
-  "c8_f": 74,
-  "d8_f": 75,
-  "e8_f": 76,
-  "f8_f": 77,
-  "g8_f": 78,
-  "h8_f": 79,
-  "a1_t": 80,
-  "b1_t": 81,
-  "c1_t": 82,
-  "d1_t": 83,
-  "e1_t": 84,
-  "f1_t": 85,
-  "g1_t": 86,
-  "h1_t": 87,
-  "a2_t": 88,
-  "b2_t": 89,
-  "c2_t": 90,
-  "d2_t": 91,
-  "e2_t": 92,
-  "f2_t": 93,
-  "g2_t": 94,
-  "h2_t": 95,
-  "a3_t": 96,
-  "b3_t": 97,
-  "c3_t": 98,
-  "d3_t": 99,
-  "e3_t": 100,
-  "f3_t": 101,
-  "g3_t": 102,
-  "h3_t": 103,
-  "a4_t": 104,
-  "b4_t": 105,
-  "c4_t": 106,
-  "d4_t": 107,
-  "e4_t": 108,
-  "f4_t": 109,
-  "g4_t": 110,
-  "h4_t": 111,
-  "a5_t": 112,
-  "b5_t": 113,
-  "c5_t": 114,
-  "d5_t": 115,
-  "e5_t": 116,
-  "f5_t": 117,
-  "g5_t": 118,
-  "h5_t": 119,
-  "a6_t": 120,
-  "b6_t": 121,
-  "c6_t": 122,
-  "d6_t": 123,
-  "e6_t": 124,
-  "f6_t": 125,
-  "g6_t": 126,
-  "h6_t": 127,
-  "a7_t": 128,
-  "b7_t": 129,
-  "c7_t": 130,
-  "d7_t": 131,
-  "e7_t": 132,
-  "f7_t": 133,
-  "g7_t": 134,
-  "h7_t": 135,
-  "a8_t": 136,
-  "b8_t": 137,
-  "c8_t": 138,
-  "d8_t": 139,
-  "e8_t": 140,
-  "f8_t": 141,
-  "g8_t": 142,
-  "h8_t": 143,
-  "(x)": 144,
-  "(+)": 145,
-  "(+*)": 146,
-  "(o)": 147,
-  "(O)": 148
+  "SIDE_W": 4,
+  "SIDE_B": 5,
+  "PIECE_P": 6,
+  "PIECE_N": 7,
+  "PIECE_B": 8,
+  "PIECE_R": 9,
+  "PIECE_Q": 10,
+  "PIECE_K": 11,
+  "SQ_a1": 12,
+  "SQ_a2": 13,
+  "SQ_a3": 14,
+  "SQ_a4": 15,
+  "SQ_a5": 16,
+  "SQ_a6": 17,
+  "SQ_a7": 18,
+  "SQ_a8": 19,
+  "SQ_b1": 20,
+  "SQ_b2": 21,
+  "SQ_b3": 22,
+  "SQ_b4": 23,
+  "SQ_b5": 24,
+  "SQ_b6": 25,
+  "SQ_b7": 26,
+  "SQ_b8": 27,
+  "SQ_c1": 28,
+  "SQ_c2": 29,
+  "SQ_c3": 30,
+  "SQ_c4": 31,
+  "SQ_c5": 32,
+  "SQ_c6": 33,
+  "SQ_c7": 34,
+  "SQ_c8": 35,
+  "SQ_d1": 36,
+  "SQ_d2": 37,
+  "SQ_d3": 38,
+  "SQ_d4": 39,
+  "SQ_d5": 40,
+  "SQ_d6": 41,
+  "SQ_d7": 42,
+  "SQ_d8": 43,
+  "SQ_e1": 44,
+  "SQ_e2": 45,
+  "SQ_e3": 46,
+  "SQ_e4": 47,
+  "SQ_e5": 48,
+  "SQ_e6": 49,
+  "SQ_e7": 50,
+  "SQ_e8": 51,
+  "SQ_f1": 52,
+  "SQ_f2": 53,
+  "SQ_f3": 54,
+  "SQ_f4": 55,
+  "SQ_f5": 56,
+  "SQ_f6": 57,
+  "SQ_f7": 58,
+  "SQ_f8": 59,
+  "SQ_g1": 60,
+  "SQ_g2": 61,
+  "SQ_g3": 62,
+  "SQ_g4": 63,
+  "SQ_g5": 64,
+  "SQ_g6": 65,
+  "SQ_g7": 66,
+  "SQ_g8": 67,
+  "SQ_h1": 68,
+  "SQ_h2": 69,
+  "SQ_h3": 70,
+  "SQ_h4": 71,
+  "SQ_h5": 72,
+  "SQ_h6": 73,
+  "SQ_h7": 74,
+  "SQ_h8": 75,
+  "PROMO_Q": 76,
+  "PROMO_R": 77,
+  "PROMO_B": 78,
+  "PROMO_N": 79,
+  "CAPTURE": 80,
+  "CHECK": 81,
+  "MATE": 82,
+  "CASTLE": 83
 }
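
The new vocabulary size follows from the factorized groups: 4 special tokens + 2 sides + 6 pieces + 64 squares + 4 promotion markers + 4 flags = 84, matching vocab_size in config.json. A one-line check against the tokenizer's fixed vocabulary:

from tokenizer import ChessTokenizer

# special + sides + pieces + squares + promotions + flags
assert ChessTokenizer().vocab_size == 4 + 2 + 6 + 64 + 4 + 4 == 84
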