swdo commited on
Commit
cb8e145
·
verified ·
1 Parent(s): 125f38d

Chess Challenge submission by swdo

Browse files
Files changed (2) hide show
  1. config.json +2 -2
  2. model.py +83 -277
config.json CHANGED
@@ -3,8 +3,8 @@
3
  "ChessTRMForCausalLM"
4
  ],
5
  "auto_map": {
6
- "AutoConfig": "model.ChessConfig",
7
- "AutoModelForCausalLM": "model.ChessForCausalLM"
8
  },
9
  "bos_token_id": 1,
10
  "dropout": 0.1,
 
3
  "ChessTRMForCausalLM"
4
  ],
5
  "auto_map": {
6
+ "AutoConfig": "model.ChessTRMConfig",
7
+ "AutoModelForCausalLM": "model.ChessTRMForCausalLM"
8
  },
9
  "bos_token_id": 1,
10
  "dropout": 0.1,
model.py CHANGED
@@ -1,63 +1,34 @@
1
  """
2
- Chess Transformer Model for the Chess Challenge.
3
 
4
- This module provides a simple GPT-style transformer architecture
5
- designed to fit within the 1M parameter constraint.
6
-
7
- Key components:
8
- - ChessConfig: Configuration class for model hyperparameters
9
- - ChessForCausalLM: The main model class for next-move prediction
10
  """
11
 
12
  from __future__ import annotations
13
 
14
  import math
15
- from dataclasses import dataclass
16
  from typing import Optional, Tuple, Union
17
 
18
  import torch
19
  import torch.nn as nn
20
  import torch.nn.functional as F
21
- from transformers import PretrainedConfig, PreTrainedModel
22
  from transformers.modeling_outputs import CausalLMOutputWithPast
23
 
24
 
25
- class ChessConfig(PretrainedConfig):
26
- """
27
- Configuration class for the Chess Transformer model.
28
-
29
- This configuration is designed for a ~1M parameter model.
30
- Students can adjust these values to explore different architectures.
31
-
32
- Parameter budget breakdown (with default values):
33
- - Embeddings (vocab): 1200 x 128 = 153,600
34
- - Position Embeddings: 256 x 128 = 32,768
35
- - Transformer Layers: 6 x ~120,000 = ~720,000
36
- - LM Head (with weight tying): 0 (shared with embeddings)
37
- - Total: ~906,000 parameters
38
-
39
- Attributes:
40
- vocab_size: Size of the vocabulary (number of unique moves).
41
- n_embd: Embedding dimension (d_model).
42
- n_layer: Number of transformer layers.
43
- n_head: Number of attention heads.
44
- n_ctx: Maximum sequence length (context window).
45
- n_inner: Feed-forward inner dimension (default: 3 * n_embd).
46
- dropout: Dropout probability.
47
- layer_norm_epsilon: Epsilon for layer normalization.
48
- tie_weights: Whether to tie embedding and output weights.
49
- """
50
-
51
- model_type = "chess_transformer"
52
-
53
  def __init__(
54
  self,
55
  vocab_size: int = 1200,
56
  n_embd: int = 128,
57
- n_layer: int = 6,
58
  n_head: int = 4,
59
  n_ctx: int = 256,
60
  n_inner: Optional[int] = None,
 
61
  dropout: float = 0.1,
62
  layer_norm_epsilon: float = 1e-5,
63
  tie_weights: bool = True,
@@ -72,113 +43,73 @@ class ChessConfig(PretrainedConfig):
72
  eos_token_id=eos_token_id,
73
  **kwargs,
74
  )
75
-
76
- self.vocab_size = vocab_size
77
- self.n_embd = n_embd
78
- self.n_layer = n_layer
79
- self.n_head = n_head
80
- self.n_ctx = n_ctx
81
- self.n_inner = n_inner if n_inner is not None else 3 * n_embd # Reduced from 4x to 3x
82
- self.dropout = dropout
83
- self.layer_norm_epsilon = layer_norm_epsilon
84
- self.tie_weights = tie_weights
85
- # Inform HF base class about tying behavior
86
  self.tie_word_embeddings = bool(tie_weights)
87
 
88
 
89
- class MultiHeadAttention(nn.Module):
90
- """
91
- Multi-head self-attention module.
92
-
93
- This is a standard scaled dot-product attention implementation
94
- with causal masking for autoregressive generation.
95
- """
96
-
97
- def __init__(self, config: ChessConfig):
98
  super().__init__()
99
-
100
- assert config.n_embd % config.n_head == 0, \
101
- f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"
102
-
103
  self.n_head = config.n_head
104
  self.n_embd = config.n_embd
105
  self.head_dim = config.n_embd // config.n_head
106
-
107
- # Combined QKV projection for efficiency
108
  self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
109
  self.c_proj = nn.Linear(config.n_embd, config.n_embd)
110
-
111
  self.dropout = nn.Dropout(config.dropout)
112
-
113
- # Causal mask (will be created on first forward pass)
114
  self.register_buffer(
115
  "bias",
116
- torch.tril(torch.ones(config.n_ctx, config.n_ctx)).view(
117
- 1, 1, config.n_ctx, config.n_ctx
118
- ),
119
  persistent=False,
120
  )
121
-
122
- def forward(
123
- self,
124
- x: torch.Tensor,
125
- attention_mask: Optional[torch.Tensor] = None,
126
- ) -> torch.Tensor:
127
  batch_size, seq_len, _ = x.size()
128
-
129
- # Compute Q, K, V
130
  qkv = self.c_attn(x)
131
  q, k, v = qkv.split(self.n_embd, dim=2)
132
-
133
- # Reshape for multi-head attention
134
  q = q.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
135
  k = k.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
136
  v = v.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
137
-
138
- # Scaled dot-product attention
139
  attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
140
-
141
- # Apply causal mask
142
  causal_mask = self.bias[:, :, :seq_len, :seq_len]
143
  attn_weights = attn_weights.masked_fill(causal_mask == 0, float("-inf"))
144
-
145
- # Apply attention mask (for padding)
146
  if attention_mask is not None:
147
- # attention_mask shape: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len)
148
- attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
149
- attn_weights = attn_weights.masked_fill(attention_mask == 0, float("-inf"))
150
-
151
  attn_weights = F.softmax(attn_weights, dim=-1)
152
  attn_weights = self.dropout(attn_weights)
153
-
154
- # Apply attention to values
155
  attn_output = torch.matmul(attn_weights, v)
156
-
157
- # Reshape back
158
- attn_output = attn_output.transpose(1, 2).contiguous().view(
159
- batch_size, seq_len, self.n_embd
160
- )
161
-
162
- # Output projection
163
  attn_output = self.c_proj(attn_output)
164
-
165
  return attn_output
166
 
167
 
168
- class FeedForward(nn.Module):
169
- """
170
- Feed-forward network (MLP) module.
171
-
172
- Standard two-layer MLP with GELU activation.
173
- """
174
-
175
- def __init__(self, config: ChessConfig):
176
  super().__init__()
177
-
178
  self.c_fc = nn.Linear(config.n_embd, config.n_inner)
179
  self.c_proj = nn.Linear(config.n_inner, config.n_embd)
180
  self.dropout = nn.Dropout(config.dropout)
181
-
182
  def forward(self, x: torch.Tensor) -> torch.Tensor:
183
  x = self.c_fc(x)
184
  x = F.gelu(x)
@@ -187,90 +118,40 @@ class FeedForward(nn.Module):
187
  return x
188
 
189
 
190
- class TransformerBlock(nn.Module):
191
- """
192
- A single transformer block with attention and feed-forward layers.
193
-
194
- Uses pre-normalization (LayerNorm before attention/FFN) for better
195
- training stability.
196
- """
197
-
198
- def __init__(self, config: ChessConfig):
199
  super().__init__()
200
-
201
  self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
202
- self.attn = MultiHeadAttention(config)
203
  self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
204
- self.mlp = FeedForward(config)
205
-
206
- def forward(
207
- self,
208
- x: torch.Tensor,
209
- attention_mask: Optional[torch.Tensor] = None,
210
- ) -> torch.Tensor:
211
- # Pre-norm attention
212
  x = x + self.attn(self.ln_1(x), attention_mask=attention_mask)
213
- # Pre-norm FFN
214
  x = x + self.mlp(self.ln_2(x))
215
  return x
216
 
217
 
218
- class ChessForCausalLM(PreTrainedModel):
219
- """
220
- Chess Transformer for Causal Language Modeling (next-move prediction).
221
-
222
- This model is designed to predict the next chess move given a sequence
223
- of previous moves. It uses a GPT-style architecture with:
224
- - Token embeddings for chess moves
225
- - Learned positional embeddings
226
- - Stacked transformer blocks
227
- - Linear head for next-token prediction
228
-
229
- The model supports weight tying between the embedding layer and the
230
- output projection to save parameters.
231
-
232
- Example:
233
- >>> config = ChessConfig(vocab_size=1200, n_embd=128, n_layer=6)
234
- >>> model = ChessForCausalLM(config)
235
- >>> inputs = {"input_ids": torch.tensor([[1, 42, 87]])}
236
- >>> outputs = model(**inputs)
237
- >>> next_move_logits = outputs.logits[:, -1, :]
238
- """
239
-
240
- config_class = ChessConfig
241
- base_model_prefix = "transformer"
242
  supports_gradient_checkpointing = True
243
- # Suppress missing-key warning for tied lm_head when loading
244
  keys_to_ignore_on_load_missing = ["lm_head.weight"]
245
-
246
- def __init__(self, config: ChessConfig):
247
  super().__init__(config)
248
-
249
- # Token and position embeddings
250
  self.wte = nn.Embedding(config.vocab_size, config.n_embd)
251
  self.wpe = nn.Embedding(config.n_ctx, config.n_embd)
252
-
253
  self.drop = nn.Dropout(config.dropout)
254
-
255
- # Transformer blocks
256
- self.h = nn.ModuleList([
257
- TransformerBlock(config) for _ in range(config.n_layer)
258
- ])
259
-
260
- # Final layer norm
261
  self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
262
-
263
- # Output head
264
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
265
-
266
- # Declare tied weights for proper serialization
267
  if config.tie_weights:
268
  self._tied_weights_keys = ["lm_head.weight"]
269
-
270
- # Initialize weights
271
  self.post_init()
272
-
273
- # Tie weights if configured
274
  if config.tie_weights:
275
  self.tie_weights()
276
 
@@ -289,12 +170,10 @@ class ChessForCausalLM(PreTrainedModel):
289
  self.lm_head = new_embeddings
290
 
291
  def tie_weights(self):
292
- # Use HF helper to tie or clone depending on config
293
  if getattr(self.config, "tie_weights", False) or getattr(self.config, "tie_word_embeddings", False):
294
  self._tie_or_clone_weights(self.lm_head, self.wte)
295
-
296
  def _init_weights(self, module: nn.Module):
297
- """Initialize weights following GPT-2 style."""
298
  if isinstance(module, nn.Linear):
299
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
300
  if module.bias is not None:
@@ -304,7 +183,7 @@ class ChessForCausalLM(PreTrainedModel):
304
  elif isinstance(module, nn.LayerNorm):
305
  torch.nn.init.ones_(module.weight)
306
  torch.nn.init.zeros_(module.bias)
307
-
308
  def forward(
309
  self,
310
  input_ids: torch.LongTensor,
@@ -314,61 +193,42 @@ class ChessForCausalLM(PreTrainedModel):
314
  return_dict: Optional[bool] = None,
315
  **kwargs,
316
  ) -> Union[Tuple, CausalLMOutputWithPast]:
317
- """
318
- Forward pass of the model.
319
-
320
- Args:
321
- input_ids: Token IDs of shape (batch_size, seq_len).
322
- attention_mask: Attention mask of shape (batch_size, seq_len).
323
- position_ids: Position IDs of shape (batch_size, seq_len).
324
- labels: Labels for language modeling loss.
325
- return_dict: Whether to return a ModelOutput object.
326
-
327
- Returns:
328
- CausalLMOutputWithPast containing loss (if labels provided) and logits.
329
- """
330
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
331
-
332
  batch_size, seq_len = input_ids.size()
333
  device = input_ids.device
334
-
335
- # Create position IDs if not provided
 
 
336
  if position_ids is None:
337
  position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
338
-
339
- # Get embeddings
340
  token_embeds = self.wte(input_ids)
341
- position_embeds = self.wpe(position_ids)
342
- hidden_states = self.drop(token_embeds + position_embeds)
343
-
344
- # Pass through transformer blocks
345
- for block in self.h:
346
- hidden_states = block(hidden_states, attention_mask=attention_mask)
347
-
348
- # Final layer norm
 
 
349
  hidden_states = self.ln_f(hidden_states)
350
-
351
- # Get logits
352
  logits = self.lm_head(hidden_states)
353
-
354
- # Compute loss if labels are provided
355
  loss = None
356
  if labels is not None:
357
- # Shift logits and labels for next-token prediction
358
  shift_logits = logits[..., :-1, :].contiguous()
359
  shift_labels = labels[..., 1:].contiguous()
360
-
361
- # Flatten for cross-entropy
362
  loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
363
- loss = loss_fct(
364
- shift_logits.view(-1, shift_logits.size(-1)),
365
- shift_labels.view(-1),
366
- )
367
-
368
  if not return_dict:
369
  output = (logits,)
370
  return ((loss,) + output) if loss is not None else output
371
-
372
  return CausalLMOutputWithPast(
373
  loss=loss,
374
  logits=logits,
@@ -376,62 +236,8 @@ class ChessForCausalLM(PreTrainedModel):
376
  hidden_states=None,
377
  attentions=None,
378
  )
379
-
380
- @torch.no_grad()
381
- def generate_move(
382
- self,
383
- input_ids: torch.LongTensor,
384
- temperature: float = 1.0,
385
- top_k: Optional[int] = None,
386
- top_p: Optional[float] = None,
387
- ) -> int:
388
- """
389
- Generate the next move given a sequence of moves.
390
-
391
- Args:
392
- input_ids: Token IDs of shape (1, seq_len).
393
- temperature: Sampling temperature (1.0 = no change).
394
- top_k: If set, only sample from top k tokens.
395
- top_p: If set, use nucleus sampling with this threshold.
396
-
397
- Returns:
398
- The token ID of the predicted next move.
399
- """
400
- self.eval()
401
-
402
- # Get logits for the last position
403
- outputs = self(input_ids)
404
- logits = outputs.logits[:, -1, :] / temperature
405
-
406
- # Apply top-k filtering
407
- if top_k is not None:
408
- indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
409
- logits[indices_to_remove] = float("-inf")
410
-
411
- # Apply top-p (nucleus) filtering
412
- if top_p is not None:
413
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
414
- cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
415
-
416
- # Remove tokens with cumulative probability above the threshold
417
- sorted_indices_to_remove = cumulative_probs > top_p
418
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
419
- sorted_indices_to_remove[..., 0] = 0
420
-
421
- indices_to_remove = sorted_indices_to_remove.scatter(
422
- dim=-1, index=sorted_indices, src=sorted_indices_to_remove
423
- )
424
- logits[indices_to_remove] = float("-inf")
425
-
426
- # Sample from the distribution
427
- probs = F.softmax(logits, dim=-1)
428
- next_token = torch.multinomial(probs, num_samples=1)
429
-
430
- return next_token.item()
431
-
432
-
433
- # Register the model with Auto classes for easy loading
434
- from transformers import AutoConfig, AutoModelForCausalLM
435
-
436
- AutoConfig.register("chess_transformer", ChessConfig)
437
- AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
 
1
  """
2
+ TRM-style model for the Chess Challenge.
3
 
4
+ This implements a weight-shared recurrent transformer (Tiny Recursive Model style)
5
+ for causal language modeling under the 1M parameter constraint.
 
 
 
 
6
  """
7
 
8
  from __future__ import annotations
9
 
10
  import math
 
11
  from typing import Optional, Tuple, Union
12
 
13
  import torch
14
  import torch.nn as nn
15
  import torch.nn.functional as F
16
+ from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
17
  from transformers.modeling_outputs import CausalLMOutputWithPast
18
 
19
 
20
+ class ChessTRMConfig(PretrainedConfig):
21
+ model_type = "chess_trm"
22
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def __init__(
24
  self,
25
  vocab_size: int = 1200,
26
  n_embd: int = 128,
27
+ n_layer: int = 2,
28
  n_head: int = 4,
29
  n_ctx: int = 256,
30
  n_inner: Optional[int] = None,
31
+ n_cycles: int = 8,
32
  dropout: float = 0.1,
33
  layer_norm_epsilon: float = 1e-5,
34
  tie_weights: bool = True,
 
43
  eos_token_id=eos_token_id,
44
  **kwargs,
45
  )
46
+ self.vocab_size = int(vocab_size)
47
+ self.n_embd = int(n_embd)
48
+ self.n_layer = int(n_layer)
49
+ self.n_head = int(n_head)
50
+ self.n_ctx = int(n_ctx)
51
+ self.n_inner = int(n_inner) if n_inner is not None else int(3 * n_embd)
52
+ self.n_cycles = int(n_cycles)
53
+ self.dropout = float(dropout)
54
+ self.layer_norm_epsilon = float(layer_norm_epsilon)
55
+ self.tie_weights = bool(tie_weights)
 
56
  self.tie_word_embeddings = bool(tie_weights)
57
 
58
 
59
+ class _TRMMultiHeadAttention(nn.Module):
60
+ def __init__(self, config: ChessTRMConfig):
 
 
 
 
 
 
 
61
  super().__init__()
62
+ if config.n_embd % config.n_head != 0:
63
+ raise ValueError(f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})")
 
 
64
  self.n_head = config.n_head
65
  self.n_embd = config.n_embd
66
  self.head_dim = config.n_embd // config.n_head
67
+
 
68
  self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
69
  self.c_proj = nn.Linear(config.n_embd, config.n_embd)
 
70
  self.dropout = nn.Dropout(config.dropout)
71
+
 
72
  self.register_buffer(
73
  "bias",
74
+ torch.tril(torch.ones(config.n_ctx, config.n_ctx)).view(1, 1, config.n_ctx, config.n_ctx),
 
 
75
  persistent=False,
76
  )
77
+
78
+ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
 
 
 
 
79
  batch_size, seq_len, _ = x.size()
80
+
 
81
  qkv = self.c_attn(x)
82
  q, k, v = qkv.split(self.n_embd, dim=2)
83
+
 
84
  q = q.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
85
  k = k.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
86
  v = v.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
87
+
 
88
  attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
89
+
 
90
  causal_mask = self.bias[:, :, :seq_len, :seq_len]
91
  attn_weights = attn_weights.masked_fill(causal_mask == 0, float("-inf"))
92
+
 
93
  if attention_mask is not None:
94
+ expanded = attention_mask.unsqueeze(1).unsqueeze(2)
95
+ attn_weights = attn_weights.masked_fill(expanded == 0, float("-inf"))
96
+
 
97
  attn_weights = F.softmax(attn_weights, dim=-1)
98
  attn_weights = self.dropout(attn_weights)
99
+
 
100
  attn_output = torch.matmul(attn_weights, v)
101
+ attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.n_embd)
 
 
 
 
 
 
102
  attn_output = self.c_proj(attn_output)
 
103
  return attn_output
104
 
105
 
106
+ class _TRMFeedForward(nn.Module):
107
+ def __init__(self, config: ChessTRMConfig):
 
 
 
 
 
 
108
  super().__init__()
 
109
  self.c_fc = nn.Linear(config.n_embd, config.n_inner)
110
  self.c_proj = nn.Linear(config.n_inner, config.n_embd)
111
  self.dropout = nn.Dropout(config.dropout)
112
+
113
  def forward(self, x: torch.Tensor) -> torch.Tensor:
114
  x = self.c_fc(x)
115
  x = F.gelu(x)
 
118
  return x
119
 
120
 
121
+ class _TRMBlock(nn.Module):
122
+ def __init__(self, config: ChessTRMConfig):
 
 
 
 
 
 
 
123
  super().__init__()
 
124
  self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
125
+ self.attn = _TRMMultiHeadAttention(config)
126
  self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
127
+ self.mlp = _TRMFeedForward(config)
128
+
129
+ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
 
 
 
 
 
130
  x = x + self.attn(self.ln_1(x), attention_mask=attention_mask)
 
131
  x = x + self.mlp(self.ln_2(x))
132
  return x
133
 
134
 
135
+ class ChessTRMForCausalLM(PreTrainedModel):
136
+ config_class = ChessTRMConfig
137
+ base_model_prefix = "trm"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  supports_gradient_checkpointing = True
 
139
  keys_to_ignore_on_load_missing = ["lm_head.weight"]
140
+
141
+ def __init__(self, config: ChessTRMConfig):
142
  super().__init__(config)
 
 
143
  self.wte = nn.Embedding(config.vocab_size, config.n_embd)
144
  self.wpe = nn.Embedding(config.n_ctx, config.n_embd)
 
145
  self.drop = nn.Dropout(config.dropout)
146
+
147
+ self.blocks = nn.ModuleList([_TRMBlock(config) for _ in range(config.n_layer)])
 
 
 
 
 
148
  self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
149
+
 
150
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
 
151
  if config.tie_weights:
152
  self._tied_weights_keys = ["lm_head.weight"]
153
+
 
154
  self.post_init()
 
 
155
  if config.tie_weights:
156
  self.tie_weights()
157
 
 
170
  self.lm_head = new_embeddings
171
 
172
  def tie_weights(self):
 
173
  if getattr(self.config, "tie_weights", False) or getattr(self.config, "tie_word_embeddings", False):
174
  self._tie_or_clone_weights(self.lm_head, self.wte)
175
+
176
  def _init_weights(self, module: nn.Module):
 
177
  if isinstance(module, nn.Linear):
178
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
179
  if module.bias is not None:
 
183
  elif isinstance(module, nn.LayerNorm):
184
  torch.nn.init.ones_(module.weight)
185
  torch.nn.init.zeros_(module.bias)
186
+
187
  def forward(
188
  self,
189
  input_ids: torch.LongTensor,
 
193
  return_dict: Optional[bool] = None,
194
  **kwargs,
195
  ) -> Union[Tuple, CausalLMOutputWithPast]:
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
197
+
198
  batch_size, seq_len = input_ids.size()
199
  device = input_ids.device
200
+
201
+ if seq_len > self.config.n_ctx:
202
+ raise ValueError(f"seq_len ({seq_len}) exceeds n_ctx ({self.config.n_ctx})")
203
+
204
  if position_ids is None:
205
  position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
206
+
 
207
  token_embeds = self.wte(input_ids)
208
+ pos_embeds = self.wpe(position_ids)
209
+ input_injection = token_embeds + pos_embeds
210
+
211
+ hidden_states = self.drop(input_injection)
212
+
213
+ for _ in range(self.config.n_cycles):
214
+ hidden_states = hidden_states + input_injection
215
+ for block in self.blocks:
216
+ hidden_states = block(hidden_states, attention_mask=attention_mask)
217
+
218
  hidden_states = self.ln_f(hidden_states)
 
 
219
  logits = self.lm_head(hidden_states)
220
+
 
221
  loss = None
222
  if labels is not None:
 
223
  shift_logits = logits[..., :-1, :].contiguous()
224
  shift_labels = labels[..., 1:].contiguous()
 
 
225
  loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
226
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
227
+
 
 
 
228
  if not return_dict:
229
  output = (logits,)
230
  return ((loss,) + output) if loss is not None else output
231
+
232
  return CausalLMOutputWithPast(
233
  loss=loss,
234
  logits=logits,
 
236
  hidden_states=None,
237
  attentions=None,
238
  )
239
+
240
+
241
+ AutoConfig.register("chess_trm", ChessTRMConfig)
242
+ AutoModelForCausalLM.register(ChessTRMConfig, ChessTRMForCausalLM)
243
+