KHOUTAIBI committed
Commit c7b499c · verified · 1 parent: 96c0542

Chess Challenge submission by KHOUTAIBI

Files changed (8)
  1. README.md +26 -0
  2. config.json +25 -0
  3. model.py +416 -0
  4. model.safetensors +3 -0
  5. special_tokens_map.json +6 -0
  6. tokenizer.py +142 -0
  7. tokenizer_config.json +49 -0
  8. vocab.json +70 -0
README.md ADDED
@@ -0,0 +1,26 @@
+ ---
+ library_name: transformers
+ tags:
+ - chess
+ - llm-course
+ - chess-challenge
+ license: mit
+ ---
+
+ # chess-model-KHOUTAIBI_gated
+
+ Chess model submitted to the LLM Course Chess Challenge.
+
+ ## Submission Info
+
+ - **Submitted by**: [KHOUTAIBI](https://huggingface.co/KHOUTAIBI)
+ - **Parameters**: 992,242
+ - **Organization**: LLM-course
+
+ ## Model Details
+
+ - **Architecture**: Chess Transformer (GPT-style)
+ - **Vocab size**: 68
+ - **Embedding dim**: 96
+ - **Layers**: 10
+ - **Heads**: 4
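
The `auto_map` entries in `config.json` and `tokenizer_config.json` (below) mean the submission loads through the standard Auto classes with `trust_remote_code=True`. A minimal loading sketch; the repo id is an assumption inferred from the org and model name above:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "LLM-course/chess-model-KHOUTAIBI_gated"  # hypothetical repo id

# trust_remote_code is needed because ChessForCausalLM and ChessTokenizer
# are defined in the repo's own model.py / tokenizer.py.
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

# The lm_head is tied to the token embedding, so parameters() deduplicates it.
print(sum(p.numel() for p in model.parameters()))  # 992242
```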
config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "_name_or_path": "./output_gated/checkpoint-48243/",
+   "architectures": [
+     "ChessForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "model.ChessConfig",
+     "AutoModelForCausalLM": "model.ChessForCausalLM"
+   },
+   "bos_token_id": 1,
+   "dropout": 0.1,
+   "eos_token_id": 2,
+   "model_type": "chess_transformer",
+   "n_ctx": 256,
+   "n_embd": 96,
+   "n_head": 4,
+   "n_inner": 210,
+   "n_kv_head": 4,
+   "n_layer": 10,
+   "pad_token_id": 0,
+   "tie_weights": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.45.2",
+   "vocab_size": 68
+ }
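
As a sanity check, the README's 992,242 parameters can be reproduced from this config alone. A back-of-the-envelope sketch (the lm_head adds nothing because `tie_weights` is true):

```python
n_embd, n_head, n_kv_head = 96, 4, 4
n_inner, n_layer, vocab = 210, 10, 68
head_dim = n_embd // n_head  # 24

attn = (
    (n_embd * n_embd + n_embd)                                    # q_proj (weight + bias)
    + 2 * (n_embd * n_kv_head * head_dim + n_kv_head * head_dim)  # k_proj, v_proj
    + (n_embd * n_embd + n_embd)                                  # out_proj
    + (n_embd + 1)                                                # scalar attention gate
)
mlp = (n_embd * 2 * n_inner + 2 * n_inner) + (n_inner * n_embd + n_embd)  # SwiGLU fc + proj
per_layer = attn + mlp + 2 * n_embd   # plus two RMSNorm weights = 98,533

total = n_layer * per_layer + vocab * n_embd + 3 * n_embd + n_embd
#       10 blocks             wte              wtt          norm_f
print(total)  # 992242
```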
model.py ADDED
@@ -0,0 +1,416 @@
+ from __future__ import annotations
+
+ import math
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from transformers import PretrainedConfig, PreTrainedModel
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+
+ # -------------------------
+ # Config
+ # -------------------------
+ class ChessConfig(PretrainedConfig):
+     model_type = "chess_transformer"
+
+     def __init__(
+         self,
+         vocab_size: int = 72,
+         n_embd: int = 96,
+         n_layer: int = 10,
+         n_head: int = 4,
+         n_ctx: int = 256,
+         n_inner: Optional[int] = None,
+         dropout: float = 0.1,
+         tie_weights: bool = True,
+         pad_token_id: int = 0,
+         bos_token_id: int = 1,
+         eos_token_id: int = 2,
+         layer_norm_epsilon: float = 1e-5,  # kept for compatibility
+         n_kv_head: Optional[int] = None,
+         **kwargs,
+     ):
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             **kwargs,
+         )
+
+         self.vocab_size = int(vocab_size)
+         self.n_embd = int(n_embd)
+         self.n_layer = int(n_layer)
+         self.n_head = int(n_head)
+         self.n_ctx = int(n_ctx)
+         self.n_inner = int(n_inner) if n_inner is not None else int(4 * n_embd)  # 4x fallback; this repo's config.json sets n_inner=210
+         self.n_kv_head = n_kv_head if n_kv_head is not None else self.n_head
+         self.dropout = float(dropout)
+         self.tie_weights = bool(tie_weights)
+
+         # HF uses this to decide whether to tie embeddings
+         self.tie_word_embeddings = bool(tie_weights)
+
+
+ # -------------------------
+ # RMSNorm
+ # -------------------------
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-8):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x: [..., dim]
+         norm = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
+         return x * norm * self.weight
+
+
+ # -------------------------
+ # ALiBi helpers (0 params)
+ # -------------------------
+ # The attention below combines:
+ #   (1) Grouped Query Attention (GQA) with n_kv_head < n_head
+ #   (2) a gated attention output (cheap scalar gate)
+ #   (3) an ALiBi bias (recommended for chess; helps small models generalize)
+
+
+ def _get_alibi_slopes(n_heads: int) -> torch.Tensor:
+     """
+     Standard ALiBi slopes (Press et al.).
+     Works for any n_heads (not just powers of 2).
+     """
+     def get_slopes_power_of_2(n):
+         start = 2 ** (-2 ** -(math.log2(n) - 3))
+         ratio = start
+         return [start * (ratio ** i) for i in range(n)]
+
+     if (n_heads & (n_heads - 1)) == 0:  # power of 2
+         slopes = get_slopes_power_of_2(n_heads)
+     else:
+         # closest power of 2 below
+         n_pow2 = 2 ** math.floor(math.log2(n_heads))
+         slopes = get_slopes_power_of_2(n_pow2)
+         # extra heads: interpolate using the next power of 2
+         slopes_extra = get_slopes_power_of_2(2 * n_pow2)[0::2][: n_heads - n_pow2]
+         slopes = slopes + slopes_extra
+
+     return torch.tensor(slopes, dtype=torch.float32)
+
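
For the four heads used here, the closed form above gives geometrically decaying slopes. A quick illustrative check (assumes `_get_alibi_slopes` from the file above is in scope):

```python
import torch

# n_heads = 4: start = 2 ** (-2 ** -(log2(4) - 3)) = 2 ** -2, ratio = start,
# so the per-head slopes are 2**-2, 2**-4, 2**-6, 2**-8.
expected = torch.tensor([0.25, 0.0625, 0.015625, 0.00390625])
assert torch.allclose(_get_alibi_slopes(4), expected)
```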
+
+ class MultiHeadAttention(nn.Module):
+     """
+     Grouped Query Attention (GQA) + gated attention output.
+
+     - Q has n_head heads
+     - K,V have n_kv_head heads (shared across query heads)
+     - Output is multiplied by a learned sigmoid gate (scalar per token)
+
+     Config expectations:
+         config.n_embd
+         config.n_head
+         config.n_ctx
+         config.dropout
+         (optional) config.n_kv_head (if absent -> defaults to n_head)
+
+     If you want GQA: set config.n_kv_head = 1 or 2 (must divide n_head).
+     """
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0, "n_embd must be divisible by n_head"
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.head_dim = config.n_embd // config.n_head
+
+         # ---- GQA heads ----
+         self.n_kv_head = getattr(config, "n_kv_head", config.n_head)
+         assert self.n_head % self.n_kv_head == 0, "n_head must be divisible by n_kv_head"
+         self.q_per_kv = self.n_head // self.n_kv_head  # how many Q heads share one KV head
+
+         # Projections:
+         #   Q:   d -> d (n_head * head_dim)
+         #   K,V: d -> (n_kv_head * head_dim)
+         self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=True)
+         self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=True)
+         self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=True)
+
+         self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=True)
+         self.dropout = nn.Dropout(config.dropout)
+
+         # Cheap gate: scalar per token (almost free params)
+         self.gate = nn.Linear(config.n_embd, 1, bias=True)
+
+         # Causal mask
+         self.register_buffer(
+             "causal",
+             torch.tril(torch.ones(config.n_ctx, config.n_ctx, dtype=torch.bool)),
+             persistent=False,
+         )
+
+         # ALiBi slopes (per head)
+         self.register_buffer("alibi_slopes", _get_alibi_slopes(self.n_head), persistent=False)
+
+     def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+         """
+         x: [B,T,C]
+         attention_mask: [B,T] with 1 for tokens, 0 for padding
+         """
+         B, T, C = x.shape
+         device = x.device
+
+         # Project to Q,K,V
+         q = self.q_proj(x)  # [B,T,C]
+         k = self.k_proj(x)  # [B,T,n_kv*D]
+         v = self.v_proj(x)  # [B,T,n_kv*D]
+
+         # Reshape to heads
+         q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)     # [B,H,T,D]
+         k = k.view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)  # [B,Hkv,T,D]
+         v = v.view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)  # [B,Hkv,T,D]
+
+         # Expand KV to match Q heads: repeat along head dimension
+         if self.n_kv_head != self.n_head:
+             k = k.repeat_interleave(self.q_per_kv, dim=1)  # [B,H,T,D]
+             v = v.repeat_interleave(self.q_per_kv, dim=1)  # [B,H,T,D]
+
+         # Attention scores
+         att = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # [B,H,T,T]
+
+         # ALiBi bias: -slope * distance
+         idx = torch.arange(T, device=device)
+         dist = (idx.view(T, 1) - idx.view(1, T)).clamp(min=0).to(att.dtype)  # [T,T]
+         alibi = -self.alibi_slopes.to(att.dtype).view(1, self.n_head, 1, 1) * dist.view(1, 1, T, T)
+         att = att + alibi
+
+         # Causal mask
+         causal = self.causal[:T, :T]  # [T,T]
+         att = att.masked_fill(~causal.view(1, 1, T, T), float("-inf"))
+
+         # Padding mask on keys (mask columns)
+         if attention_mask is not None:
+             key_mask = (attention_mask == 0).unsqueeze(1).unsqueeze(2)  # [B,1,1,T]
+             att = att.masked_fill(key_mask, float("-inf"))
+
+         att = F.softmax(att, dim=-1)
+         att = self.dropout(att)
+
+         out = torch.matmul(att, v)  # [B,H,T,D]
+         out = out.transpose(1, 2).contiguous().view(B, T, C)  # [B,T,C]
+         out = self.out_proj(out)
+
+         # Gate (scalar per token)
+         g = torch.sigmoid(self.gate(x))  # [B,T,1]
+         out = out * g
+
+         return out
+
+
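
Note that with this submission's config (`n_kv_head = 4 = n_head`) the GQA branch is inactive and the layer behaves as plain multi-head attention. A shape sketch of the `repeat_interleave` expansion for a hypothetical `n_kv_head = 2`:

```python
import torch

B, T, n_head, n_kv_head, head_dim = 1, 5, 4, 2, 24
q_per_kv = n_head // n_kv_head  # 2 query heads share each KV head

k = torch.randn(B, n_kv_head, T, head_dim)  # [1, 2, 5, 24]
k = k.repeat_interleave(q_per_kv, dim=1)    # [1, 4, 5, 24]
# Query heads 0-1 now attend to KV head 0; heads 2-3 to KV head 1.
print(k.shape)  # torch.Size([1, 4, 5, 24])
```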
+ # -------------------------
+ # SwiGLU MLP
+ # -------------------------
+ class SwiGLU(nn.Module):
+     def __init__(self, config: ChessConfig):
+         super().__init__()
+         self.fc = nn.Linear(config.n_embd, 2 * config.n_inner, bias=True)
+         self.proj = nn.Linear(config.n_inner, config.n_embd, bias=True)
+         self.dropout = nn.Dropout(config.dropout)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.fc(x)  # [B,T,2*inner]
+         a, b = x.chunk(2, dim=-1)
+         x = F.silu(a) * b
+         x = self.proj(x)
+         x = self.dropout(x)
+         return x
+
+
+ # -------------------------
+ # Transformer block (Pre-norm)
+ # -------------------------
+ class TransformerBlock(nn.Module):
+     def __init__(self, config: ChessConfig):
+         super().__init__()
+         self.norm1 = RMSNorm(config.n_embd)
+         self.attn = MultiHeadAttention(config)
+         self.norm2 = RMSNorm(config.n_embd)
+         self.mlp = SwiGLU(config)
+
+     def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+         x = x + self.attn(self.norm1(x), attention_mask=attention_mask)
+         x = x + self.mlp(self.norm2(x))
+         return x
+
+
+ # -------------------------
+ # Main model
+ # -------------------------
+ class ChessForCausalLM(PreTrainedModel):
+     """
+     GPT-style chess model with:
+       - Token embeddings
+       - Token-type embeddings (square / promo / special-other)
+       - ALiBi attention (no learned positions)
+       - RMSNorm + SwiGLU
+     """
+     config_class = ChessConfig
+     base_model_prefix = "transformer"
+     supports_gradient_checkpointing = True
+     keys_to_ignore_on_load_missing = ["lm_head.weight"]
+     _no_split_modules = ["TransformerBlock"]
+
+     def __init__(self, config: ChessConfig):
+         super().__init__(config)
+
+         self.wte = nn.Embedding(config.vocab_size, config.n_embd)
+
+         # token-type embeddings: 0=square, 1=special/other, 2=promotion
+         self.wtt = nn.Embedding(3, config.n_embd)
+
+         self.drop = nn.Dropout(config.dropout)
+         self.h = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)])
+         self.norm_f = RMSNorm(config.n_embd)
+
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         if config.tie_weights:
+             self._tied_weights_keys = ["lm_head.weight"]
+
+         self.post_init()
+
+         if config.tie_weights:
+             self.tie_weights()
+
+     def get_input_embeddings(self) -> nn.Module:
+         return self.wte
+
+     def set_input_embeddings(self, new_embeddings: nn.Module):
+         self.wte = new_embeddings
+         if getattr(self.config, "tie_weights", False) or getattr(self.config, "tie_word_embeddings", False):
+             self.tie_weights()
+
+     def get_output_embeddings(self) -> nn.Module:
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings: nn.Module):
+         self.lm_head = new_embeddings
+
+     def tie_weights(self):
+         if getattr(self.config, "tie_weights", False) or getattr(self.config, "tie_word_embeddings", False):
+             self._tie_or_clone_weights(self.lm_head, self.wte)
+
+     def _init_weights(self, module: nn.Module):
+         if isinstance(module, nn.Linear):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     @staticmethod
+     def make_type_ids(input_ids: torch.LongTensor) -> torch.LongTensor:
+         """
+         Build token-type ids without needing the tokenizer at runtime.
+         Assumes:
+             specials: 0..3
+             squares : 4..67 (64 tokens)
+             promos  : 68..71 (4 tokens)
+         Everything else -> type 1 (special/other).
+         """
+         t = torch.ones_like(input_ids)  # default type=1
+         # squares
+         t = torch.where((input_ids >= 4) & (input_ids <= 67), torch.zeros_like(t), t)
+         # promos
+         t = torch.where((input_ids >= 68) & (input_ids <= 71), torch.full_like(t, 2), t)
+         return t
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # embeddings
+         tok = self.wte(input_ids)
+         type_ids = self.make_type_ids(input_ids)
+         typ = self.wtt(type_ids)
+
+         x = self.drop(tok + typ)
+
+         # blocks
+         for block in self.h:
+             x = block(x, attention_mask=attention_mask)
+
+         x = self.norm_f(x)
+         logits = self.lm_head(x)
+
+         loss = None
+         if labels is not None:
+             # next-token loss
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss = F.cross_entropy(
+                 shift_logits.view(-1, shift_logits.size(-1)),
+                 shift_labels.view(-1),
+                 ignore_index=-100,
+             )
+
+         if not return_dict:
+             out = (logits,)
+             return ((loss,) + out) if loss is not None else out
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=None,
+             hidden_states=None,
+             attentions=None,
+         )
+
+     @torch.no_grad()
+     def generate_move(
+         self,
+         input_ids: torch.LongTensor,
+         temperature: float = 1.0,
+         top_k: Optional[int] = None,
+         top_p: Optional[float] = None,
+     ) -> int:
+         self.eval()
+         out = self(input_ids=input_ids)
+         logits = out.logits[:, -1, :] / max(temperature, 1e-6)
+
+         if top_k is not None and top_k > 0:
+             top_k = min(top_k, logits.size(-1))
+             thresh = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
+             logits = logits.masked_fill(logits < thresh, float("-inf"))
+
+         if top_p is not None and 0.0 < top_p < 1.0:
+             sorted_logits, sorted_idx = torch.sort(logits, descending=True)
+             probs = F.softmax(sorted_logits, dim=-1)
+             cum = torch.cumsum(probs, dim=-1)
+             remove = cum > top_p
+             remove[..., 1:] = remove[..., :-1].clone()
+             remove[..., 0] = 0
+             remove_idx = remove.scatter(dim=-1, index=sorted_idx, src=remove)
+             logits = logits.masked_fill(remove_idx, float("-inf"))
+
+         probs = F.softmax(logits, dim=-1)
+         next_id = torch.multinomial(probs, num_samples=1)
+         return int(next_id.item())
+
+
+ # Register custom model for HF Auto classes
+ from transformers import AutoConfig, AutoModelForCausalLM
+
+ AutoConfig.register("chess_transformer", ChessConfig)
+ AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
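
Since moves are encoded as square pairs, generating one move means calling `generate_move` twice. A minimal play-loop sketch (illustrative only; it uses the private `_tokenize` and does not check move legality):

```python
import torch

config = ChessConfig(vocab_size=68)
model = ChessForCausalLM(config)  # or AutoModelForCausalLM.from_pretrained(...)
tok = ChessTokenizer()            # fixed 68-token vocab

# Encode the history "e2e4 e7e5" as [BOS] + square tokens.
ids = [tok.bos_token_id] + tok.convert_tokens_to_ids(tok._tokenize("e2e4 e7e5"))
input_ids = torch.tensor([ids])

# One move = from-square token, then to-square token.
frm = model.generate_move(input_ids, temperature=0.7, top_k=10)
input_ids = torch.cat([input_ids, torch.tensor([[frm]])], dim=1)
to = model.generate_move(input_ids, temperature=0.7, top_k=10)
print(tok.convert_ids_to_tokens([frm, to]))  # e.g. ['g1', 'f3']
```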
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:147c8fdd375c14c340ce536523d9a6142bbb85e8cc5ee4c598aabc3a270c250d
+ size 3982840
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "bos_token": "[BOS]",
+   "eos_token": "[EOS]",
+   "pad_token": "[PAD]",
+   "unk_token": "[UNK]"
+ }
tokenizer.py ADDED
@@ -0,0 +1,142 @@
+ # src/tokenizer.py
+ from __future__ import annotations
+
+ import json
+ import os
+ import re
+ from typing import Dict, List, Optional
+
+ from transformers import PreTrainedTokenizer
+
+
+ class ChessTokenizer(PreTrainedTokenizer):
+     """
+     Ultra-simple square tokenizer.
+
+     Vocab (68 tokens):
+       - 4 specials: [PAD] [BOS] [EOS] [UNK]
+       - 64 squares: a1..h8
+
+     Tokenization:
+       - Any text containing two squares -> emits those squares as tokens
+       - Accepts:
+           "WPe2e4(x+)", "e2e4", "e2 e4" -> ["e2", "e4"]
+       - For longer histories, extracts ALL squares in order.
+
+     Decoding:
+       - Joins square tokens with spaces => evaluator regex sees them easily.
+     """
+
+     model_input_names = ["input_ids", "attention_mask"]
+     vocab_files_names = {"vocab_file": "vocab.json"}
+
+     PAD_TOKEN = "[PAD]"
+     BOS_TOKEN = "[BOS]"
+     EOS_TOKEN = "[EOS]"
+     UNK_TOKEN = "[UNK]"
+
+     _SQUARE_PATTERN = r"[a-h][1-8]"
+     _SQUARE_RE = re.compile(_SQUARE_PATTERN)
+
+     def __init__(
+         self,
+         vocab_file: Optional[str] = None,
+         vocab: Optional[Dict[str, int]] = None,
+         **kwargs,
+     ):
+         self._pad_token = self.PAD_TOKEN
+         self._bos_token = self.BOS_TOKEN
+         self._eos_token = self.EOS_TOKEN
+         self._unk_token = self.UNK_TOKEN
+
+         kwargs.pop("pad_token", None)
+         kwargs.pop("bos_token", None)
+         kwargs.pop("eos_token", None)
+         kwargs.pop("unk_token", None)
+
+         if vocab is not None:
+             self._vocab = vocab
+         elif vocab_file is not None and os.path.exists(vocab_file):
+             with open(vocab_file, "r", encoding="utf-8") as f:
+                 self._vocab = json.load(f)
+         else:
+             self._vocab = self._create_fixed_vocab()
+
+         self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
+
+         super().__init__(
+             pad_token=self._pad_token,
+             bos_token=self._bos_token,
+             eos_token=self._eos_token,
+             unk_token=self._unk_token,
+             **kwargs,
+         )
+
+     @classmethod
+     def _create_fixed_vocab(cls) -> Dict[str, int]:
+         specials = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
+         files = "abcdefgh"
+         ranks = "12345678"
+         squares = [f + r for r in ranks for f in files]  # a1..h8
+         tokens = specials + squares
+         return {tok: i for i, tok in enumerate(tokens)}
+
+     @classmethod
+     def build_vocab_from_iterator(cls, iterator, **kwargs) -> "ChessTokenizer":
+         return cls(vocab=cls._create_fixed_vocab())
+
+     @classmethod
+     def build_vocab_from_dataset(cls, *args, **kwargs) -> "ChessTokenizer":
+         return cls(vocab=cls._create_fixed_vocab())
+
+     @property
+     def vocab_size(self) -> int:
+         return len(self._vocab)
+
+     def get_vocab(self) -> Dict[str, int]:
+         return dict(self._vocab)
+
+     def _tokenize(self, text: str) -> List[str]:
+         text = text.strip()
+         if not text:
+             return []
+
+         # Keep BOS/EOS tokens if they appear as standalone strings
+         # (rare, but safe)
+         if text in {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}:
+             return [text]
+
+         # Extract all squares in order from the text
+         squares = self._SQUARE_RE.findall(text)
+         if not squares:
+             # if nothing parsable, return UNK token
+             return [self.UNK_TOKEN]
+
+         # Filter to vocab squares only (should always be true)
+         out = [sq for sq in squares if sq in self._vocab]
+         return out if out else [self.UNK_TOKEN]
+
+     def _convert_token_to_id(self, token: str) -> int:
+         return self._vocab.get(token, self._vocab[self.UNK_TOKEN])
+
+     def _convert_id_to_token(self, index: int) -> str:
+         return self._ids_to_tokens.get(index, self.UNK_TOKEN)
+
+     def convert_tokens_to_string(self, tokens: List[str]) -> str:
+         # Drop special tokens; join squares with spaces so evaluator can parse.
+         special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
+         toks = [t for t in tokens if t not in special]
+         return " ".join(toks)
+
+     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
+         if not os.path.isdir(save_directory):
+             os.makedirs(save_directory, exist_ok=True)
+
+         vocab_file = os.path.join(
+             save_directory,
+             (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
+         )
+         with open(vocab_file, "w", encoding="utf-8") as f:
+             json.dump(self._vocab, f, ensure_ascii=False, indent=2)
+
+         return (vocab_file,)
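
The docstring's tokenization claims are easy to verify; a round-trip sketch (uses the private `_tokenize` purely for illustration):

```python
tok = ChessTokenizer()  # falls back to the fixed 68-token vocab

# All of these reduce to the same two square tokens:
for text in ("WPe2e4", "e2e4", "e2 e4"):
    assert tok._tokenize(text) == ["e2", "e4"]

enc = tok("WPe2e4x+")                # no special tokens are added by default
print(enc["input_ids"])              # [16, 32]  (e2 -> 16, e4 -> 32)
print(tok.decode(enc["input_ids"]))  # "e2 e4"
```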
tokenizer_config.json ADDED
@@ -0,0 +1,49 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "[BOS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "[EOS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "auto_map": {
+     "AutoTokenizer": [
+       "tokenizer.ChessTokenizer",
+       null
+     ]
+   },
+   "bos_token": "[BOS]",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "[EOS]",
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "[PAD]",
+   "tokenizer_class": "ChessTokenizer",
+   "unk_token": "[UNK]"
+ }
vocab.json ADDED
@@ -0,0 +1,70 @@
+ {
+   "[PAD]": 0,
+   "[BOS]": 1,
+   "[EOS]": 2,
+   "[UNK]": 3,
+   "a1": 4,
+   "b1": 5,
+   "c1": 6,
+   "d1": 7,
+   "e1": 8,
+   "f1": 9,
+   "g1": 10,
+   "h1": 11,
+   "a2": 12,
+   "b2": 13,
+   "c2": 14,
+   "d2": 15,
+   "e2": 16,
+   "f2": 17,
+   "g2": 18,
+   "h2": 19,
+   "a3": 20,
+   "b3": 21,
+   "c3": 22,
+   "d3": 23,
+   "e3": 24,
+   "f3": 25,
+   "g3": 26,
+   "h3": 27,
+   "a4": 28,
+   "b4": 29,
+   "c4": 30,
+   "d4": 31,
+   "e4": 32,
+   "f4": 33,
+   "g4": 34,
+   "h4": 35,
+   "a5": 36,
+   "b5": 37,
+   "c5": 38,
+   "d5": 39,
+   "e5": 40,
+   "f5": 41,
+   "g5": 42,
+   "h5": 43,
+   "a6": 44,
+   "b6": 45,
+   "c6": 46,
+   "d6": 47,
+   "e6": 48,
+   "f6": 49,
+   "g6": 50,
+   "h6": 51,
+   "a7": 52,
+   "b7": 53,
+   "c7": 54,
+   "d7": 55,
+   "e7": 56,
+   "f7": 57,
+   "g7": 58,
+   "h7": 59,
+   "a8": 60,
+   "b8": 61,
+   "c8": 62,
+   "d8": 63,
+   "e8": 64,
+   "f8": 65,
+   "g8": 66,
+   "h8": 67
+ }
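
The square ids follow a closed form (4 specials, then ranks 1..8 with files a..h inside each rank), so an id can be computed rather than looked up. A small sketch of the mapping implied by this file:

```python
def square_id(sq: str) -> int:
    file_idx = ord(sq[0]) - ord("a")    # a..h -> 0..7
    rank_idx = int(sq[1]) - 1           # 1..8 -> 0..7
    return 4 + 8 * rank_idx + file_idx  # 4 specials precede the squares

assert square_id("a1") == 4
assert square_id("e4") == 32
assert square_id("h8") == 67
```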