steakmagnan commited on
Commit
8b3dbb3
·
verified ·
1 Parent(s): 825bdcf

Chess Challenge submission by steakmagnan

Browse files
Files changed (9) hide show
  1. README.md +31 -0
  2. config.json +24 -0
  3. model.py +352 -0
  4. model.safetensors +3 -0
  5. special_tokens_map.json +6 -0
  6. tokenizer.py +226 -0
  7. tokenizer_config.json +50 -0
  8. training_args.bin +3 -0
  9. vocab.json +83 -0
README.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags:
4
+ - chess
5
+ - llm-course
6
+ - chess-challenge
7
+ license: mit
8
+ ---
9
+
10
+ # chess-maxence-V2
11
+
12
+ Chess model submitted to the LLM Course Chess Challenge.
13
+
14
+ ## Submission Info
15
+
16
+ - **Submitted by**: [steakmagnan](https://huggingface.co/steakmagnan)
17
+ - **Parameters**: 798,208
18
+ - **Organization**: LLM-course
19
+
20
+ ## Usage
21
+
22
+ ```python
23
+ from transformers import AutoModelForCausalLM, AutoTokenizer
24
+
25
+ model = AutoModelForCausalLM.from_pretrained("LLM-course/chess-maxence-V2", trust_remote_code=True)
26
+ tokenizer = AutoTokenizer.from_pretrained("LLM-course/chess-maxence-V2", trust_remote_code=True)
27
+ ```
28
+
29
+ ## Evaluation
30
+
31
+ This model is evaluated at the [Chess Challenge Arena](https://huggingface.co/spaces/LLM-course/Chess1MChallenge).
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ChessForCausalLM"
4
+ ],
5
+ "bos_token_id": 1,
6
+ "dropout": 0.1,
7
+ "dtype": "float32",
8
+ "eos_token_id": 2,
9
+ "layer_norm_epsilon": 1e-05,
10
+ "model_type": "chess_transformer",
11
+ "n_ctx": 256,
12
+ "n_embd": 128,
13
+ "n_head": 4,
14
+ "n_inner": 384,
15
+ "n_layer": 4,
16
+ "pad_token_id": 0,
17
+ "tie_weights": true,
18
+ "transformers_version": "4.57.4",
19
+ "vocab_size": 81,
20
+ "auto_map": {
21
+ "AutoConfig": "model.ChessConfig",
22
+ "AutoModelForCausalLM": "model.ChessForCausalLM"
23
+ }
24
+ }
model.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """
3
+ Chess Transformer Model for the Chess Challenge.
4
+
5
+ This module provides a simple GPT-style transformer architecture
6
+ designed to fit within the 1M parameter constraint.
7
+
8
+ Key components:
9
+ - ChessConfig: Configuration class for model hyperparameters
10
+ - ChessForCausalLM: The main model class for next-move prediction
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import Optional, Tuple, Union, List
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ from transformers import PretrainedConfig, PreTrainedModel
23
+ from transformers.modeling_outputs import CausalLMOutputWithPast
24
+
25
+
26
class ChessConfig(PretrainedConfig):
    """
    Hyperparameter container for the Chess Transformer.

    Defaults describe a small GPT-style model intended to fit the
    challenge's 1M-parameter budget.
    """

    model_type = "chess_transformer"

    def __init__(
        self,
        vocab_size: int = 200,   # approx size for the component vocab
        n_embd: int = 120,       # divisible by n_head; fits the budget
        n_layer: int = 6,
        n_head: int = 4,
        n_ctx: int = 250,        # maximum number of moves (not tokens)
        n_inner: Optional[int] = None,
        dropout: float = 0.1,
        layer_norm_epsilon: float = 1e-5,
        tie_weights: bool = True,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_ctx = n_ctx
        # FFN width defaults to 3x the embedding size when not given.
        self.n_inner = 3 * n_embd if n_inner is None else n_inner
        self.dropout = dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.tie_weights = tie_weights
        # Mirror the flag into the standard transformers attribute.
        self.tie_word_embeddings = bool(tie_weights)
66
+
67
+
68
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention over a fixed maximum context.

    A lower-triangular mask of size (n_ctx, n_ctx) is precomputed as a
    non-persistent buffer and sliced to the current sequence length.
    """

    def __init__(self, config: ChessConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head

        # Fused QKV projection followed by an output projection.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

        causal = torch.tril(torch.ones(config.n_ctx, config.n_ctx))
        self.register_buffer(
            "bias",
            causal.view(1, 1, config.n_ctx, config.n_ctx),
            persistent=False,
        )

    def forward(self, x, attention_mask=None):
        bsz, steps, _ = x.size()

        def to_heads(t):
            # (B, T, E) -> (B, H, T, head_dim)
            return t.view(bsz, steps, self.n_head, self.head_dim).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.c_attn(x).chunk(3, dim=2))

        # Scaled dot-product scores: (B, H, T, T).
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Forbid attention to future positions.
        tri = self.bias[:, :, :steps, :steps]
        scores = scores.masked_fill(tri == 0, float("-inf"))

        if attention_mask is not None:
            # Additive mask; caller supplies a broadcastable tensor.
            scores = scores + attention_mask

        weights = self.dropout(F.softmax(scores, dim=-1))

        context = torch.matmul(weights, v)
        context = context.transpose(1, 2).contiguous().view(bsz, steps, self.n_embd)
        return self.c_proj(context)
113
+
114
+
115
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Linear -> Dropout."""

    def __init__(self, config: ChessConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.n_inner)
        self.c_proj = nn.Linear(config.n_inner, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
128
+
129
+
130
class TransformerBlock(nn.Module):
    """Pre-LayerNorm block: x + Attn(LN1(x)), then x + MLP(LN2(x))."""

    def __init__(self, config: ChessConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = MultiHeadAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = FeedForward(config)

    def forward(self, x, attention_mask=None):
        attn_out = self.attn(self.ln_1(x), attention_mask=attention_mask)
        x = x + attn_out
        return x + self.mlp(self.ln_2(x))
142
+
143
+
144
class ChessForCausalLM(PreTrainedModel):
    """
    GPT-style causal model over "product-encoded" chess moves.

    Each move occupies 5 consecutive token ids (Color, Piece, Src, Dst,
    Suffix).  The five component embeddings are combined with an
    element-wise product into a single position per move, the transformer
    runs over moves, and five parallel heads predict the components of the
    next move (conditionally independent given the history).

    Fix vs. the previous revision: ``generate_move`` accepted ``top_p``
    but silently ignored it; nucleus filtering is now implemented.
    """

    config_class = ChessConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True

    def __init__(self, config: ChessConfig):
        super().__init__(config)

        # Component embeddings (Color, Piece, Src, Dst, Suffix).  Each table
        # spans the full vocab so any component id is addressable.
        self.wte_color = nn.Embedding(config.vocab_size, config.n_embd)
        self.wte_piece = nn.Embedding(config.vocab_size, config.n_embd)
        self.wte_src = nn.Embedding(config.vocab_size, config.n_embd)
        self.wte_dst = nn.Embedding(config.vocab_size, config.n_embd)
        self.wte_suf = nn.Embedding(config.vocab_size, config.n_embd)

        # Learned position embedding per MOVE (not per component token).
        self.wpe = nn.Embedding(config.n_ctx, config.n_embd)
        self.drop = nn.Dropout(config.dropout)

        self.h = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layer)
        ])

        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # 5 parallel heads predicting the next move's components.
        # NOTE(review): config.tie_weights is advertised, but no explicit
        # tying between these heads and the embeddings happens here — confirm
        # whether tying is intended.
        self.head_color = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.head_piece = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.head_src = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.head_dst = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.head_suf = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.post_init()

    def _init_weights(self, module):
        """GPT-2-style init: N(0, 0.02) weights, zero biases, unit LayerNorm."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def get_input_embeddings(self):
        # There are 5 embedding tables; return the first as a proxy so the
        # transformers machinery has something to report.
        return self.wte_color

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Args:
            input_ids: (B, 5*L) component ids; truncated to a multiple of 5.
            attention_mask: (B, 5*L) padding mask; a move is kept only if ALL
                5 of its component positions are unmasked.
            position_ids: (1, L) per-move positions; defaults to arange(L).
            labels: (B, 5*L) targets; shifted by one MOVE for the loss.

        Returns:
            CausalLMOutputWithPast with logits flattened back to (B, 5*L, V)
            so Trainer-style consumers see logits matching the input length.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_len = input_ids.size()

        # Ensure sequence length is a multiple of 5 (one move = 5 ids);
        # truncate to the nearest multiple rather than pad.
        if seq_len % 5 != 0:
            new_len = (seq_len // 5) * 5
            input_ids = input_ids[:, :new_len]
            if labels is not None:
                labels = labels[:, :new_len]
            if attention_mask is not None:
                attention_mask = attention_mask[:, :new_len]
            seq_len = new_len

        num_moves = seq_len // 5

        # Reshape to (B, L, 5); components: 0=Color, 1=Piece, 2=Src, 3=Dst, 4=Suf.
        reshaped_ids = input_ids.view(batch_size, num_moves, 5)

        # Product embedding: element-wise product of the 5 component embeddings.
        emb_c = self.wte_color(reshaped_ids[:, :, 0])
        emb_p = self.wte_piece(reshaped_ids[:, :, 1])
        emb_s = self.wte_src(reshaped_ids[:, :, 2])
        emb_d = self.wte_dst(reshaped_ids[:, :, 3])
        emb_f = self.wte_suf(reshaped_ids[:, :, 4])
        token_embeds = emb_c * emb_p * emb_s * emb_d * emb_f

        device = input_ids.device
        if position_ids is None:
            position_ids = torch.arange(num_moves, device=device).unsqueeze(0)

        position_embeds = self.wpe(position_ids)
        hidden_states = self.drop(token_embeds + position_embeds)

        # Collapse the (B, 5L) padding mask to one flag per move, then turn it
        # into an additive mask broadcastable to (B, H, L, L).
        if attention_mask is not None:
            reshaped_mask = attention_mask.view(batch_size, num_moves, 5)
            chess_mask = reshaped_mask.all(dim=-1).float()  # (B, L)
            extended_attention_mask = (1.0 - chess_mask) * -10000.0
            extended_attention_mask = extended_attention_mask.unsqueeze(1).unsqueeze(2)
        else:
            extended_attention_mask = None

        for block in self.h:
            hidden_states = block(hidden_states, attention_mask=extended_attention_mask)

        hidden_states = self.ln_f(hidden_states)

        # Component heads applied in parallel to every move position.
        logits_c = self.head_color(hidden_states)
        logits_p = self.head_piece(hidden_states)
        logits_s = self.head_src(hidden_states)
        logits_d = self.head_dst(hidden_states)
        logits_f = self.head_suf(hidden_states)

        # (B, L, 5, V)
        logits_stacked = torch.stack([logits_c, logits_p, logits_s, logits_d, logits_f], dim=2)

        loss = None
        if labels is not None:
            labels_reshaped = labels.view(batch_size, num_moves, 5)

            # Shift by one MOVE: hidden[t] predicts labels[t+1].
            shift_logits = logits_stacked[:, :-1, :, :].contiguous()
            shift_labels = labels_reshaped[:, 1:, :].contiguous()

            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1)
            )

        # Flatten back to (B, 5L, V) so logits match the input token length.
        flat_logits = logits_stacked.view(batch_size, -1, self.config.vocab_size)

        if not return_dict:
            output = (flat_logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=flat_logits,
        )

    @torch.no_grad()
    def generate_move(
        self,
        input_ids: torch.LongTensor,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> List[int]:
        """
        Sample the next move (5 component token ids).

        Args:
            input_ids: (1, 5*L) move-component history.
            temperature: softmax temperature (> 0).
            top_k: keep only the k most likely tokens per component.
            top_p: nucleus sampling — keep the smallest descending-probability
                prefix whose cumulative mass exceeds ``top_p``.  (This
                parameter was previously accepted but ignored.)

        Returns:
            A list of 5 sampled ids: [color, piece, src, dst, suffix].
        """
        self.eval()

        outputs = self(input_ids)
        # Flattened logits are (1, 5L, V); the final block of 5 rows holds the
        # component predictions made from the last hidden state.
        next_move_logits = outputs.logits[:, -5:, :]  # (1, 5, V)

        generated: List[int] = []
        for i in range(5):
            logits = next_move_logits[:, i, :] / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')

            if top_p is not None:
                # Nucleus (top-p) filtering.
                sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                to_remove = cum_probs > top_p
                # Shift right so the first token crossing the threshold stays.
                to_remove[:, 1:] = to_remove[:, :-1].clone()
                to_remove[:, 0] = False
                # Map the sorted-order mask back to vocabulary order.
                vocab_mask = to_remove.scatter(1, sorted_idx, to_remove)
                logits = logits.masked_fill(vocab_mask, -float('Inf'))

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated.append(next_token.item())

        return generated
348
+
349
# Register the custom config/model with the Auto* factories so that
# AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True)
# can resolve "chess_transformer" to these classes.
from transformers import AutoConfig, AutoModelForCausalLM
AutoConfig.register("chess_transformer", ChessConfig)
AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9de2b4871bde7dc163d4f8a6f930da7a4f2f5bd89b47e4fadc5c628ceca711d
3
+ size 3197992
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "[BOS]",
3
+ "eos_token": "[EOS]",
4
+ "pad_token": "[PAD]",
5
+ "unk_token": "[UNK]"
6
+ }
tokenizer.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom Chess Tokenizer for the Chess Challenge.
3
+
4
+ This tokenizer breaks down moves into 5 components:
5
+ Color, Piece, Source, Destination, Suffix.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ import os
12
+ import re
13
+ from typing import Dict, List, Optional
14
+
15
+ from transformers import PreTrainedTokenizer
16
+
17
+
18
class ChessTokenizer(PreTrainedTokenizer):
    """
    A component-based tokenizer for chess moves.

    Each move is split into 5 tokens:
        [Color, Piece, Source, Destination, Suffix]

    Vocabulary is fixed and deterministic.  NOTE: "B" appears in both
    COLORS and PIECES; the dedup pass in _create_default_vocab keeps a
    single shared id for it (matching vocab.json, which has 81 entries).
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    # Component definitions
    COLORS = ["W", "B"]
    PIECES = ["P", "N", "B", "R", "Q", "K"]
    FILES = ["a", "b", "c", "d", "e", "f", "g", "h"]
    RANKS = ["1", "2", "3", "4", "5", "6", "7", "8"]
    SUFFIXES = ["", "(x)", "(+)", "(+*)", "(o)", "(O)"]

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """Load vocab from dict / file, or build the fixed default vocab.

        The _vocab attribute must exist BEFORE super().__init__ runs,
        because the base class may call vocab-dependent methods during init.
        """
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Drop caller-supplied special tokens: this tokenizer's specials are fixed.
        kwargs.pop("pad_token", None)
        kwargs.pop("bos_token", None)
        kwargs.pop("eos_token", None)
        kwargs.pop("unk_token", None)

        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._create_default_vocab()

        # Reverse map for id -> token lookups.
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _create_default_vocab(self) -> Dict[str, int]:
        """Build the fixed vocabulary: specials, colors, pieces, squares, suffixes."""
        tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]

        # Add all possible components
        tokens.extend(self.COLORS)
        tokens.extend(self.PIECES)

        # Squares
        squares = [f"{f}{r}" for f in self.FILES for r in self.RANKS]
        tokens.extend(squares)

        # Suffixes: the empty suffix needs an explicit token so every move
        # remains a fixed 5-component tuple; "[None]" stands in for "".
        for s in self.SUFFIXES:
            if s == "":
                tokens.append("[None]")  # Representation for empty suffix
            else:
                tokens.append(s)

        # Unique tokens only (insertion order preserved for ID stability —
        # this is what collapses the color "B" / piece "B" duplicate).
        seen = set()
        unique_tokens = []
        for t in tokens:
            if t not in seen:
                unique_tokens.append(t)
                seen.add(t)

        return {t: i for i, t in enumerate(unique_tokens)}

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab)

    def _tokenize(self, text: str) -> List[str]:
        """Split space-separated moves into 5-component token tuples.

        Every move yields exactly 5 tokens so downstream reshaping to
        (..., 5) is always valid; special tokens are repeated 5 times and
        unparseable moves become 5 UNK tokens for the same reason.
        """
        # Text is space-separated moves
        moves = text.strip().split()
        tokens = []

        for move in moves:
            # Special tokens are expanded to 5 copies to keep the sequence
            # length a multiple of 5.
            if move in [self.BOS_TOKEN, self.EOS_TOKEN, self.PAD_TOKEN, self.UNK_TOKEN]:
                tokens.extend([move] * 5)
                continue

            # Expected move format (e.g. WPe2e4(x)):
            #   Color: 1 char, Piece: 1 char, Src: 2 chars, Dst: 2 chars,
            #   Suffix: remainder (may be empty).
            # NOTE(review): castling is encoded as a suffix like (o)/(O) on a
            # full king move (e.g. WKe1g1(o)) — confirm against the dataset.
            if len(move) < 6:  # Shortest move WPe2e4 is 6 chars.
                # Unparseable: emit UNK x 5 to preserve alignment.
                tokens.extend([self.UNK_TOKEN] * 5)
                continue

            c = move[0]
            p = move[1]
            src = move[2:4]
            dst = move[4:6]
            suf = move[6:]

            if suf == "":
                suf_tok = "[None]"
            else:
                suf_tok = suf

            raw_components = [c, p, src, dst, suf_tok]

            # Replace any out-of-vocab component with UNK.
            final_components = []
            for comp in raw_components:
                if comp in self._vocab:
                    final_components.append(comp)
                else:
                    final_components.append(self.UNK_TOKEN)

            tokens.extend(final_components)

        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the UNK id."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its token, falling back to UNK."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Reassemble component tokens into space-separated move strings.

        Processes tokens in chunks of 5; trailing partial chunks are dropped
        and special-token chunks are skipped.
        """
        output = []
        # Process in chunks of 5
        for i in range(0, len(tokens), 5):
            chunk = tokens[i:i+5]
            if len(chunk) < 5:
                break

            # Check if special
            if chunk[0] in [self.BOS_TOKEN, self.EOS_TOKEN, self.PAD_TOKEN]:
                continue  # Skip specials for string output

            c, p, src, dst, suf = chunk
            if suf == "[None]":
                suf = ""

            output.append(f"{c}{p}{src}{dst}{suf}")

        return " ".join(output)

    @classmethod
    def build_vocab_from_dataset(cls, *args, **kwargs):
        """Dataset is ignored: the vocab is fixed, so just return an instance."""
        return cls()

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        """Write vocab.json into save_directory; returns the written path."""
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)
tokenizer_config.json ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "[BOS]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "[EOS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "[UNK]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ }
35
+ },
36
+ "bos_token": "[BOS]",
37
+ "clean_up_tokenization_spaces": false,
38
+ "eos_token": "[EOS]",
39
+ "extra_special_tokens": {},
40
+ "model_max_length": 1000000000000000019884624838656,
41
+ "pad_token": "[PAD]",
42
+ "tokenizer_class": "ChessTokenizer",
43
+ "unk_token": "[UNK]",
44
+ "auto_map": {
45
+ "AutoTokenizer": [
46
+ "tokenizer.ChessTokenizer",
47
+ null
48
+ ]
49
+ }
50
+ }
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4afcad87091f85d79d28ad591f8e547afc89dd10d790e5c8998ddbf4fb90f97
3
+ size 5777
vocab.json ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "[PAD]": 0,
3
+ "[BOS]": 1,
4
+ "[EOS]": 2,
5
+ "[UNK]": 3,
6
+ "W": 4,
7
+ "B": 5,
8
+ "P": 6,
9
+ "N": 7,
10
+ "R": 8,
11
+ "Q": 9,
12
+ "K": 10,
13
+ "a1": 11,
14
+ "a2": 12,
15
+ "a3": 13,
16
+ "a4": 14,
17
+ "a5": 15,
18
+ "a6": 16,
19
+ "a7": 17,
20
+ "a8": 18,
21
+ "b1": 19,
22
+ "b2": 20,
23
+ "b3": 21,
24
+ "b4": 22,
25
+ "b5": 23,
26
+ "b6": 24,
27
+ "b7": 25,
28
+ "b8": 26,
29
+ "c1": 27,
30
+ "c2": 28,
31
+ "c3": 29,
32
+ "c4": 30,
33
+ "c5": 31,
34
+ "c6": 32,
35
+ "c7": 33,
36
+ "c8": 34,
37
+ "d1": 35,
38
+ "d2": 36,
39
+ "d3": 37,
40
+ "d4": 38,
41
+ "d5": 39,
42
+ "d6": 40,
43
+ "d7": 41,
44
+ "d8": 42,
45
+ "e1": 43,
46
+ "e2": 44,
47
+ "e3": 45,
48
+ "e4": 46,
49
+ "e5": 47,
50
+ "e6": 48,
51
+ "e7": 49,
52
+ "e8": 50,
53
+ "f1": 51,
54
+ "f2": 52,
55
+ "f3": 53,
56
+ "f4": 54,
57
+ "f5": 55,
58
+ "f6": 56,
59
+ "f7": 57,
60
+ "f8": 58,
61
+ "g1": 59,
62
+ "g2": 60,
63
+ "g3": 61,
64
+ "g4": 62,
65
+ "g5": 63,
66
+ "g6": 64,
67
+ "g7": 65,
68
+ "g8": 66,
69
+ "h1": 67,
70
+ "h2": 68,
71
+ "h3": 69,
72
+ "h4": 70,
73
+ "h5": 71,
74
+ "h6": 72,
75
+ "h7": 73,
76
+ "h8": 74,
77
+ "[None]": 75,
78
+ "(x)": 76,
79
+ "(+)": 77,
80
+ "(+*)": 78,
81
+ "(o)": 79,
82
+ "(O)": 80
83
+ }