MaximeMuhlethaler committed
Commit b1f8d8d · verified · 1 Parent(s): 481eec0

Chess Challenge submission by MaximeMuhlethaler

Files changed (2)
  1. config.json +1 -0
  2. model.py +47 -97
config.json CHANGED
@@ -13,6 +13,7 @@
   "pad_token_id": 0,
   "tie_weights": true,
   "transformers_version": "4.57.3",
+  "unk_token_id": 3,
   "vocab_size": 1200,
   "auto_map": {
     "AutoModelForCausalLM": "model.ChessForCausalLM",
model.py CHANGED
@@ -1,46 +1,36 @@
 """
-Chess Transformer Model - Final Stable Version with Inference Patch
+Chess Transformer Model - The "Nuclear Patch" Edition
 """
 from __future__ import annotations
-
 import math
-from dataclasses import dataclass
 from typing import Optional, Tuple, Union
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
-
 class ChessConfig(PretrainedConfig):
     model_type = "chess_transformer"
 
     def __init__(
         self,
-        vocab_size: int = 1200,
-        n_embd: int = 128,
-        n_layer: int = 6,
-        n_head: int = 4,
-        n_ctx: int = 256,
-        n_inner: Optional[int] = None,
-        dropout: float = 0.1,
-        layer_norm_epsilon: float = 1e-5,
-        tie_weights: bool = True,
-        pad_token_id: int = 0,
-        bos_token_id: int = 1,
-        eos_token_id: int = 2,
+        vocab_size=1200,
+        n_embd=128,
+        n_layer=6,
+        n_head=4,
+        n_ctx=256,
+        n_inner=None,
+        dropout=0.1,
+        layer_norm_epsilon=1e-5,
+        tie_weights=True,
+        # Strict default values
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        unk_token_id=3,
         **kwargs,
     ):
-
-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            **kwargs,
-        )
-
         self.vocab_size = vocab_size
         self.n_embd = n_embd
         self.n_layer = n_layer
@@ -50,26 +40,25 @@ class ChessConfig(PretrainedConfig):
         self.dropout = dropout
         self.layer_norm_epsilon = layer_norm_epsilon
         self.tie_weights = tie_weights
-        self.tie_word_embeddings = bool(tie_weights)
-
+
+        # Pass the vital token ids via kwargs to the parent class
+        kwargs["pad_token_id"] = pad_token_id
+        kwargs["bos_token_id"] = bos_token_id
+        kwargs["eos_token_id"] = eos_token_id
+        kwargs["unk_token_id"] = unk_token_id
+
+        super().__init__(**kwargs)
 
 class MultiHeadAttention(nn.Module):
     def __init__(self, config: ChessConfig):
         super().__init__()
-        assert config.n_embd % config.n_head == 0
         self.n_head = config.n_head
         self.n_embd = config.n_embd
         self.head_dim = config.n_embd // config.n_head
-
         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
         self.dropout = nn.Dropout(config.dropout)
-
-        self.register_buffer(
-            "bias",
-            torch.tril(torch.ones(config.n_ctx, config.n_ctx)).view(1, 1, config.n_ctx, config.n_ctx),
-            persistent=False,
-        )
+        self.register_buffer("bias", torch.tril(torch.ones(config.n_ctx, config.n_ctx)).view(1, 1, config.n_ctx, config.n_ctx), persistent=False)
 
     def forward(self, x, attention_mask=None):
         B, T, C = x.size()
@@ -78,31 +67,25 @@ class MultiHeadAttention(nn.Module):
         q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
         k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
         v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
-
         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float("-inf"))
-
+        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
         if attention_mask is not None:
-            att = att.masked_fill(attention_mask.view(B, 1, 1, T) == 0, float("-inf"))
-
+            att = att.masked_fill(attention_mask.view(B, 1, 1, T) == 0, float('-inf'))
         att = F.softmax(att, dim=-1)
         att = self.dropout(att)
         y = att @ v
         y = y.transpose(1, 2).contiguous().view(B, T, C)
         return self.c_proj(y)
 
-
 class FeedForward(nn.Module):
     def __init__(self, config: ChessConfig):
         super().__init__()
         self.c_fc = nn.Linear(config.n_embd, config.n_inner)
         self.c_proj = nn.Linear(config.n_inner, config.n_embd)
         self.dropout = nn.Dropout(config.dropout)
-
    def forward(self, x):
         return self.dropout(self.c_proj(F.gelu(self.c_fc(x))))
 
-
 class TransformerBlock(nn.Module):
     def __init__(self, config: ChessConfig):
         super().__init__()
@@ -110,18 +93,14 @@ class TransformerBlock(nn.Module):
         self.attn = MultiHeadAttention(config)
         self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
         self.mlp = FeedForward(config)
-
     def forward(self, x, attention_mask=None):
-        x = x + self.attn(self.ln_1(x), attention_mask=attention_mask)
+        x = x + self.attn(self.ln_1(x), attention_mask)
         x = x + self.mlp(self.ln_2(x))
         return x
 
-
 class ChessForCausalLM(PreTrainedModel):
     config_class = ChessConfig
     base_model_prefix = "transformer"
-    supports_gradient_checkpointing = True
-    keys_to_ignore_on_load_missing = ["lm_head.weight"]
 
     def __init__(self, config: ChessConfig):
         super().__init__(config)
@@ -131,79 +110,50 @@ class ChessForCausalLM(PreTrainedModel):
         self.h = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)])
         self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-
-        if config.tie_weights:
-            self.post_init()
-            self.tie_weights()
+        if config.tie_weights: self.post_init()
 
     def get_input_embeddings(self): return self.wte
     def set_input_embeddings(self, new_embeddings): self.wte = new_embeddings
     def get_output_embeddings(self): return self.lm_head
     def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings
 
-    def tie_weights(self):
-        if getattr(self.config, "tie_weights", False):
-            self._tie_or_clone_weights(self.lm_head, self.wte)
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        return_dict: Optional[bool] = None,
-        **kwargs,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
-
-
+    def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, return_dict=None, **kwargs):
+        # 1. FIX RETURN TYPE
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
+        if return_dict is None: return_dict = True
+
         device = input_ids.device
         b, t = input_ids.size()
-        if position_ids is None:
-            position_ids = torch.arange(t, device=device).unsqueeze(0).expand(b, -1)
+        if position_ids is None: position_ids = torch.arange(t, device=device).unsqueeze(0)
 
         x = self.wte(input_ids) + self.wpe(position_ids)
         x = self.drop(x)
-        for block in self.h:
-            x = block(x, attention_mask)
+        for block in self.h: x = block(x, attention_mask)
         x = self.ln_f(x)
         logits = self.lm_head(x)
 
+        # ---------------------------------------------------------
+        # 2. NUCLEAR PATCH: hard-ban ids 0, 1, 2, 3
+        # ---------------------------------------------------------
         if labels is None:
-            bad_tokens = [
-                self.config.eos_token_id,
-                self.config.pad_token_id,
-                self.config.bos_token_id
-            ]
-            if hasattr(self.config, "unk_token_id") and self.config.unk_token_id is not None:
-                bad_tokens.append(self.config.unk_token_id)
+            # PAD=0, BOS=1, EOS=2, UNK=3 (the tokenizer's standard ids)
+            nuclear_bad_ids = [0, 1, 2, 3]
 
-            bad_tokens = [t for t in bad_tokens if t is not None]
-
-            if len(bad_tokens) > 0:
-                logits[:, :, bad_tokens] = float("-inf")
-
+            # Set them to -inf (impossible to pick)
+            # The [:, :, ids] slice covers the whole batch and the whole sequence
+            logits[:, :, nuclear_bad_ids] = float("-inf")
+            # ---------------------------------------------------------
 
         loss = None
         if labels is not None:
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
-            loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-
+            loss = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
         if not return_dict:
-            output = (logits,)
-            return ((loss,) + output) if loss is not None else output
+            return ((loss,) + (logits,)) if loss is not None else (logits,)
 
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=None,
-            hidden_states=None,
-            attentions=None,
-        )
-
+        return CausalLMOutputWithPast(loss=loss, logits=logits)
 
 from transformers import AutoConfig, AutoModelForCausalLM
 AutoConfig.register("chess_transformer", ChessConfig)
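For completeness, a minimal greedy-decoding sketch against the patched forward pass. It assumes this commit's model.py, config.json and trained weights are available in the working directory; the prompt token ids below are purely illustrative.

import torch
from model import ChessForCausalLM

model = ChessForCausalLM.from_pretrained(".")
model.eval()

ids = torch.tensor([[1, 17, 243]])  # BOS followed by two made-up move tokens
with torch.no_grad():
    for _ in range(8):
        logits = model(input_ids=ids).logits                      # labels is None, so ids 0-3 are -inf
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # greedy pick of the next move token
        ids = torch.cat([ids, next_id], dim=-1)
print(ids.tolist())

Since the masking only runs when labels is None, the training loss is unaffected, while decoding can never emit PAD, BOS, EOS or UNK.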