haiphamcse commited on
Commit
41e1785
·
verified ·
1 Parent(s): 89c36b8

Chess Challenge submission by haiphamcse

Browse files
Files changed (8) hide show
  1. README.md +26 -0
  2. config.json +25 -0
  3. model.py +566 -0
  4. model.safetensors +3 -0
  5. special_tokens_map.json +6 -0
  6. tokenizer.py +309 -0
  7. tokenizer_config.json +50 -0
  8. vocab.json +88 -0
README.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags:
4
+ - chess
5
+ - llm-course
6
+ - chess-challenge
7
+ license: mit
8
+ ---
9
+
10
+ # chess-haiphamcse_olmo_12k_mod_token
11
+
12
+ Chess model submitted to the LLM Course Chess Challenge.
13
+
14
+ ## Submission Info
15
+
16
+ - **Submitted by**: [haiphamcse](https://huggingface.co/haiphamcse)
17
+ - **Parameters**: 838,144
18
+ - **Organization**: LLM-course
19
+
20
+ ## Model Details
21
+
22
+ - **Architecture**: Chess Transformer (GPT-style)
23
+ - **Vocab size**: 86
24
+ - **Embedding dim**: 128
25
+ - **Layers**: 5
26
+ - **Heads**: 4
config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ChessForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "model.ChessConfig",
7
+ "AutoModelForCausalLM": "model.ChessForCausalLM"
8
+ },
9
+ "bos_token_id": 1,
10
+ "dropout": 0.1,
11
+ "dtype": "float32",
12
+ "eos_token_id": 2,
13
+ "layer_norm_epsilon": 1e-05,
14
+ "model_type": "chess_transformer",
15
+ "n_ctx": 256,
16
+ "n_embd": 128,
17
+ "n_head": 4,
18
+ "n_inner": 384,
19
+ "n_layer": 5,
20
+ "pad_token_id": 0,
21
+ "tie_weights": true,
22
+ "tie_word_embeddings": true,
23
+ "transformers_version": "4.57.6",
24
+ "vocab_size": 86
25
+ }
model.py ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+ from typing import Optional, Tuple, Union
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from transformers import PretrainedConfig, PreTrainedModel
11
+ from transformers.modeling_outputs import CausalLMOutputWithPast
12
+
13
+
14
def precompute_rope_(head_dim: int, max_seq_length: int, base=10000):
    """Precompute rotary-embedding angle table.

    Returns a tensor of shape (max_seq_length, head_dim) holding the angle
    m * theta_i for every position m and frequency theta_i, with each angle
    duplicated for its adjacent (even, odd) channel pair (interleaved layout,
    as consumed by ``apply_rope_``).
    """
    # theta_i = base^(-2i/d) over the even channel indices -> (head_dim // 2,)
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (base ** exponents)
    # Positions m = 0 .. max_seq_length - 1
    positions = torch.arange(max_seq_length, dtype=torch.float32)
    # Outer product m * theta_i -> (max_seq_length, head_dim // 2)
    angles = positions[:, None] * inv_freq[None, :]
    # Duplicate each frequency so the table covers all head_dim channels
    return angles.repeat_interleave(2, dim=-1)
25
+
26
def apply_rope_(x, rope_emb):
    """Apply rotary position embeddings to ``x``.

    Args:
        x: Tensor with the sequence on dim 1 and channels last, e.g.
            (batch, seq_len, n_head, head_dim).
        rope_emb: Angle table of shape (max_seq_len, head_dim) as produced
            by ``precompute_rope_`` (interleaved pair layout).

    Returns:
        Tensor of the same shape and dtype as ``x`` with RoPE applied.
    """
    # Slice the table down to the actual sequence length.
    seq_len = x.shape[1]
    rope_emb_sliced = rope_emb[:seq_len, :]
    # Reshape for broadcasting: (1, seq_len, 1, head_dim)
    emb = rope_emb_sliced.unsqueeze(0).unsqueeze(2)

    cos = emb.cos()
    sin = emb.sin()

    # Build the rotated "partner": for each adjacent pair (x0, x1) the
    # partner is (-x1, x0), interleaved back to the original channel order.
    x_reshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    x_partner = torch.stack([-x_reshaped[..., 1], x_reshaped[..., 0]], dim=-1)
    # BUG FIX: the flattened partner (same shape as x) must be used below.
    # The original assigned the flatten to a typo'd name ("x_partnet") and
    # returned with the unflattened (..., head_dim // 2, 2) tensor, which
    # cannot broadcast against x.
    x_partner = x_partner.flatten(-2)

    return (x * cos + x_partner * sin).type_as(x)
43
+
44
+
45
+
46
class ChessConfig(PretrainedConfig):
    """
    Configuration class for the Chess Transformer model.

    This configuration is designed for a ~1M parameter model.
    Students can adjust these values to explore different architectures.

    Parameter budget breakdown (with default values):
    - Embeddings (vocab): 1200 x 128 = 153,600
    - Position Embeddings: 256 x 128 = 32,768
    - Transformer Layers: 6 x ~120,000 = ~720,000
    - LM Head (with weight tying): 0 (shared with embeddings)
    - Total: ~906,000 parameters

    NOTE(review): the budget above reflects the class defaults; the shipped
    config.json overrides several of them (vocab_size=86, n_layer=5,
    tie_weights=true), and the model uses rotary embeddings (no learned
    position table) — confirm before relying on these numbers.

    Attributes:
        vocab_size: Size of the vocabulary (number of unique moves).
        n_embd: Embedding dimension (d_model).
        n_layer: Number of transformer layers.
        n_head: Number of attention heads.
        n_ctx: Maximum sequence length (context window).
        n_inner: Feed-forward inner dimension (default: 3 * n_embd).
        dropout: Dropout probability.
        layer_norm_epsilon: Epsilon for layer normalization.
        tie_weights: Whether to tie embedding and output weights.
    """

    model_type = "chess_transformer"

    def __init__(
        self,
        vocab_size: int = 1200,
        n_embd: int = 128,
        n_layer: int = 6,
        n_head: int = 4,
        n_ctx: int = 256,
        n_inner: Optional[int] = None,
        dropout: float = 0.1,
        layer_norm_epsilon: float = 1e-5,
        tie_weights: bool = False,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ):
        # Special-token IDs are forwarded to the HF base class so generation
        # utilities can find them.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_ctx = n_ctx
        self.n_inner = n_inner if n_inner is not None else 3 * n_embd  # Reduced from 4x to 3x
        self.dropout = dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.tie_weights = tie_weights
        # Inform HF base class about tying behavior
        self.tie_word_embeddings = bool(tie_weights)
108
+
109
+
110
+
111
class LlamaRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) module, Llama-style.

    Computes per-position (cos, sin) tensors from precomputed inverse
    frequencies; the actual rotation of Q/K is applied later by
    ``apply_rotary_pos_emb`` inside the attention module.
    """
    # inv_freq: torch.Tensor # fix linting for `register_buffer`

    def __init__(self, config: ChessConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.n_ctx
        self.original_max_seq_len = config.n_ctx

        self.config = config

        rope_init_fn = self.compute_default_rope_parameters
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Non-persistent: recomputed at init rather than stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: ChessConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = 10_000.0
        # Frequencies are defined over the per-head channel width.
        dim = config.n_embd // config.n_head

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies: base^(-2i/dim) for even indices i.
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device).float() / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return (cos, sin), each of shape (batch, seq_len, head_dim), in x's dtype.

        ``x`` supplies only the target device/dtype; angles depend solely on
        ``position_ids``.
        """
        # (dim/2,) -> (batch, dim/2, 1) for the batched matmul below.
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Outer product position x frequency, transposed to (batch, seq, dim/2).
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
        # Duplicate the half-spectrum so cos/sin cover all head_dim channels
        # (half-rotation layout, matching rotate_half / apply_rotary_pos_emb).
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos() * self.attention_scaling
        sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
167
+
168
+
169
def rotate_half(x):
    """Return ``x`` with its channel halves swapped and the second negated.

    Splitting the last dimension into [a, b], the result is [-b, a] — the
    90-degree-rotation companion used by rotary position embeddings.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
174
+
175
+
176
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Rotate query and key tensors by their rotary position embeddings.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine component of the rotary embedding.
        sin: Sine component of the rotary embedding.
        unsqueeze_dim: Dimension along which to unsqueeze ``cos``/``sin`` so
            they broadcast against q/k. With q/k shaped
            (batch, heads, seq, head_dim) and cos/sin shaped
            (batch, seq, head_dim), the default of 1 inserts the heads axis;
            use 2 for a (batch, seq, heads, head_dim) layout.

    Returns:
        Tuple of the rotated (q, k) tensors.
    """
    # Insert a broadcast axis for the heads dimension.
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # Standard RoPE: mix each tensor with its half-rotated companion.
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
199
+
200
+
201
class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention module.

    Standard scaled dot-product attention with a causal mask for
    autoregressive generation, RMSNorm applied to the query/key projections
    ("QK-norm"), and rotary position embeddings applied per head.
    """

    def __init__(self, config: ChessConfig):
        super().__init__()

        assert config.n_embd % config.n_head == 0, \
            f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head

        # Using QK-norm
        # NOTE: the norms act on the full n_embd width, i.e. before the
        # reshape into heads (not per-head).
        self.q_norm = nn.RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.k_norm = nn.RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
        # Combined QKV projection for efficiency
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)

        self.dropout = nn.Dropout(config.dropout)

        # Precomputed lower-triangular causal mask of shape
        # (1, 1, n_ctx, n_ctx); sliced to the running length in forward().
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.n_ctx, config.n_ctx)).view(
                1, 1, config.n_ctx, config.n_ctx
            ),
            persistent=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeds = None,
    ) -> torch.Tensor:
        """Apply causal self-attention.

        Args:
            x: Hidden states of shape (batch, seq_len, n_embd).
            attention_mask: Optional padding mask of shape (batch, seq_len);
                key positions where the mask is 0 are hidden from attention.
            position_embeds: (cos, sin) pair from the rotary embedding
                module; required (it is unpacked unconditionally below).

        Returns:
            Tensor of shape (batch, seq_len, n_embd).
        """
        batch_size, seq_len, _ = x.size()

        # Compute Q, K, V with one projection, then split along channels.
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)

        q = self.q_norm(q)
        k = self.k_norm(k)

        # Reshape for multi-head attention: (batch, n_head, seq, head_dim)
        q = q.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)

        cos, sin = position_embeds
        # Rotary positional embedding on queries and keys only.
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Scaled dot-product attention
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Apply causal mask
        causal_mask = self.bias[:, :, :seq_len, :seq_len]
        attn_weights = attn_weights.masked_fill(causal_mask == 0, float("-inf"))

        # Apply attention mask (for padding)
        if attention_mask is not None:
            # attention_mask shape: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len)
            # NOTE(review): a query row whose keys are ALL masked becomes
            # all -inf, which softmax turns into NaN. Harmless only if such
            # rows are padding positions excluded from the loss — confirm.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attn_weights = attn_weights.masked_fill(attention_mask == 0, float("-inf"))

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, v)

        # Merge heads back: (batch, seq_len, n_embd)
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.n_embd
        )

        # Output projection
        attn_output = self.c_proj(attn_output)

        return attn_output
291
+
292
+
293
class FeedForward(nn.Module):
    """
    Position-wise feed-forward network (MLP).

    Two linear layers with a GELU non-linearity in between, followed by
    dropout: n_embd -> n_inner -> n_embd.
    """

    def __init__(self, config: ChessConfig):
        super().__init__()
        # Attribute names mirror the GPT-2 convention (c_fc / c_proj) so
        # state_dict keys stay stable.
        self.c_fc = nn.Linear(config.n_embd, config.n_inner)
        self.c_proj = nn.Linear(config.n_inner, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand, activate, project back, and regularize."""
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
313
+
314
+
315
class TransformerBlock(nn.Module):
    """
    A single transformer block with attention and feed-forward sub-layers.

    Each sub-layer's output is normalized and then added to the residual
    stream, i.e. ``x = norm(f(x)) + x`` (post-norm on the sub-layer output).
    The original docstring claimed pre-normalization, which did not match
    the code; the computation itself is unchanged.
    """

    def __init__(self, config: ChessConfig):
        super().__init__()

        self.ln_1 = nn.RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = MultiHeadAttention(config)
        self.ln_2 = nn.RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = FeedForward(config)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeds = None,
    ) -> torch.Tensor:
        """Run the block: attention sub-layer, then MLP sub-layer.

        Args:
            x: Hidden states (batch, seq_len, n_embd).
            attention_mask: Optional padding mask, forwarded to attention.
            position_embeds: (cos, sin) rotary pair, forwarded to attention.

        Returns:
            Updated hidden states of the same shape.
        """
        # (Removed a stray `pass` statement that preceded the body.)
        # Attention sub-layer: normalize the output, then add the residual.
        x = self.ln_1(self.attn(x, attention_mask=attention_mask, position_embeds=position_embeds)) + x
        # Feed-forward sub-layer: same pattern.
        x = self.ln_2(self.mlp(x)) + x
        return x
343
+
344
+
345
+
346
class ChessForCausalLM(PreTrainedModel):
    """
    Chess Transformer for Causal Language Modeling (next-move prediction).

    This model is designed to predict the next chess move given a sequence
    of previous moves. It uses a GPT-style architecture with:
    - Token embeddings for chess moves
    - Rotary positional embeddings (RoPE) applied inside attention
    - Stacked transformer blocks
    - Linear head for next-token prediction

    The model supports weight tying between the embedding layer and the
    output projection to save parameters.

    Example:
        >>> config = ChessConfig(vocab_size=1200, n_embd=128, n_layer=6)
        >>> model = ChessForCausalLM(config)
        >>> inputs = {"input_ids": torch.tensor([[1, 42, 87]])}
        >>> outputs = model(**inputs)
        >>> next_move_logits = outputs.logits[:, -1, :]
    """

    config_class = ChessConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    # Suppress missing-key warning for tied lm_head when loading
    keys_to_ignore_on_load_missing = ["lm_head.weight"]

    def __init__(self, config: ChessConfig):
        super().__init__(config)

        # Token embeddings; positions are handled by RoPE inside attention.
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        # self.wpe = nn.Embedding(config.n_ctx, config.n_embd)
        self.wpe = LlamaRotaryEmbedding(config)
        self.drop = nn.Dropout(config.dropout)

        # Transformer blocks
        self.h = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layer)
        ])

        # Final layer norm
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # Output head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Declare tied weights for proper serialization
        if config.tie_weights:
            self._tied_weights_keys = ["lm_head.weight"]

        # Initialize weights
        self.post_init()

        # Tie weights if configured
        if config.tie_weights:
            self.tie_weights()

    def get_input_embeddings(self) -> nn.Module:
        return self.wte

    def set_input_embeddings(self, new_embeddings: nn.Module):
        self.wte = new_embeddings
        # Re-tie so lm_head keeps sharing the (new) embedding matrix.
        if getattr(self.config, "tie_weights", False):
            self.tie_weights()

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module):
        self.lm_head = new_embeddings

    def tie_weights(self):
        # Use HF helper to tie or clone depending on config
        if getattr(self.config, "tie_weights", False) or getattr(self.config, "tie_word_embeddings", False):
            self._tie_or_clone_weights(self.lm_head, self.wte)

    def _init_weights(self, module: nn.Module):
        """Initialize weights following GPT-2 style (N(0, 0.02), zero biases).

        NOTE: nn.RMSNorm modules are not matched here and keep their default
        initialization.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass of the model.

        Args:
            input_ids: Token IDs of shape (batch_size, seq_len).
            attention_mask: Attention mask of shape (batch_size, seq_len).
            position_ids: Position IDs of shape (batch_size, seq_len).
            labels: Labels for language modeling loss (shifted internally;
                -100 entries are ignored).
            return_dict: Whether to return a ModelOutput object.

        Returns:
            CausalLMOutputWithPast containing loss (if labels provided) and logits.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_len = input_ids.size()
        device = input_ids.device

        # Create position IDs if not provided
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)

        # Get embeddings; position_embeds is the (cos, sin) rotary pair that
        # each block forwards into attention.
        token_embeds = self.wte(input_ids)
        position_embeds = self.wpe(token_embeds, position_ids)
        hidden_states = self.drop(token_embeds)

        # Pass through transformer blocks
        for block in self.h:
            hidden_states = block(hidden_states, attention_mask=attention_mask, position_embeds=position_embeds)

        # Final layer norm
        hidden_states = self.ln_f(hidden_states)

        # Get logits
        logits = self.lm_head(hidden_states)

        # Compute loss if labels are provided
        loss = None
        if labels is not None:
            # Shift logits and labels for next-token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Flatten for cross-entropy
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            # loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    @torch.no_grad()
    def generate_move(
        self,
        input_ids: torch.LongTensor,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> int:
        """
        Generate the next move given a sequence of moves.

        Args:
            input_ids: Token IDs of shape (1, seq_len).
            temperature: Sampling temperature (1.0 = no change).
            top_k: If set, only sample from top k tokens.
            top_p: If set, use nucleus sampling with this threshold.

        Returns:
            The token ID of the predicted next move.
        """
        self.eval()

        # Get logits for the last position
        outputs = self(input_ids)
        logits = outputs.logits[:, -1, :] / temperature

        # Apply top-k filtering: drop everything below the k-th largest logit.
        if top_k is not None:
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = float("-inf")

        # Apply top-p (nucleus) filtering
        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold;
            # the shift keeps the first token past the threshold included.
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            # Map the sorted-order mask back to vocabulary order.
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=-1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits[indices_to_remove] = float("-inf")

        # Sample from the distribution
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        return next_token.item()
560
+
561
+
562
# Register the model with Auto classes for easy loading
# (lets AutoModelForCausalLM.from_pretrained resolve "chess_transformer").
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("chess_transformer", ChessConfig)
AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32a745c333ad9a014b7cd6d9a15ee45e84658c20e6c7d278451b775998eedad8
3
+ size 3358008
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "[BOS]",
3
+ "eos_token": "[EOS]",
4
+ "pad_token": "[PAD]",
5
+ "unk_token": "[UNK]"
6
+ }
tokenizer.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom Chess Tokenizer for the Chess Challenge.
3
+
4
+ This tokenizer treats each move as a single token using the extended UCI notation
5
+ from the Lichess dataset (e.g., WPe2e4, BNg8f6).
6
+
7
+ The dataset format uses:
8
+ - W/B prefix for White/Black
9
+ - Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King
10
+ - Source and destination squares (e.g., e2e4)
11
+ - Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import json
17
+ import os
18
+ from pathlib import Path
19
+ from typing import Dict, List, Optional
20
+
21
+ from transformers import PreTrainedTokenizer
22
+
23
+
24
class ChessTokenizer(PreTrainedTokenizer):
    """
    A custom tokenizer for chess moves in extended UCI notation.

    Each move string is split into sub-tokens rather than mapped whole:
    color (W/B), piece letter, source square, destination square, and an
    optional "(...)" suffix (capture/check/castling). The vocabulary is
    built from the training dataset so every observed sub-token gets an ID.

    Example:
        >>> tokenizer = ChessTokenizer()
        >>> tokenizer.encode("WPe2e4 BPe7e5")
        [1, 42, 87, 2]  # illustrative; actual IDs depend on the built vocab
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """
        Initialize the chess tokenizer.

        Args:
            vocab_file: Path to a JSON file containing the vocabulary mapping.
            vocab: Dictionary mapping tokens to IDs (alternative to vocab_file;
                takes precedence over vocab_file when both are given).
            **kwargs: Additional arguments passed to PreTrainedTokenizer.
        """
        # Initialize special tokens
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Remove any duplicate special-token entries passed through kwargs
        # to avoid "multiple values for keyword" errors when loading from disk.
        kwargs.pop("pad_token", None)
        kwargs.pop("bos_token", None)
        kwargs.pop("eos_token", None)
        kwargs.pop("unk_token", None)

        # Load or create vocabulary
        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            # Create a minimal vocabulary with just special tokens
            # The full vocabulary should be built from the dataset
            self._vocab = self._create_default_vocab()

        # Create reverse mapping
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # Call parent init AFTER setting up vocab — the base class queries
        # vocab_size/get_vocab during construction.
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _create_default_vocab(self) -> Dict[str, int]:
        """
        Create a minimal default vocabulary with just special tokens.

        For the full vocabulary, use `build_vocab_from_dataset()`.
        This minimal vocab is just a placeholder - you should build from data.
        """
        special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
        vocab = {token: idx for idx, token in enumerate(special_tokens)}
        return vocab

    @classmethod
    def build_vocab_from_iterator(
        cls,
        iterator,
        min_frequency: int = 1,
    ) -> "ChessTokenizer":
        """
        Build a tokenizer vocabulary from an iterator of game strings.

        Args:
            iterator: An iterator yielding game strings (space-separated moves).
            min_frequency: Minimum frequency for a token to be included.

        Returns:
            A ChessTokenizer with the built vocabulary.
        """
        from collections import Counter

        token_counts = Counter()

        for game in iterator:
            moves = game.strip().split()
            # Break each move into sub-token chunks. NOTE: this chunking
            # duplicates the logic in _tokenize; keep the two in sync.
            for move in moves:
                if '(' in move:
                    indx = move.find('(')
                    command = move[:indx]
                    special = move[indx:]
                    # color, piece, source square, destination square, suffix
                    token_counts.update([command[:1], command[1:2], command[2:4], command[4:], special])
                else:
                    token_counts.update([move[:1], move[1:2], move[2:4], move[4:]])

        # Filter by frequency
        tokens = [
            token for token, count in token_counts.items()
            if count >= min_frequency
        ]

        # Sort for reproducibility
        tokens = sorted(tokens)

        # Build vocabulary; special tokens get the lowest IDs (0..3)
        special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
        vocab = {token: idx for idx, token in enumerate(special_tokens + tokens)}

        return cls(vocab=vocab)

    @classmethod
    def build_vocab_from_dataset(
        cls,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        min_frequency: int = 500,
        max_samples: Optional[int] = 100000,
    ) -> "ChessTokenizer":
        """
        Build a tokenizer vocabulary from a Hugging Face dataset.

        Args:
            dataset_name: Name of the dataset on Hugging Face Hub.
            split: Dataset split to use.
            column: Column containing the game strings.
            min_frequency: Minimum frequency for a token to be included (default: 500).
            max_samples: Maximum number of samples to process (default: 100k).

        Returns:
            A ChessTokenizer with the built vocabulary.
        """
        from datasets import load_dataset

        dataset = load_dataset(dataset_name, split=split)

        if max_samples is not None:
            dataset = dataset.select(range(min(max_samples, len(dataset))))

        def game_iterator():
            for example in dataset:
                yield example[column]

        return cls.build_vocab_from_iterator(game_iterator(), min_frequency=min_frequency)

    @property
    def vocab_size(self) -> int:
        """Return the size of the vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return the vocabulary as a dictionary (defensive copy)."""
        return dict(self._vocab)

    def _tokenize(self, text: str) -> List[str]:
        """
        Tokenize a string of moves into a list of sub-tokens.

        Each move yields 4 chunks — color, piece, source square, destination
        square — plus a 5th "(...)" suffix chunk when present (per the
        extended-UCI format described in the module docstring).

        Args:
            text: A string of space-separated moves.

        Returns:
            List of sub-token strings.
        """
        moves = text.strip().split()
        tokens = []
        for move in moves:
            if '(' in move:
                indx = move.find('(')
                command = move[:indx]
                special = move[indx:]
                tokens.extend([command[:1], command[1:2], command[2:4], command[4:], special])
            else:
                tokens.extend([move[:1], move[1:2], move[2:4], move[4:]])

        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to its ID (falls back to UNK, then 0)."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an ID to its token (UNK for unknown IDs)."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Convert a list of tokens back to a string."""
        # Filter out special tokens for cleaner output
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        return " ".join(t for t in tokens if t not in special)

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None,
    ) -> tuple:
        """
        Save the vocabulary to a JSON file.

        Args:
            save_directory: Directory to save the vocabulary.
            filename_prefix: Optional prefix for the filename.

        Returns:
            Tuple containing the path to the saved vocabulary file.
        """
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory, exist_ok=True)

        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)

        return (vocab_file,)
265
+
266
+
267
def count_vocab_from_dataset(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    split: str = "train",
    column: str = "text",
    max_samples: Optional[int] = 10000,
) -> Dict[str, int]:
    """
    Count whole-move frequencies in a dataset (vocabulary analysis aid).

    Unlike the tokenizer's chunked vocabulary, this counts each
    space-separated move string as a single unit.

    Args:
        dataset_name: Name of the dataset on Hugging Face Hub.
        split: Dataset split to use.
        column: Column containing the game strings.
        max_samples: Maximum number of samples to process.

    Returns:
        Dictionary mapping move strings to their frequencies.
    """
    from collections import Counter

    from datasets import load_dataset

    dataset = load_dataset(dataset_name, split=split)

    # Optionally cap the number of games scanned.
    if max_samples is not None:
        dataset = dataset.select(range(min(max_samples, len(dataset))))

    counts = Counter()
    for example in dataset:
        counts.update(example[column].strip().split())

    return dict(counts)
300
+
301
+
302
+ if __name__ == "__main__":
303
+ tokenizer = ChessTokenizer.build_vocab_from_dataset(
304
+ dataset_name="dlouapre/lichess_2025-01_1M",
305
+ min_frequency=500, # Only keep moves that appear at least 500 times
306
+ max_samples=100000, # Use 100k games to build vocabulary
307
+ )
308
+ out = tokenizer.encode("WPe2e4 BPe7e5")
309
+ breakpoint()
tokenizer_config.json ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "[BOS]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "[EOS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "[UNK]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ }
35
+ },
36
+ "auto_map": {
37
+ "AutoTokenizer": [
38
+ "tokenizer.ChessTokenizer",
39
+ null
40
+ ]
41
+ },
42
+ "bos_token": "[BOS]",
43
+ "clean_up_tokenization_spaces": false,
44
+ "eos_token": "[EOS]",
45
+ "extra_special_tokens": {},
46
+ "model_max_length": 1000000000000000019884624838656,
47
+ "pad_token": "[PAD]",
48
+ "tokenizer_class": "ChessTokenizer",
49
+ "unk_token": "[UNK]"
50
+ }
vocab.json ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "[PAD]": 0,
3
+ "[BOS]": 1,
4
+ "[EOS]": 2,
5
+ "[UNK]": 3,
6
+ "(+)": 4,
7
+ "(+*)": 5,
8
+ "(+Q)": 6,
9
+ "(O)": 7,
10
+ "(Q)": 8,
11
+ "(o)": 9,
12
+ "(x)": 10,
13
+ "(x+)": 11,
14
+ "(x+*)": 12,
15
+ "(x+Q)": 13,
16
+ "(xE)": 14,
17
+ "B": 15,
18
+ "K": 16,
19
+ "N": 17,
20
+ "P": 18,
21
+ "Q": 19,
22
+ "R": 20,
23
+ "W": 21,
24
+ "a1": 22,
25
+ "a2": 23,
26
+ "a3": 24,
27
+ "a4": 25,
28
+ "a5": 26,
29
+ "a6": 27,
30
+ "a7": 28,
31
+ "a8": 29,
32
+ "b1": 30,
33
+ "b2": 31,
34
+ "b3": 32,
35
+ "b4": 33,
36
+ "b5": 34,
37
+ "b6": 35,
38
+ "b7": 36,
39
+ "b8": 37,
40
+ "c1": 38,
41
+ "c2": 39,
42
+ "c3": 40,
43
+ "c4": 41,
44
+ "c5": 42,
45
+ "c6": 43,
46
+ "c7": 44,
47
+ "c8": 45,
48
+ "d1": 46,
49
+ "d2": 47,
50
+ "d3": 48,
51
+ "d4": 49,
52
+ "d5": 50,
53
+ "d6": 51,
54
+ "d7": 52,
55
+ "d8": 53,
56
+ "e1": 54,
57
+ "e2": 55,
58
+ "e3": 56,
59
+ "e4": 57,
60
+ "e5": 58,
61
+ "e6": 59,
62
+ "e7": 60,
63
+ "e8": 61,
64
+ "f1": 62,
65
+ "f2": 63,
66
+ "f3": 64,
67
+ "f4": 65,
68
+ "f5": 66,
69
+ "f6": 67,
70
+ "f7": 68,
71
+ "f8": 69,
72
+ "g1": 70,
73
+ "g2": 71,
74
+ "g3": 72,
75
+ "g4": 73,
76
+ "g5": 74,
77
+ "g6": 75,
78
+ "g7": 76,
79
+ "g8": 77,
80
+ "h1": 78,
81
+ "h2": 79,
82
+ "h3": 80,
83
+ "h4": 81,
84
+ "h5": 82,
85
+ "h6": 83,
86
+ "h7": 84,
87
+ "h8": 85
88
+ }