Nestor02 commited on
Commit
db4f1c0
·
verified ·
1 Parent(s): cc82a34

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +817 -0
model.py ADDED
@@ -0,0 +1,817 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chess Transformer Model for the Chess Challenge.
3
+
4
+ This module provides a modular GPT-style transformer architecture
5
+ designed to fit within the 1M parameter constraint.
6
+
7
+ Key components:
8
+ - ChessConfig: Configuration class for model hyperparameters
9
+ - ChessForCausalLM: The main model class for next-move prediction
10
+
11
+ Modular options:
12
+ - Attention: MHA (standard), GQA (grouped query), MQA (multi-query)
13
+ - Position encoding: learned, rope (rotary), alibi
14
+ - FFN activation: gelu, swiglu
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import math
20
+ from dataclasses import dataclass
21
+ from typing import List, Optional, Tuple, Union, Literal
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+ from transformers import PretrainedConfig, PreTrainedModel
27
+ try:
28
+ from transformers.generation.utils import GenerationMixin
29
+ except ImportError: # Fallback for older transformers
30
+ from transformers import GenerationMixin
31
+ from transformers.modeling_outputs import CausalLMOutputWithPast
32
+
33
+ # Type aliases for configuration options
34
+ AttentionType = Literal["mha", "gqa", "mqa"]
35
+ PositionEncoding = Literal["learned", "rope", "alibi"]
36
+ FFNType = Literal["gelu", "swiglu"]
37
+
38
+
39
+ class ChessConfig(PretrainedConfig):
40
+ """
41
+ Configuration class for the Chess Transformer model.
42
+
43
+ This configuration is designed for a ~1M parameter model.
44
+ Students can adjust these values to explore different architectures.
45
+
46
+ Parameter budget breakdown (with default values):
47
+ - Embeddings (vocab): 1200 x 128 = 153,600
48
+ - Position Embeddings: 256 x 128 = 32,768 (0 with rope/alibi)
49
+ - Transformer Layers: 6 x ~120,000 = ~720,000
50
+ - LM Head (with weight tying): 0 (shared with embeddings)
51
+ - Total: ~906,000 parameters
52
+
53
+ Attributes:
54
+ vocab_size: Size of the vocabulary (number of unique moves).
55
+ n_embd: Embedding dimension (d_model).
56
+ n_layer: Number of transformer layers.
57
+ n_head: Number of attention heads.
58
+ n_kv_heads: Number of key-value heads (for GQA/MQA). None = same as n_head.
59
+ n_ctx: Maximum sequence length (context window).
60
+ n_inner: Feed-forward inner dimension (default: 3 * n_embd).
61
+ dropout: Dropout probability.
62
+ layer_norm_epsilon: Epsilon for layer normalization.
63
+ tie_weights: Whether to tie embedding and output weights.
64
+ attention_type: Type of attention mechanism ("mha", "gqa", "mqa").
65
+ pos_encoding: Type of position encoding ("learned", "rope", "alibi").
66
+ ffn_type: Type of FFN activation ("gelu", "swiglu").
67
+ rope_theta: Base frequency for RoPE (default 10000.0).
68
+ legal_loss_weight: Auxiliary legal-move loss weight (default 0.0).
69
+ """
70
+
71
+ model_type = "chess_transformer"
72
+
73
+ def __init__(
74
+ self,
75
+ vocab_size: int = 1200,
76
+ n_embd: int = 128,
77
+ n_layer: int = 6,
78
+ n_head: int = 4,
79
+ n_kv_heads: Optional[int] = None,
80
+ n_ctx: int = 256,
81
+ n_inner: Optional[int] = None,
82
+ dropout: float = 0.1,
83
+ layer_norm_epsilon: float = 1e-5,
84
+ tie_weights: bool = True,
85
+ # New modular options
86
+ attention_type: AttentionType = "mha",
87
+ pos_encoding: PositionEncoding = "learned",
88
+ ffn_type: FFNType = "gelu",
89
+ rope_theta: float = 10000.0,
90
+ legal_loss_weight: float = 0.0,
91
+ # Token IDs
92
+ pad_token_id: int = 0,
93
+ bos_token_id: int = 1,
94
+ eos_token_id: int = 2,
95
+ **kwargs,
96
+ ):
97
+ super().__init__(
98
+ pad_token_id=pad_token_id,
99
+ bos_token_id=bos_token_id,
100
+ eos_token_id=eos_token_id,
101
+ **kwargs,
102
+ )
103
+
104
+ self.vocab_size = vocab_size
105
+ self.n_embd = n_embd
106
+ self.n_layer = n_layer
107
+ self.n_head = n_head
108
+ self.n_ctx = n_ctx
109
+ self.n_inner = n_inner if n_inner is not None else 3 * n_embd
110
+ self.dropout = dropout
111
+ self.layer_norm_epsilon = layer_norm_epsilon
112
+ self.tie_weights = tie_weights
113
+ # Inform HF base class about tying behavior
114
+ self.tie_word_embeddings = bool(tie_weights)
115
+
116
+ # Modular architecture options
117
+ self.attention_type = attention_type
118
+ self.pos_encoding = pos_encoding
119
+ self.ffn_type = ffn_type
120
+ self.rope_theta = rope_theta
121
+ self.legal_loss_weight = legal_loss_weight
122
+
123
+ # Handle n_kv_heads based on attention type
124
+ if n_kv_heads is None:
125
+ if attention_type == "mqa":
126
+ self.n_kv_heads = 1
127
+ elif attention_type == "gqa":
128
+ # Default to n_head // 2 for GQA, but at least 1
129
+ self.n_kv_heads = max(1, n_head // 2)
130
+ else: # mha
131
+ self.n_kv_heads = n_head
132
+ else:
133
+ self.n_kv_heads = n_kv_heads
134
+
135
+ # Validation
136
+ assert n_embd % n_head == 0, f"n_embd ({n_embd}) must be divisible by n_head ({n_head})"
137
+ assert n_head % self.n_kv_heads == 0, f"n_head ({n_head}) must be divisible by n_kv_heads ({self.n_kv_heads})"
138
+ assert attention_type in ("mha", "gqa", "mqa"), f"Invalid attention_type: {attention_type}"
139
+ assert pos_encoding in ("learned", "rope", "alibi"), f"Invalid pos_encoding: {pos_encoding}"
140
+ assert ffn_type in ("gelu", "swiglu"), f"Invalid ffn_type: {ffn_type}"
141
+
142
+
143
+ # ==============================================================================
144
+ # Position Encoding Modules
145
+ # ==============================================================================
146
+
147
+
148
+ class RotaryEmbedding(nn.Module):
149
+ """
150
+ Rotary Position Embedding (RoPE).
151
+
152
+ Applies rotary embeddings to queries and keys, encoding position
153
+ information through rotation in the complex plane. This allows
154
+ relative position information without explicit position embeddings.
155
+
156
+ Reference: https://arxiv.org/abs/2104.09864
157
+ """
158
+
159
+ def __init__(self, dim: int, max_seq_len: int = 256, theta: float = 10000.0):
160
+ super().__init__()
161
+ self.dim = dim
162
+ self.max_seq_len = max_seq_len
163
+ self.theta = theta
164
+
165
+ # Precompute frequency bands
166
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
167
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
168
+
169
+ # Precompute sin/cos for all positions
170
+ self._build_cache(max_seq_len)
171
+
172
+ def _build_cache(self, seq_len: int):
173
+ """Build sin/cos cache for given sequence length."""
174
+ positions = torch.arange(seq_len, dtype=torch.float32)
175
+ freqs = torch.outer(positions, self.inv_freq)
176
+ # Create [cos, sin] interleaved for rotation
177
+ emb = torch.cat([freqs, freqs], dim=-1)
178
+ self.register_buffer("cos_cached", emb.cos(), persistent=False)
179
+ self.register_buffer("sin_cached", emb.sin(), persistent=False)
180
+
181
+ def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
182
+ """Return cos and sin for the given sequence length."""
183
+ if seq_len > self.max_seq_len:
184
+ self._build_cache(seq_len)
185
+ self.max_seq_len = seq_len
186
+ return (
187
+ self.cos_cached[:seq_len].to(x.dtype),
188
+ self.sin_cached[:seq_len].to(x.dtype),
189
+ )
190
+
191
+
192
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
193
+ """Rotate half the hidden dims of the input."""
194
+ x1 = x[..., : x.shape[-1] // 2]
195
+ x2 = x[..., x.shape[-1] // 2 :]
196
+ return torch.cat([-x2, x1], dim=-1)
197
+
198
+
199
+ def apply_rotary_pos_emb(
200
+ q: torch.Tensor,
201
+ k: torch.Tensor,
202
+ cos: torch.Tensor,
203
+ sin: torch.Tensor,
204
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
205
+ """
206
+ Apply rotary position embedding to queries and keys.
207
+
208
+ Args:
209
+ q: Query tensor of shape (batch, n_heads, seq_len, head_dim)
210
+ k: Key tensor of shape (batch, n_kv_heads, seq_len, head_dim)
211
+ cos: Cosine of rotation angles
212
+ sin: Sine of rotation angles
213
+
214
+ Returns:
215
+ Rotated q and k tensors
216
+ """
217
+ # cos/sin shape: (seq_len, head_dim) -> (1, 1, seq_len, head_dim)
218
+ cos = cos.unsqueeze(0).unsqueeze(0)
219
+ sin = sin.unsqueeze(0).unsqueeze(0)
220
+
221
+ q_embed = (q * cos) + (rotate_half(q) * sin)
222
+ k_embed = (k * cos) + (rotate_half(k) * sin)
223
+ return q_embed, k_embed
224
+
225
+
226
+ def build_alibi_slopes(n_heads: int) -> torch.Tensor:
227
+ """
228
+ Build ALiBi slopes for attention bias.
229
+
230
+ ALiBi adds a linear bias to attention scores based on position distance.
231
+ The slope decreases geometrically for each head.
232
+
233
+ Reference: https://arxiv.org/abs/2108.12409
234
+ """
235
+ def get_slopes_power_of_2(n: int) -> list:
236
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
237
+ ratio = start
238
+ return [start * (ratio ** i) for i in range(n)]
239
+
240
+ if math.log2(n_heads).is_integer():
241
+ slopes = get_slopes_power_of_2(n_heads)
242
+ else:
243
+ # For non-power-of-2, use closest power of 2 and interpolate
244
+ closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
245
+ slopes = get_slopes_power_of_2(closest_power_of_2)
246
+ extra_slopes = get_slopes_power_of_2(2 * closest_power_of_2)
247
+ slopes = slopes + extra_slopes[0::2][: n_heads - closest_power_of_2]
248
+
249
+ return torch.tensor(slopes, dtype=torch.float32)
250
+
251
+
252
+ def build_alibi_bias(seq_len: int, slopes: torch.Tensor) -> torch.Tensor:
253
+ """
254
+ Build the ALiBi attention bias matrix.
255
+
256
+ Args:
257
+ seq_len: Sequence length
258
+ slopes: ALiBi slopes tensor of shape (n_heads,)
259
+
260
+ Returns:
261
+ Bias tensor of shape (1, n_heads, seq_len, seq_len)
262
+ """
263
+ # Create distance matrix: distance[i, j] = j - i (negative for causal)
264
+ positions = torch.arange(seq_len)
265
+ distance = positions.unsqueeze(0) - positions.unsqueeze(1) # (seq_len, seq_len)
266
+
267
+ # Apply slopes: (n_heads, 1, 1) * (seq_len, seq_len) -> (n_heads, seq_len, seq_len)
268
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * distance.unsqueeze(0)
269
+
270
+ return alibi.unsqueeze(0) # (1, n_heads, seq_len, seq_len)
271
+
272
+
273
+ # ==============================================================================
274
+ # Attention Modules
275
+ # ==============================================================================
276
+
277
+
278
+ class Attention(nn.Module):
279
+ """
280
+ Unified attention module supporting MHA, GQA, and MQA.
281
+
282
+ Supports multiple position encoding methods:
283
+ - learned: Standard learned position embeddings (handled externally)
284
+ - rope: Rotary Position Embeddings (applied to Q and K)
285
+ - alibi: Attention with Linear Biases (added to attention scores)
286
+
287
+ Architecture variants:
288
+ - MHA (Multi-Head Attention): n_kv_heads == n_head
289
+ - GQA (Grouped Query Attention): n_kv_heads < n_head, n_head % n_kv_heads == 0
290
+ - MQA (Multi-Query Attention): n_kv_heads == 1
291
+ """
292
+
293
+ def __init__(self, config: ChessConfig):
294
+ super().__init__()
295
+
296
+ self.n_head = config.n_head
297
+ self.n_kv_heads = config.n_kv_heads
298
+ self.n_embd = config.n_embd
299
+ self.head_dim = config.n_embd // config.n_head
300
+ self.n_rep = config.n_head // config.n_kv_heads # Repetition factor for GQA/MQA
301
+ self.pos_encoding = config.pos_encoding
302
+
303
+ # Compute projection sizes
304
+ # Q: n_head * head_dim = n_embd
305
+ # K, V: n_kv_heads * head_dim (smaller for GQA/MQA)
306
+ self.q_proj = nn.Linear(config.n_embd, config.n_head * self.head_dim, bias=False)
307
+ self.k_proj = nn.Linear(config.n_embd, config.n_kv_heads * self.head_dim, bias=False)
308
+ self.v_proj = nn.Linear(config.n_embd, config.n_kv_heads * self.head_dim, bias=False)
309
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
310
+
311
+ self.dropout = nn.Dropout(config.dropout)
312
+
313
+ # Position encoding components
314
+ if config.pos_encoding == "rope":
315
+ self.rotary_emb = RotaryEmbedding(
316
+ dim=self.head_dim,
317
+ max_seq_len=config.n_ctx,
318
+ theta=config.rope_theta,
319
+ )
320
+ elif config.pos_encoding == "alibi":
321
+ # Precompute ALiBi slopes
322
+ slopes = build_alibi_slopes(config.n_head)
323
+ self.register_buffer("alibi_slopes", slopes, persistent=False)
324
+
325
+ # Causal mask
326
+ self.register_buffer(
327
+ "causal_mask",
328
+ torch.tril(torch.ones(config.n_ctx, config.n_ctx)).view(
329
+ 1, 1, config.n_ctx, config.n_ctx
330
+ ),
331
+ persistent=False,
332
+ )
333
+
334
+ def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
335
+ """
336
+ Repeat KV heads to match the number of query heads.
337
+
338
+ For GQA/MQA, we need to expand K and V to match Q's head count.
339
+ Input shape: (batch, n_kv_heads, seq_len, head_dim)
340
+ Output shape: (batch, n_head, seq_len, head_dim)
341
+ """
342
+ if self.n_rep == 1:
343
+ return x
344
+ batch, n_kv_heads, seq_len, head_dim = x.shape
345
+ x = x.unsqueeze(2).expand(batch, n_kv_heads, self.n_rep, seq_len, head_dim)
346
+ return x.reshape(batch, n_kv_heads * self.n_rep, seq_len, head_dim)
347
+
348
+ def forward(
349
+ self,
350
+ x: torch.Tensor,
351
+ attention_mask: Optional[torch.Tensor] = None,
352
+ ) -> torch.Tensor:
353
+ batch_size, seq_len, _ = x.size()
354
+
355
+ # Compute Q, K, V projections
356
+ q = self.q_proj(x)
357
+ k = self.k_proj(x)
358
+ v = self.v_proj(x)
359
+
360
+ # Reshape for multi-head attention
361
+ q = q.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
362
+ k = k.view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
363
+ v = v.view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
364
+
365
+ # Apply rotary embeddings if using RoPE
366
+ if self.pos_encoding == "rope":
367
+ cos, sin = self.rotary_emb(q, seq_len)
368
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)
369
+
370
+ # Repeat K and V for GQA/MQA
371
+ k = self._repeat_kv(k)
372
+ v = self._repeat_kv(v)
373
+
374
+ # Scaled dot-product attention
375
+ attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
376
+
377
+ # Apply ALiBi bias if using ALiBi
378
+ if self.pos_encoding == "alibi":
379
+ alibi_bias = build_alibi_bias(seq_len, self.alibi_slopes.to(x.device))
380
+ attn_weights = attn_weights + alibi_bias.to(attn_weights.dtype)
381
+
382
+ # Apply causal mask
383
+ causal_mask = self.causal_mask[:, :, :seq_len, :seq_len]
384
+ attn_weights = attn_weights.masked_fill(causal_mask == 0, float("-inf"))
385
+
386
+ # Apply attention mask (for padding)
387
+ if attention_mask is not None:
388
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
389
+ attn_weights = attn_weights.masked_fill(attention_mask == 0, float("-inf"))
390
+
391
+ attn_weights = F.softmax(attn_weights, dim=-1)
392
+ attn_weights = self.dropout(attn_weights)
393
+
394
+ # Apply attention to values
395
+ attn_output = torch.matmul(attn_weights, v)
396
+
397
+ # Reshape back
398
+ attn_output = attn_output.transpose(1, 2).contiguous().view(
399
+ batch_size, seq_len, self.n_embd
400
+ )
401
+
402
+ # Output projection
403
+ attn_output = self.c_proj(attn_output)
404
+
405
+ return attn_output
406
+
407
+
408
+ # Alias for backward compatibility
409
+ MultiHeadAttention = Attention
410
+
411
+
412
+ # ==============================================================================
413
+ # Feed-Forward Modules
414
+ # ==============================================================================
415
+
416
+
417
+ class FeedForward(nn.Module):
418
+ """
419
+ Feed-forward network (MLP) module with configurable activation.
420
+
421
+ Supports:
422
+ - gelu: Standard GELU activation (2 weight matrices)
423
+ - swiglu: SwiGLU activation (3 weight matrices, better performance)
424
+
425
+ For SwiGLU, the hidden dimension is adjusted to keep parameter count similar:
426
+ - GELU: 2 * n_embd * n_inner parameters
427
+ - SwiGLU: 3 * n_embd * n_inner_swiglu parameters
428
+ To match, n_inner_swiglu = 2/3 * n_inner
429
+ """
430
+
431
+ def __init__(self, config: ChessConfig):
432
+ super().__init__()
433
+
434
+ self.ffn_type = config.ffn_type
435
+
436
+ if config.ffn_type == "swiglu":
437
+ # SwiGLU uses 3 projections, so reduce hidden dim to compensate
438
+ # Adjust n_inner for SwiGLU to maintain similar parameter count
439
+ hidden_dim = int(2 * config.n_inner / 3)
440
+ # Round to nearest multiple of 8 for efficiency
441
+ hidden_dim = ((hidden_dim + 7) // 8) * 8
442
+
443
+ self.w1 = nn.Linear(config.n_embd, hidden_dim, bias=False) # Gate
444
+ self.w2 = nn.Linear(config.n_embd, hidden_dim, bias=False) # Up
445
+ self.w3 = nn.Linear(hidden_dim, config.n_embd, bias=False) # Down
446
+ else: # gelu
447
+ self.c_fc = nn.Linear(config.n_embd, config.n_inner)
448
+ self.c_proj = nn.Linear(config.n_inner, config.n_embd)
449
+
450
+ self.dropout = nn.Dropout(config.dropout)
451
+
452
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
453
+ if self.ffn_type == "swiglu":
454
+ # SwiGLU: Swish(W1*x) * W2*x, then W3
455
+ gate = F.silu(self.w1(x)) # Swish activation
456
+ up = self.w2(x)
457
+ x = gate * up
458
+ x = self.w3(x)
459
+ x = self.dropout(x)
460
+ else: # gelu
461
+ x = self.c_fc(x)
462
+ x = F.gelu(x)
463
+ x = self.c_proj(x)
464
+ x = self.dropout(x)
465
+ return x
466
+
467
+
468
+ # ==============================================================================
469
+ # Transformer Block
470
+ # ==============================================================================
471
+
472
+
473
+ class TransformerBlock(nn.Module):
474
+ """
475
+ A single transformer block with attention and feed-forward layers.
476
+
477
+ Uses pre-normalization (LayerNorm before attention/FFN) for better
478
+ training stability.
479
+ """
480
+
481
+ def __init__(self, config: ChessConfig):
482
+ super().__init__()
483
+
484
+ self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
485
+ self.attn = Attention(config)
486
+ self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
487
+ self.mlp = FeedForward(config)
488
+
489
+ def forward(
490
+ self,
491
+ x: torch.Tensor,
492
+ attention_mask: Optional[torch.Tensor] = None,
493
+ ) -> torch.Tensor:
494
+ # Pre-norm attention
495
+ x = x + self.attn(self.ln_1(x), attention_mask=attention_mask)
496
+ # Pre-norm FFN
497
+ x = x + self.mlp(self.ln_2(x))
498
+ return x
499
+
500
+
501
+ # ==============================================================================
502
+ # Main Model
503
+ # ==============================================================================
504
+
505
+
506
+ class ChessForCausalLM(PreTrainedModel, GenerationMixin):
507
+ """
508
+ Chess Transformer for Causal Language Modeling (next-move prediction).
509
+
510
+ This model is designed to predict the next chess move given a sequence
511
+ of previous moves. It uses a modular GPT-style architecture with:
512
+ - Token embeddings for chess moves
513
+ - Configurable positional embeddings (learned/RoPE/ALiBi)
514
+ - Stacked transformer blocks with configurable attention (MHA/GQA/MQA)
515
+ - Configurable FFN activation (GELU/SwiGLU)
516
+ - Linear head for next-token prediction
517
+
518
+ The model supports weight tying between the embedding layer and the
519
+ output projection to save parameters.
520
+
521
+ Example:
522
+ >>> # Baseline configuration
523
+ >>> config = ChessConfig(vocab_size=1200, n_embd=128, n_layer=6)
524
+ >>> model = ChessForCausalLM(config)
525
+
526
+ >>> # GQA with RoPE (saves parameters, allows more layers)
527
+ >>> config = ChessConfig(
528
+ ... vocab_size=1200, n_embd=128, n_layer=8,
529
+ ... attention_type="gqa", n_kv_heads=2,
530
+ ... pos_encoding="rope"
531
+ ... )
532
+ >>> model = ChessForCausalLM(config)
533
+ """
534
+
535
+ config_class = ChessConfig
536
+ base_model_prefix = "transformer"
537
+ supports_gradient_checkpointing = True
538
+ # Suppress missing-key warning for tied lm_head when loading
539
+ keys_to_ignore_on_load_missing = ["lm_head.weight"]
540
+
541
+ def __init__(self, config: ChessConfig):
542
+ super().__init__(config)
543
+
544
+ self.pos_encoding = config.pos_encoding
545
+
546
+ # Token embeddings (always needed)
547
+ self.wte = nn.Embedding(config.vocab_size, config.n_embd)
548
+
549
+ # Position embeddings (only for learned position encoding)
550
+ if config.pos_encoding == "learned":
551
+ self.wpe = nn.Embedding(config.n_ctx, config.n_embd)
552
+ else:
553
+ # RoPE and ALiBi don't need position embeddings
554
+ self.wpe = None
555
+
556
+ self.drop = nn.Dropout(config.dropout)
557
+
558
+ # Transformer blocks
559
+ self.h = nn.ModuleList([
560
+ TransformerBlock(config) for _ in range(config.n_layer)
561
+ ])
562
+
563
+ # Final layer norm
564
+ self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
565
+
566
+ # Output head
567
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
568
+
569
+ # Declare tied weights for proper serialization
570
+ if config.tie_weights:
571
+ self._tied_weights_keys = ["lm_head.weight"]
572
+
573
+ # Initialize weights
574
+ self.post_init()
575
+
576
+ # Tie weights if configured
577
+ if config.tie_weights:
578
+ self.tie_weights()
579
+
580
+ def get_input_embeddings(self) -> nn.Module:
581
+ return self.wte
582
+
583
+ def set_input_embeddings(self, new_embeddings: nn.Module):
584
+ self.wte = new_embeddings
585
+ if getattr(self.config, "tie_weights", False):
586
+ self.tie_weights()
587
+
588
+ def get_output_embeddings(self) -> nn.Module:
589
+ return self.lm_head
590
+
591
+ def set_output_embeddings(self, new_embeddings: nn.Module):
592
+ self.lm_head = new_embeddings
593
+
594
+ def tie_weights(self):
595
+ # Use HF helper to tie or clone depending on config
596
+ if getattr(self.config, "tie_weights", False) or getattr(self.config, "tie_word_embeddings", False):
597
+ self._tie_or_clone_weights(self.lm_head, self.wte)
598
+
599
+ def prepare_inputs_for_generation(
600
+ self,
601
+ input_ids: torch.LongTensor,
602
+ past_key_values: Optional[Tuple] = None,
603
+ attention_mask: Optional[torch.Tensor] = None,
604
+ **kwargs,
605
+ ) -> dict:
606
+ # No KV-cache support; fall back to full forward each step.
607
+ if past_key_values is not None:
608
+ input_ids = input_ids[:, -1:]
609
+ return {
610
+ "input_ids": input_ids,
611
+ "attention_mask": attention_mask,
612
+ "past_key_values": past_key_values,
613
+ }
614
+
615
+ def _init_weights(self, module: nn.Module):
616
+ """Initialize weights following GPT-2 style."""
617
+ if isinstance(module, nn.Linear):
618
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
619
+ if module.bias is not None:
620
+ torch.nn.init.zeros_(module.bias)
621
+ elif isinstance(module, nn.Embedding):
622
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
623
+ elif isinstance(module, nn.LayerNorm):
624
+ torch.nn.init.ones_(module.weight)
625
+ torch.nn.init.zeros_(module.bias)
626
+
627
+ def forward(
628
+ self,
629
+ input_ids: torch.LongTensor,
630
+ attention_mask: Optional[torch.Tensor] = None,
631
+ position_ids: Optional[torch.LongTensor] = None,
632
+ labels: Optional[torch.LongTensor] = None,
633
+ return_dict: Optional[bool] = None,
634
+ legal_token_ids: Optional[List[List[int]]] = None,
635
+ **kwargs,
636
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
637
+ """
638
+ Forward pass of the model.
639
+
640
+ Args:
641
+ input_ids: Token IDs of shape (batch_size, seq_len).
642
+ attention_mask: Attention mask of shape (batch_size, seq_len).
643
+ position_ids: Position IDs of shape (batch_size, seq_len).
644
+ labels: Labels for language modeling loss.
645
+ return_dict: Whether to return a ModelOutput object.
646
+
647
+ Returns:
648
+ CausalLMOutputWithPast containing loss (if labels provided) and logits.
649
+ """
650
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
651
+
652
+ batch_size, seq_len = input_ids.size()
653
+ device = input_ids.device
654
+
655
+ # Get token embeddings
656
+ hidden_states = self.wte(input_ids)
657
+
658
+ # Add position embeddings only for learned encoding
659
+ if self.pos_encoding == "learned":
660
+ if position_ids is None:
661
+ position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
662
+ position_embeds = self.wpe(position_ids)
663
+ hidden_states = hidden_states + position_embeds
664
+
665
+ # Apply dropout
666
+ hidden_states = self.drop(hidden_states)
667
+
668
+ # Pass through transformer blocks
669
+ for block in self.h:
670
+ hidden_states = block(hidden_states, attention_mask=attention_mask)
671
+
672
+ # Final layer norm
673
+ hidden_states = self.ln_f(hidden_states)
674
+
675
+ # Get logits
676
+ logits = self.lm_head(hidden_states)
677
+
678
+ # Compute loss if labels are provided
679
+ loss = None
680
+ if labels is not None:
681
+ # Shift logits and labels for next-token prediction
682
+ shift_logits = logits[..., :-1, :].contiguous()
683
+ shift_labels = labels[..., 1:].contiguous()
684
+
685
+ # Flatten for cross-entropy
686
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
687
+ # loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
688
+ loss = loss_fct(
689
+ shift_logits.view(-1, shift_logits.size(-1)),
690
+ shift_labels.view(-1),
691
+ )
692
+
693
+ if self.config.legal_loss_weight > 0 and legal_token_ids:
694
+ aux_loss = self._legal_move_loss(logits, labels, legal_token_ids)
695
+ if aux_loss is not None:
696
+ loss = loss + self.config.legal_loss_weight * aux_loss
697
+
698
+ if not return_dict:
699
+ output = (logits,)
700
+ return ((loss,) + output) if loss is not None else output
701
+
702
+ return CausalLMOutputWithPast(
703
+ loss=loss,
704
+ logits=logits,
705
+ past_key_values=None,
706
+ hidden_states=None,
707
+ attentions=None,
708
+ )
709
+
710
+ def _legal_move_loss(
711
+ self,
712
+ logits: torch.Tensor,
713
+ labels: torch.Tensor,
714
+ legal_token_ids: List[List[int]],
715
+ ) -> Optional[torch.Tensor]:
716
+ batch_size = logits.size(0)
717
+ total_loss = logits.new_tensor(0.0)
718
+ count = 0
719
+
720
+ for batch_idx in range(batch_size):
721
+ if batch_idx >= len(legal_token_ids):
722
+ continue
723
+ legal_ids = legal_token_ids[batch_idx]
724
+ if not legal_ids:
725
+ continue
726
+
727
+ label_row = labels[batch_idx]
728
+ valid_mask = label_row != -100
729
+ for special_id in (
730
+ getattr(self.config, "pad_token_id", None),
731
+ getattr(self.config, "bos_token_id", None),
732
+ getattr(self.config, "eos_token_id", None),
733
+ ):
734
+ if special_id is not None:
735
+ valid_mask = valid_mask & (label_row != int(special_id))
736
+
737
+ valid_positions = valid_mask.nonzero(as_tuple=False)
738
+ if valid_positions.numel() == 0:
739
+ continue
740
+
741
+ last_pos = int(valid_positions[-1].item())
742
+ pred_pos = last_pos - 1
743
+ if pred_pos < 0:
744
+ continue
745
+
746
+ logits_slice = logits[batch_idx, pred_pos]
747
+ legal_logits = logits_slice.index_select(
748
+ 0,
749
+ torch.tensor(legal_ids, device=logits_slice.device, dtype=torch.long),
750
+ )
751
+
752
+ loss = torch.logsumexp(logits_slice, dim=-1) - torch.logsumexp(legal_logits, dim=-1)
753
+ total_loss = total_loss + loss
754
+ count += 1
755
+
756
+ if count == 0:
757
+ return None
758
+ return total_loss / count
759
+
760
+ @torch.no_grad()
761
+ def generate_move(
762
+ self,
763
+ input_ids: torch.LongTensor,
764
+ temperature: float = 1.0,
765
+ top_k: Optional[int] = None,
766
+ top_p: Optional[float] = None,
767
+ ) -> int:
768
+ """
769
+ Generate the next move given a sequence of moves.
770
+
771
+ Args:
772
+ input_ids: Token IDs of shape (1, seq_len).
773
+ temperature: Sampling temperature (1.0 = no change).
774
+ top_k: If set, only sample from top k tokens.
775
+ top_p: If set, use nucleus sampling with this threshold.
776
+
777
+ Returns:
778
+ The token ID of the predicted next move.
779
+ """
780
+ self.eval()
781
+
782
+ # Get logits for the last position
783
+ outputs = self(input_ids)
784
+ logits = outputs.logits[:, -1, :] / temperature
785
+
786
+ # Apply top-k filtering
787
+ if top_k is not None:
788
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
789
+ logits[indices_to_remove] = float("-inf")
790
+
791
+ # Apply top-p (nucleus) filtering
792
+ if top_p is not None:
793
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
794
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
795
+
796
+ # Remove tokens with cumulative probability above the threshold
797
+ sorted_indices_to_remove = cumulative_probs > top_p
798
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
799
+ sorted_indices_to_remove[..., 0] = 0
800
+
801
+ indices_to_remove = sorted_indices_to_remove.scatter(
802
+ dim=-1, index=sorted_indices, src=sorted_indices_to_remove
803
+ )
804
+ logits[indices_to_remove] = float("-inf")
805
+
806
+ # Sample from the distribution
807
+ probs = F.softmax(logits, dim=-1)
808
+ next_token = torch.multinomial(probs, num_samples=1)
809
+
810
+ return next_token.item()
811
+
812
+
813
+ # Register the model with Auto classes for easy loading
814
+ from transformers import AutoConfig, AutoModelForCausalLM
815
+
816
+ AutoConfig.register("chess_transformer", ChessConfig)
817
+ AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)