Bichrai committed on
Commit
2c2b862
·
verified ·
1 Parent(s): 0ff8eb7

Fix model loading with auto_map

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ model.safetensors.backup filter=lfs diff=lfs merge=lfs -text
__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Chess Challenge source module."""
2
+
3
+ from .model import ChessConfig, ChessForCausalLM
4
+ from .tokenizer import ChessTokenizer
5
+
6
+ # Lazy import for evaluate to avoid RuntimeWarning when running as module
7
def __getattr__(name):
    """Lazily resolve evaluation helpers (PEP 562 module ``__getattr__``).

    ``.evaluate`` is imported only on first attribute access so that
    importing this package does not trigger a RuntimeWarning when it is
    executed as a module.
    """
    if name == "ChessEvaluator":
        from .evaluate import ChessEvaluator as attr
    elif name == "load_model_from_hub":
        from .evaluate import load_model_from_hub as attr
    else:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return attr
15
+
16
# Public API of the package. ChessEvaluator and load_model_from_hub are
# resolved lazily through the module-level __getattr__ defined in this file.
__all__ = [
    "ChessConfig",
    "ChessForCausalLM",
    "ChessTokenizer",
    "ChessEvaluator",
    "load_model_from_hub",
]
config.json CHANGED
@@ -17,5 +17,9 @@
17
  "rope_base": 10000,
18
  "tie_weights": true,
19
  "transformers_version": "4.57.0",
20
- "vocab_size": 1682
21
- }
 
 
 
 
 
17
  "rope_base": 10000,
18
  "tie_weights": true,
19
  "transformers_version": "4.57.0",
20
+ "vocab_size": 1682,
21
+ "auto_map": {
22
+ "AutoConfig": "model.ChessConfig",
23
+ "AutoModelForCausalLM": "model.ChessForCausalLM"
24
+ }
25
+ }
model.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chess Transformer Model for the Chess Challenge.
3
+
4
+ This module provides a simple GPT-style transformer architecture
5
+ designed to fit within the 1M parameter constraint.
6
+
7
+ Key components:
8
+ - ChessConfig: Configuration class for model hyperparameters
9
+ - ChessForCausalLM: The main model class for next-move prediction
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import math
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from transformers import PretrainedConfig, PreTrainedModel
22
+ from transformers.modeling_outputs import CausalLMOutputWithPast
23
+
24
+
25
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization.

    Unlike LayerNorm, RMSNorm neither re-centers the input (no mean
    subtraction) nor carries a bias term — only a learned per-channel
    gain — which saves parameters and compute while matching LayerNorm
    quality in practice.  Paper: https://arxiv.org/abs/1910.07467
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # Divide by the RMS of the last dimension (eps guards against /0).
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for numerical stability, cast back, scale.
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight
49
+
50
+
51
class RotaryPositionalEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) cos/sin table provider.

    RoPE encodes positions by rotating query/key channel pairs, which
    yields relative-position awareness with zero learned parameters
    (saving n_ctx * n_embd compared to learned embeddings) and better
    length extrapolation.  Paper: https://arxiv.org/abs/2104.09864

    Tables are precomputed up to ``max_seq_len`` and regenerated lazily
    if a longer sequence is ever requested.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # One inverse frequency per rotated channel pair.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Fill the cos/sin cache for the expected maximum length.
        self._set_cos_sin_cache(max_seq_len)

    def _set_cos_sin_cache(self, seq_len: int):
        # Build [seq_len, dim] tables; the frequencies are duplicated so the
        # layout matches the rotate-half form used by apply_rotary_pos_emb.
        self.max_cached_len = seq_len
        positions = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
        freqs = torch.outer(positions, self.inv_freq)
        table = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", table.cos(), persistent=False)
        self.register_buffer("sin_cached", table.sin(), persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) tables of shape [seq_len, dim] on x's device."""
        # Grow the cache if this sequence is longer than anything seen so far.
        if seq_len > self.max_cached_len:
            self._set_cos_sin_cache(seq_len)

        cos = self.cos_cached[:seq_len, ...].to(x.device)
        sin = self.sin_cached[:seq_len, ...].to(x.device)
        return cos, sin
97
+
98
+
99
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors by their positional cos/sin tables.

    Args:
        q: Queries of shape [batch, seq_len, n_head, head_dim].
        k: Keys, same shape as ``q``.
        cos: Cosine table of shape [seq_len, head_dim].
        sin: Sine table of shape [seq_len, head_dim].

    Returns:
        The (q, k) pair with rotary position embeddings applied.
    """

    def _rotate_half(t: torch.Tensor) -> torch.Tensor:
        # Move the second half of the channels in front of the first,
        # negated — the standard RoPE "rotate half" trick.
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    # Insert batch and head axes so the tables broadcast: [1, seq, 1, dim].
    cos_b = cos.unsqueeze(0).unsqueeze(2)
    sin_b = sin.unsqueeze(0).unsqueeze(2)

    rotated_q = (q * cos_b) + (_rotate_half(q) * sin_b)
    rotated_k = (k * cos_b) + (_rotate_half(k) * sin_b)
    return rotated_q, rotated_k
125
+
126
+
127
class ChessConfig(PretrainedConfig):
    """Hyperparameter container for the ~1M-parameter chess transformer.

    Parameter savings baked into the architecture:
      * RoPE instead of learned positions -> saves n_ctx * n_embd params
      * RMSNorm instead of LayerNorm      -> no bias / mean-centering params
      * tied input/output embeddings      -> saves vocab_size * n_embd params

    With the defaults below the model lands around ~843k parameters
    (token embeddings 1200 x 128 = 153,600; six blocks at roughly 115k
    each; no positional table; no separate LM head when tied).

    Args:
        vocab_size: Number of distinct move tokens.
        n_embd: Embedding / model width (d_model).
        n_layer: Number of transformer blocks.
        n_head: Attention heads per block.
        n_ctx: Maximum context length.
        n_inner: Feed-forward hidden width; defaults to 3 * n_embd.
        dropout: Dropout probability.
        rms_norm_eps: Epsilon used by RMSNorm.
        tie_weights: Tie lm_head weights to the token embedding matrix.
        rope_base: Frequency base for RoPE.
    """

    model_type = "chess_transformer"

    def __init__(
        self,
        vocab_size: int = 1200,
        n_embd: int = 128,
        n_layer: int = 6,
        n_head: int = 4,
        n_ctx: int = 256,
        n_inner: Optional[int] = None,
        dropout: float = 0.1,
        rms_norm_eps: float = 1e-6,
        tie_weights: bool = True,
        rope_base: int = 10000,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_ctx = n_ctx
        # FFN expansion is 3x rather than the usual 4x to stay on budget.
        self.n_inner = 3 * n_embd if n_inner is None else n_inner
        self.dropout = dropout
        self.rms_norm_eps = rms_norm_eps
        self.tie_weights = tie_weights
        self.rope_base = rope_base
        # Mirror the flag onto the name the HF base class understands.
        self.tie_word_embeddings = bool(tie_weights)
195
+
196
+
197
class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention module with RoPE.

    This implementation uses:
    - Rotary Position Embeddings (RoPE) for position encoding
    - Causal masking for autoregressive generation
    - Scaled dot-product attention

    Note: no key/value cache is kept — every forward pass recomputes
    attention over the full sequence.
    """

    def __init__(self, config: ChessConfig):
        super().__init__()

        assert config.n_embd % config.n_head == 0, \
            f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head

        # Combined QKV projection for efficiency (one matmul instead of three)
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)

        self.dropout = nn.Dropout(config.dropout)

        # RoPE for positional encoding (applied to Q and K per head in forward)
        self.rotary_emb = RotaryPositionalEmbedding(
            dim=self.head_dim,
            max_seq_len=config.n_ctx,
            base=config.rope_base,
        )

        # Causal mask: lower-triangular ones, shaped to broadcast over
        # batch and head axes. Named "bias" per GPT-2 convention; not saved.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.n_ctx, config.n_ctx)).view(
                1, 1, config.n_ctx, config.n_ctx
            ),
            persistent=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply causal multi-head self-attention.

        Args:
            x: Hidden states of shape (batch_size, seq_len, n_embd).
            attention_mask: Optional padding mask of shape
                (batch_size, seq_len); positions equal to 0 are masked out.

        Returns:
            Attention output of shape (batch_size, seq_len, n_embd).
        """
        batch_size, seq_len, _ = x.size()

        # Compute Q, K, V with one fused projection, then split channels.
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)

        # Reshape for multi-head attention: [batch, seq_len, n_head, head_dim]
        q = q.view(batch_size, seq_len, self.n_head, self.head_dim)
        k = k.view(batch_size, seq_len, self.n_head, self.head_dim)
        v = v.view(batch_size, seq_len, self.n_head, self.head_dim)

        # Apply RoPE to Q and K (V carries no positional information)
        cos, sin = self.rotary_emb(q, seq_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Transpose for attention: [batch, n_head, seq_len, head_dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Scaled dot-product attention
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Apply causal mask: future positions get -inf before softmax
        causal_mask = self.bias[:, :, :seq_len, :seq_len]
        attn_weights = attn_weights.masked_fill(causal_mask == 0, float("-inf"))

        # Apply attention mask (for padding)
        if attention_mask is not None:
            # attention_mask shape: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len)
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attn_weights = attn_weights.masked_fill(attention_mask == 0, float("-inf"))

        attn_weights = F.softmax(attn_weights, dim=-1)
        # Dropout on attention probabilities (active only in training mode).
        attn_weights = self.dropout(attn_weights)

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, v)

        # Reshape back to [batch, seq_len, n_embd]
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.n_embd
        )

        # Output projection
        attn_output = self.c_proj(attn_output)

        return attn_output
292
+
293
+
294
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    Expands from ``n_embd`` to ``n_inner`` and projects back, applied
    independently at every sequence position.
    """

    def __init__(self, config: ChessConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.n_inner)
        self.c_proj = nn.Linear(config.n_inner, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        projected = self.c_proj(hidden)
        return self.dropout(projected)
314
+
315
+
316
class TransformerBlock(nn.Module):
    """One pre-norm transformer block: attention then MLP, each residual.

    RMSNorm is applied *before* each sub-layer (pre-normalization), which
    improves training stability; the sub-layer output is added back onto
    the untouched residual stream.
    """

    def __init__(self, config: ChessConfig):
        super().__init__()
        self.rms_1 = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.attn = MultiHeadAttention(config)
        self.rms_2 = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.mlp = FeedForward(config)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Residual connection around pre-normed attention.
        attn_out = self.attn(self.rms_1(x), attention_mask=attention_mask)
        x = x + attn_out
        # Residual connection around pre-normed feed-forward.
        return x + self.mlp(self.rms_2(x))
342
+
343
+
344
class ChessForCausalLM(PreTrainedModel):
    """
    Chess Transformer for Causal Language Modeling (next-move prediction).

    This model is designed to predict the next chess move given a sequence
    of previous moves. It uses an optimized GPT-style architecture with:
    - Token embeddings for chess moves
    - RoPE (Rotary Position Embeddings) - no learned positional embeddings
    - RMSNorm for efficient normalization
    - Stacked transformer blocks
    - Linear head for next-token prediction

    The model supports weight tying between the embedding layer and the
    output projection to save parameters.

    Optimizations:
    - RoPE saves n_ctx * n_embd parameters vs learned embeddings
    - RMSNorm saves parameters vs LayerNorm (no bias, no mean centering)
    - Weight tying saves vocab_size * n_embd parameters

    Example:
        >>> config = ChessConfig(vocab_size=1200, n_embd=128, n_layer=6)
        >>> model = ChessForCausalLM(config)
        >>> inputs = {"input_ids": torch.tensor([[1, 42, 87]])}
        >>> outputs = model(**inputs)
        >>> next_move_logits = outputs.logits[:, -1, :]
    """

    config_class = ChessConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    # Suppress missing-key warning for tied lm_head when loading
    keys_to_ignore_on_load_missing = ["lm_head.weight"]

    def __init__(self, config: ChessConfig):
        super().__init__(config)

        # Token embeddings (no positional embeddings - using RoPE instead)
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)

        self.drop = nn.Dropout(config.dropout)

        # Transformer blocks
        self.h = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layer)
        ])

        # Final RMS norm
        self.rms_f = RMSNorm(config.n_embd, eps=config.rms_norm_eps)

        # Output head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Declare tied weights for proper serialization
        # (set before post_init so HF save/load machinery sees the tie)
        if config.tie_weights:
            self._tied_weights_keys = ["lm_head.weight"]

        # Initialize weights
        self.post_init()

        # Tie weights if configured
        if config.tie_weights:
            self.tie_weights()

    def get_input_embeddings(self) -> nn.Module:
        # HF hook: the token embedding table.
        return self.wte

    def set_input_embeddings(self, new_embeddings: nn.Module):
        # HF hook: swap the embedding table and re-tie the head if needed.
        self.wte = new_embeddings
        if getattr(self.config, "tie_weights", False):
            self.tie_weights()

    def get_output_embeddings(self) -> nn.Module:
        # HF hook: the LM head used for tying/resizing.
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module):
        self.lm_head = new_embeddings

    def tie_weights(self):
        # Use HF helper to tie or clone depending on config
        if getattr(self.config, "tie_weights", False) or getattr(self.config, "tie_word_embeddings", False):
            self._tie_or_clone_weights(self.lm_head, self.wte)

    def _init_weights(self, module: nn.Module):
        """Initialize weights following GPT-2 style with adaptations for RMSNorm."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, RMSNorm):
            # RMSNorm gains start at identity scaling.
            torch.nn.init.ones_(module.weight)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass of the model.

        Args:
            input_ids: Token IDs of shape (batch_size, seq_len).
            attention_mask: Attention mask of shape (batch_size, seq_len).
            position_ids: Not used (kept for compatibility). Position encoding via RoPE.
            labels: Labels for language modeling loss; -100 entries are ignored.
            return_dict: Whether to return a ModelOutput object.

        Returns:
            CausalLMOutputWithPast containing loss (if labels provided) and logits.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Get token embeddings (no positional embeddings - using RoPE in attention)
        hidden_states = self.wte(input_ids)
        hidden_states = self.drop(hidden_states)

        # Pass through transformer blocks
        for block in self.h:
            hidden_states = block(hidden_states, attention_mask=attention_mask)

        # Final RMS norm
        hidden_states = self.rms_f(hidden_states)

        # Get logits
        logits = self.lm_head(hidden_states)

        # Compute loss if labels are provided
        loss = None
        if labels is not None:
            # Shift logits and labels for next-token prediction
            # (position t predicts token t+1)
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Flatten for cross-entropy
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            # Legacy tuple return: (loss, logits) or (logits,)
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        # No KV cache is implemented, hence past_key_values is always None.
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    @torch.no_grad()
    def generate_move(
        self,
        input_ids: torch.LongTensor,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> int:
        """
        Generate the next move given a sequence of moves.

        Args:
            input_ids: Token IDs of shape (1, seq_len).
            temperature: Sampling temperature (1.0 = no change).
            top_k: If set, only sample from top k tokens.
            top_p: If set, use nucleus sampling with this threshold.

        Returns:
            The token ID of the predicted next move.

        Note: sampling uses torch.multinomial, so results depend on the
        global RNG state; seed torch for reproducibility.
        """
        self.eval()

        # Get logits for the last position
        outputs = self(input_ids)
        logits = outputs.logits[:, -1, :] / temperature

        # Apply top-k filtering
        if top_k is not None:
            # Everything below the k-th largest logit is removed.
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = float("-inf")

        # Apply top-p (nucleus) filtering
        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            # Scatter the removal mask back to the unsorted vocabulary order.
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=-1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits[indices_to_remove] = float("-inf")

        # Sample from the distribution
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        return next_token.item()
553
+
554
+
555
# Register the model with Auto classes for easy loading.
# Runs at import time so AutoConfig / AutoModelForCausalLM can resolve
# model_type "chess_transformer" when this module has been imported,
# complementing the auto_map entries in config.json.
# NOTE(review): registering the same model_type twice raises in
# transformers — confirm this module is only imported once per process.
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("chess_transformer", ChessConfig)
AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5f2ecf2379e1f892827eb5d85dc3f5422614d4e0b1d66ce350e6928b00bce750
3
- size 3289680
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd0e9f845e5a66d913b529f674adbde1d65b6d3e032db545c2418f5d64859f0c
3
+ size 3289648
model.safetensors.backup ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa6a0947d4bdf00843602a929307237639ee2d504e9ea9ba30ccb24f3506d3b6
3
+ size 3290200
tokenizer_config.json CHANGED
@@ -33,12 +33,6 @@
33
  "special": true
34
  }
35
  },
36
- "auto_map": {
37
- "AutoTokenizer": [
38
- "tokenizer.ChessTokenizer",
39
- null
40
- ]
41
- },
42
  "bos_token": "[BOS]",
43
  "clean_up_tokenization_spaces": false,
44
  "eos_token": "[EOS]",
 
33
  "special": true
34
  }
35
  },
 
 
 
 
 
 
36
  "bos_token": "[BOS]",
37
  "clean_up_tokenization_spaces": false,
38
  "eos_token": "[EOS]",
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6cff970f63c9bb8338a25bc0711d8b45428ad0af0ea1bd1296c70cafac37a322
3
+ size 5777