bouhss commited on
Commit
663d8ea
·
verified ·
1 Parent(s): 1045380

Chess Challenge submission by bouhss

Browse files
README.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags:
4
+ - chess
5
+ - llm-course
6
+ - chess-challenge
7
+ license: mit
8
+ ---
9
+
10
+ # chess-stockbird2
11
+
12
+ Chess model submitted to the LLM Course Chess Challenge.
13
+
14
+ ## Submission Info
15
+ - **Submitted by**: bouhss
16
+ - **Parameters**: 992,032
17
+ - **Vocab size**: 148
18
+ - **Embedding dim**: 128
19
+ - **Layers**: 6
20
+ - **Heads**: 8
config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "my_model_gpu_full/final_model",
3
+ "architectures": [
4
+ "ChessForCausalLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "model.ChessConfig",
8
+ "AutoModelForCausalLM": "model.ChessForCausalLM"
9
+ },
10
+ "bos_token_id": 1,
11
+ "dropout": 0.05,
12
+ "eos_token_id": 2,
13
+ "layer_norm_epsilon": 1e-06,
14
+ "mlp_type": "swiglu",
15
+ "model_type": "chess_transformer",
16
+ "n_ctx": 256,
17
+ "n_embd": 128,
18
+ "n_head": 8,
19
+ "n_inner": 248,
20
+ "n_layer": 6,
21
+ "pad_token_id": 0,
22
+ "rope_theta": 10000.0,
23
+ "tie_weights": true,
24
+ "torch_dtype": "float32",
25
+ "transformers_version": "4.39.3",
26
+ "use_rmsnorm": true,
27
+ "use_rope": true,
28
+ "vocab_size": 148
29
+ }
model.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chess Transformer Model for the Chess Challenge.
3
+
4
+ Modern small-LLM upgrades:
5
+ - RoPE (rotary positional embeddings): no learned positional embeddings needed
6
+ - RMSNorm (optional, default True)
7
+ - SwiGLU MLP (optional, default True)
8
+ - Weight tying (default True)
9
+ - Safe loss ignore_index = -100 (HF convention)
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import math
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from transformers import PretrainedConfig, PreTrainedModel
21
+ from transformers.modeling_outputs import CausalLMOutputWithPast
22
+
23
+
24
class ChessConfig(PretrainedConfig):
    """Configuration for the chess transformer.

    Captures architecture size (n_embd / n_layer / n_head / n_ctx), the MLP
    inner width, feature switches (RoPE, RMSNorm, SwiGLU, weight tying) and
    the special-token ids forwarded to PretrainedConfig.
    """

    model_type = "chess_transformer"

    def __init__(
        self,
        vocab_size: int = 1200,
        # Architecture (defaults tuned to be < 1M params for common vocabs)
        n_embd: int = 112,
        n_layer: int = 7,
        n_head: int = 7,
        # Context window
        n_ctx: int = 512,
        # MLP hidden size: SwiGLU hidden h for mlp_type="swiglu",
        # otherwise the FFN inner size for mlp_type="gelu".
        n_inner: Optional[int] = 192,
        dropout: float = 0.05,
        layer_norm_epsilon: float = 1e-6,
        # Position encoding
        use_rope: bool = True,
        rope_theta: float = 10000.0,
        # Normalization / MLP type
        use_rmsnorm: bool = True,
        mlp_type: str = "swiglu",  # "swiglu" or "gelu"
        # Weight tying
        tie_weights: bool = True,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ):
        # Fail fast on impossible head geometry before any other setup.
        if n_embd % n_head != 0:
            raise ValueError(f"n_embd ({n_embd}) must be divisible by n_head ({n_head})")

        head_dim = n_embd // n_head
        if use_rope and (head_dim % 2 != 0):
            raise ValueError(
                f"RoPE requires even head_dim, got head_dim={head_dim}. "
                f"Choose n_embd/n_head even."
            )

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_ctx = n_ctx
        # Fall back to a 2x expansion when no inner width is given.
        self.n_inner = n_inner if n_inner is not None else (2 * n_embd)
        self.dropout = dropout
        self.layer_norm_epsilon = layer_norm_epsilon

        self.use_rope = use_rope
        self.rope_theta = rope_theta

        self.use_rmsnorm = use_rmsnorm
        self.mlp_type = mlp_type

        self.tie_weights = tie_weights
        # HF reads tie_word_embeddings to drive embedding-tying behavior.
        self.tie_word_embeddings = bool(tie_weights)
98
+
99
+
100
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction, no bias).

    Normalizes the last dimension by its RMS and applies a learned
    per-channel gain.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps  # numerical floor inside the rsqrt
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        scale = torch.rsqrt(mean_square + self.eps)
        return x * scale * self.weight
109
+
110
+
111
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved pairs: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...).

    Each (even, odd) pair is treated as a complex number and multiplied
    by i; used by the interleaved RoPE convention.
    """
    even = x[..., 0::2]
    odd = x[..., 1::2]
    # Stack (-odd, even) pairwise, then flatten back to the original layout.
    return torch.stack((-odd, even), dim=-1).flatten(-2)
118
+
119
+
120
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (interleaved-pair convention).

    Lazily builds and caches cos/sin tables per (length, device, dtype)
    and rotates q/k of shape (B, H, T, D).
    """

    def __init__(self, head_dim: int, theta: float = 10000.0):
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError(f"RoPE requires even head_dim, got {head_dim}")

        inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Cache state: invalidated when length grows or device/dtype change.
        self._cos_cached = None
        self._sin_cached = None
        self._seq_len_cached = 0
        self._device_cached = None
        self._dtype_cached = None

    def _build_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        """(Re)compute cos/sin tables of shape (seq_len, D/2)."""
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)

        self._cos_cached = angles.cos().to(dtype=dtype)
        self._sin_cached = angles.sin().to(dtype=dtype)
        self._seq_len_cached = seq_len
        self._device_cached = device
        self._dtype_cached = dtype

    @staticmethod
    def _rotate(x: torch.Tensor) -> torch.Tensor:
        # Interleaved rotation: (x0, x1, ...) -> (-x1, x0, ...); inlined
        # equivalent of the module-level rotate_half helper.
        even = x[..., 0::2]
        odd = x[..., 1::2]
        return torch.stack((-odd, even), dim=-1).flatten(-2)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE to q and k (both (B, H, T, D)); return the rotated pair."""
        seq_len, device, dtype = q.size(-2), q.device, q.dtype

        cache_stale = (
            self._cos_cached is None
            or seq_len > self._seq_len_cached
            or device != self._device_cached
            or dtype != self._dtype_cached
        )
        if cache_stale:
            self._build_cache(seq_len, device, dtype)

        # Expand (T, D/2) tables to (1, 1, T, D) by duplicating each column,
        # matching the interleaved layout of _rotate.
        cos = torch.repeat_interleave(self._cos_cached[:seq_len].unsqueeze(0).unsqueeze(0), 2, dim=-1)
        sin = torch.repeat_interleave(self._sin_cached[:seq_len].unsqueeze(0).unsqueeze(0), 2, dim=-1)

        q_rot = (q * cos) + (self._rotate(q) * sin)
        k_rot = (k * cos) + (self._rotate(k) * sin)
        return q_rot, k_rot
176
+
177
+
178
+ class MultiHeadAttention(nn.Module):
179
+ def __init__(self, config: ChessConfig):
180
+ super().__init__()
181
+
182
+ self.n_head = config.n_head
183
+ self.n_embd = config.n_embd
184
+ self.head_dim = config.n_embd // config.n_head
185
+
186
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
187
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd)
188
+ self.dropout = nn.Dropout(config.dropout)
189
+
190
+ self.use_rope = bool(config.use_rope)
191
+ self.rope = RotaryEmbedding(self.head_dim, theta=config.rope_theta) if self.use_rope else None
192
+
193
+ # causal mask buffer (expandable)
194
+ self.register_buffer(
195
+ "bias",
196
+ torch.tril(torch.ones(config.n_ctx, config.n_ctx)).view(1, 1, config.n_ctx, config.n_ctx),
197
+ persistent=False,
198
+ )
199
+
200
+ def _ensure_causal_mask(self, seq_len: int, device: torch.device, dtype: torch.dtype):
201
+ if self.bias.size(-1) >= seq_len and self.bias.device == device:
202
+ return
203
+ self.bias = torch.tril(torch.ones(seq_len, seq_len, device=device, dtype=dtype)).view(1, 1, seq_len, seq_len)
204
+
205
+ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
206
+ B, T, _ = x.size()
207
+
208
+ qkv = self.c_attn(x)
209
+ q, k, v = qkv.split(self.n_embd, dim=2)
210
+
211
+ q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B,H,T,D)
212
+ k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
213
+ v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
214
+
215
+ if self.use_rope:
216
+ q, k = self.rope(q, k)
217
+
218
+ attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
219
+
220
+ self._ensure_causal_mask(T, attn.device, attn.dtype)
221
+ causal_mask = self.bias[:, :, :T, :T]
222
+ mask_value = torch.finfo(attn.dtype).min
223
+ attn = attn.masked_fill(causal_mask == 0, mask_value)
224
+
225
+ # padding mask (1=keep, 0=mask)
226
+ if attention_mask is not None:
227
+ am = attention_mask.unsqueeze(1).unsqueeze(2) # (B,1,1,T)
228
+ attn = attn.masked_fill(am == 0, mask_value)
229
+
230
+ attn = F.softmax(attn, dim=-1)
231
+ attn = self.dropout(attn)
232
+
233
+ y = torch.matmul(attn, v) # (B,H,T,D)
234
+ y = y.transpose(1, 2).contiguous().view(B, T, self.n_embd)
235
+
236
+ y = self.c_proj(y)
237
+ y = self.dropout(y)
238
+ return y
239
+
240
+
241
+ class SwiGLU(nn.Module):
242
+ def __init__(self, config: ChessConfig):
243
+ super().__init__()
244
+ h = config.n_inner
245
+ self.w12 = nn.Linear(config.n_embd, 2 * h)
246
+ self.w3 = nn.Linear(h, config.n_embd)
247
+ self.dropout = nn.Dropout(config.dropout)
248
+
249
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
250
+ x12 = self.w12(x)
251
+ x1, x2 = x12.chunk(2, dim=-1)
252
+ x = F.silu(x1) * x2
253
+ x = self.w3(x)
254
+ x = self.dropout(x)
255
+ return x
256
+
257
+
258
+ class FeedForwardGELU(nn.Module):
259
+ def __init__(self, config: ChessConfig):
260
+ super().__init__()
261
+ self.c_fc = nn.Linear(config.n_embd, config.n_inner)
262
+ self.c_proj = nn.Linear(config.n_inner, config.n_embd)
263
+ self.dropout = nn.Dropout(config.dropout)
264
+
265
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
266
+ x = self.c_fc(x)
267
+ x = F.gelu(x)
268
+ x = self.c_proj(x)
269
+ x = self.dropout(x)
270
+ return x
271
+
272
+
273
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: x + attn(norm(x)), then x + mlp(norm(x))."""

    def __init__(self, config: ChessConfig):
        super().__init__()

        # Norm flavor and MLP flavor are both chosen from the config.
        norm_cls = RMSNorm if config.use_rmsnorm else nn.LayerNorm
        self.ln_1 = norm_cls(config.n_embd, eps=config.layer_norm_epsilon)
        self.ln_2 = norm_cls(config.n_embd, eps=config.layer_norm_epsilon)

        self.attn = MultiHeadAttention(config)

        mlp_cls = SwiGLU if config.mlp_type.lower() == "swiglu" else FeedForwardGELU
        self.mlp = mlp_cls(config)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Residual around attention, then residual around the MLP.
        x = x + self.attn(self.ln_1(x), attention_mask=attention_mask)
        return x + self.mlp(self.ln_2(x))
295
+
296
+
297
class ChessForCausalLM(PreTrainedModel):
    """Decoder-only causal language model over chess-move tokens.

    Pipeline: token embedding (plus learned positional embedding only when
    RoPE is disabled) -> dropout -> n_layer TransformerBlocks -> final norm
    -> linear LM head, weight-tied to the embedding when configured.
    """

    config_class = ChessConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    # With tied weights, checkpoints may omit lm_head.weight; don't warn.
    keys_to_ignore_on_load_missing = ["lm_head.weight"]
    # Keep a whole block on one device when sharding (accelerate).
    _no_split_modules = ["TransformerBlock"]

    def __init__(self, config: ChessConfig):
        super().__init__(config)

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)

        # learned positional embeddings only if RoPE disabled
        self.wpe = None
        if not config.use_rope:
            self.wpe = nn.Embedding(config.n_ctx, config.n_embd)

        self.drop = nn.Dropout(config.dropout)
        self.h = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)])

        if config.use_rmsnorm:
            self.ln_f = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
        else:
            self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        if config.tie_weights:
            # Tells HF save/load machinery which parameters are shared.
            self._tied_weights_keys = ["lm_head.weight"]

        # post_init() runs HF weight init (_init_weights below); tying is
        # re-applied afterwards so initialization cannot break the share.
        self.post_init()

        if config.tie_weights:
            self.tie_weights()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token-embedding module (HF resize/tie hook)."""
        return self.wte

    def set_input_embeddings(self, new_embeddings: nn.Module):
        """Swap the token embedding; re-tie the LM head if tying is enabled."""
        self.wte = new_embeddings
        if getattr(self.config, "tie_weights", False):
            self.tie_weights()

    def get_output_embeddings(self) -> nn.Module:
        """Return the LM head (HF resize/tie hook)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module):
        """Replace the LM head module."""
        self.lm_head = new_embeddings

    def tie_weights(self):
        """Share (or clone, per HF internals) lm_head.weight with wte.weight."""
        if getattr(self.config, "tie_weights", False) or getattr(self.config, "tie_word_embeddings", False):
            self._tie_or_clone_weights(self.lm_head, self.wte)

    def _init_weights(self, module: nn.Module):
        """GPT-2-style init: N(0, 0.02) weights, zero biases (HF init hook)."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the language model.

        Args:
            input_ids: (B, T) token ids.
            attention_mask: (B, T), 1 = keep, 0 = padding.
            position_ids: explicit positions; only consulted when learned
                positional embeddings are active (use_rope=False).
            labels: (B, T) targets; -100 entries are ignored by the loss.
                Shifting is done here (predict token t+1 from prefix <= t).
            return_dict: HF convention; tuple output when False.
            **kwargs: ignored extra Trainer-passed arguments.

        Returns:
            CausalLMOutputWithPast (or tuple) with `loss` (if labels given)
            and `logits` of shape (B, T, vocab_size).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        B, T = input_ids.size()
        device = input_ids.device

        x = self.wte(input_ids)

        # Learned positions are only added when RoPE is off.
        if self.wpe is not None:
            if position_ids is None:
                position_ids = torch.arange(T, device=device).unsqueeze(0).expand(B, -1)
            x = x + self.wpe(position_ids)

        x = self.drop(x)

        for block in self.h:
            x = block(x, attention_mask=attention_mask)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Standard causal shift: logits[t] predicts labels[t+1].
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    @torch.no_grad()
    def generate_move(
        self,
        input_ids: torch.LongTensor,
        temperature: float = 0.7,
        top_k: Optional[int] = 50,
        top_p: Optional[float] = None,
    ) -> int:
        """Sample the next token id for a (1, T) prompt.

        Applies temperature scaling, then optional top-k and nucleus
        (top-p) filtering, then samples from what remains.

        Returns:
            The sampled token id as a plain int.
        """
        self.eval()

        outputs = self(input_ids)
        # Last-position logits; temperature is clamped away from zero.
        logits = outputs.logits[:, -1, :] / max(float(temperature), 1e-6)

        if top_k is not None and top_k > 0:
            # Mask everything strictly below the k-th largest logit.
            k = min(int(top_k), logits.size(-1))
            thresh = torch.topk(logits, k)[0][..., -1, None]
            logits = logits.masked_fill(logits < thresh, torch.finfo(logits.dtype).min)

        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            probs = F.softmax(sorted_logits, dim=-1)
            cum = torch.cumsum(probs, dim=-1)
            to_remove = cum > float(top_p)
            # Shift right so the first token crossing the threshold is kept.
            to_remove[..., 1:] = to_remove[..., :-1].clone()
            to_remove[..., 0] = 0
            # Scatter un-sorts the removal mask back to vocabulary order.
            indices_to_remove = to_remove.scatter(dim=-1, index=sorted_indices, src=to_remove)
            logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        return int(next_token.item())
440
+
441
+
442
# Register the custom architecture with the transformers Auto* factories so
# AutoConfig / AutoModelForCausalLM resolve "chess_transformer" (e.g. when
# loading this repo with trust_remote_code=True).
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("chess_transformer", ChessConfig)
AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9148bdf02f882142d8414a64c14a568340412aa2d8c046ee1979da5d498f62e3
3
+ size 3973424
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "[BOS]",
3
+ "eos_token": "[EOS]",
4
+ "pad_token": "[PAD]",
5
+ "unk_token": "[UNK]"
6
+ }
src/.ipynb_checkpoints/__init__-checkpoint.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Chess Challenge source module."""

from .model import ChessConfig, ChessForCausalLM
from .tokenizer import ChessTokenizer

# Names re-exported lazily from .evaluate so importing this package as a
# module does not trigger a RuntimeWarning from an eager evaluate import.
_LAZY_EVALUATE_EXPORTS = ("ChessEvaluator", "load_model_from_hub")


def __getattr__(name):
    """PEP 562 hook: resolve evaluate exports on first access."""
    if name in _LAZY_EVALUATE_EXPORTS:
        from . import evaluate
        return getattr(evaluate, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


__all__ = [
    "ChessConfig",
    "ChessForCausalLM",
    "ChessTokenizer",
    "ChessEvaluator",
    "load_model_from_hub",
]
src/.ipynb_checkpoints/data-checkpoint.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data loading utilities for the Chess Challenge.
3
+
4
+ This module provides functions to load and process chess game data
5
+ from the Lichess dataset on Hugging Face.
6
+
7
+ IMPORTANT NOTE (compat with template evaluate + custom tokenizers):
8
+ - Do NOT manually prepend BOS in the raw text.
9
+ The tokenizer should handle BOS via build_inputs_with_special_tokens.
10
+ This avoids double-BOS issues and keeps train/eval conventions aligned.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from typing import Dict, Iterator, List, Optional
16
+
17
+ import torch
18
+ from torch.utils.data import Dataset
19
+
20
+
21
class ChessDataset(Dataset):
    """PyTorch Dataset over tokenized chess games.

    Every game is tokenized and truncated/padded to ``max_length``.
    ``labels`` mirror ``input_ids`` (the model shifts internally), with
    padding positions set to -100 so cross-entropy ignores them.
    """

    def __init__(
        self,
        tokenizer,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        max_length: int = 256,
        max_samples: Optional[int] = None,
    ):
        from datasets import load_dataset

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.column = column

        loaded = load_dataset(dataset_name, split=split)
        if max_samples is not None:
            keep = min(max_samples, len(loaded))
            loaded = loaded.select(range(keep))
        self.data = loaded

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        game_text = self.data[idx][self.column]

        # BOS is deliberately NOT prepended here; the tokenizer is expected
        # to add special tokens itself (build_inputs_with_special_tokens),
        # keeping train/eval conventions aligned and avoiding double BOS.
        enc = self.tokenizer(
            game_text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        ids = enc["input_ids"].squeeze(0)
        mask = enc["attention_mask"].squeeze(0)

        # HF convention: -100 labels are ignored by the CE loss.
        labels = ids.clone()
        labels[mask == 0] = -100

        return {
            "input_ids": ids,
            "attention_mask": mask,
            "labels": labels,
        }
+ }
80
+
81
+
82
class ChessDataCollator:
    """Collator that stacks pre-padded examples into batch tensors.

    ChessDataset already pads every sequence to a fixed length, so
    collation reduces to one torch.stack per field.
    """

    def __init__(self, tokenizer, max_length: int = 256):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        return {
            key: torch.stack([feature[key] for feature in features])
            for key in ("input_ids", "attention_mask", "labels")
        }
+ }
104
+
105
+
106
def create_train_val_datasets(
    tokenizer,
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    max_length: int = 256,
    train_samples: Optional[int] = None,
    val_samples: int = 5000,
    val_ratio: float = 0.05,
):
    """
    Create training and validation datasets.

    Splits the dataset deterministically by index:
        - train: [0:n_train)
        - val:   [n_train:n_train+n_val)

    Args:
        tokenizer: tokenizer handed to each ChessDataset.
        dataset_name: HF dataset repo id.
        max_length: max token length per game.
        train_samples: cap on training examples; None means use
            (1 - val_ratio) of the dataset.
        val_samples: desired validation size (clamped to what remains).
        val_ratio: validation fraction used when train_samples is None.

    Returns:
        (train_dataset, val_dataset)
    """
    from datasets import load_dataset

    full_dataset = load_dataset(dataset_name, split="train")
    total = len(full_dataset)

    if train_samples is not None:
        n_train = min(train_samples, total - val_samples)
    else:
        n_train = int(total * (1 - val_ratio))

    n_val = min(val_samples, total - n_train)

    def _wrap(subset):
        # BUGFIX: previously each ChessDataset(...) constructor re-loaded the
        # full dataset from the hub only to have its .data overwritten right
        # after — three loads total. Build the wrapper around the
        # already-loaded split instead (bypassing __init__ via __new__).
        ds = ChessDataset.__new__(ChessDataset)
        ds.tokenizer = tokenizer
        ds.max_length = max_length
        ds.column = "text"  # ChessDataset's default column
        ds.data = subset
        return ds

    train_dataset = _wrap(full_dataset.select(range(n_train)))
    val_dataset = _wrap(full_dataset.select(range(n_train, n_train + n_val)))

    return train_dataset, val_dataset
+ return train_dataset, val_dataset
154
+
155
+
156
def stream_games(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    split: str = "train",
    column: str = "text",
) -> Iterator[str]:
    """
    Stream games from the dataset for memory-efficient processing.

    Uses HF streaming mode, so the dataset is never fully materialized.
    """
    from datasets import load_dataset

    stream = load_dataset(dataset_name, split=split, streaming=True)
    yield from (record[column] for record in stream)
+ yield example[column]
169
+
170
+
171
def analyze_dataset_statistics(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    max_samples: int = 10000,
) -> Dict:
    """
    Analyze statistics of the chess dataset (non-streaming).

    Games are whitespace-split on the "text" column; the first 4 tokens of
    each game are treated as its opening signature.

    Args:
        dataset_name: HF dataset repo id.
        max_samples: cap on how many games are analyzed.

    Returns:
        Dict with game-length stats, move frequencies, and common openings.
        For an empty dataset, returns zeroed stats instead of raising.
    """
    from collections import Counter
    from datasets import load_dataset

    dataset = load_dataset(dataset_name, split="train")
    dataset = dataset.select(range(min(max_samples, len(dataset))))

    game_lengths = []
    move_counts = Counter()
    opening_moves = Counter()

    for example in dataset:
        moves = example["text"].strip().split()
        game_lengths.append(len(moves))
        move_counts.update(moves)

        if len(moves) >= 4:
            opening = " ".join(moves[:4])
            opening_moves[opening] += 1

    # BUGFIX: the original divided by len(game_lengths) and called min()/max()
    # unconditionally, crashing on an empty dataset/selection.
    if not game_lengths:
        return {
            "total_games": 0,
            "avg_game_length": 0.0,
            "min_game_length": 0,
            "max_game_length": 0,
            "unique_moves": 0,
            "most_common_moves": [],
            "most_common_openings": [],
        }

    return {
        "total_games": len(dataset),
        "avg_game_length": sum(game_lengths) / len(game_lengths),
        "min_game_length": min(game_lengths),
        "max_game_length": max(game_lengths),
        "unique_moves": len(move_counts),
        "most_common_moves": move_counts.most_common(20),
        "most_common_openings": opening_moves.most_common(10),
    }
+ }
src/.ipynb_checkpoints/evaluate-checkpoint.py ADDED
@@ -0,0 +1,710 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluation script for the Chess Challenge.
3
+
4
+ This script evaluates a trained chess model by playing games against
5
+ Stockfish and computing ELO ratings.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import argparse
11
+ import random
12
+ import re
13
+ from dataclasses import dataclass
14
+ from typing import List, Optional, Tuple
15
+
16
+ import torch
17
+
18
+
19
@dataclass
class GameResult:
    """Result of a single game.

    Plain record of one model-vs-engine game: what was played, who the
    model was, and how the game ended.
    """
    moves: List[str]  # moves in play order; notation depends on the evaluator's detected format
    result: str  # "1-0", "0-1", or "1/2-1/2"
    model_color: str  # "white" or "black"
    termination: str  # "checkmate", "stalemate", "illegal_move", "max_moves", etc.
    illegal_move_count: int  # how many illegal moves the model generated during the game
27
+
28
+
29
+ class ChessEvaluator:
30
+ """
31
+ Evaluator for chess models.
32
+
33
+ This class handles playing games between a trained model and Stockfish,
34
+ tracking results, and computing ELO ratings.
35
+
36
+ Supports any tokenization format as long as the model generates valid
37
+ chess squares (e.g., e2, e4). The evaluator extracts UCI moves by finding
38
+ square patterns in the generated output.
39
+ """
40
+
41
+ # Regex pattern to match chess squares
42
+ SQUARE_PATTERN = r"[a-h][1-8]"
43
+
44
    def __init__(
        self,
        model,
        tokenizer,
        stockfish_path: Optional[str] = None,
        stockfish_level: int = 1,
        max_retries: int = 3,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """
        Initialize the evaluator.

        Args:
            model: The trained chess model.
            tokenizer: The chess tokenizer.
            stockfish_path: Path to Stockfish executable.
            stockfish_level: Stockfish skill level (0-20).
            max_retries: Maximum retries for illegal moves.
            device: Device to run the model on. NOTE: the default is
                evaluated once at class-definition time, not per call.

        Raises:
            ImportError: if python-chess is not installed.
        """
        self.model = model.to(device)
        self.model.eval()  # inference only: disables dropout
        self.tokenizer = tokenizer
        self.max_retries = max_retries
        self.device = device

        # Initialize Stockfish
        try:
            import chess
            import chess.engine

            # Keep the chess module handy for other methods (Board, square_name, ...).
            self.chess = chess

            if stockfish_path is None:
                # Try common paths
                import shutil

                stockfish_path = shutil.which("stockfish")

            if stockfish_path:
                # Spawns a Stockfish subprocess; shut down in __del__.
                self.engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
                self.engine.configure({"Skill Level": stockfish_level})
            else:
                # Degraded mode: evaluation that needs an engine won't work.
                print("WARNING: Stockfish not found. Install it for full evaluation.")
                self.engine = None

        except ImportError:
            raise ImportError(
                "python-chess is required for evaluation. "
                "Install it with: pip install python-chess"
            )
95
+
96
+ def __del__(self):
97
+ """Clean up Stockfish engine."""
98
+ if hasattr(self, "engine") and self.engine:
99
+ self.engine.quit()
100
+
101
    def _detect_tokenizer_format(self) -> str:
        """
        Detect the tokenizer's expected move format by testing tokenization.

        Tests various formats with a sample move and picks the one that
        produces the fewest unknown tokens. This makes evaluation work
        with any tokenizer format.

        Supported formats:
            - 'decomposed': "WP e2_f e4_t" (piece, from_suffix, to_suffix)
            - 'standard': "WPe2e4" (combined with optional annotations)
            - 'uci': "e2e4" (pure UCI notation)
            - 'uci_spaced': "e2 e4" (UCI with space separator)

        Returns:
            The format string that best matches the tokenizer's vocabulary.
        """
        # Memoized: the tokenizer does not change over the evaluator's lifetime.
        if hasattr(self, "_cached_format"):
            return self._cached_format

        test_formats = {
            "decomposed": "WP e2_f e4_t",
            "standard": "WPe2e4",
            "uci": "e2e4",
            "uci_spaced": "e2 e4",
        }

        unk_token_id = getattr(self.tokenizer, "unk_token_id", None)
        best_format = "standard"  # fallback if every probe fails or raises
        min_unk_count = float("inf")

        for fmt, sample in test_formats.items():
            try:
                tokens = self.tokenizer.encode(sample, add_special_tokens=False)
                unk_count = tokens.count(unk_token_id) if unk_token_id is not None else 0
                # A single token that is itself <unk> means the tokenizer
                # understood nothing about this format: penalize heavily.
                if len(tokens) == 1 and unk_count == 1:
                    unk_count = 100  # heavy penalty
                # Strict '<' keeps the FIRST format on ties, so the dict
                # order above doubles as the preference order.
                if unk_count < min_unk_count:
                    min_unk_count = unk_count
                    best_format = fmt
            except Exception:
                # Some tokenizers may raise on alien input; just skip.
                continue

        self._cached_format = best_format
        return best_format
146
+
147
+ def _format_move(
148
+ self,
149
+ color: str,
150
+ piece: str,
151
+ from_sq: str,
152
+ to_sq: str,
153
+ promotion: str = None,
154
+ ) -> str:
155
+ fmt = self._detect_tokenizer_format()
156
+
157
+ if fmt == "decomposed":
158
+ move_str = f"{color}{piece} {from_sq}_f {to_sq}_t"
159
+ elif fmt == "uci":
160
+ move_str = f"{from_sq}{to_sq}"
161
+ if promotion:
162
+ move_str += promotion.lower()
163
+ elif fmt == "uci_spaced":
164
+ move_str = f"{from_sq} {to_sq}"
165
+ if promotion:
166
+ move_str += f" {promotion.lower()}"
167
+ else: # standard
168
+ move_str = f"{color}{piece}{from_sq}{to_sq}"
169
+ if promotion:
170
+ move_str += f"={promotion}"
171
+
172
+ return move_str
173
+
174
    def _convert_board_to_moves(self, board) -> str:
        """Re-render a finished board's move stack in the tokenizer's format.

        Replays the game on a scratch board so per-move state (capture,
        check, castling) can be annotated as it was at the time of each move.
        Annotations — "(x)", "(+)", "(x+)", "(+*)", "(x+*)", "(o)"/"(O)" —
        are only emitted for the "standard" format; other formats get bare
        moves. Returns the space-joined move string.
        """
        moves = []
        temp_board = self.chess.Board()
        fmt = self._detect_tokenizer_format()

        for move in board.move_stack:
            # Side to move and moving piece must be read BEFORE pushing.
            color = "W" if temp_board.turn == self.chess.WHITE else "B"
            piece = temp_board.piece_at(move.from_square)
            piece_letter = piece.symbol().upper() if piece else "P"

            from_sq = self.chess.square_name(move.from_square)
            to_sq = self.chess.square_name(move.to_square)

            promo = None
            if move.promotion:
                promo = self.chess.piece_symbol(move.promotion).upper()

            move_str = self._format_move(color, piece_letter, from_sq, to_sq, promo)

            if fmt == "standard":
                # Capture is tested pre-push; check/checkmate post-push.
                if temp_board.is_capture(move):
                    move_str += "(x)"

                temp_board.push(move)

                if temp_board.is_checkmate():
                    if "(x)" in move_str:
                        move_str = move_str.replace("(x)", "(x+*)")
                    else:
                        move_str += "(+*)"
                elif temp_board.is_check():
                    if "(x)" in move_str:
                        move_str = move_str.replace("(x)", "(x+)")
                    else:
                        move_str += "(+)"

                # Castling: king moves more than one file. Any earlier
                # annotation is replaced by "(o)" (kingside, to g-file)
                # or "(O)" (queenside).
                if piece_letter == "K":
                    if abs(ord(from_sq[0]) - ord(to_sq[0])) > 1:
                        if to_sq[0] == "g":
                            move_str = move_str.split("(")[0] + "(o)"
                        else:
                            move_str = move_str.split("(")[0] + "(O)"
            else:
                # Non-standard formats: just advance the scratch board.
                temp_board.push(move)

            moves.append(move_str)

        return " ".join(moves)
222
+
223
+ def _is_separator_token(self, token_str: str) -> bool:
224
+ if hasattr(self.tokenizer, "eos_token") and token_str == self.tokenizer.eos_token:
225
+ return True
226
+ if token_str.strip() == "" and len(token_str) > 0:
227
+ return True
228
+ if token_str != token_str.rstrip():
229
+ return True
230
+ return False
231
+
232
+ def _extract_uci_move(self, text: str) -> Optional[str]:
233
+ if not text:
234
+ return None
235
+
236
+ squares = re.findall(self.SQUARE_PATTERN, text)
237
+ if len(squares) < 2:
238
+ return None
239
+
240
+ from_sq, to_sq = squares[0], squares[1]
241
+ uci_move = from_sq + to_sq
242
+
243
+ to_sq_idx = text.find(to_sq)
244
+ if to_sq_idx != -1:
245
+ remaining = text[to_sq_idx + 2 : to_sq_idx + 5]
246
+ promo_match = re.search(r"[=]?([qrbnQRBN])", remaining)
247
+ if promo_match:
248
+ uci_move += promo_match.group(1).lower()
249
+
250
+ return uci_move
251
+
252
+ def _has_complete_move(self, text: str) -> bool:
253
+ squares = re.findall(self.SQUARE_PATTERN, text)
254
+ return len(squares) >= 2
255
+
256
def _generate_move_tokens(
    self,
    input_ids: torch.Tensor,
    temperature: float = 0.7,
    top_k: int = 10,
    max_tokens: int = 20,
) -> str:
    """Sample tokens from the model until one complete move is produced.

    Performs temperature + top-k sampling one token at a time, stopping at
    a separator token once a full move (two square names) has accumulated,
    or after ``max_tokens`` tokens.

    Args:
        input_ids: Prompt token ids; presumably shape (1, T) — TODO confirm.
        temperature: Softmax temperature (clamped away from zero).
        top_k: Keep only the k most likely tokens per step; <= 0 disables.
        max_tokens: Hard cap on tokens generated per move.

    Returns:
        The decoded, stripped move text, or "" if nothing was generated.
    """
    generated_tokens = []
    current_ids = input_ids.clone()
    accumulated_text = ""

    for _ in range(max_tokens):
        with torch.no_grad():
            outputs = self.model(input_ids=current_ids)
            # Clamp temperature to avoid division by zero.
            logits = outputs.logits[:, -1, :] / max(temperature, 1e-6)

        if top_k > 0:
            # Mask out everything below the k-th largest logit.
            top_k_vals = torch.topk(logits, min(top_k, logits.size(-1)))
            indices_to_remove = logits < top_k_vals[0][..., -1, None]
            logits[indices_to_remove] = float("-inf")

        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        token_str = self.tokenizer.decode(next_token[0])

        if self._is_separator_token(token_str):
            # Separator after a complete move: the move is done.
            if self._has_complete_move(accumulated_text):
                break
            # EOS always terminates, complete move or not.
            if hasattr(self.tokenizer, "eos_token") and token_str == self.tokenizer.eos_token:
                break
            # Separator after a partial move: abandon this attempt.
            if accumulated_text:
                break

        generated_tokens.append(next_token[0])
        current_ids = torch.cat([current_ids, next_token], dim=-1)
        accumulated_text += token_str

        if self._has_complete_move(accumulated_text):
            squares = re.findall(self.SQUARE_PATTERN, accumulated_text)
            if len(squares) >= 2:
                to_sq = squares[1]
                if to_sq[1] in "18":
                    # Destination on a back rank: keep sampling briefly in
                    # case a promotion piece token follows.
                    if len(generated_tokens) > 3:
                        break
                else:
                    break

    if generated_tokens:
        all_tokens = torch.cat(generated_tokens, dim=0)
        move_str = self.tokenizer.decode(all_tokens, skip_special_tokens=True)
        return move_str.strip()

    return ""
310
+
311
def _get_model_move(
    self,
    board,
    temperature: float = 0.7,
    top_k: int = 10,
) -> Tuple[Optional[str], int]:
    """Ask the model for a legal move in the given position.

    Encodes the game so far, samples up to ``self.max_retries`` candidate
    moves, and returns the first one that is legal on ``board``.

    Args:
        board: Current board (python-chess style interface).
        temperature: Sampling temperature forwarded to generation.
        top_k: Top-k cutoff forwarded to generation.

    Returns:
        ``(uci_move, retries)`` — ``uci_move`` is a legal UCI string or
        None if every attempt failed; ``retries`` counts failed attempts
        before success (== self.max_retries on total failure).
    """
    self.model.eval()

    moves_str = self._convert_board_to_moves(board)

    # Prompt is BOS, then the move history (if any).
    if not moves_str:
        input_text = self.tokenizer.bos_token
    else:
        input_text = self.tokenizer.bos_token + " " + moves_str

    inputs = self.tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        # Leave headroom for the tokens of the move to be generated.
        max_length=self.model.config.n_ctx - 10,
    ).to(self.device)

    for retry in range(self.max_retries):
        move_text = self._generate_move_tokens(
            inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
        )

        uci_move = self._extract_uci_move(move_text)

        if uci_move:
            try:
                move = self.chess.Move.from_uci(uci_move)
                if move in board.legal_moves:
                    return uci_move, retry
            except (ValueError, self.chess.InvalidMoveError):
                # Malformed UCI string: count as a failed attempt.
                pass

    return None, self.max_retries
351
+
352
+ def _get_stockfish_move(self, board, time_limit: float = 0.1) -> str:
353
+ if self.engine is None:
354
+ raise RuntimeError("Stockfish engine not initialized")
355
+
356
+ result = self.engine.play(board, self.chess.engine.Limit(time=time_limit))
357
+ return result.move.uci()
358
+
359
def play_game(
    self,
    model_color: str = "white",
    max_moves: int = 200,
    temperature: float = 0.7,
) -> GameResult:
    """Play one full game: the model vs. Stockfish (or a random mover).

    The opponent is Stockfish when an engine is configured, otherwise a
    uniform-random legal mover. If the model cannot produce a legal move
    within its retry budget it forfeits immediately.

    Args:
        model_color: "white" or "black" — the side the model plays.
        max_moves: Cap on total half-moves before declaring a draw.
        temperature: Sampling temperature for the model's moves.

    Returns:
        A GameResult with the move list, result string, termination reason
        and the accumulated number of illegal generation attempts.
    """
    board = self.chess.Board()
    moves = []
    illegal_move_count = 0

    model_is_white = model_color == "white"

    while not board.is_game_over() and len(moves) < max_moves:
        is_model_turn = (board.turn == self.chess.WHITE) == model_is_white

        if is_model_turn:
            uci_move, retries = self._get_model_move(board, temperature)
            illegal_move_count += retries

            if uci_move is None:
                # Model never produced a legal move: immediate forfeit.
                return GameResult(
                    moves=moves,
                    result="0-1" if model_is_white else "1-0",
                    model_color=model_color,
                    termination="illegal_move",
                    illegal_move_count=illegal_move_count + 1,
                )

            move = self.chess.Move.from_uci(uci_move)
        else:
            if self.engine:
                uci_move = self._get_stockfish_move(board)
                move = self.chess.Move.from_uci(uci_move)
            else:
                # No engine configured: opponent plays random legal moves.
                move = random.choice(list(board.legal_moves))

        board.push(move)
        moves.append(move.uci())

    # Classify the final position. board.turn is the side *to move*, i.e.
    # the side that has been checkmated when is_checkmate() is True.
    if board.is_checkmate():
        if board.turn == self.chess.WHITE:
            result = "0-1"
        else:
            result = "1-0"
        termination = "checkmate"
    elif board.is_stalemate():
        result = "1/2-1/2"
        termination = "stalemate"
    elif board.is_insufficient_material():
        result = "1/2-1/2"
        termination = "insufficient_material"
    elif board.can_claim_draw():
        result = "1/2-1/2"
        termination = "draw_claim"
    elif len(moves) >= max_moves:
        result = "1/2-1/2"
        termination = "max_moves"
    else:
        result = "1/2-1/2"
        termination = "unknown"

    return GameResult(
        moves=moves,
        result=result,
        model_color=model_color,
        termination=termination,
        illegal_move_count=illegal_move_count,
    )
427
+
428
def evaluate_legal_moves(
    self,
    n_positions: int = 1000,
    temperature: float = 0.7,
    verbose: bool = True,
    seed: int = 42,
) -> dict:
    """Measure how often the model proposes a legal move.

    Generates random positions (5-40 random plies from the start), asks
    the model for a move in each, and tallies first-try / with-retry
    legality.

    Args:
        n_positions: Number of random positions to attempt.
        temperature: Sampling temperature for the model.
        verbose: Print a progress line every 100 positions.
        seed: Seed for both ``random`` and torch, for reproducibility.

    Returns:
        Dict with raw counts, per-position records, and legality rates.
    """
    random.seed(seed)
    torch.manual_seed(seed)

    results = {
        "total_positions": 0,
        "legal_first_try": 0,
        "legal_with_retry": 0,
        "illegal_all_retries": 0,
        "positions": [],
    }

    for i in range(n_positions):
        board = self.chess.Board()

        # Reach a random position by playing 5-40 random legal moves.
        n_random_moves = random.randint(5, 40)
        for _ in range(n_random_moves):
            if board.is_game_over():
                break
            move = random.choice(list(board.legal_moves))
            board.push(move)

        # Skip finished games: there is no move for the model to make.
        if board.is_game_over():
            continue

        results["total_positions"] += 1

        uci_move, retries = self._get_model_move(board, temperature)

        position_result = {
            "fen": board.fen(),
            "move_number": len(board.move_stack),
            "legal": uci_move is not None,
            "retries": retries,
        }
        results["positions"].append(position_result)

        if uci_move is not None:
            if retries == 0:
                results["legal_first_try"] += 1
            else:
                results["legal_with_retry"] += 1
        else:
            results["illegal_all_retries"] += 1

        if verbose and (i + 1) % 100 == 0:
            legal_rate = (results["legal_first_try"] + results["legal_with_retry"]) / results["total_positions"]
            print(f" Positions: {i + 1}/{n_positions} | Legal rate: {legal_rate:.1%}")

    # Convert counts to rates; guard against zero tested positions.
    total = results["total_positions"]
    if total > 0:
        results["legal_rate_first_try"] = results["legal_first_try"] / total
        results["legal_rate_with_retry"] = (results["legal_first_try"] + results["legal_with_retry"]) / total
        results["illegal_rate"] = results["illegal_all_retries"] / total
    else:
        results["legal_rate_first_try"] = 0
        results["legal_rate_with_retry"] = 0
        results["illegal_rate"] = 1

    return results
494
+
495
def evaluate(
    self,
    n_games: int = 100,
    temperature: float = 0.7,
    verbose: bool = True,
) -> dict:
    """Play ``n_games`` against the opponent and aggregate statistics.

    The model alternates colors between games. Results include win/draw/
    loss counts and rates, move statistics, and a rough ELO estimate
    relative to the opponent.

    Args:
        n_games: Number of games to play.
        temperature: Sampling temperature for the model's moves.
        verbose: Print a progress line every 10 games.

    Returns:
        Dict of aggregate statistics plus the per-game GameResult list.
    """
    import math  # local import: only needed for the ELO estimate

    results = {
        "wins": 0,
        "losses": 0,
        "draws": 0,
        "illegal_moves": 0,
        "total_moves": 0,
        "games": [],
    }

    for i in range(n_games):
        # Alternate colors so the model plays both sides equally.
        model_color = "white" if i % 2 == 0 else "black"

        game = self.play_game(
            model_color=model_color,
            temperature=temperature,
        )

        results["games"].append(game)
        results["total_moves"] += len(game.moves)
        results["illegal_moves"] += game.illegal_move_count

        if game.result == "1/2-1/2":
            results["draws"] += 1
        elif (game.result == "1-0" and model_color == "white") or (game.result == "0-1" and model_color == "black"):
            results["wins"] += 1
        else:
            results["losses"] += 1

        if verbose and (i + 1) % 10 == 0:
            print(
                f" Games: {i + 1}/{n_games} | "
                f"W: {results['wins']} L: {results['losses']} D: {results['draws']}"
            )

    total = results["wins"] + results["losses"] + results["draws"]
    results["win_rate"] = results["wins"] / total if total > 0 else 0
    results["draw_rate"] = results["draws"] / total if total > 0 else 0
    results["loss_rate"] = results["losses"] / total if total > 0 else 0

    total_attempts = results["total_moves"] + results["illegal_moves"]
    results["avg_game_length"] = total_attempts / total if total > 0 else 0
    results["illegal_move_rate"] = results["illegal_moves"] / total_attempts if total_attempts > 0 else 0

    # --- ELO estimate ----------------------------------------------------
    # BUGFIX: the previous formula, -400 * (1 - 2*r) / (1 if r > 0.5 else -1),
    # produced a *positive* rating difference for losing records (r < 0.5).
    # Use the standard logistic Elo model instead:
    #     expected score r = 1 / (1 + 10^(-d/400))  =>  d = 400*log10(r/(1-r))
    stockfish_elo = 1350  # assumed opponent rating at this skill level
    if results["win_rate"] > 0 or results["loss_rate"] > 0:
        score = results["wins"] + 0.5 * results["draws"]
        if score > 0:
            win_ratio = score / total
            if 0 < win_ratio < 1:
                elo_diff = 400 * math.log10(win_ratio / (1 - win_ratio))
                # Clamp: tiny samples make extreme ratios meaningless.
                elo_diff = max(-400.0, min(400.0, elo_diff))
                results["estimated_elo"] = stockfish_elo + elo_diff
            else:
                results["estimated_elo"] = stockfish_elo + (400 if win_ratio >= 1 else -400)
        else:
            results["estimated_elo"] = stockfish_elo - 400
    else:
        results["estimated_elo"] = None

    return results
560
+
561
+
562
def load_model_from_hub(model_id: str, device: str = "auto", verbose: bool = True):
    """Load a chess model + tokenizer from the Hugging Face Hub.

    Tries AutoTokenizer with trust_remote_code first; falls back to the
    bundled ChessTokenizer class (vocab still fetched from the Hub).

    Args:
        model_id: Hub repo id, e.g. "user/chess-model".
        device: Passed through as ``device_map`` to ``from_pretrained``.
        verbose: Print diagnostics about which tokenizer path was used.

    Returns:
        ``(model, tokenizer)`` tuple.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Import to register custom classes
    from src.model import ChessConfig, ChessForCausalLM
    from src.tokenizer import ChessTokenizer

    tokenizer_source = None
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        tokenizer_source = "AutoTokenizer (from Hub with trust_remote_code=True)"
    except Exception as e:
        # Broad catch is deliberate: any tokenizer failure falls back to
        # the local implementation rather than aborting the evaluation.
        if verbose:
            print(f" AutoTokenizer failed: {e}")
        tokenizer = ChessTokenizer.from_pretrained(model_id)
        tokenizer_source = "ChessTokenizer (local class, vocab from Hub)"

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        device_map=device,
    )

    if verbose:
        print(f" Tokenizer loaded via: {tokenizer_source}")
        print(f" Tokenizer class: {type(tokenizer).__name__}")
        print(f" Tokenizer vocab size: {tokenizer.vocab_size}")
        if hasattr(tokenizer, "_vocab"):
            print(f" Tokenizer has _vocab attribute: yes ({len(tokenizer._vocab)} entries)")

    return model, tokenizer
593
+
594
+
595
def main():
    """CLI entry point: load a model, then run legal-move and/or win-rate evals."""
    parser = argparse.ArgumentParser(description="Evaluate a chess model")

    parser.add_argument("--model_path", type=str, required=True, help="Path to the model or Hugging Face model ID")
    parser.add_argument("--mode", type=str, default="legal", choices=["legal", "winrate", "both"])
    parser.add_argument("--stockfish_path", type=str, default=None, help="Path to Stockfish executable")
    parser.add_argument("--stockfish_level", type=int, default=1, help="Stockfish skill level (0-20)")
    parser.add_argument("--n_positions", type=int, default=500, help="Number of positions for legal move evaluation")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument("--n_games", type=int, default=100, help="Number of games to play for win rate evaluation")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")

    args = parser.parse_args()

    print("=" * 60)
    print("CHESS CHALLENGE - EVALUATION")
    print("=" * 60)

    print(f"\nLoading model from: {args.model_path}")

    import os
    # An existing filesystem path takes priority over a Hub repo id.
    is_local_path = os.path.exists(args.model_path)

    if is_local_path:
        # Local path
        from transformers import AutoModelForCausalLM
        from src.tokenizer import ChessTokenizer
        from src.model import ChessConfig, ChessForCausalLM

        tokenizer = ChessTokenizer.from_pretrained(args.model_path)

        # IMPORTANT FIX:
        # Our custom ChessForCausalLM does NOT support device_map="auto" unless _no_split_modules is defined.
        # So we load normally and move to device explicitly.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        model = AutoModelForCausalLM.from_pretrained(
            args.model_path,
            trust_remote_code=True,
        )
        model.to(device)
        model.eval()
    else:
        # A path-looking argument that does not exist gets a clear error
        # instead of being mistaken for a Hub repo id.
        if args.model_path.startswith(".") or args.model_path.startswith("/"):
            raise FileNotFoundError(
                f"Local model path not found: {args.model_path}\n"
                f"Please check that the path exists and contains model files."
            )
        model, tokenizer = load_model_from_hub(args.model_path)

    print(f"\nSetting up evaluator...")
    evaluator = ChessEvaluator(
        model=model,
        tokenizer=tokenizer,
        stockfish_path=args.stockfish_path,
        stockfish_level=args.stockfish_level,
    )

    if args.mode in ["legal", "both"]:
        print(f"\n" + "=" * 60)
        print("PHASE 1: LEGAL MOVE EVALUATION")
        print("=" * 60)
        print(f"Testing {args.n_positions} random positions...")

        legal_results = evaluator.evaluate_legal_moves(
            n_positions=args.n_positions,
            temperature=args.temperature,
            verbose=True,
            seed=args.seed,
        )

        print("\n" + "-" * 40)
        print("LEGAL MOVE RESULTS")
        print("-" * 40)
        print(f" Positions tested: {legal_results['total_positions']}")
        print(f" Legal (1st try): {legal_results['legal_first_try']} ({legal_results['legal_rate_first_try']:.1%})")
        print(
            f" Legal (with retry): {legal_results['legal_first_try'] + legal_results['legal_with_retry']}"
            f" ({legal_results['legal_rate_with_retry']:.1%})"
        )
        print(f" Always illegal: {legal_results['illegal_all_retries']} ({legal_results['illegal_rate']:.1%})")

    if args.mode in ["winrate", "both"]:
        print(f"\n" + "=" * 60)
        print("PHASE 2: WIN RATE EVALUATION")
        print("=" * 60)
        print(f"Playing {args.n_games} games against Stockfish (Level {args.stockfish_level})...")

        winrate_results = evaluator.evaluate(
            n_games=args.n_games,
            temperature=args.temperature,
            verbose=True,
        )

        print("\n" + "-" * 40)
        print("WIN RATE RESULTS")
        print("-" * 40)
        print(f" Wins: {winrate_results['wins']}")
        print(f" Losses: {winrate_results['losses']}")
        print(f" Draws: {winrate_results['draws']}")
        print(f"\n Win Rate: {winrate_results['win_rate']:.1%}")
        print(f" Draw Rate: {winrate_results['draw_rate']:.1%}")
        print(f" Loss Rate: {winrate_results['loss_rate']:.1%}")
        print(f"\n Avg Game Length: {winrate_results['avg_game_length']:.1f} moves")
        print(f" Illegal Move Rate: {winrate_results['illegal_move_rate']:.2%}")

        if winrate_results.get("estimated_elo", None):
            print(f"\n Estimated ELO: {winrate_results['estimated_elo']:.0f}")

    print("\n" + "=" * 60)
    print("EVALUATION COMPLETE")
    print("=" * 60)


if __name__ == "__main__":
    main()
src/.ipynb_checkpoints/model-checkpoint.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chess Transformer Model for the Chess Challenge.
3
+
4
+ Modern small-LLM upgrades:
5
+ - RoPE (rotary positional embeddings): no learned positional embeddings needed
6
+ - RMSNorm (optional, default True)
7
+ - SwiGLU MLP (optional, default True)
8
+ - Weight tying (default True)
9
+ - Safe loss ignore_index = -100 (HF convention)
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import math
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from transformers import PretrainedConfig, PreTrainedModel
21
+ from transformers.modeling_outputs import CausalLMOutputWithPast
22
+
23
+
24
class ChessConfig(PretrainedConfig):
    """Configuration for the chess transformer.

    Validates that n_embd is divisible by n_head and, when RoPE is enabled,
    that the resulting head dimension is even. Mirrors ``tie_weights`` into
    the HF-standard ``tie_word_embeddings`` field.
    """

    model_type = "chess_transformer"

    def __init__(
        self,
        vocab_size: int = 1200,

        # Architecture (defaults tuned to be < 1M params for common vocabs)
        n_embd: int = 112,
        n_layer: int = 7,
        n_head: int = 7,

        # Context window
        n_ctx: int = 512,

        # MLP hidden size:
        # - if mlp_type="swiglu", this is SwiGLU hidden size h
        # - if mlp_type="gelu", this is FFN inner size
        n_inner: Optional[int] = 192,

        dropout: float = 0.05,
        layer_norm_epsilon: float = 1e-6,

        # Position encoding
        use_rope: bool = True,
        rope_theta: float = 10000.0,

        # Normalization / MLP type
        use_rmsnorm: bool = True,
        mlp_type: str = "swiglu",  # "swiglu" or "gelu"

        # Weight tying
        tie_weights: bool = True,

        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        if n_embd % n_head != 0:
            raise ValueError(f"n_embd ({n_embd}) must be divisible by n_head ({n_head})")

        head_dim = n_embd // n_head
        # RoPE rotates pairs of channels, so head_dim must be even.
        if use_rope and (head_dim % 2 != 0):
            raise ValueError(
                f"RoPE requires even head_dim, got head_dim={head_dim}. "
                f"Choose n_embd/n_head even."
            )

        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_ctx = n_ctx
        # Fall back to a 2x expansion when no inner size is given.
        self.n_inner = n_inner if n_inner is not None else (2 * n_embd)
        self.dropout = dropout
        self.layer_norm_epsilon = layer_norm_epsilon

        self.use_rope = use_rope
        self.rope_theta = rope_theta

        self.use_rmsnorm = use_rmsnorm
        self.mlp_type = mlp_type

        self.tie_weights = tie_weights
        # HF uses this field for embedding tying behavior
        self.tie_word_embeddings = bool(tie_weights)
98
+
99
+
100
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: no mean-centering, no bias, learned gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the RMS over the last dimension, then scale.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps) * self.weight
109
+
110
+
111
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Interleaved RoPE rotation: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)."""
    evens = x[..., 0::2]
    odds = x[..., 1::2]
    # Pair (-odd, even) along a fresh axis, then flatten back.
    return torch.stack((-odds, evens), dim=-1).flatten(-2)
118
+
119
+
120
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) applied to q/k of shape (B, H, T, D).

    Caches cos/sin tables and rebuilds them only when the sequence grows or
    the device/dtype changes.
    """

    def __init__(self, head_dim: int, theta: float = 10000.0):
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError(f"RoPE requires even head_dim, got {head_dim}")

        inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self._cos_cached = None
        self._sin_cached = None
        self._seq_len_cached = 0
        self._device_cached = None
        self._dtype_cached = None

    def _build_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        # Angle table: outer product of positions and inverse frequencies.
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)  # (T, D/2)

        self._cos_cached = angles.cos().to(dtype=dtype)
        self._sin_cached = angles.sin().to(dtype=dtype)
        self._seq_len_cached = seq_len
        self._device_cached = device
        self._dtype_cached = dtype

    @staticmethod
    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # Same transform as the module-level rotate_half, inlined here:
        # (t0, t1, t2, t3, ...) -> (-t1, t0, -t3, t2, ...)
        evens = t[..., 0::2]
        odds = t[..., 1::2]
        rotated = torch.empty_like(t)
        rotated[..., 0::2] = -odds
        rotated[..., 1::2] = evens
        return rotated

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # q, k: (B, H, T, D)
        seq_len, device, dtype = q.size(-2), q.device, q.dtype

        stale = (
            self._cos_cached is None
            or seq_len > self._seq_len_cached
            or device != self._device_cached
            or dtype != self._dtype_cached
        )
        if stale:
            self._build_cache(seq_len, device, dtype)

        # Broadcast to (1, 1, T, D): each (T, D/2) entry duplicated per pair.
        cos = self._cos_cached[:seq_len].unsqueeze(0).unsqueeze(0).repeat_interleave(2, dim=-1)
        sin = self._sin_cached[:seq_len].unsqueeze(0).unsqueeze(0).repeat_interleave(2, dim=-1)

        q_rot = q * cos + self._rotate(q) * sin
        k_rot = k * cos + self._rotate(k) * sin
        return q_rot, k_rot
176
+
177
+
178
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with optional RoPE.

    Q/K/V come from a single fused projection; scores are masked causally
    and, when provided, by a padding mask (1 = keep, 0 = mask).
    """

    def __init__(self, config: ChessConfig):
        super().__init__()

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head

        # Fused QKV projection (GPT-2 naming convention).
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

        self.use_rope = bool(config.use_rope)
        self.rope = RotaryEmbedding(self.head_dim, theta=config.rope_theta) if self.use_rope else None

        # causal mask buffer (expandable)
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.n_ctx, config.n_ctx)).view(1, 1, config.n_ctx, config.n_ctx),
            persistent=False,
        )

    def _ensure_causal_mask(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        # Rebuild the cached lower-triangular mask only when it is too
        # small for seq_len or lives on the wrong device.
        if self.bias.size(-1) >= seq_len and self.bias.device == device:
            return
        self.bias = torch.tril(torch.ones(seq_len, seq_len, device=device, dtype=dtype)).view(1, 1, seq_len, seq_len)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply causal self-attention to x of shape (B, T, n_embd)."""
        B, T, _ = x.size()

        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)

        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B,H,T,D)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        if self.use_rope:
            q, k = self.rope(q, k)

        # Scaled dot-product scores: (B,H,T,T).
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        self._ensure_causal_mask(T, attn.device, attn.dtype)
        causal_mask = self.bias[:, :, :T, :T]
        # Use the dtype's min instead of -inf so fp16 stays finite.
        mask_value = torch.finfo(attn.dtype).min
        attn = attn.masked_fill(causal_mask == 0, mask_value)

        # padding mask (1=keep, 0=mask)
        if attention_mask is not None:
            am = attention_mask.unsqueeze(1).unsqueeze(2)  # (B,1,1,T)
            attn = attn.masked_fill(am == 0, mask_value)

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        y = torch.matmul(attn, v)  # (B,H,T,D)
        y = y.transpose(1, 2).contiguous().view(B, T, self.n_embd)

        y = self.c_proj(y)
        y = self.dropout(y)
        return y
239
+
240
+
241
class SwiGLU(nn.Module):
    """SwiGLU feed-forward: silu(W1 x) * (W2 x) projected by W3, with dropout.

    W1 and W2 are fused into a single linear layer (``w12``) whose output is
    split in half.
    """

    def __init__(self, config: ChessConfig):
        super().__init__()
        hidden = config.n_inner
        self.w12 = nn.Linear(config.n_embd, 2 * hidden)
        self.w3 = nn.Linear(hidden, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.dropout(self.w3(F.silu(gate) * value))
256
+
257
+
258
class FeedForwardGELU(nn.Module):
    """Classic transformer MLP: linear -> GELU -> linear, with dropout."""

    def __init__(self, config: ChessConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.n_inner)
        self.c_proj = nn.Linear(config.n_inner, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
271
+
272
+
273
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: x + attn(norm(x)), then x + mlp(norm(x))."""

    def __init__(self, config: ChessConfig):
        super().__init__()

        # Both norms share the same class; RMSNorm and nn.LayerNorm take
        # the same (dim, eps) constructor arguments.
        norm_cls = RMSNorm if config.use_rmsnorm else nn.LayerNorm
        self.ln_1 = norm_cls(config.n_embd, eps=config.layer_norm_epsilon)
        self.ln_2 = norm_cls(config.n_embd, eps=config.layer_norm_epsilon)

        self.attn = MultiHeadAttention(config)

        if config.mlp_type.lower() == "swiglu":
            self.mlp = SwiGLU(config)
        else:
            self.mlp = FeedForwardGELU(config)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = x + self.attn(self.ln_1(x), attention_mask=attention_mask)
        return x + self.mlp(self.ln_2(x))
295
+
296
+
297
class ChessForCausalLM(PreTrainedModel):
    """Decoder-only causal LM over chess-move tokens.

    GPT-style stack: token embeddings (plus learned positions only when
    RoPE is disabled), ``n_layer`` pre-norm TransformerBlocks, a final
    norm, and an optionally weight-tied LM head.
    """

    config_class = ChessConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    # lm_head.weight may be absent from checkpoints when weights are tied.
    keys_to_ignore_on_load_missing = ["lm_head.weight"]
    _no_split_modules = ["TransformerBlock"]


    def __init__(self, config: ChessConfig):
        super().__init__(config)

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)

        # learned positional embeddings only if RoPE disabled
        self.wpe = None
        if not config.use_rope:
            self.wpe = nn.Embedding(config.n_ctx, config.n_embd)

        self.drop = nn.Dropout(config.dropout)
        self.h = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)])

        if config.use_rmsnorm:
            self.ln_f = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
        else:
            self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        if config.tie_weights:
            # Tell HF serialization which tensors are tied duplicates.
            self._tied_weights_keys = ["lm_head.weight"]

        self.post_init()

        if config.tie_weights:
            self.tie_weights()

    def get_input_embeddings(self) -> nn.Module:
        return self.wte

    def set_input_embeddings(self, new_embeddings: nn.Module):
        self.wte = new_embeddings
        # Re-tie so the LM head tracks the replaced embedding matrix.
        if getattr(self.config, "tie_weights", False):
            self.tie_weights()

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module):
        self.lm_head = new_embeddings

    def tie_weights(self):
        # Share (or clone, under torchscript) the embedding weights into
        # the LM head when either tying flag is set.
        if getattr(self.config, "tie_weights", False) or getattr(self.config, "tie_word_embeddings", False):
            self._tie_or_clone_weights(self.lm_head, self.wte)

    def _init_weights(self, module: nn.Module):
        # GPT-2-style init: N(0, 0.02) weights, zero biases.
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder stack; computes shifted cross-entropy when labels are given."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        B, T = input_ids.size()
        device = input_ids.device

        x = self.wte(input_ids)

        # Learned positions are only present when RoPE is disabled.
        if self.wpe is not None:
            if position_ids is None:
                position_ids = torch.arange(T, device=device).unsqueeze(0).expand(B, -1)
            x = x + self.wpe(position_ids)

        x = self.drop(x)

        for block in self.h:
            x = block(x, attention_mask=attention_mask)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Next-token prediction: logits at t predict labels at t+1;
            # -100 labels are ignored (HF convention).
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    @torch.no_grad()
    def generate_move(
        self,
        input_ids: torch.LongTensor,
        temperature: float = 0.7,
        top_k: Optional[int] = 50,
        top_p: Optional[float] = None,
    ) -> int:
        """Sample one next-token id with temperature / top-k / top-p filtering.

        Returns:
            The sampled token id as a plain int.
        """
        self.eval()

        outputs = self(input_ids)
        # Clamp temperature away from zero to avoid division by zero.
        logits = outputs.logits[:, -1, :] / max(float(temperature), 1e-6)

        if top_k is not None and top_k > 0:
            # Drop everything below the k-th largest logit.
            k = min(int(top_k), logits.size(-1))
            thresh = torch.topk(logits, k)[0][..., -1, None]
            logits = logits.masked_fill(logits < thresh, torch.finfo(logits.dtype).min)

        if top_p is not None:
            # Nucleus sampling: keep the smallest prefix whose cumulative
            # probability exceeds top_p.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            probs = F.softmax(sorted_logits, dim=-1)
            cum = torch.cumsum(probs, dim=-1)
            to_remove = cum > float(top_p)
            # Shift right so the first token crossing the threshold survives.
            to_remove[..., 1:] = to_remove[..., :-1].clone()
            to_remove[..., 0] = 0
            indices_to_remove = to_remove.scatter(dim=-1, index=sorted_indices, src=to_remove)
            logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        return int(next_token.item())
440
+
441
+
442
# Register the custom config/model with the HF Auto* factories so that
# AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True)
# resolves "chess_transformer" checkpoints to these classes.
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("chess_transformer", ChessConfig)
AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
src/.ipynb_checkpoints/tokenizer-checkpoint.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Decomposed Chess Tokenizer for the Chess Challenge.
3
+
4
+ Each move becomes 3 or 4 tokens:
5
+ WP e2_f e4_t
6
+ BN g8_f f6_t
7
+ Promotion adds an extra token:
8
+ WP e7_f e8_t =q
9
+
10
+ Why this helps:
11
+ - Fixed small vocab (~150 tokens)
12
+ - Near-zero OOV / UNK, so the evaluator can always parse squares
13
+ - Compatible with the provided evaluate.py (it auto-detects 'decomposed')
14
+
15
+ Special tokens behavior:
16
+ - Adds BOS only (NO EOS)
17
+ - If BOS already present, does not add it twice
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ import os
24
+ from typing import Dict, List, Optional
25
+
26
+ from transformers import PreTrainedTokenizer
27
+
28
+
29
class ChessTokenizer(PreTrainedTokenizer):
    """Fixed-vocabulary tokenizer that decomposes extended-UCI moves.

    Each move token like ``WPe2e4`` is split into piece, from-square and
    to-square tokens (``WP e2_f e4_t``), plus an optional promotion token
    (``=q`` etc.).  The vocabulary is fully enumerable (~150 tokens), so
    generation can never produce an out-of-vocabulary square.

    Special-token behavior: only BOS is auto-added (never EOS), and BOS is
    not duplicated if the sequence already starts with it.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"  # kept for compatibility, not auto-added
    UNK_TOKEN = "[UNK]"

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """Build the tokenizer from an explicit vocab dict, a vocab.json
        file, or (by default) the fixed generated vocabulary.

        The vocab must exist BEFORE super().__init__ runs, because recent
        transformers versions call get_vocab() during base-class init.
        """
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # avoid duplicates from kwargs (we always pass our own below)
        kwargs.pop("pad_token", None)
        kwargs.pop("bos_token", None)
        kwargs.pop("eos_token", None)
        kwargs.pop("unk_token", None)

        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._build_fixed_vocab()

        # Reverse mapping for id -> token lookups.
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    # --------------------------
    # Fixed vocab: pieces + squares + promos
    # --------------------------
    @staticmethod
    def _all_squares() -> List[str]:
        """Return the 64 board squares, rank-major: a1..h1, a2..h2, ... h8."""
        files = "abcdefgh"
        ranks = "12345678"
        return [f + r for r in ranks for f in files]  # a1..h8

    def _build_fixed_vocab(self) -> Dict[str, int]:
        """Enumerate the full token set and assign sequential ids.

        Order: specials, 12 piece tokens, 64 from-squares, 64 to-squares,
        4 promotion tokens (148 tokens total).
        """
        special = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]

        # piece tokens: WP..WK, BP..BK
        piece_tokens = [f"{c}{p}" for c in "WB" for p in "PNBRQK"]

        squares = self._all_squares()
        from_tokens = [f"{sq}_f" for sq in squares]
        to_tokens = [f"{sq}_t" for sq in squares]

        promo_tokens = ["=q", "=r", "=b", "=n"]

        tokens = special + piece_tokens + from_tokens + to_tokens + promo_tokens
        return {tok: i for i, tok in enumerate(tokens)}

    # --------------------------
    # Special tokens handling (robust with evaluate.py)
    # --------------------------
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Prepend BOS (never EOS); skip if BOS is already first."""
        # BOS only, NO EOS
        if token_ids_1 is not None:
            token_ids_0 = token_ids_0 + token_ids_1

        if token_ids_0 and token_ids_0[0] == self.bos_token_id:
            return token_ids_0
        return [self.bos_token_id] + token_ids_0

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Return 1 for special-token positions, 0 for regular tokens.

        The non-`already_has_special_tokens` branches account for the single
        BOS that build_inputs_with_special_tokens prepends.
        """
        if already_has_special_tokens:
            specials = {self.pad_token_id, self.bos_token_id, self.eos_token_id, self.unk_token_id}
            return [1 if t in specials else 0 for t in token_ids_0]

        if token_ids_1 is None:
            return [1] + [0] * len(token_ids_0)
        return [1] + [0] * (len(token_ids_0) + len(token_ids_1))

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Single-segment model: all type ids are 0 (the +1 is for BOS)."""
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 1)
        return [0] * (len(token_ids_0) + len(token_ids_1) + 1)

    # --------------------------
    # Tokenization
    # --------------------------
    def _tokenize(self, text: str) -> List[str]:
        """Split whitespace-separated moves into decomposed tokens.

        Accepts literal special tokens, already-decomposed tokens, and
        extended-UCI moves (e.g. "WPe2e4", "BNg8f6(x)", "WPe7e8=Q(+)").
        Unparseable items become [UNK].
        """
        if not text or not text.strip():
            return []

        parts = text.strip().split()
        out: List[str] = []

        for tok in parts:
            # allow literal special tokens present in text
            if tok in {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}:
                out.append(tok)
                continue

            # already decomposed tokens
            if (len(tok) == 2 and tok[0] in "WB" and tok[1] in "PNBRQK") or tok.endswith("_f") or tok.endswith("_t") or tok in {"=q", "=r", "=b", "=n"}:
                out.append(tok)
                continue

            # parse extended UCI (dataset): WPe2e4, BNg8f6(x), WPe7e8=Q(+), ...
            # anything shorter than [color][piece][from][to] is malformed
            if len(tok) < 6:
                out.append(self.UNK_TOKEN)
                continue

            color = tok[0]
            piece = tok[1]
            from_sq = tok[2:4]
            to_sq = tok[4:6]

            out.append(f"{color}{piece}")
            out.append(f"{from_sq}_f")
            out.append(f"{to_sq}_t")

            # promotion like "=Q"
            if "=" in tok:
                try:
                    promo_part = tok.split("=", 1)[1]
                    promo_letter = promo_part[0].lower()
                    promo_tok = f"={promo_letter}"
                    if promo_tok in self._vocab:
                        out.append(promo_tok)
                except Exception:
                    # malformed promotion suffix: emit just the 3 base tokens
                    pass

        return out

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token string to its id, falling back to [UNK]."""
        return self._vocab.get(token, self._vocab[self.UNK_TOKEN])

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its token string, falling back to [UNK]."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens with single spaces (no de-decomposition)."""
        return " ".join(tokens)

    # --------------------------
    # Vocab I/O
    # --------------------------
    @property
    def vocab_size(self) -> int:
        """Number of entries in the (fixed) vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        """Write vocab.json to `save_directory`; returns the written path(s)."""
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)
src/.ipynb_checkpoints/train-checkpoint.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training script for the Chess Challenge.
3
+
4
+ GPU-optimized version (still compatible with older transformers/accelerate):
5
+ - Uses fp16/bf16 automatically on GPU
6
+ - Uses evaluation + saving per EPOCH by default (much faster than steps)
7
+ - Enables dataloader_num_workers + pin_memory on GPU
8
+ - Optional torch.compile for speed (safe-guarded)
9
+ - Keeps your robust TrainingArguments compatibility (evaluation_strategy vs eval_strategy)
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import os
16
+ import warnings
17
+ from pathlib import Path
18
+
19
+ warnings.filterwarnings("ignore", message="'return' in a 'finally' block")
20
+
21
+ import torch
22
+ from transformers import Trainer, TrainingArguments, set_seed
23
+
24
+ from src.data import ChessDataCollator, create_train_val_datasets
25
+ from src.model import ChessConfig, ChessForCausalLM
26
+ from src.tokenizer import ChessTokenizer
27
+ from src.utils import count_parameters, print_parameter_budget
28
+
29
+
30
def parse_args():
    """Parse command-line arguments for training.

    Groups: model architecture, data source, optimizer/training schedule,
    logging/checkpointing, and GPU speed knobs.  Returns the argparse
    Namespace consumed by main().
    """
    p = argparse.ArgumentParser(description="Train a chess-playing language model")

    # ---------------- Model ----------------
    p.add_argument("--n_embd", type=int, default=128, help="Embedding dimension")
    p.add_argument("--n_layer", type=int, default=6, help="Number of transformer layers")
    p.add_argument("--n_head", type=int, default=8, help="Number of attention heads")
    # For speed on GPU, 256 is often a great default; override via CLI if needed.
    p.add_argument("--n_ctx", type=int, default=256, help="Maximum context length")

    p.add_argument("--n_inner", type=int, default=248, help="MLP hidden size (SwiGLU: h)")
    p.add_argument("--dropout", type=float, default=0.05, help="Dropout probability")
    p.add_argument("--no_tie_weights", action="store_true", help="Disable weight tying")

    # improved model.py flags (all default OFF; pass explicitly to enable)
    p.add_argument("--use_rope", action="store_true", help="Use RoPE (recommended)")
    p.add_argument("--mlp_type", type=str, default="swiglu", choices=["swiglu", "gelu"], help="MLP type")
    p.add_argument("--use_rmsnorm", action="store_true", help="Use RMSNorm (recommended)")

    # ---------------- Data ----------------
    p.add_argument("--dataset_name", type=str, default="dlouapre/lichess_2025-01_1M")
    p.add_argument("--max_train_samples", type=int, default=None, help="Optional cap for train samples")
    p.add_argument("--val_samples", type=int, default=5000)

    p.add_argument(
        "--tokenizer_dir",
        type=str,
        default="./tokenizer_cache",
        help="Where to save/load the tokenizer (vocab.json)",
    )

    # ---------------- Training ----------------
    p.add_argument("--output_dir", type=str, default="./output")
    p.add_argument("--num_train_epochs", type=int, default=3)

    # For speed: prefer larger batch and smaller accumulation.
    p.add_argument("--per_device_train_batch_size", type=int, default=64)
    p.add_argument("--per_device_eval_batch_size", type=int, default=128)
    p.add_argument("--gradient_accumulation_steps", type=int, default=1)

    p.add_argument("--learning_rate", type=float, default=3e-4)
    p.add_argument("--weight_decay", type=float, default=0.1)
    p.add_argument("--warmup_steps", type=int, default=300)

    p.add_argument("--seed", type=int, default=42)

    # ---------------- Logging / Save ----------------
    p.add_argument("--logging_steps", type=int, default=50)

    # Eval/save config: epoch by default (much faster). Still allow steps if user wants.
    p.add_argument("--eval_strategy", type=str, default="epoch", choices=["epoch", "steps"], help="Evaluation strategy")
    p.add_argument("--save_strategy", type=str, default="epoch", choices=["epoch", "steps"], help="Save strategy")
    p.add_argument("--eval_steps", type=int, default=1000, help="Only used if eval_strategy=steps")
    p.add_argument("--save_steps", type=int, default=1000, help="Only used if save_strategy=steps")

    # ---------------- Speed knobs ----------------
    p.add_argument("--dataloader_num_workers", type=int, default=2, help="CPU workers for dataloader")
    p.add_argument("--torch_compile", action="store_true", help="Enable torch.compile on GPU (can speed up)")

    return p.parse_args()
90
+
91
+
92
def load_or_create_tokenizer(args) -> ChessTokenizer:
    """Return a ChessTokenizer, reusing a cached vocab.json when available.

    Looks in ``args.tokenizer_dir`` first; if no vocab.json is found there,
    builds the fixed decomposed vocabulary and caches it for future runs.
    """
    directory = Path(args.tokenizer_dir)
    directory.mkdir(parents=True, exist_ok=True)

    cached_vocab = directory / "vocab.json"
    if cached_vocab.exists():
        print(f"Loading tokenizer from {directory} ...")
        return ChessTokenizer(vocab_file=str(cached_vocab))

    # No cache yet: build the fixed vocabulary and persist it.
    print("Creating fixed-vocab tokenizer (decomposed) ...")
    tokenizer = ChessTokenizer()
    tokenizer.save_pretrained(str(directory))
    print(f"Tokenizer saved to {directory} (vocab_size={tokenizer.vocab_size})")
    return tokenizer
106
+
107
+
108
def _make_training_args(args) -> TrainingArguments:
    """
    Compatibility layer for transformers versions:
    - some use evaluation_strategy, others use eval_strategy
    - we keep it robust while using faster defaults (epoch eval/save).

    Precision is auto-selected: bf16 on supporting GPUs, fp16 on other
    GPUs, full fp32 on CPU.
    """
    use_gpu = torch.cuda.is_available()
    use_bf16 = bool(use_gpu and torch.cuda.is_bf16_supported())
    use_fp16 = bool(use_gpu and not use_bf16)

    common = dict(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,

        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,

        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_steps=args.warmup_steps,
        lr_scheduler_type="cosine",

        max_grad_norm=1.0,

        logging_dir=os.path.join(args.output_dir, "logs"),
        logging_steps=args.logging_steps,

        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,

        seed=args.seed,
        report_to=["none"],

        # Mixed precision for GPU speed
        fp16=use_fp16,
        bf16=use_bf16,

        # DataLoader perf
        dataloader_num_workers=args.dataloader_num_workers,
        dataloader_pin_memory=use_gpu,

        # Important for custom batches
        remove_unused_columns=False,
    )

    # Build kwargs depending on epoch vs steps
    # (eval_steps/save_steps are only meaningful for the "steps" strategy)
    eval_kwargs = {}
    if args.eval_strategy == "steps":
        eval_kwargs["eval_steps"] = args.eval_steps
    save_kwargs = {}
    if args.save_strategy == "steps":
        save_kwargs["save_steps"] = args.save_steps

    # Try standard HF arg names first
    try:
        return TrainingArguments(
            **common,
            evaluation_strategy=args.eval_strategy,
            save_strategy=args.save_strategy,
            **eval_kwargs,
            **save_kwargs,
        )
    except TypeError:
        # Fallback for forks/older variants that renamed args
        # (newer transformers renamed evaluation_strategy -> eval_strategy)
        return TrainingArguments(
            **common,
            eval_strategy=args.eval_strategy,
            save_strategy=args.save_strategy,
            **eval_kwargs,
            **save_kwargs,
        )
182
+
183
+
184
def main():
    """Entry point: build tokenizer, model, datasets, and run training.

    Pipeline: parse CLI args -> seed -> tokenizer -> ChessConfig/model
    (with parameter-budget report) -> datasets + collator -> HF Trainer ->
    train -> save final model + tokenizer to <output_dir>/final_model.
    """
    args = parse_args()
    set_seed(args.seed)

    print("=" * 60)
    print("CHESS CHALLENGE - TRAINING")
    print("=" * 60)

    tokenizer = load_or_create_tokenizer(args)
    # Use the tokenizer's actual vocab size (may differ from any default).
    actual_vocab_size = tokenizer.vocab_size
    print(f" Vocab size used: {actual_vocab_size}")

    print("\nCreating model configuration...")
    config = ChessConfig(
        vocab_size=actual_vocab_size,
        n_embd=args.n_embd,
        n_layer=args.n_layer,
        n_head=args.n_head,
        n_ctx=args.n_ctx,
        n_inner=args.n_inner,
        dropout=args.dropout,
        tie_weights=not args.no_tie_weights,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_rope=bool(args.use_rope),
        mlp_type=args.mlp_type,
        use_rmsnorm=bool(args.use_rmsnorm),
    )

    print_parameter_budget(config)

    print("\nCreating model...")
    model = ChessForCausalLM(config)

    # Optional torch.compile (GPU only); failure is non-fatal.
    if args.torch_compile and torch.cuda.is_available():
        try:
            model = torch.compile(model)
            print("✓ torch.compile enabled")
        except Exception as e:
            print(f"WARNING: torch.compile failed ({e}). Continuing without it.")

    n_params = count_parameters(model)
    print(f" Total parameters: {n_params:,}")
    print("✓ Model is within 1M parameter limit" if n_params <= 1_000_000 else "WARNING: Model exceeds 1M!")

    print("\nLoading datasets...")
    train_dataset, val_dataset = create_train_val_datasets(
        tokenizer=tokenizer,
        dataset_name=args.dataset_name,
        max_length=args.n_ctx,
        train_samples=args.max_train_samples,
        val_samples=args.val_samples,
    )
    print(f" Training samples: {len(train_dataset):,}")
    print(f" Validation samples: {len(val_dataset):,}")

    data_collator = ChessDataCollator(tokenizer, max_length=args.n_ctx)

    training_args = _make_training_args(args)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    print("\nStarting training...")
    trainer.train()

    out_final = os.path.join(args.output_dir, "final_model")
    print("\nSaving final model...")
    trainer.save_model(out_final)
    tokenizer.save_pretrained(out_final)

    print("\nTraining complete!")
    print(f" Model saved to: {out_final}")
266
+
267
# Script entry point (no side effects on import).
if __name__ == "__main__":
    main()
src/.ipynb_checkpoints/utils-checkpoint.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for the Chess Challenge.
3
+
4
+ This module provides helper functions for:
5
+ - Parameter counting and budget analysis (including RoPE / SwiGLU / RMSNorm variants)
6
+ - Move validation and conversion with python-chess
7
+ - Optional: compute legal-move rate over a whole game string
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import re
13
+ from typing import Dict, Optional, TYPE_CHECKING
14
+
15
+ import torch.nn as nn
16
+
17
+ if TYPE_CHECKING:
18
+ from src.model import ChessConfig
19
+
20
+
21
+ # =========================
22
+ # Parameter counting
23
+ # =========================
24
+
25
def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    """
    Count the number of parameters in a model.

    Args:
        model: The PyTorch model.
        trainable_only: If True, only count trainable parameters.

    Returns:
        Total number of parameters.
    """
    params = model.parameters()
    if trainable_only:
        # Skip anything frozen via requires_grad=False.
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)
39
+
40
+
41
def count_parameters_by_component(model: nn.Module) -> Dict[str, int]:
    """
    Count parameters broken down by leaf modules.

    Args:
        model: The PyTorch model.

    Returns:
        Dictionary mapping leaf-module names to their (non-zero)
        parameter counts.
    """
    counts: Dict[str, int] = {}
    for name, module in model.named_modules():
        # Only leaf modules; containers are covered by their children.
        if next(iter(module.children()), None) is not None:
            continue
        n_params = sum(p.numel() for p in module.parameters(recurse=False))
        if n_params:
            counts[name] = n_params
    return counts
58
+
59
+
60
def estimate_parameters(config: "ChessConfig") -> Dict[str, int]:
    """
    Estimate the parameter count (weights + biases) for a configuration.

    Handles every architecture variant used in this repo:
    - learned position embeddings (wpe) vs RoPE (no positional params)
    - GELU FFN (d -> n_inner -> d) vs SwiGLU FFN (d -> 2h and h -> d, h = n_inner)
    - LayerNorm (weight + bias) vs RMSNorm (weight only)
    - tied vs untied LM head (untied adds V*d; lm_head has no bias)

    Returns a breakdown dict with per-component counts, a "total", and a
    "notes" sub-dict recording which variants were assumed.
    """
    vocab = int(config.vocab_size)
    dim = int(config.n_embd)
    depth = int(config.n_layer)
    ctx = int(config.n_ctx)
    inner = int(config.n_inner)

    rope = bool(getattr(config, "use_rope", False))
    rmsnorm = bool(getattr(config, "use_rmsnorm", False))
    mlp_kind = str(getattr(config, "mlp_type", "gelu")).lower()
    tied = bool(getattr(config, "tie_weights", True))

    # Embedding tables (RoPE contributes no positional parameters).
    tok_emb = vocab * dim
    pos_emb = 0 if rope else ctx * dim

    # Attention: c_attn is d -> 3d (weight + bias), c_proj is d -> d.
    qkv_per_layer = 3 * dim * dim + 3 * dim
    proj_per_layer = dim * dim + dim

    # Feed-forward block.
    if mlp_kind == "swiglu":
        # w12: d -> 2h (weight 2h*d, bias 2h); w3: h -> d (weight d*h, bias d)
        ffn_per_layer = 3 * dim * inner + 2 * inner + dim
    else:
        # GELU: d -> n_inner -> d, both with bias.
        ffn_per_layer = 2 * dim * inner + inner + dim

    # One norm's params: RMSNorm has only a weight, LayerNorm weight + bias.
    one_norm = dim if rmsnorm else 2 * dim
    norms_per_layer = 2 * one_norm  # ln_1 + ln_2
    final_norm = one_norm

    per_layer = qkv_per_layer + proj_per_layer + ffn_per_layer + norms_per_layer
    all_layers = depth * per_layer

    # lm_head is Linear(d, V, bias=False); tied heads share the embedding.
    head = 0 if tied else vocab * dim

    total = tok_emb + pos_emb + all_layers + final_norm + head

    return {
        "token_embeddings": tok_emb,
        "position_embeddings": pos_emb,
        "attention_qkv_per_layer": qkv_per_layer,
        "attention_proj_per_layer": proj_per_layer,
        "ffn_per_layer": ffn_per_layer,
        "norms_per_layer": norms_per_layer,
        "final_norm": final_norm,
        "total_transformer_layers": all_layers,
        "lm_head": head,
        "total": total,
        "notes": {
            "use_rope": rope,
            "use_rmsnorm": rmsnorm,
            "mlp_type": mlp_kind,
            "tie_weights": tied,
        },
    }
139
+
140
+
141
def print_parameter_budget(config: "ChessConfig", limit: int = 1_000_000) -> None:
    """
    Print a formatted parameter budget analysis.

    Uses estimate_parameters() for the counts; purely informational
    (writes to stdout, returns nothing).

    Args:
        config: Model configuration.
        limit: Parameter limit (challenge budget, default 1M).
    """
    est = estimate_parameters(config)

    print("=" * 60)
    print("PARAMETER BUDGET ANALYSIS")
    print("=" * 60)
    print("\nConfiguration:")
    print(f" vocab_size (V) = {config.vocab_size}")
    print(f" n_embd (d) = {config.n_embd}")
    print(f" n_layer (L) = {config.n_layer}")
    print(f" n_head = {config.n_head}")
    print(f" n_ctx = {config.n_ctx}")
    print(f" n_inner = {config.n_inner}")
    print(f" tie_weights = {getattr(config, 'tie_weights', True)}")
    # Variant flags are optional on older configs; only show when present.
    if hasattr(config, "use_rope"):
        print(f" use_rope = {getattr(config, 'use_rope', False)}")
    if hasattr(config, "mlp_type"):
        print(f" mlp_type = {getattr(config, 'mlp_type', 'gelu')}")
    if hasattr(config, "use_rmsnorm"):
        print(f" use_rmsnorm = {getattr(config, 'use_rmsnorm', False)}")

    print("\nParameter Breakdown (estimate):")
    print(f" Token Embeddings: {est['token_embeddings']:>10,}")
    print(f" Position Embeddings: {est['position_embeddings']:>10,}")
    print(f" Transformer Layers: {est['total_transformer_layers']:>10,}")
    print(f" Final Norm: {est['final_norm']:>10,}")
    if getattr(config, "tie_weights", True):
        print(f" LM Head: {'(tied)':>10}")
    else:
        print(f" LM Head: {est['lm_head']:>10,}")

    print(" " + "-" * 32)
    print(f" TOTAL: {est['total']:>10,}")

    remaining = limit - est["total"]
    print("\nBudget Status:")
    print(f" Limit: {limit:>10,}")
    print(f" Used: {est['total']:>10,}")
    print(f" Remaining: {remaining:>10,}")

    if est["total"] <= limit:
        print(f"\n✓ Within budget! ({est['total'] / limit * 100:.1f}% used)")
    else:
        # remaining is negative here; negate for a readable overage figure.
        print(f"\n✗ OVER BUDGET by {-remaining:,} parameters!")
    print("=" * 60)
193
+
194
+
195
+ # =========================
196
+ # Move conversion / validation (python-chess)
197
+ # =========================
198
+
199
def convert_extended_uci_to_uci(move: str) -> str:
    """
    Convert extended UCI format to standard UCI format.

    Extended UCI format (dataset):
        [W|B][Piece][from_sq][to_sq][suffixes...]
        e.g. "WPe2e4", "BNg8f6(x)", "WKe1g1(o)", "WPe7e8=Q(+)"
    Standard UCI:
        "e2e4", "g8f6", "e1g1", "e7e8q"

    Moves shorter than the minimal 6-char extended form are returned
    unchanged.
    """
    if len(move) < 6:
        return move

    # Squares sit at fixed offsets after the 2-char color+piece prefix.
    squares = move[2:6]

    # Promotion piece, if any, is the character right after '='.
    promotion = ""
    eq_pos = move.find("=")
    if eq_pos != -1 and eq_pos + 1 < len(move):
        promotion = move[eq_pos + 1].lower()

    return squares + promotion
222
+
223
+
224
def validate_move_with_chess(move: str, board_fen: Optional[str] = None) -> bool:
    """
    Validate a single move using python-chess against a given board state.

    IMPORTANT:
    - If board_fen is None, validation is against the initial position.
      For validating a *game*, use `legal_rate_game_text` which advances the board.

    Args:
        move: Move in extended UCI format (e.g. "WPe2e4").
        board_fen: FEN string of the current board (optional).

    Returns:
        True if move is legal on that board, else False.

    Raises:
        ImportError: If python-chess is not installed.
    """
    try:
        import chess
    except ImportError:
        raise ImportError(
            "python-chess is required for move validation. Install it with: pip install python-chess"
        )

    # Too short to contain color+piece+from+to: reject without parsing.
    if len(move) < 6:
        return False

    board = chess.Board(board_fen) if board_fen else chess.Board()
    uci_move = convert_extended_uci_to_uci(move)

    try:
        move_obj = chess.Move.from_uci(uci_move)
        return move_obj in board.legal_moves
    except Exception:
        # Malformed UCI string (from_uci raised): treat as illegal.
        return False
257
+
258
+
259
def legal_rate_game_text(game_text: str, stop_on_illegal: bool = True) -> float:
    """
    Compute the fraction of legal moves in a space-separated extended-UCI game string.

    The board is advanced after every legal move, so each move is checked
    against the actual position it would be played from.

    Args:
        game_text: "WPe2e4 BPe7e5 ..." (space-separated moves)
        stop_on_illegal: If True, stop at first illegal move.

    Returns:
        legal / total (total is moves processed, or total moves if stop_on_illegal=False).
        Returns 0.0 for an empty/whitespace-only string.

    Raises:
        ImportError: If python-chess is not installed.
    """
    try:
        import chess
    except ImportError:
        raise ImportError("python-chess is required. Install it with: pip install python-chess")

    moves = game_text.strip().split()
    if not moves:
        return 0.0

    board = chess.Board()
    legal = 0
    total = 0

    for mv in moves:
        total += 1
        uci = convert_extended_uci_to_uci(mv)
        try:
            m = chess.Move.from_uci(uci)
        except Exception:
            # Unparseable move counts toward total but not legal.
            if stop_on_illegal:
                break
            continue

        if m in board.legal_moves:
            legal += 1
            board.push(m)  # advance so later moves see the right position
        else:
            if stop_on_illegal:
                break

    # max() guards division even though moves is non-empty here.
    return legal / max(total, 1)
301
+
302
+
303
def convert_uci_to_extended(uci_move: str, board_fen: str) -> str:
    """
    Convert standard UCI move to extended UCI format used by the dataset.

    Args:
        uci_move: e.g., "e2e4", "e7e8q", "e1g1"
        board_fen: FEN of current board (must match move)

    Returns:
        Extended UCI like "WPe2e4", with suffixes:
        - (x) capture
        - (+) check
        - (+*) checkmate
        - (x+) capture+check
        - (x+*) capture+checkmate
        - (o) / (O) castling (kingside / queenside; replaces other suffixes)
        - promotions as "=Q" etc

    Raises:
        ImportError: If python-chess is not installed.
    """
    try:
        import chess
    except ImportError:
        raise ImportError("python-chess is required for move conversion. Install it with: pip install python-chess")

    board = chess.Board(board_fen)
    move = chess.Move.from_uci(uci_move)

    color = "W" if board.turn == chess.WHITE else "B"

    piece = board.piece_at(move.from_square)
    # Fallback to "P" if the square is empty (inconsistent FEN/move input).
    piece_letter = piece.symbol().upper() if piece else "P"

    from_sq = chess.square_name(move.from_square)
    to_sq = chess.square_name(move.to_square)

    result = f"{color}{piece_letter}{from_sq}{to_sq}"

    # Promotion
    if move.promotion:
        result += f"={chess.piece_symbol(move.promotion).upper()}"

    # Capture suffix (must be tested BEFORE pushing the move)
    if board.is_capture(move):
        result += "(x)"

    # Check / mate suffix (need to push); board is restored with pop()
    # so the castling test below sees the pre-move position again.
    board.push(move)
    if board.is_checkmate():
        if "(x)" in result:
            result = result.replace("(x)", "(x+*)")
        else:
            result += "(+*)"
    elif board.is_check():
        if "(x)" in result:
            result = result.replace("(x)", "(x+)")
        else:
            result += "(+)"
    board.pop()

    # Castling (dataset wants (o)/(O), usually no other suffix with it)
    if board.is_castling(move):
        result = re.sub(r"\([^)]*\)", "", result)  # drop any (...) suffix
        # (o) = kingside (king lands on g1/g8), (O) = queenside
        if move.to_square in [chess.G1, chess.G8]:
            result += "(o)"
        else:
            result += "(O)"

    return result
src/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Chess Challenge source module."""
2
+
3
+ from .model import ChessConfig, ChessForCausalLM
4
+ from .tokenizer import ChessTokenizer
5
+
6
+ # Lazy import for evaluate to avoid RuntimeWarning when running as module
7
+ def __getattr__(name):
8
+ if name == "ChessEvaluator":
9
+ from .evaluate import ChessEvaluator
10
+ return ChessEvaluator
11
+ if name == "load_model_from_hub":
12
+ from .evaluate import load_model_from_hub
13
+ return load_model_from_hub
14
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
15
+
16
+ __all__ = [
17
+ "ChessConfig",
18
+ "ChessForCausalLM",
19
+ "ChessTokenizer",
20
+ "ChessEvaluator",
21
+ "load_model_from_hub",
22
+ ]
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (878 Bytes). View file
 
src/__pycache__/data.cpython-311.pyc ADDED
Binary file (8.93 kB). View file
 
src/__pycache__/evaluate.cpython-311.pyc ADDED
Binary file (32.5 kB). View file
 
src/__pycache__/model.cpython-311.pyc ADDED
Binary file (26.3 kB). View file
 
src/__pycache__/tokenizer.cpython-311.pyc ADDED
Binary file (11.4 kB). View file
 
src/__pycache__/train.cpython-311.pyc ADDED
Binary file (13.2 kB). View file
 
src/__pycache__/utils.cpython-311.pyc ADDED
Binary file (15.8 kB). View file
 
src/data.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data loading utilities for the Chess Challenge.
3
+
4
+ This module provides functions to load and process chess game data
5
+ from the Lichess dataset on Hugging Face.
6
+
7
+ IMPORTANT NOTE (compat with template evaluate + custom tokenizers):
8
+ - Do NOT manually prepend BOS in the raw text.
9
+ The tokenizer should handle BOS via build_inputs_with_special_tokens.
10
+ This avoids double-BOS issues and keeps train/eval conventions aligned.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from typing import Dict, Iterator, List, Optional
16
+
17
+ import torch
18
+ from torch.utils.data import Dataset
19
+
20
+
21
class ChessDataset(Dataset):
    """
    PyTorch Dataset for chess games.

    Each game is tokenized and truncated/padded to max_length.
    Labels are identical to input_ids; the model shifts internally.
    Padding labels are set to -100 (HF convention) so they are ignored by CE loss.
    """

    def __init__(
        self,
        tokenizer,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        max_length: int = 256,
        max_samples: Optional[int] = None,
    ):
        from datasets import load_dataset

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.column = column

        loaded = load_dataset(dataset_name, split=split)
        if max_samples is not None:
            # Keep only a leading slice when a sample cap is requested.
            loaded = loaded.select(range(min(max_samples, len(loaded))))
        self.data = loaded

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Tokenize one game into fixed-length tensors with masked labels."""
        game_text = self.data[idx][self.column]

        # IMPORTANT: do NOT prepend BOS manually in the raw text; the
        # tokenizer adds special tokens via build_inputs_with_special_tokens,
        # which keeps train/eval conventions aligned with evaluate.py.
        encoded = self.tokenizer(
            game_text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        ids = encoded["input_ids"].squeeze(0)
        mask = encoded["attention_mask"].squeeze(0)

        # Padding positions get label -100 so cross-entropy skips them.
        masked_labels = ids.clone()
        masked_labels[mask == 0] = -100

        return {
            "input_ids": ids,
            "attention_mask": mask,
            "labels": masked_labels,
        }
80
+
81
+
82
class ChessDataCollator:
    """
    Data collator for chess games.

    ChessDataset already pads every example to max_length, so collation
    reduces to stacking the per-example tensors along a new batch axis.
    """

    def __init__(self, tokenizer, max_length: int = 256):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        """Stack the fixed-size tensors of each feature into one batch."""
        batch_keys = ("input_ids", "attention_mask", "labels")
        return {
            key: torch.stack([example[key] for example in features])
            for key in batch_keys
        }
104
+
105
+
106
def create_train_val_datasets(
    tokenizer,
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    max_length: int = 256,
    train_samples: Optional[int] = None,
    val_samples: int = 5000,
    val_ratio: float = 0.05,
):
    """
    Create training and validation datasets.

    The split is deterministic by index:
    - train: [0:n_train)
    - val:   [n_train:n_train+n_val)

    Returns:
        (train_dataset, val_dataset)
    """
    from datasets import load_dataset

    full = load_dataset(dataset_name, split="train")
    total = len(full)

    # Training-slice size: explicit cap if given, otherwise (1 - val_ratio).
    n_train = (
        min(train_samples, total - val_samples)
        if train_samples is not None
        else int(total * (1 - val_ratio))
    )
    n_val = min(val_samples, total - n_train)

    def _wrap(rows):
        # Build a ChessDataset, then swap in the pre-selected rows.
        subset = ChessDataset(
            tokenizer=tokenizer,
            dataset_name=dataset_name,
            max_length=max_length,
        )
        subset.data = rows
        return subset

    train_dataset = _wrap(full.select(range(n_train)))
    val_dataset = _wrap(full.select(range(n_train, n_train + n_val)))
    return train_dataset, val_dataset
154
+
155
+
156
def stream_games(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    split: str = "train",
    column: str = "text",
) -> Iterator[str]:
    """Lazily yield one game string per dataset row (streaming mode)."""
    from datasets import load_dataset

    for row in load_dataset(dataset_name, split=split, streaming=True):
        yield row[column]
169
+
170
+
171
def analyze_dataset_statistics(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    max_samples: int = 10000,
) -> Dict:
    """
    Analyze statistics of the chess dataset (non-streaming).

    Args:
        dataset_name: Hugging Face dataset identifier.
        max_samples: Upper bound on the number of games inspected.

    Returns:
        Dict with game-length statistics, move frequencies, and the most
        common opening lines. When no games are available the numeric
        stats are reported as 0 instead of raising.
    """
    from collections import Counter
    from datasets import load_dataset

    dataset = load_dataset(dataset_name, split="train")
    dataset = dataset.select(range(min(max_samples, len(dataset))))

    game_lengths = []
    move_counts = Counter()
    opening_moves = Counter()

    for example in dataset:
        moves = example["text"].strip().split()
        game_lengths.append(len(moves))
        move_counts.update(moves)

        # The first four tokens identify the opening line.
        if len(moves) >= 4:
            opening = " ".join(moves[:4])
            opening_moves[opening] += 1

    # FIX: guard the empty case (max_samples <= 0 or an empty dataset); the
    # original raised ZeroDivisionError on the average and ValueError on
    # min()/max() of an empty sequence.
    n_games = len(game_lengths)
    return {
        "total_games": len(dataset),
        "avg_game_length": sum(game_lengths) / n_games if n_games else 0,
        "min_game_length": min(game_lengths) if n_games else 0,
        "max_game_length": max(game_lengths) if n_games else 0,
        "unique_moves": len(move_counts),
        "most_common_moves": move_counts.most_common(20),
        "most_common_openings": opening_moves.most_common(10),
    }
src/evaluate.py ADDED
@@ -0,0 +1,710 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluation script for the Chess Challenge.
3
+
4
+ This script evaluates a trained chess model by playing games against
5
+ Stockfish and computing ELO ratings.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import argparse
11
+ import random
12
+ import re
13
+ from dataclasses import dataclass
14
+ from typing import List, Optional, Tuple
15
+
16
+ import torch
17
+
18
+
19
@dataclass
class GameResult:
    """Result of a single game between the model and its opponent."""
    # Moves in the order played, as UCI strings (see play_game).
    moves: List[str]
    result: str  # "1-0", "0-1", or "1/2-1/2"
    model_color: str  # "white" or "black"
    termination: str  # "checkmate", "stalemate", "illegal_move", "max_moves", etc.
    # Total retries spent on illegal generations over the whole game.
    illegal_move_count: int
+
28
+
29
+ class ChessEvaluator:
30
+ """
31
+ Evaluator for chess models.
32
+
33
+ This class handles playing games between a trained model and Stockfish,
34
+ tracking results, and computing ELO ratings.
35
+
36
+ Supports any tokenization format as long as the model generates valid
37
+ chess squares (e.g., e2, e4). The evaluator extracts UCI moves by finding
38
+ square patterns in the generated output.
39
+ """
40
+
41
+ # Regex pattern to match chess squares
42
+ SQUARE_PATTERN = r"[a-h][1-8]"
43
+
44
    def __init__(
        self,
        model,
        tokenizer,
        stockfish_path: Optional[str] = None,
        stockfish_level: int = 1,
        max_retries: int = 3,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """
        Initialize the evaluator.

        Args:
            model: The trained chess model.
            tokenizer: The chess tokenizer.
            stockfish_path: Path to Stockfish executable.
            stockfish_level: Stockfish skill level (0-20).
            max_retries: Maximum retries for illegal moves.
            device: Device to run the model on.

        Raises:
            ImportError: If python-chess is not installed.
        """
        self.model = model.to(device)
        self.model.eval()
        self.tokenizer = tokenizer
        self.max_retries = max_retries
        self.device = device

        # Initialize Stockfish
        try:
            import chess
            import chess.engine

            # Keep a handle on the chess module so other methods can use it
            # without re-importing.
            self.chess = chess

            if stockfish_path is None:
                # Try common paths
                import shutil

                # Fall back to whatever "stockfish" binary is on PATH.
                stockfish_path = shutil.which("stockfish")

            if stockfish_path:
                self.engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
                self.engine.configure({"Skill Level": stockfish_level})
            else:
                # Degrade gracefully: with engine=None the opponent plays
                # random legal moves instead (see play_game).
                print("WARNING: Stockfish not found. Install it for full evaluation.")
                self.engine = None

        except ImportError:
            raise ImportError(
                "python-chess is required for evaluation. "
                "Install it with: pip install python-chess"
            )
+ )
95
+
96
+ def __del__(self):
97
+ """Clean up Stockfish engine."""
98
+ if hasattr(self, "engine") and self.engine:
99
+ self.engine.quit()
100
+
101
    def _detect_tokenizer_format(self) -> str:
        """
        Detect the tokenizer's expected move format by testing tokenization.

        Tests various formats with a sample move and picks the one that
        produces the fewest unknown tokens. This makes evaluation work
        with any tokenizer format.

        Supported formats:
            - 'decomposed': "WP e2_f e4_t" (piece, from_suffix, to_suffix)
            - 'standard': "WPe2e4" (combined with optional annotations)
            - 'uci': "e2e4" (pure UCI notation)
            - 'uci_spaced': "e2 e4" (UCI with space separator)

        Returns:
            The format string that best matches the tokenizer's vocabulary.
        """
        # Detection is deterministic for a given tokenizer; cache the result.
        if hasattr(self, "_cached_format"):
            return self._cached_format

        test_formats = {
            "decomposed": "WP e2_f e4_t",
            "standard": "WPe2e4",
            "uci": "e2e4",
            "uci_spaced": "e2 e4",
        }

        unk_token_id = getattr(self.tokenizer, "unk_token_id", None)
        best_format = "standard"
        min_unk_count = float("inf")

        for fmt, sample in test_formats.items():
            try:
                tokens = self.tokenizer.encode(sample, add_special_tokens=False)
                unk_count = tokens.count(unk_token_id) if unk_token_id is not None else 0
                # A single token that *is* the unk token means the whole
                # sample was unrecognized; penalize so it never wins.
                if len(tokens) == 1 and unk_count == 1:
                    unk_count = 100  # heavy penalty
                if unk_count < min_unk_count:
                    min_unk_count = unk_count
                    best_format = fmt
            except Exception:
                # Some tokenizers raise on text they cannot encode; skip.
                continue

        self._cached_format = best_format
        return best_format
146
+
147
+ def _format_move(
148
+ self,
149
+ color: str,
150
+ piece: str,
151
+ from_sq: str,
152
+ to_sq: str,
153
+ promotion: str = None,
154
+ ) -> str:
155
+ fmt = self._detect_tokenizer_format()
156
+
157
+ if fmt == "decomposed":
158
+ move_str = f"{color}{piece} {from_sq}_f {to_sq}_t"
159
+ elif fmt == "uci":
160
+ move_str = f"{from_sq}{to_sq}"
161
+ if promotion:
162
+ move_str += promotion.lower()
163
+ elif fmt == "uci_spaced":
164
+ move_str = f"{from_sq} {to_sq}"
165
+ if promotion:
166
+ move_str += f" {promotion.lower()}"
167
+ else: # standard
168
+ move_str = f"{color}{piece}{from_sq}{to_sq}"
169
+ if promotion:
170
+ move_str += f"={promotion}"
171
+
172
+ return move_str
173
+
174
    def _convert_board_to_moves(self, board) -> str:
        """Re-encode *board*'s move stack as tokenizer-format move strings.

        Replays the game from the initial position so per-move facts
        (mover color, moved piece, capture/check status) can be recomputed.
        Annotation suffixes — (x), (+), (+*), (o)/(O) — are produced only
        for the "standard" format.

        Returns:
            Space-joined move strings, empty string for an empty game.
        """
        moves = []
        temp_board = self.chess.Board()
        fmt = self._detect_tokenizer_format()

        for move in board.move_stack:
            # Mover color and piece must be read BEFORE pushing the move.
            color = "W" if temp_board.turn == self.chess.WHITE else "B"
            piece = temp_board.piece_at(move.from_square)
            piece_letter = piece.symbol().upper() if piece else "P"

            from_sq = self.chess.square_name(move.from_square)
            to_sq = self.chess.square_name(move.to_square)

            promo = None
            if move.promotion:
                promo = self.chess.piece_symbol(move.promotion).upper()

            move_str = self._format_move(color, piece_letter, from_sq, to_sq, promo)

            if fmt == "standard":
                # Capture is checked pre-push; check/checkmate post-push.
                if temp_board.is_capture(move):
                    move_str += "(x)"

                temp_board.push(move)

                if temp_board.is_checkmate():
                    if "(x)" in move_str:
                        move_str = move_str.replace("(x)", "(x+*)")
                    else:
                        move_str += "(+*)"
                elif temp_board.is_check():
                    if "(x)" in move_str:
                        move_str = move_str.replace("(x)", "(x+)")
                    else:
                        move_str += "(+)"

                # A king moving more than one file is castling; the dataset
                # encodes it as (o) kingside / (O) queenside and drops any
                # other suffix.
                if piece_letter == "K":
                    if abs(ord(from_sq[0]) - ord(to_sq[0])) > 1:
                        if to_sq[0] == "g":
                            move_str = move_str.split("(")[0] + "(o)"
                        else:
                            move_str = move_str.split("(")[0] + "(O)"
            else:
                # Non-"standard" formats carry no annotations; just advance.
                temp_board.push(move)

            moves.append(move_str)

        return " ".join(moves)
222
+
223
+ def _is_separator_token(self, token_str: str) -> bool:
224
+ if hasattr(self.tokenizer, "eos_token") and token_str == self.tokenizer.eos_token:
225
+ return True
226
+ if token_str.strip() == "" and len(token_str) > 0:
227
+ return True
228
+ if token_str != token_str.rstrip():
229
+ return True
230
+ return False
231
+
232
+ def _extract_uci_move(self, text: str) -> Optional[str]:
233
+ if not text:
234
+ return None
235
+
236
+ squares = re.findall(self.SQUARE_PATTERN, text)
237
+ if len(squares) < 2:
238
+ return None
239
+
240
+ from_sq, to_sq = squares[0], squares[1]
241
+ uci_move = from_sq + to_sq
242
+
243
+ to_sq_idx = text.find(to_sq)
244
+ if to_sq_idx != -1:
245
+ remaining = text[to_sq_idx + 2 : to_sq_idx + 5]
246
+ promo_match = re.search(r"[=]?([qrbnQRBN])", remaining)
247
+ if promo_match:
248
+ uci_move += promo_match.group(1).lower()
249
+
250
+ return uci_move
251
+
252
+ def _has_complete_move(self, text: str) -> bool:
253
+ squares = re.findall(self.SQUARE_PATTERN, text)
254
+ return len(squares) >= 2
255
+
256
    def _generate_move_tokens(
        self,
        input_ids: torch.Tensor,
        temperature: float = 0.7,
        top_k: int = 10,
        max_tokens: int = 20,
    ) -> str:
        """Sample tokens until one full move (two squares) has been produced.

        Args:
            input_ids: Prompt token ids; assumed shape (1, seq_len) — TODO confirm.
            temperature: Softmax temperature (clamped away from zero).
            top_k: Keep only the k most likely tokens before sampling.
            max_tokens: Hard cap on generated tokens for a single move.

        Returns:
            The decoded, stripped move text; empty string if nothing was kept.
        """
        generated_tokens = []
        current_ids = input_ids.clone()
        accumulated_text = ""

        for _ in range(max_tokens):
            with torch.no_grad():
                outputs = self.model(input_ids=current_ids)
                # Last-position logits; max() guards temperature == 0.
                logits = outputs.logits[:, -1, :] / max(temperature, 1e-6)

            # Top-k filtering: mask everything below the k-th largest logit.
            if top_k > 0:
                top_k_vals = torch.topk(logits, min(top_k, logits.size(-1)))
                indices_to_remove = logits < top_k_vals[0][..., -1, None]
                logits[indices_to_remove] = float("-inf")

            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            token_str = self.tokenizer.decode(next_token[0])

            # Separator handling: stop once a complete move has been seen,
            # on EOS, or on any separator after partial text.
            if self._is_separator_token(token_str):
                if self._has_complete_move(accumulated_text):
                    break
                if hasattr(self.tokenizer, "eos_token") and token_str == self.tokenizer.eos_token:
                    break
                if accumulated_text:
                    break

            generated_tokens.append(next_token[0])
            current_ids = torch.cat([current_ids, next_token], dim=-1)
            accumulated_text += token_str

            if self._has_complete_move(accumulated_text):
                squares = re.findall(self.SQUARE_PATTERN, accumulated_text)
                if len(squares) >= 2:
                    to_sq = squares[1]
                    # Back-rank destination: keep sampling briefly, since a
                    # promotion letter may still follow the to-square.
                    if to_sq[1] in "18":
                        if len(generated_tokens) > 3:
                            break
                    else:
                        break

        if generated_tokens:
            all_tokens = torch.cat(generated_tokens, dim=0)
            move_str = self.tokenizer.decode(all_tokens, skip_special_tokens=True)
            return move_str.strip()

        return ""
310
+
311
    def _get_model_move(
        self,
        board,
        temperature: float = 0.7,
        top_k: int = 10,
    ) -> Tuple[Optional[str], int]:
        """Ask the model for a legal move in the current position.

        The game so far is re-encoded in the tokenizer's format, generation
        is attempted up to ``self.max_retries`` times, and the first legal
        UCI move is returned.

        Returns:
            (uci_move, retries_used); uci_move is None when every attempt
            produced an illegal or unparseable move.
        """
        self.model.eval()

        moves_str = self._convert_board_to_moves(board)

        # BOS-prefixed prompt; an empty history means BOS alone.
        if not moves_str:
            input_text = self.tokenizer.bos_token
        else:
            input_text = self.tokenizer.bos_token + " " + moves_str

        # Leave headroom below the context window for the generated move.
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=self.model.config.n_ctx - 10,
        ).to(self.device)

        for retry in range(self.max_retries):
            move_text = self._generate_move_tokens(
                inputs["input_ids"],
                temperature=temperature,
                top_k=top_k,
            )

            uci_move = self._extract_uci_move(move_text)

            if uci_move:
                try:
                    move = self.chess.Move.from_uci(uci_move)
                    if move in board.legal_moves:
                        return uci_move, retry
                except (ValueError, self.chess.InvalidMoveError):
                    # Malformed UCI string: treat as a failed attempt.
                    pass

        return None, self.max_retries
351
+
352
+ def _get_stockfish_move(self, board, time_limit: float = 0.1) -> str:
353
+ if self.engine is None:
354
+ raise RuntimeError("Stockfish engine not initialized")
355
+
356
+ result = self.engine.play(board, self.chess.engine.Limit(time=time_limit))
357
+ return result.move.uci()
358
+
359
    def play_game(
        self,
        model_color: str = "white",
        max_moves: int = 200,
        temperature: float = 0.7,
    ) -> GameResult:
        """Play one full game: the model vs. Stockfish (or random moves).

        Args:
            model_color: "white" or "black" — the model's side.
            max_moves: Ply cap; reaching it scores the game as a draw.
            temperature: Sampling temperature for move generation.

        Returns:
            A GameResult with the move list, score, and termination reason.
        """
        board = self.chess.Board()
        moves = []
        illegal_move_count = 0

        model_is_white = model_color == "white"

        while not board.is_game_over() and len(moves) < max_moves:
            is_model_turn = (board.turn == self.chess.WHITE) == model_is_white

            if is_model_turn:
                uci_move, retries = self._get_model_move(board, temperature)
                illegal_move_count += retries

                # The model never produced a legal move: it forfeits.
                if uci_move is None:
                    return GameResult(
                        moves=moves,
                        result="0-1" if model_is_white else "1-0",
                        model_color=model_color,
                        termination="illegal_move",
                        illegal_move_count=illegal_move_count + 1,
                    )

                move = self.chess.Move.from_uci(uci_move)
            else:
                if self.engine:
                    uci_move = self._get_stockfish_move(board)
                    move = self.chess.Move.from_uci(uci_move)
                else:
                    # No engine available: opponent plays uniform random.
                    move = random.choice(list(board.legal_moves))

            board.push(move)
            moves.append(move.uci())

        # Score the final position. On checkmate the side to move is the
        # side that got mated.
        if board.is_checkmate():
            if board.turn == self.chess.WHITE:
                result = "0-1"
            else:
                result = "1-0"
            termination = "checkmate"
        elif board.is_stalemate():
            result = "1/2-1/2"
            termination = "stalemate"
        elif board.is_insufficient_material():
            result = "1/2-1/2"
            termination = "insufficient_material"
        elif board.can_claim_draw():
            result = "1/2-1/2"
            termination = "draw_claim"
        elif len(moves) >= max_moves:
            result = "1/2-1/2"
            termination = "max_moves"
        else:
            result = "1/2-1/2"
            termination = "unknown"

        return GameResult(
            moves=moves,
            result=result,
            model_color=model_color,
            termination=termination,
            illegal_move_count=illegal_move_count,
        )
427
+
428
    def evaluate_legal_moves(
        self,
        n_positions: int = 1000,
        temperature: float = 0.7,
        verbose: bool = True,
        seed: int = 42,
    ) -> dict:
        """Measure how often the model proposes legal moves.

        Random positions are reached by playing 5-40 random plies from the
        start; the model then gets up to ``max_retries`` attempts per
        position.

        Returns:
            Dict with per-position records and aggregate legal-move rates.
        """
        # Seed both RNGs so position sampling and generation reproduce.
        random.seed(seed)
        torch.manual_seed(seed)

        results = {
            "total_positions": 0,
            "legal_first_try": 0,
            "legal_with_retry": 0,
            "illegal_all_retries": 0,
            "positions": [],
        }

        for i in range(n_positions):
            board = self.chess.Board()

            # Walk to a random mid-game position.
            n_random_moves = random.randint(5, 40)
            for _ in range(n_random_moves):
                if board.is_game_over():
                    break
                move = random.choice(list(board.legal_moves))
                board.push(move)

            # A finished game has no move to predict; skip the position.
            if board.is_game_over():
                continue

            results["total_positions"] += 1

            uci_move, retries = self._get_model_move(board, temperature)

            position_result = {
                "fen": board.fen(),
                "move_number": len(board.move_stack),
                "legal": uci_move is not None,
                "retries": retries,
            }
            results["positions"].append(position_result)

            if uci_move is not None:
                if retries == 0:
                    results["legal_first_try"] += 1
                else:
                    results["legal_with_retry"] += 1
            else:
                results["illegal_all_retries"] += 1

            if verbose and (i + 1) % 100 == 0:
                legal_rate = (results["legal_first_try"] + results["legal_with_retry"]) / results["total_positions"]
                print(f" Positions: {i + 1}/{n_positions} | Legal rate: {legal_rate:.1%}")

        # Aggregate rates; degenerate case: nothing was testable.
        total = results["total_positions"]
        if total > 0:
            results["legal_rate_first_try"] = results["legal_first_try"] / total
            results["legal_rate_with_retry"] = (results["legal_first_try"] + results["legal_with_retry"]) / total
            results["illegal_rate"] = results["illegal_all_retries"] / total
        else:
            results["legal_rate_first_try"] = 0
            results["legal_rate_with_retry"] = 0
            results["illegal_rate"] = 1

        return results
494
+
495
+ def evaluate(
496
+ self,
497
+ n_games: int = 100,
498
+ temperature: float = 0.7,
499
+ verbose: bool = True,
500
+ ) -> dict:
501
+ results = {
502
+ "wins": 0,
503
+ "losses": 0,
504
+ "draws": 0,
505
+ "illegal_moves": 0,
506
+ "total_moves": 0,
507
+ "games": [],
508
+ }
509
+
510
+ for i in range(n_games):
511
+ model_color = "white" if i % 2 == 0 else "black"
512
+
513
+ game = self.play_game(
514
+ model_color=model_color,
515
+ temperature=temperature,
516
+ )
517
+
518
+ results["games"].append(game)
519
+ results["total_moves"] += len(game.moves)
520
+ results["illegal_moves"] += game.illegal_move_count
521
+
522
+ if game.result == "1/2-1/2":
523
+ results["draws"] += 1
524
+ elif (game.result == "1-0" and model_color == "white") or (game.result == "0-1" and model_color == "black"):
525
+ results["wins"] += 1
526
+ else:
527
+ results["losses"] += 1
528
+
529
+ if verbose and (i + 1) % 10 == 0:
530
+ print(
531
+ f" Games: {i + 1}/{n_games} | "
532
+ f"W: {results['wins']} L: {results['losses']} D: {results['draws']}"
533
+ )
534
+
535
+ total = results["wins"] + results["losses"] + results["draws"]
536
+ results["win_rate"] = results["wins"] / total if total > 0 else 0
537
+ results["draw_rate"] = results["draws"] / total if total > 0 else 0
538
+ results["loss_rate"] = results["losses"] / total if total > 0 else 0
539
+
540
+ total_attempts = results["total_moves"] + results["illegal_moves"]
541
+ results["avg_game_length"] = total_attempts / total if total > 0 else 0
542
+ results["illegal_move_rate"] = results["illegal_moves"] / total_attempts if total_attempts > 0 else 0
543
+
544
+ stockfish_elo = 1350
545
+ if results["win_rate"] > 0 or results["loss_rate"] > 0:
546
+ score = results["wins"] + 0.5 * results["draws"]
547
+ if score > 0:
548
+ win_ratio = score / total
549
+ if 0 < win_ratio < 1:
550
+ elo_diff = -400 * (1 - 2 * win_ratio) / (1 if win_ratio > 0.5 else -1)
551
+ results["estimated_elo"] = stockfish_elo + elo_diff
552
+ else:
553
+ results["estimated_elo"] = stockfish_elo + (400 if win_ratio >= 1 else -400)
554
+ else:
555
+ results["estimated_elo"] = stockfish_elo - 400
556
+ else:
557
+ results["estimated_elo"] = None
558
+
559
+ return results
560
+
561
+
562
def load_model_from_hub(model_id: str, device: str = "auto", verbose: bool = True):
    """Load a chess model + tokenizer from the Hugging Face Hub.

    Tries AutoTokenizer (with remote code) first and falls back to the
    local ChessTokenizer class with vocabulary files fetched from the Hub.

    Args:
        model_id: Hub repository id.
        device: Passed through as ``device_map`` to ``from_pretrained``.
        verbose: Print diagnostics about what was loaded.

    Returns:
        (model, tokenizer)
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Import to register custom classes
    from src.model import ChessConfig, ChessForCausalLM
    from src.tokenizer import ChessTokenizer

    tokenizer_source = None
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        tokenizer_source = "AutoTokenizer (from Hub with trust_remote_code=True)"
    except Exception as e:
        if verbose:
            print(f" AutoTokenizer failed: {e}")
        # Fallback: local tokenizer class, vocabulary from the Hub.
        tokenizer = ChessTokenizer.from_pretrained(model_id)
        tokenizer_source = "ChessTokenizer (local class, vocab from Hub)"

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        device_map=device,
    )

    if verbose:
        print(f" Tokenizer loaded via: {tokenizer_source}")
        print(f" Tokenizer class: {type(tokenizer).__name__}")
        print(f" Tokenizer vocab size: {tokenizer.vocab_size}")
        if hasattr(tokenizer, "_vocab"):
            print(f" Tokenizer has _vocab attribute: yes ({len(tokenizer._vocab)} entries)")

    return model, tokenizer
593
+
594
+
595
def main():
    """CLI entry point: load a model (local path or Hub id) and evaluate it.

    Two phases, selectable via --mode: "legal" checks the legal-move rate
    on random positions; "winrate" plays full games against Stockfish.
    """
    parser = argparse.ArgumentParser(description="Evaluate a chess model")

    parser.add_argument("--model_path", type=str, required=True, help="Path to the model or Hugging Face model ID")
    parser.add_argument("--mode", type=str, default="legal", choices=["legal", "winrate", "both"])
    parser.add_argument("--stockfish_path", type=str, default=None, help="Path to Stockfish executable")
    parser.add_argument("--stockfish_level", type=int, default=1, help="Stockfish skill level (0-20)")
    parser.add_argument("--n_positions", type=int, default=500, help="Number of positions for legal move evaluation")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument("--n_games", type=int, default=100, help="Number of games to play for win rate evaluation")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")

    args = parser.parse_args()

    print("=" * 60)
    print("CHESS CHALLENGE - EVALUATION")
    print("=" * 60)

    print(f"\nLoading model from: {args.model_path}")

    import os
    is_local_path = os.path.exists(args.model_path)

    if is_local_path:
        # Local path
        from transformers import AutoModelForCausalLM
        from src.tokenizer import ChessTokenizer
        from src.model import ChessConfig, ChessForCausalLM

        tokenizer = ChessTokenizer.from_pretrained(args.model_path)

        # IMPORTANT FIX:
        # Our custom ChessForCausalLM does NOT support device_map="auto" unless _no_split_modules is defined.
        # So we load normally and move to device explicitly.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        model = AutoModelForCausalLM.from_pretrained(
            args.model_path,
            trust_remote_code=True,
        )
        model.to(device)
        model.eval()
    else:
        # A path-like string that does not exist is a user error, not a Hub id.
        if args.model_path.startswith(".") or args.model_path.startswith("/"):
            raise FileNotFoundError(
                f"Local model path not found: {args.model_path}\n"
                f"Please check that the path exists and contains model files."
            )
        model, tokenizer = load_model_from_hub(args.model_path)

    print(f"\nSetting up evaluator...")
    evaluator = ChessEvaluator(
        model=model,
        tokenizer=tokenizer,
        stockfish_path=args.stockfish_path,
        stockfish_level=args.stockfish_level,
    )

    # Phase 1: how often does the model produce a legal move?
    if args.mode in ["legal", "both"]:
        print(f"\n" + "=" * 60)
        print("PHASE 1: LEGAL MOVE EVALUATION")
        print("=" * 60)
        print(f"Testing {args.n_positions} random positions...")

        legal_results = evaluator.evaluate_legal_moves(
            n_positions=args.n_positions,
            temperature=args.temperature,
            verbose=True,
            seed=args.seed,
        )

        print("\n" + "-" * 40)
        print("LEGAL MOVE RESULTS")
        print("-" * 40)
        print(f" Positions tested: {legal_results['total_positions']}")
        print(f" Legal (1st try): {legal_results['legal_first_try']} ({legal_results['legal_rate_first_try']:.1%})")
        print(
            f" Legal (with retry): {legal_results['legal_first_try'] + legal_results['legal_with_retry']}"
            f" ({legal_results['legal_rate_with_retry']:.1%})"
        )
        print(f" Always illegal: {legal_results['illegal_all_retries']} ({legal_results['illegal_rate']:.1%})")

    # Phase 2: full games against Stockfish.
    if args.mode in ["winrate", "both"]:
        print(f"\n" + "=" * 60)
        print("PHASE 2: WIN RATE EVALUATION")
        print("=" * 60)
        print(f"Playing {args.n_games} games against Stockfish (Level {args.stockfish_level})...")

        winrate_results = evaluator.evaluate(
            n_games=args.n_games,
            temperature=args.temperature,
            verbose=True,
        )

        print("\n" + "-" * 40)
        print("WIN RATE RESULTS")
        print("-" * 40)
        print(f" Wins: {winrate_results['wins']}")
        print(f" Losses: {winrate_results['losses']}")
        print(f" Draws: {winrate_results['draws']}")
        print(f"\n Win Rate: {winrate_results['win_rate']:.1%}")
        print(f" Draw Rate: {winrate_results['draw_rate']:.1%}")
        print(f" Loss Rate: {winrate_results['loss_rate']:.1%}")
        print(f"\n Avg Game Length: {winrate_results['avg_game_length']:.1f} moves")
        print(f" Illegal Move Rate: {winrate_results['illegal_move_rate']:.2%}")

        if winrate_results.get("estimated_elo", None):
            print(f"\n Estimated ELO: {winrate_results['estimated_elo']:.0f}")

    print("\n" + "=" * 60)
    print("EVALUATION COMPLETE")
    print("=" * 60)
+
708
+
709
+ if __name__ == "__main__":
710
+ main()
src/model.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chess Transformer Model for the Chess Challenge.
3
+
4
+ Modern small-LLM upgrades:
5
+ - RoPE (rotary positional embeddings): no learned positional embeddings needed
6
+ - RMSNorm (optional, default True)
7
+ - SwiGLU MLP (optional, default True)
8
+ - Weight tying (default True)
9
+ - Safe loss ignore_index = -100 (HF convention)
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import math
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from transformers import PretrainedConfig, PreTrainedModel
21
+ from transformers.modeling_outputs import CausalLMOutputWithPast
22
+
23
+
24
class ChessConfig(PretrainedConfig):
    """Configuration for the small chess causal LM.

    Defaults are tuned so the full model stays under 1M parameters for
    common vocabulary sizes. Validates head divisibility, RoPE head-dim
    parity, and the MLP type at construction time.
    """

    model_type = "chess_transformer"

    def __init__(
        self,
        vocab_size: int = 1200,

        # Architecture (defaults tuned to be < 1M params for common vocabs)
        n_embd: int = 112,
        n_layer: int = 7,
        n_head: int = 7,

        # Context window
        n_ctx: int = 512,

        # MLP hidden size:
        # - if mlp_type="swiglu", this is SwiGLU hidden size h
        # - if mlp_type="gelu", this is FFN inner size
        n_inner: Optional[int] = 192,

        dropout: float = 0.05,
        layer_norm_epsilon: float = 1e-6,

        # Position encoding
        use_rope: bool = True,
        rope_theta: float = 10000.0,

        # Normalization / MLP type
        use_rmsnorm: bool = True,
        mlp_type: str = "swiglu",  # "swiglu" or "gelu"

        # Weight tying
        tie_weights: bool = True,

        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        if n_embd % n_head != 0:
            raise ValueError(f"n_embd ({n_embd}) must be divisible by n_head ({n_head})")

        head_dim = n_embd // n_head
        if use_rope and (head_dim % 2 != 0):
            raise ValueError(
                f"RoPE requires even head_dim, got head_dim={head_dim}. "
                f"Choose n_embd/n_head even."
            )

        # FIX: fail fast on an unknown MLP type. Previously an invalid value
        # silently fell through to the GELU branch in TransformerBlock.
        if mlp_type.lower() not in ("swiglu", "gelu"):
            raise ValueError(f"mlp_type must be 'swiglu' or 'gelu', got {mlp_type!r}")

        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_ctx = n_ctx
        # Fall back to a 2x expansion when no inner size is given.
        self.n_inner = n_inner if n_inner is not None else (2 * n_embd)
        self.dropout = dropout
        self.layer_norm_epsilon = layer_norm_epsilon

        self.use_rope = use_rope
        self.rope_theta = rope_theta

        self.use_rmsnorm = use_rmsnorm
        self.mlp_type = mlp_type

        self.tie_weights = tie_weights
        # HF uses this field for embedding tying behavior
        self.tie_word_embeddings = bool(tie_weights)
98
+
99
+
100
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: no mean subtraction, no bias, one gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale each feature vector by the reciprocal of its RMS, then
        # apply the learned per-dimension gain.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps) * self.weight
109
+
110
+
111
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map interleaved pairs (a, b) to (-b, a) along the last dim (RoPE helper)."""
    even = x[..., 0::2]
    odd = x[..., 1::2]
    # Pair up (-odd, even) and flatten back to the original interleaving.
    return torch.stack((-odd, even), dim=-1).flatten(start_dim=-2)
118
+
119
+
120
class RotaryEmbedding(nn.Module):
    """Rotary positional embeddings (RoPE).

    Lazily builds cos/sin tables and applies the rotation to q and k of
    shape (B, H, T, D). The cache is rebuilt whenever the requested length
    grows or the device/dtype changes.
    """

    def __init__(self, head_dim: int, theta: float = 10000.0):
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError(f"RoPE requires even head_dim, got {head_dim}")

        inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Cached tables plus the key they were built for.
        self._cos_cached = None
        self._sin_cached = None
        self._seq_len_cached = 0
        self._device_cached = None
        self._dtype_cached = None

    def _build_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)  # (T, D/2)

        self._cos_cached = angles.cos().to(dtype=dtype)
        self._sin_cached = angles.sin().to(dtype=dtype)
        self._seq_len_cached = seq_len
        self._device_cached = device
        self._dtype_cached = dtype

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # q, k: (B, H, T, D)
        seq_len = q.size(-2)
        device, dtype = q.device, q.dtype

        cache_stale = (
            self._cos_cached is None
            or seq_len > self._seq_len_cached
            or device != self._device_cached
            or dtype != self._dtype_cached
        )
        if cache_stale:
            self._build_cache(seq_len, device, dtype)

        # (T, D/2) -> (1, 1, T, D): duplicate each angle for its (a, b) pair.
        cos = torch.repeat_interleave(self._cos_cached[:seq_len].unsqueeze(0).unsqueeze(0), 2, dim=-1)
        sin = torch.repeat_interleave(self._sin_cached[:seq_len].unsqueeze(0).unsqueeze(0), 2, dim=-1)

        rotated_q = (q * cos) + (rotate_half(q) * sin)
        rotated_k = (k * cos) + (rotate_half(k) * sin)
        return rotated_q, rotated_k
176
+
177
+
178
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with optional rotary embeddings.

    Operates on inputs of shape (B, T, n_embd); applies a lower-triangular
    causal mask and, when provided, a padding mask (1 = keep, 0 = mask).
    """

    def __init__(self, config: ChessConfig):
        super().__init__()

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head

        # Fused QKV projection plus output projection (GPT-2 naming).
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

        self.use_rope = bool(config.use_rope)
        self.rope = RotaryEmbedding(self.head_dim, theta=config.rope_theta) if self.use_rope else None

        # causal mask buffer (expandable); non-persistent so it is rebuilt
        # rather than loaded from checkpoints.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.n_ctx, config.n_ctx)).view(1, 1, config.n_ctx, config.n_ctx),
            persistent=False,
        )

    def _ensure_causal_mask(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        # Rebuild the triangular mask only when the cached one is too small
        # or sits on the wrong device; otherwise reuse a slice of it.
        if self.bias.size(-1) >= seq_len and self.bias.device == device:
            return
        self.bias = torch.tril(torch.ones(seq_len, seq_len, device=device, dtype=dtype)).view(1, 1, seq_len, seq_len)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the attention output, same shape as ``x`` (B, T, n_embd)."""
        B, T, _ = x.size()

        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)

        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B,H,T,D)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        if self.use_rope:
            q, k = self.rope(q, k)

        # Scaled dot-product scores: (B, H, T, T).
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        self._ensure_causal_mask(T, attn.device, attn.dtype)
        causal_mask = self.bias[:, :, :T, :T]
        # Most-negative finite value of the score dtype stands in for -inf.
        mask_value = torch.finfo(attn.dtype).min
        attn = attn.masked_fill(causal_mask == 0, mask_value)

        # padding mask (1=keep, 0=mask)
        if attention_mask is not None:
            am = attention_mask.unsqueeze(1).unsqueeze(2)  # (B,1,1,T)
            attn = attn.masked_fill(am == 0, mask_value)

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        y = torch.matmul(attn, v)  # (B,H,T,D)
        y = y.transpose(1, 2).contiguous().view(B, T, self.n_embd)

        y = self.c_proj(y)
        y = self.dropout(y)
        return y
239
+
240
+
241
class SwiGLU(nn.Module):
    """SwiGLU feed-forward: w3(silu(w1 x) * (w2 x)), with w1/w2 fused in w12."""

    def __init__(self, config: "ChessConfig"):
        super().__init__()
        hidden = config.n_inner
        # Single projection producing both the gate and the value halves.
        self.w12 = nn.Linear(config.n_embd, 2 * hidden)
        self.w3 = nn.Linear(hidden, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.dropout(self.w3(F.silu(gate) * value))
256
+
257
+
258
class FeedForwardGELU(nn.Module):
    """Classic two-layer GELU feed-forward block (GPT-2 style naming)."""

    def __init__(self, config: "ChessConfig"):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.n_inner)
        self.c_proj = nn.Linear(config.n_inner, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
271
+
272
+
273
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: x + attn(norm(x)), then x + mlp(norm(x))."""

    def __init__(self, config: "ChessConfig"):
        super().__init__()

        # Pick the normalization class once, build both layers from it.
        norm_cls = RMSNorm if config.use_rmsnorm else nn.LayerNorm
        self.ln_1 = norm_cls(config.n_embd, eps=config.layer_norm_epsilon)
        self.ln_2 = norm_cls(config.n_embd, eps=config.layer_norm_epsilon)

        self.attn = MultiHeadAttention(config)

        if config.mlp_type.lower() == "swiglu":
            self.mlp = SwiGLU(config)
        else:
            self.mlp = FeedForwardGELU(config)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = x + self.attn(self.ln_1(x), attention_mask=attention_mask)
        return x + self.mlp(self.ln_2(x))
295
+
296
+
297
class ChessForCausalLM(PreTrainedModel):
    """Decoder-only causal LM over decomposed chess-move tokens.

    Token embedding -> n_layer pre-norm TransformerBlocks -> final norm ->
    (optionally weight-tied) LM head. Loss follows the HF convention of
    shifting logits/labels and ignoring label index -100.
    """

    config_class = ChessConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    # FIX: Transformers reads the underscore-prefixed class attribute;
    # the previous `keys_to_ignore_on_load_missing` spelling was silently
    # ignored, so tied checkpoints warned about a missing lm_head.weight.
    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
    _no_split_modules = ["TransformerBlock"]

    def __init__(self, config: ChessConfig):
        super().__init__(config)

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)

        # learned positional embeddings only if RoPE disabled
        self.wpe = None
        if not config.use_rope:
            self.wpe = nn.Embedding(config.n_ctx, config.n_embd)

        self.drop = nn.Dropout(config.dropout)
        self.h = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)])

        if config.use_rmsnorm:
            self.ln_f = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
        else:
            self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        if config.tie_weights:
            # Tell save/load machinery the head shares storage with wte.
            self._tied_weights_keys = ["lm_head.weight"]

        self.post_init()

        if config.tie_weights:
            self.tie_weights()

    def get_input_embeddings(self) -> nn.Module:
        return self.wte

    def set_input_embeddings(self, new_embeddings: nn.Module):
        # Re-tie so the LM head tracks the replaced embedding matrix.
        self.wte = new_embeddings
        if getattr(self.config, "tie_weights", False):
            self.tie_weights()

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module):
        self.lm_head = new_embeddings

    def tie_weights(self):
        if getattr(self.config, "tie_weights", False) or getattr(self.config, "tie_word_embeddings", False):
            self._tie_or_clone_weights(self.lm_head, self.wte)

    def _init_weights(self, module: nn.Module):
        # GPT-2-style init: N(0, 0.02) weights, zero biases.
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the transformer; if ``labels`` is given, also compute the
        shifted cross-entropy loss (ignore_index=-100)."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        B, T = input_ids.size()
        device = input_ids.device

        x = self.wte(input_ids)

        # Learned positions are only present when RoPE is disabled.
        if self.wpe is not None:
            if position_ids is None:
                position_ids = torch.arange(T, device=device).unsqueeze(0).expand(B, -1)
            x = x + self.wpe(position_ids)

        x = self.drop(x)

        for block in self.h:
            x = block(x, attention_mask=attention_mask)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Predict token t+1 from positions <= t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    @torch.no_grad()
    def generate_move(
        self,
        input_ids: torch.LongTensor,
        temperature: float = 0.7,
        top_k: Optional[int] = 50,
        top_p: Optional[float] = None,
    ) -> int:
        """Sample one next-token id with temperature / top-k / top-p filtering.

        NOTE: switches the module to eval mode and leaves it there.
        """
        self.eval()

        outputs = self(input_ids)
        # Clamp temperature away from zero to avoid division blow-up.
        logits = outputs.logits[:, -1, :] / max(float(temperature), 1e-6)

        if top_k is not None and top_k > 0:
            k = min(int(top_k), logits.size(-1))
            thresh = torch.topk(logits, k)[0][..., -1, None]
            logits = logits.masked_fill(logits < thresh, torch.finfo(logits.dtype).min)

        if top_p is not None:
            # Nucleus filtering: drop the tail whose cumulative prob > top_p,
            # always keeping at least the most likely token.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            probs = F.softmax(sorted_logits, dim=-1)
            cum = torch.cumsum(probs, dim=-1)
            to_remove = cum > float(top_p)
            to_remove[..., 1:] = to_remove[..., :-1].clone()
            to_remove[..., 0] = 0
            indices_to_remove = to_remove.scatter(dim=-1, index=sorted_indices, src=to_remove)
            logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        return int(next_token.item())
440
+
441
+
442
# Register the model with Auto classes so that
# AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True)
# resolves model_type "chess_transformer" to the classes in this module.
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("chess_transformer", ChessConfig)
AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
src/tokenizer.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Decomposed Chess Tokenizer for the Chess Challenge.
3
+
4
+ Each move becomes 3 or 4 tokens:
5
+ WP e2_f e4_t
6
+ BN g8_f f6_t
7
+ Promotion adds an extra token:
8
+ WP e7_f e8_t =q
9
+
10
+ Why this helps:
11
+ - Fixed small vocab (~150 tokens)
12
+ - Near-zero OOV / UNK, so the evaluator can always parse squares
13
+ - Compatible with the provided evaluate.py (it auto-detects 'decomposed')
14
+
15
+ Special tokens behavior:
16
+ - Adds BOS only (NO EOS)
17
+ - If BOS already present, does not add it twice
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ import os
24
+ from typing import Dict, List, Optional
25
+
26
+ from transformers import PreTrainedTokenizer
27
+
28
+
29
class ChessTokenizer(PreTrainedTokenizer):
    """Fixed-vocabulary tokenizer that decomposes each extended-UCI move
    into piece / from-square / to-square tokens (plus optional promotion).

    Vocabulary layout: 4 specials + 12 piece tokens + 64 "<sq>_f" +
    64 "<sq>_t" + 4 promotion tokens (~148 total). BOS is added on encode;
    EOS is kept in the vocab for compatibility but never auto-added.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"  # kept for compatibility, not auto-added
    UNK_TOKEN = "[UNK]"

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        # Vocabulary source priority: explicit dict > vocab.json file >
        # freshly built fixed vocabulary.
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # avoid duplicates from kwargs
        kwargs.pop("pad_token", None)
        kwargs.pop("bos_token", None)
        kwargs.pop("eos_token", None)
        kwargs.pop("unk_token", None)

        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._build_fixed_vocab()

        # Reverse map must exist before the base-class __init__ queries it.
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    # --------------------------
    # Fixed vocab: pieces + squares + promos
    # --------------------------
    @staticmethod
    def _all_squares() -> List[str]:
        """Return the 64 square names a1..h8, rank-major."""
        files = "abcdefgh"
        ranks = "12345678"
        return [f + r for r in ranks for f in files]  # a1..h8

    def _build_fixed_vocab(self) -> Dict[str, int]:
        """Build the deterministic token->id map (specials first, id 0 = PAD)."""
        special = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]

        # piece tokens: WP..WK, BP..BK
        piece_tokens = [f"{c}{p}" for c in "WB" for p in "PNBRQK"]

        squares = self._all_squares()
        from_tokens = [f"{sq}_f" for sq in squares]
        to_tokens = [f"{sq}_t" for sq in squares]

        promo_tokens = ["=q", "=r", "=b", "=n"]

        tokens = special + piece_tokens + from_tokens + to_tokens + promo_tokens
        return {tok: i for i, tok in enumerate(tokens)}

    # --------------------------
    # Special tokens handling (robust with evaluate.py)
    # --------------------------
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Prepend BOS exactly once; never append EOS."""
        # BOS only, NO EOS
        if token_ids_1 is not None:
            token_ids_0 = token_ids_0 + token_ids_1

        if token_ids_0 and token_ids_0[0] == self.bos_token_id:
            return token_ids_0
        return [self.bos_token_id] + token_ids_0

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Return 1 for special-token positions, 0 for regular tokens."""
        if already_has_special_tokens:
            specials = {self.pad_token_id, self.bos_token_id, self.eos_token_id, self.unk_token_id}
            return [1 if t in specials else 0 for t in token_ids_0]

        # The leading 1 accounts for the BOS added on encode.
        if token_ids_1 is None:
            return [1] + [0] * len(token_ids_0)
        return [1] + [0] * (len(token_ids_0) + len(token_ids_1))

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Single-segment model: all token type ids are 0 (incl. the BOS slot)."""
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 1)
        return [0] * (len(token_ids_0) + len(token_ids_1) + 1)

    # --------------------------
    # Tokenization
    # --------------------------
    def _tokenize(self, text: str) -> List[str]:
        """Split whitespace-separated moves, decomposing extended-UCI moves
        (e.g. "WPe2e4") into piece / from / to (+ promotion) tokens.

        Already-decomposed tokens and literal special tokens pass through;
        anything unparseable maps to [UNK].
        """
        if not text or not text.strip():
            return []

        parts = text.strip().split()
        out: List[str] = []

        for tok in parts:
            # allow literal special tokens present in text
            if tok in {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}:
                out.append(tok)
                continue

            # already decomposed tokens
            if (len(tok) == 2 and tok[0] in "WB" and tok[1] in "PNBRQK") or tok.endswith("_f") or tok.endswith("_t") or tok in {"=q", "=r", "=b", "=n"}:
                out.append(tok)
                continue

            # parse extended UCI (dataset): WPe2e4, BNg8f6(x), WPe7e8=Q(+), ...
            if len(tok) < 6:
                out.append(self.UNK_TOKEN)
                continue

            color = tok[0]
            piece = tok[1]
            from_sq = tok[2:4]
            to_sq = tok[4:6]

            out.append(f"{color}{piece}")
            out.append(f"{from_sq}_f")
            out.append(f"{to_sq}_t")

            # promotion like "=Q"
            if "=" in tok:
                try:
                    promo_part = tok.split("=", 1)[1]
                    promo_letter = promo_part[0].lower()
                    promo_tok = f"={promo_letter}"
                    if promo_tok in self._vocab:
                        out.append(promo_tok)
                except Exception:
                    # Malformed promotion suffix: emit the move without it.
                    pass

        return out

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the [UNK] id."""
        return self._vocab.get(token, self._vocab[self.UNK_TOKEN])

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its token, falling back to [UNK]."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join decomposed tokens with single spaces."""
        return " ".join(tokens)

    # --------------------------
    # Vocab I/O
    # --------------------------
    @property
    def vocab_size(self) -> int:
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        # Return a copy so callers cannot mutate the internal map.
        return dict(self._vocab)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        """Write vocab.json into ``save_directory``; returns the file path tuple."""
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)
src/train.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training script for the Chess Challenge.
3
+
4
+ GPU-optimized version (still compatible with older transformers/accelerate):
5
+ - Uses fp16/bf16 automatically on GPU
6
+ - Uses evaluation + saving per EPOCH by default (much faster than steps)
7
+ - Enables dataloader_num_workers + pin_memory on GPU
8
+ - Optional torch.compile for speed (safe-guarded)
9
+ - Keeps your robust TrainingArguments compatibility (evaluation_strategy vs eval_strategy)
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import os
16
+ import warnings
17
+ from pathlib import Path
18
+
19
+ warnings.filterwarnings("ignore", message="'return' in a 'finally' block")
20
+
21
+ import torch
22
+ from transformers import Trainer, TrainingArguments, set_seed
23
+
24
+ from src.data import ChessDataCollator, create_train_val_datasets
25
+ from src.model import ChessConfig, ChessForCausalLM
26
+ from src.tokenizer import ChessTokenizer
27
+ from src.utils import count_parameters, print_parameter_budget
28
+
29
+
30
def parse_args():
    """Build and parse the CLI for model, data, training, and speed options."""
    parser = argparse.ArgumentParser(description="Train a chess-playing language model")

    # ---------------- Model ----------------
    parser.add_argument("--n_embd", type=int, default=128, help="Embedding dimension")
    parser.add_argument("--n_layer", type=int, default=6, help="Number of transformer layers")
    parser.add_argument("--n_head", type=int, default=8, help="Number of attention heads")
    # For speed on GPU, 256 is often a great default; override via CLI if needed.
    parser.add_argument("--n_ctx", type=int, default=256, help="Maximum context length")

    parser.add_argument("--n_inner", type=int, default=248, help="MLP hidden size (SwiGLU: h)")
    parser.add_argument("--dropout", type=float, default=0.05, help="Dropout probability")
    parser.add_argument("--no_tie_weights", action="store_true", help="Disable weight tying")

    # Flags matching the improved model.py (all default to off).
    parser.add_argument("--use_rope", action="store_true", help="Use RoPE (recommended)")
    parser.add_argument("--mlp_type", type=str, default="swiglu", choices=["swiglu", "gelu"], help="MLP type")
    parser.add_argument("--use_rmsnorm", action="store_true", help="Use RMSNorm (recommended)")

    # ---------------- Data ----------------
    parser.add_argument("--dataset_name", type=str, default="dlouapre/lichess_2025-01_1M")
    parser.add_argument("--max_train_samples", type=int, default=None, help="Optional cap for train samples")
    parser.add_argument("--val_samples", type=int, default=5000)

    parser.add_argument(
        "--tokenizer_dir",
        type=str,
        default="./tokenizer_cache",
        help="Where to save/load the tokenizer (vocab.json)",
    )

    # ---------------- Training ----------------
    parser.add_argument("--output_dir", type=str, default="./output")
    parser.add_argument("--num_train_epochs", type=int, default=3)

    # For speed: prefer larger batch and smaller accumulation.
    parser.add_argument("--per_device_train_batch_size", type=int, default=64)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=128)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)

    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--weight_decay", type=float, default=0.1)
    parser.add_argument("--warmup_steps", type=int, default=300)

    parser.add_argument("--seed", type=int, default=42)

    # ---------------- Logging / Save ----------------
    parser.add_argument("--logging_steps", type=int, default=50)

    # Eval/save per epoch by default (much faster); steps mode still available.
    parser.add_argument("--eval_strategy", type=str, default="epoch", choices=["epoch", "steps"], help="Evaluation strategy")
    parser.add_argument("--save_strategy", type=str, default="epoch", choices=["epoch", "steps"], help="Save strategy")
    parser.add_argument("--eval_steps", type=int, default=1000, help="Only used if eval_strategy=steps")
    parser.add_argument("--save_steps", type=int, default=1000, help="Only used if save_strategy=steps")

    # ---------------- Speed knobs ----------------
    parser.add_argument("--dataloader_num_workers", type=int, default=2, help="CPU workers for dataloader")
    parser.add_argument("--torch_compile", action="store_true", help="Enable torch.compile on GPU (can speed up)")

    return parser.parse_args()
90
+
91
+
92
def load_or_create_tokenizer(args) -> ChessTokenizer:
    """Load the cached fixed-vocab tokenizer, or build and cache a new one."""
    cache_dir = Path(args.tokenizer_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)

    vocab_path = cache_dir / "vocab.json"
    if vocab_path.exists():
        print(f"Loading tokenizer from {cache_dir} ...")
        return ChessTokenizer(vocab_file=str(vocab_path))

    # No cached vocabulary yet: build the deterministic one and persist it.
    print("Creating fixed-vocab tokenizer (decomposed) ...")
    tokenizer = ChessTokenizer()
    tokenizer.save_pretrained(str(cache_dir))
    print(f"Tokenizer saved to {cache_dir} (vocab_size={tokenizer.vocab_size})")
    return tokenizer
106
+
107
+
108
def _make_training_args(args) -> TrainingArguments:
    """
    Compatibility layer for transformers versions:
    - some use evaluation_strategy, others use eval_strategy
    - we keep it robust while using faster defaults (epoch eval/save).
    """
    # Mixed precision: prefer bf16 where the GPU supports it, else fp16.
    use_gpu = torch.cuda.is_available()
    use_bf16 = bool(use_gpu and torch.cuda.is_bf16_supported())
    use_fp16 = bool(use_gpu and not use_bf16)

    # Arguments shared by both TrainingArguments spellings below.
    common = dict(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,

        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,

        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_steps=args.warmup_steps,
        lr_scheduler_type="cosine",

        max_grad_norm=1.0,

        logging_dir=os.path.join(args.output_dir, "logs"),
        logging_steps=args.logging_steps,

        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,

        seed=args.seed,
        report_to=["none"],

        # Mixed precision for GPU speed
        fp16=use_fp16,
        bf16=use_bf16,

        # DataLoader perf
        dataloader_num_workers=args.dataloader_num_workers,
        dataloader_pin_memory=use_gpu,

        # Important for custom batches
        remove_unused_columns=False,
    )

    # Build kwargs depending on epoch vs steps
    eval_kwargs = {}
    if args.eval_strategy == "steps":
        eval_kwargs["eval_steps"] = args.eval_steps
    save_kwargs = {}
    if args.save_strategy == "steps":
        save_kwargs["save_steps"] = args.save_steps

    # Try standard HF arg names first
    try:
        return TrainingArguments(
            **common,
            evaluation_strategy=args.eval_strategy,
            save_strategy=args.save_strategy,
            **eval_kwargs,
            **save_kwargs,
        )
    except TypeError:
        # Fallback for forks/older variants that renamed args
        return TrainingArguments(
            **common,
            eval_strategy=args.eval_strategy,
            save_strategy=args.save_strategy,
            **eval_kwargs,
            **save_kwargs,
        )
182
+
183
+
184
def main():
    """End-to-end training entry point: tokenizer -> config -> model -> Trainer."""
    args = parse_args()
    set_seed(args.seed)

    print("=" * 60)
    print("CHESS CHALLENGE - TRAINING")
    print("=" * 60)

    tokenizer = load_or_create_tokenizer(args)
    actual_vocab_size = tokenizer.vocab_size
    print(f" Vocab size used: {actual_vocab_size}")

    print("\nCreating model configuration...")
    # The config's vocab size comes from the tokenizer, not the CLI.
    config = ChessConfig(
        vocab_size=actual_vocab_size,
        n_embd=args.n_embd,
        n_layer=args.n_layer,
        n_head=args.n_head,
        n_ctx=args.n_ctx,
        n_inner=args.n_inner,
        dropout=args.dropout,
        tie_weights=not args.no_tie_weights,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_rope=bool(args.use_rope),
        mlp_type=args.mlp_type,
        use_rmsnorm=bool(args.use_rmsnorm),
    )

    print_parameter_budget(config)

    print("\nCreating model...")
    model = ChessForCausalLM(config)

    # Optional torch.compile (GPU only); failures fall back to eager mode.
    if args.torch_compile and torch.cuda.is_available():
        try:
            model = torch.compile(model)
            print("✓ torch.compile enabled")
        except Exception as e:
            print(f"WARNING: torch.compile failed ({e}). Continuing without it.")

    n_params = count_parameters(model)
    print(f" Total parameters: {n_params:,}")
    print("✓ Model is within 1M parameter limit" if n_params <= 1_000_000 else "WARNING: Model exceeds 1M!")

    print("\nLoading datasets...")
    train_dataset, val_dataset = create_train_val_datasets(
        tokenizer=tokenizer,
        dataset_name=args.dataset_name,
        max_length=args.n_ctx,
        train_samples=args.max_train_samples,
        val_samples=args.val_samples,
    )
    print(f" Training samples: {len(train_dataset):,}")
    print(f" Validation samples: {len(val_dataset):,}")

    data_collator = ChessDataCollator(tokenizer, max_length=args.n_ctx)

    training_args = _make_training_args(args)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    print("\nStarting training...")
    trainer.train()

    # Persist model and tokenizer together so the folder is self-contained.
    out_final = os.path.join(args.output_dir, "final_model")
    print("\nSaving final model...")
    trainer.save_model(out_final)
    tokenizer.save_pretrained(out_final)

    print("\nTraining complete!")
    print(f" Model saved to: {out_final}")


if __name__ == "__main__":
    main()
src/utils.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for the Chess Challenge.
3
+
4
+ This module provides helper functions for:
5
+ - Parameter counting and budget analysis (including RoPE / SwiGLU / RMSNorm variants)
6
+ - Move validation and conversion with python-chess
7
+ - Optional: compute legal-move rate over a whole game string
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import re
13
+ from typing import Dict, Optional, TYPE_CHECKING
14
+
15
+ import torch.nn as nn
16
+
17
+ if TYPE_CHECKING:
18
+ from src.model import ChessConfig
19
+
20
+
21
+ # =========================
22
+ # Parameter counting
23
+ # =========================
24
+
25
def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    """
    Count the number of parameters in a model.

    Args:
        model: The PyTorch model.
        trainable_only: If True, only count parameters with ``requires_grad=True``.

    Returns:
        Total number of (trainable) parameters.
    """
    params = model.parameters()
    if trainable_only:
        # Frozen parameters (requires_grad=False) are excluded from the budget.
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)
39
+
40
+
41
def count_parameters_by_component(model: nn.Module) -> Dict[str, int]:
    """
    Count parameters broken down by leaf modules.

    Args:
        model: The PyTorch model.

    Returns:
        Dictionary mapping leaf-module names to their parameter counts
        (modules with zero parameters are omitted).
    """
    breakdown: Dict[str, int] = {}
    for name, module in model.named_modules():
        # Only leaf modules are counted, so every parameter is
        # attributed exactly once (no double counting via parents).
        if next(module.children(), None) is not None:
            continue
        n_params = sum(p.numel() for p in module.parameters(recurse=False))
        if n_params > 0:
            breakdown[name] = n_params
    return breakdown
58
+
59
+
60
def estimate_parameters(config: "ChessConfig") -> Dict[str, int]:
    """
    Estimate parameter count for a configuration.

    Works for:
      - learned position embeddings (wpe) or RoPE (no positional params)
      - GELU FFN (d -> n_inner -> d)
      - SwiGLU FFN (d -> 2h, h -> d) where h = n_inner
      - LayerNorm (weight+bias) vs RMSNorm (weight only)
      - tied or untied LM head

    NOTE: This is an estimate of *weights + biases* for the common
    implementation patterns used in this repo.
    """
    vocab = int(config.vocab_size)
    dim = int(config.n_embd)
    n_layers = int(config.n_layer)
    ctx = int(config.n_ctx)
    inner = int(config.n_inner)

    use_rope = bool(getattr(config, "use_rope", False))
    use_rmsnorm = bool(getattr(config, "use_rmsnorm", False))
    mlp_type = str(getattr(config, "mlp_type", "gelu")).lower()
    tied = bool(getattr(config, "tie_weights", True))

    # Embeddings: RoPE computes positions analytically, so no learned table.
    token_embeddings = vocab * dim
    position_embeddings = 0 if use_rope else ctx * dim

    # Attention per layer:
    #   fused QKV projection d -> 3d (weight 3d*d, bias 3d)
    #   output projection    d -> d  (weight d*d,  bias d)
    attn_qkv_per_layer = 3 * dim * dim + 3 * dim
    attn_proj_per_layer = dim * dim + dim

    # FFN per layer.
    if mlp_type == "swiglu":
        # gate+up fused: d -> 2h (weight 2h*d, bias 2h); down: h -> d (weight d*h, bias d)
        ffn_per_layer = (2 * inner * dim + 2 * inner) + (dim * inner + dim)
    else:
        # classic GELU MLP: d -> n_inner -> d, both with biases
        ffn_per_layer = (dim * inner + inner) + (inner * dim + dim)

    # Norms: LayerNorm carries weight+bias (2d); RMSNorm only weight (d).
    norm_params = dim if use_rmsnorm else 2 * dim
    norms_per_layer = 2 * norm_params  # pre-attention + pre-FFN
    final_norm = norm_params

    per_layer = attn_qkv_per_layer + attn_proj_per_layer + ffn_per_layer + norms_per_layer
    total_transformer_layers = n_layers * per_layer

    # LM head: Linear(d, V, bias=False) when untied; tied reuses token embeddings.
    lm_head = 0 if tied else vocab * dim

    total = (
        token_embeddings
        + position_embeddings
        + total_transformer_layers
        + final_norm
        + lm_head
    )

    return {
        "token_embeddings": token_embeddings,
        "position_embeddings": position_embeddings,
        "attention_qkv_per_layer": attn_qkv_per_layer,
        "attention_proj_per_layer": attn_proj_per_layer,
        "ffn_per_layer": ffn_per_layer,
        "norms_per_layer": norms_per_layer,
        "final_norm": final_norm,
        "total_transformer_layers": total_transformer_layers,
        "lm_head": lm_head,
        "total": total,
        "notes": {
            "use_rope": use_rope,
            "use_rmsnorm": use_rmsnorm,
            "mlp_type": mlp_type,
            "tie_weights": tied,
        },
    }
139
+
140
+
141
def print_parameter_budget(config: "ChessConfig", limit: int = 1_000_000) -> None:
    """
    Print a formatted parameter budget analysis.

    Delegates the arithmetic to :func:`estimate_parameters` and renders the
    breakdown plus a within-budget / over-budget verdict to stdout.

    Args:
        config: Model configuration.
        limit: Parameter limit (defaults to the challenge's 1M cap).
    """
    est = estimate_parameters(config)

    print("=" * 60)
    print("PARAMETER BUDGET ANALYSIS")
    print("=" * 60)
    print("\nConfiguration:")
    print(f" vocab_size (V) = {config.vocab_size}")
    print(f" n_embd (d) = {config.n_embd}")
    print(f" n_layer (L) = {config.n_layer}")
    print(f" n_head = {config.n_head}")
    print(f" n_ctx = {config.n_ctx}")
    print(f" n_inner = {config.n_inner}")
    print(f" tie_weights = {getattr(config, 'tie_weights', True)}")
    # Optional architecture flags are only printed when the config defines them,
    # so the report stays valid for older/simpler configs.
    if hasattr(config, "use_rope"):
        print(f" use_rope = {getattr(config, 'use_rope', False)}")
    if hasattr(config, "mlp_type"):
        print(f" mlp_type = {getattr(config, 'mlp_type', 'gelu')}")
    if hasattr(config, "use_rmsnorm"):
        print(f" use_rmsnorm = {getattr(config, 'use_rmsnorm', False)}")

    print("\nParameter Breakdown (estimate):")
    print(f" Token Embeddings: {est['token_embeddings']:>10,}")
    print(f" Position Embeddings: {est['position_embeddings']:>10,}")
    print(f" Transformer Layers: {est['total_transformer_layers']:>10,}")
    print(f" Final Norm: {est['final_norm']:>10,}")
    # A tied head shares the token-embedding matrix, so it adds no parameters.
    if getattr(config, "tie_weights", True):
        print(f" LM Head: {'(tied)':>10}")
    else:
        print(f" LM Head: {est['lm_head']:>10,}")

    print(" " + "-" * 32)
    print(f" TOTAL: {est['total']:>10,}")

    # remaining < 0 means the configuration is over budget.
    remaining = limit - est["total"]
    print("\nBudget Status:")
    print(f" Limit: {limit:>10,}")
    print(f" Used: {est['total']:>10,}")
    print(f" Remaining: {remaining:>10,}")

    if est["total"] <= limit:
        print(f"\n✓ Within budget! ({est['total'] / limit * 100:.1f}% used)")
    else:
        print(f"\n✗ OVER BUDGET by {-remaining:,} parameters!")
    print("=" * 60)
193
+
194
+
195
+ # =========================
196
+ # Move conversion / validation (python-chess)
197
+ # =========================
198
+
199
def convert_extended_uci_to_uci(move: str) -> str:
    """
    Convert extended UCI format to standard UCI format.

    Extended UCI format (dataset):
        [W|B][Piece][from_sq][to_sq][suffixes...]
        e.g. "WPe2e4", "BNg8f6(x)", "WKe1g1(o)", "WPe7e8=Q(+)"
    Standard UCI:
        "e2e4", "g8f6", "e1g1", "e7e8q"
    """
    # Anything shorter than color+piece+two squares is passed through untouched.
    if len(move) < 6:
        return move

    source = move[2:4]
    target = move[4:6]

    # A promotion is encoded as "=X" somewhere after the squares;
    # standard UCI wants the lowercase piece letter appended.
    promo = ""
    eq_pos = move.find("=")
    if eq_pos != -1 and eq_pos + 1 < len(move):
        promo = move[eq_pos + 1].lower()

    return source + target + promo
222
+
223
+
224
def validate_move_with_chess(move: str, board_fen: Optional[str] = None) -> bool:
    """
    Validate a single move using python-chess against a given board state.

    IMPORTANT:
        - If board_fen is None, validation is against the initial position.
          For validating a *game*, use `legal_rate_game_text` which advances
          the board after every move.

    Args:
        move: Move in extended UCI format.
        board_fen: FEN string of the current board (optional).

    Returns:
        True if move is legal on that board, else False.
    """
    try:
        import chess
    except ImportError:
        raise ImportError(
            "python-chess is required for move validation. Install it with: pip install python-chess"
        )

    # Extended-UCI moves carry at least color+piece+from+to (6 chars).
    if len(move) < 6:
        return False

    board = chess.Board(board_fen) if board_fen else chess.Board()

    try:
        candidate = chess.Move.from_uci(convert_extended_uci_to_uci(move))
        return candidate in board.legal_moves
    except Exception:
        # Unparseable UCI (bad squares, bad promotion letter, ...) is illegal.
        return False
257
+
258
+
259
def legal_rate_game_text(game_text: str, stop_on_illegal: bool = True) -> float:
    """
    Compute the fraction of legal moves in a space-separated extended-UCI game string.

    Args:
        game_text: "WPe2e4 BPe7e5 ..." (space-separated moves)
        stop_on_illegal: If True, stop at first illegal move.

    Returns:
        legal / total (total is moves processed, or total moves if stop_on_illegal=False)
    """
    try:
        import chess
    except ImportError:
        raise ImportError("python-chess is required. Install it with: pip install python-chess")

    tokens = game_text.strip().split()
    if not tokens:
        return 0.0

    board = chess.Board()
    n_legal = 0
    n_seen = 0

    for raw in tokens:
        n_seen += 1

        # Unparseable UCI counts as illegal.
        try:
            parsed = chess.Move.from_uci(convert_extended_uci_to_uci(raw))
        except Exception:
            parsed = None

        if parsed is not None and parsed in board.legal_moves:
            n_legal += 1
            # Advance the position so later moves are checked in context.
            board.push(parsed)
        elif stop_on_illegal:
            break

    # n_seen >= 1 here, but guard against division by zero defensively.
    return n_legal / max(n_seen, 1)
301
+
302
+
303
def convert_uci_to_extended(uci_move: str, board_fen: str) -> str:
    """
    Convert standard UCI move to extended UCI format used by the dataset.

    Args:
        uci_move: e.g., "e2e4", "e7e8q", "e1g1"
        board_fen: FEN of current board (must match move)

    Returns:
        Extended UCI like "WPe2e4", with suffixes:
            - (x) capture
            - (+) check
            - (+*) checkmate
            - (x+) capture+check
            - (x+*) capture+checkmate
            - (o) / (O) castling
            - promotions as "=Q" etc
    """
    try:
        import chess
    except ImportError:
        raise ImportError("python-chess is required for move conversion. Install it with: pip install python-chess")

    board = chess.Board(board_fen)
    move = chess.Move.from_uci(uci_move)

    # The side to move determines the color prefix.
    color = "W" if board.turn == chess.WHITE else "B"

    # Piece letter comes from the origin square; fall back to "P" defensively
    # if the board and move disagree (should not happen for legal moves).
    piece = board.piece_at(move.from_square)
    piece_letter = piece.symbol().upper() if piece else "P"

    from_sq = chess.square_name(move.from_square)
    to_sq = chess.square_name(move.to_square)

    result = f"{color}{piece_letter}{from_sq}{to_sq}"

    # Promotion (dataset uses uppercase, e.g. "=Q")
    if move.promotion:
        result += f"={chess.piece_symbol(move.promotion).upper()}"

    # Capture suffix — evaluated on the pre-move board (covers en passant).
    if board.is_capture(move):
        result += "(x)"

    # Check / mate suffix: the move must be pushed to inspect the resulting
    # position; pop() restores the board for the castling test below.
    board.push(move)
    if board.is_checkmate():
        # Merge with an existing capture marker into the combined suffix.
        if "(x)" in result:
            result = result.replace("(x)", "(x+*)")
        else:
            result += "(+*)"
    elif board.is_check():
        if "(x)" in result:
            result = result.replace("(x)", "(x+)")
        else:
            result += "(+)"
    board.pop()

    # Castling (dataset wants (o)/(O), usually no other suffix with it)
    if board.is_castling(move):
        result = re.sub(r"\([^)]*\)", "", result)  # drop any (...) suffix
        if move.to_square in [chess.G1, chess.G8]:
            result += "(o)"  # kingside
        else:
            result += "(O)"  # queenside

    return result
tokenizer.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Decomposed Chess Tokenizer for the Chess Challenge.
3
+
4
+ Each move becomes 3 or 4 tokens:
5
+ WP e2_f e4_t
6
+ BN g8_f f6_t
7
+ Promotion adds an extra token:
8
+ WP e7_f e8_t =q
9
+
10
+ Why this helps:
11
+ - Fixed small vocab (~150 tokens)
12
+ - Near-zero OOV / UNK, so the evaluator can always parse squares
13
+ - Compatible with the provided evaluate.py (it auto-detects 'decomposed')
14
+
15
+ Special tokens behavior:
16
+ - Adds BOS only (NO EOS)
17
+ - If BOS already present, does not add it twice
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ import os
24
+ from typing import Dict, List, Optional
25
+
26
+ from transformers import PreTrainedTokenizer
27
+
28
+
29
class ChessTokenizer(PreTrainedTokenizer):
    """Word-level tokenizer that decomposes each extended-UCI chess move.

    A move such as ``WPe2e4`` becomes three tokens ``WP e2_f e4_t``; a
    promotion such as ``WPe7e8=Q`` adds a fourth token ``=q``.  The
    vocabulary is fixed and small (~148 tokens), so well-formed games
    produce essentially no [UNK] tokens.

    Special-token behavior:
        - ``build_inputs_with_special_tokens`` adds BOS only (no EOS),
          and never duplicates an existing leading BOS.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Literal special-token strings; their ids are fixed by vocab order
    # ([PAD]=0, [BOS]=1, [EOS]=2, [UNK]=3 in the built-in vocab).
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"  # kept for compatibility, not auto-added
    UNK_TOKEN = "[UNK]"

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        # Stash the special-token strings before the base constructor runs,
        # since it may consult them via properties.
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # avoid duplicates from kwargs (e.g. when reloaded via from_pretrained)
        kwargs.pop("pad_token", None)
        kwargs.pop("bos_token", None)
        kwargs.pop("eos_token", None)
        kwargs.pop("unk_token", None)

        # Vocabulary priority: explicit dict > vocab file on disk > built-in fixed vocab.
        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._build_fixed_vocab()

        # Reverse mapping used for id -> token decoding.
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # NOTE: _vocab must exist before super().__init__, which can call
        # get_vocab()/vocab_size while registering special tokens.
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    # --------------------------
    # Fixed vocab: pieces + squares + promos
    # --------------------------
    @staticmethod
    def _all_squares() -> List[str]:
        # Rank-major order (a1..h1, a2..h2, ...) — must match vocab.json.
        files = "abcdefgh"
        ranks = "12345678"
        return [f + r for r in ranks for f in files]  # a1..h8

    def _build_fixed_vocab(self) -> Dict[str, int]:
        """Build the deterministic 148-token vocabulary (order defines ids)."""
        special = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]

        # piece tokens: WP..WK, BP..BK
        piece_tokens = [f"{c}{p}" for c in "WB" for p in "PNBRQK"]

        # Separate "_f"/"_t" suffixes disambiguate origin vs destination squares.
        squares = self._all_squares()
        from_tokens = [f"{sq}_f" for sq in squares]
        to_tokens = [f"{sq}_t" for sq in squares]

        promo_tokens = ["=q", "=r", "=b", "=n"]

        tokens = special + piece_tokens + from_tokens + to_tokens + promo_tokens
        return {tok: i for i, tok in enumerate(tokens)}

    # --------------------------
    # Special tokens handling (robust with evaluate.py)
    # --------------------------
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # BOS only, NO EOS
        if token_ids_1 is not None:
            token_ids_0 = token_ids_0 + token_ids_1

        # Avoid a double BOS when the text already started with one.
        if token_ids_0 and token_ids_0[0] == self.bos_token_id:
            return token_ids_0
        return [self.bos_token_id] + token_ids_0

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        # 1 marks a special token, 0 a regular move token.
        if already_has_special_tokens:
            specials = {self.pad_token_id, self.bos_token_id, self.eos_token_id, self.unk_token_id}
            return [1 if t in specials else 0 for t in token_ids_0]

        # Otherwise only the BOS that build_inputs... will prepend is special.
        if token_ids_1 is None:
            return [1] + [0] * len(token_ids_0)
        return [1] + [0] * (len(token_ids_0) + len(token_ids_1))

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # Single-segment model: everything (incl. the prepended BOS) is type 0.
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 1)
        return [0] * (len(token_ids_0) + len(token_ids_1) + 1)

    # --------------------------
    # Tokenization
    # --------------------------
    def _tokenize(self, text: str) -> List[str]:
        """Split whitespace-separated moves into decomposed sub-tokens."""
        if not text or not text.strip():
            return []

        parts = text.strip().split()
        out: List[str] = []

        for tok in parts:
            # allow literal special tokens present in text
            if tok in {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}:
                out.append(tok)
                continue

            # already decomposed tokens (idempotent re-tokenization)
            if (len(tok) == 2 and tok[0] in "WB" and tok[1] in "PNBRQK") or tok.endswith("_f") or tok.endswith("_t") or tok in {"=q", "=r", "=b", "=n"}:
                out.append(tok)
                continue

            # parse extended UCI (dataset): WPe2e4, BNg8f6(x), WPe7e8=Q(+), ...
            if len(tok) < 6:
                out.append(self.UNK_TOKEN)
                continue

            color = tok[0]
            piece = tok[1]
            from_sq = tok[2:4]
            to_sq = tok[4:6]

            out.append(f"{color}{piece}")
            out.append(f"{from_sq}_f")
            out.append(f"{to_sq}_t")

            # promotion like "=Q" -> lowercase promo token "=q"
            if "=" in tok:
                try:
                    promo_part = tok.split("=", 1)[1]
                    promo_letter = promo_part[0].lower()
                    promo_tok = f"={promo_letter}"
                    # Drop unknown promotion letters silently rather than
                    # emitting an out-of-vocab token.
                    if promo_tok in self._vocab:
                        out.append(promo_tok)
                except Exception:
                    pass

        return out

    def _convert_token_to_id(self, token: str) -> int:
        # Unknown tokens map to the [UNK] id.
        return self._vocab.get(token, self._vocab[self.UNK_TOKEN])

    def _convert_id_to_token(self, index: int) -> str:
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        # Decomposed sub-tokens are simply space-joined (no move re-assembly).
        return " ".join(tokens)

    # --------------------------
    # Vocab I/O
    # --------------------------
    @property
    def vocab_size(self) -> int:
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        # Return a copy so callers cannot mutate the internal mapping.
        return dict(self._vocab)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        """Write vocab.json to *save_directory*; returns the file path tuple."""
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)
tokenizer_config.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "[BOS]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "[EOS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "[UNK]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ }
35
+ },
36
+ "auto_map": {
37
+ "AutoTokenizer": [
38
+ "tokenizer.ChessTokenizer",
39
+ null
40
+ ]
41
+ },
42
+ "bos_token": "[BOS]",
43
+ "clean_up_tokenization_spaces": true,
44
+ "eos_token": "[EOS]",
45
+ "model_max_length": 1000000000000000019884624838656,
46
+ "pad_token": "[PAD]",
47
+ "tokenizer_class": "ChessTokenizer",
48
+ "unk_token": "[UNK]"
49
+ }
vocab.json ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "[PAD]": 0,
3
+ "[BOS]": 1,
4
+ "[EOS]": 2,
5
+ "[UNK]": 3,
6
+ "WP": 4,
7
+ "WN": 5,
8
+ "WB": 6,
9
+ "WR": 7,
10
+ "WQ": 8,
11
+ "WK": 9,
12
+ "BP": 10,
13
+ "BN": 11,
14
+ "BB": 12,
15
+ "BR": 13,
16
+ "BQ": 14,
17
+ "BK": 15,
18
+ "a1_f": 16,
19
+ "b1_f": 17,
20
+ "c1_f": 18,
21
+ "d1_f": 19,
22
+ "e1_f": 20,
23
+ "f1_f": 21,
24
+ "g1_f": 22,
25
+ "h1_f": 23,
26
+ "a2_f": 24,
27
+ "b2_f": 25,
28
+ "c2_f": 26,
29
+ "d2_f": 27,
30
+ "e2_f": 28,
31
+ "f2_f": 29,
32
+ "g2_f": 30,
33
+ "h2_f": 31,
34
+ "a3_f": 32,
35
+ "b3_f": 33,
36
+ "c3_f": 34,
37
+ "d3_f": 35,
38
+ "e3_f": 36,
39
+ "f3_f": 37,
40
+ "g3_f": 38,
41
+ "h3_f": 39,
42
+ "a4_f": 40,
43
+ "b4_f": 41,
44
+ "c4_f": 42,
45
+ "d4_f": 43,
46
+ "e4_f": 44,
47
+ "f4_f": 45,
48
+ "g4_f": 46,
49
+ "h4_f": 47,
50
+ "a5_f": 48,
51
+ "b5_f": 49,
52
+ "c5_f": 50,
53
+ "d5_f": 51,
54
+ "e5_f": 52,
55
+ "f5_f": 53,
56
+ "g5_f": 54,
57
+ "h5_f": 55,
58
+ "a6_f": 56,
59
+ "b6_f": 57,
60
+ "c6_f": 58,
61
+ "d6_f": 59,
62
+ "e6_f": 60,
63
+ "f6_f": 61,
64
+ "g6_f": 62,
65
+ "h6_f": 63,
66
+ "a7_f": 64,
67
+ "b7_f": 65,
68
+ "c7_f": 66,
69
+ "d7_f": 67,
70
+ "e7_f": 68,
71
+ "f7_f": 69,
72
+ "g7_f": 70,
73
+ "h7_f": 71,
74
+ "a8_f": 72,
75
+ "b8_f": 73,
76
+ "c8_f": 74,
77
+ "d8_f": 75,
78
+ "e8_f": 76,
79
+ "f8_f": 77,
80
+ "g8_f": 78,
81
+ "h8_f": 79,
82
+ "a1_t": 80,
83
+ "b1_t": 81,
84
+ "c1_t": 82,
85
+ "d1_t": 83,
86
+ "e1_t": 84,
87
+ "f1_t": 85,
88
+ "g1_t": 86,
89
+ "h1_t": 87,
90
+ "a2_t": 88,
91
+ "b2_t": 89,
92
+ "c2_t": 90,
93
+ "d2_t": 91,
94
+ "e2_t": 92,
95
+ "f2_t": 93,
96
+ "g2_t": 94,
97
+ "h2_t": 95,
98
+ "a3_t": 96,
99
+ "b3_t": 97,
100
+ "c3_t": 98,
101
+ "d3_t": 99,
102
+ "e3_t": 100,
103
+ "f3_t": 101,
104
+ "g3_t": 102,
105
+ "h3_t": 103,
106
+ "a4_t": 104,
107
+ "b4_t": 105,
108
+ "c4_t": 106,
109
+ "d4_t": 107,
110
+ "e4_t": 108,
111
+ "f4_t": 109,
112
+ "g4_t": 110,
113
+ "h4_t": 111,
114
+ "a5_t": 112,
115
+ "b5_t": 113,
116
+ "c5_t": 114,
117
+ "d5_t": 115,
118
+ "e5_t": 116,
119
+ "f5_t": 117,
120
+ "g5_t": 118,
121
+ "h5_t": 119,
122
+ "a6_t": 120,
123
+ "b6_t": 121,
124
+ "c6_t": 122,
125
+ "d6_t": 123,
126
+ "e6_t": 124,
127
+ "f6_t": 125,
128
+ "g6_t": 126,
129
+ "h6_t": 127,
130
+ "a7_t": 128,
131
+ "b7_t": 129,
132
+ "c7_t": 130,
133
+ "d7_t": 131,
134
+ "e7_t": 132,
135
+ "f7_t": 133,
136
+ "g7_t": 134,
137
+ "h7_t": 135,
138
+ "a8_t": 136,
139
+ "b8_t": 137,
140
+ "c8_t": 138,
141
+ "d8_t": 139,
142
+ "e8_t": 140,
143
+ "f8_t": 141,
144
+ "g8_t": 142,
145
+ "h8_t": 143,
146
+ "=q": 144,
147
+ "=r": 145,
148
+ "=b": 146,
149
+ "=n": 147
150
+ }