MaximeMuhlethaler committed on
Commit 481eec0 · verified · 1 Parent(s): c1a4876

Chess Challenge submission by MaximeMuhlethaler

Files changed (5):
  1. README.md +1 -4
  2. config.json +6 -5
  3. model.py +210 -0
  4. pytorch_model.bin +3 -0
  5. tokenizer.py +75 -60
README.md CHANGED
@@ -12,15 +12,12 @@ license: mit
 Chess model submitted to the LLM Course Chess Challenge.
 
 ## Submission Info
-
 - **Submitted by**: [MaximeMuhlethaler](https://huggingface.co/MaximeMuhlethaler)
 - **Parameters**: 924,000
 - **Organization**: LLM-course
+- **Architecture**: Custom Chess Transformer (Regex Tokenizer + EOS Protection)
 
 ## Model Details
-
-- **Architecture**: Chess Transformer (GPT-style)
 - **Vocab size**: 1200
-- **Embedding dim**: 112
 - **Layers**: 6
 - **Heads**: 8
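
The deleted `Embedding dim: 112` bullet is still consistent with the 924,000-parameter figure. A quick sanity check, as a sketch assuming n_ctx=256, n_inner=3*n_embd, biased GPT-style linears, and a tied lm_head as in the model.py added further down:

```python
# Back-of-the-envelope parameter count for the stated 924,000 figure.
d, ctx, vocab, layers = 112, 256, 1200, 6   # 112 = the dim removed from the README
per_layer = (
    (d * 3 * d + 3 * d)    # attention c_attn: weight + bias
    + (d * d + d)          # attention c_proj
    + (d * 3 * d + 3 * d)  # mlp c_fc (n_inner = 3 * n_embd)
    + (3 * d * d + d)      # mlp c_proj
    + 4 * d                # two LayerNorms (weight + bias each)
)
total = vocab * d + ctx * d + layers * per_layer + 2 * d  # wte + wpe + blocks + ln_f
print(total)  # 924000 (lm_head is tied to wte, so it adds nothing)
```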
config.json CHANGED
@@ -1,7 +1,4 @@
 {
-  "architectures": [
-    "ChessForCausalLM"
-  ],
   "bos_token_id": 1,
   "dropout": 0.1,
   "dtype": "float32",
@@ -16,5 +13,9 @@
   "pad_token_id": 0,
   "tie_weights": true,
   "transformers_version": "4.57.3",
-  "vocab_size": 1200
-}
+  "vocab_size": 1200,
+  "auto_map": {
+    "AutoModelForCausalLM": "model.ChessForCausalLM",
+    "AutoConfig": "model.ChessConfig"
+  }
+}
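
The new `auto_map` block is what lets the stock `Auto*` loaders resolve the custom classes in `model.py`. A minimal loading sketch, assuming a placeholder repo id (the actual repo path is not shown in this commit) and that executing remote code is acceptable:

```python
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "LLM-course/chess-submission"  # placeholder, not the real repo id

# trust_remote_code=True lets transformers import model.py from the repo,
# resolving "model.ChessConfig" / "model.ChessForCausalLM" via auto_map.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
print(type(model).__name__)  # ChessForCausalLM
```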
model.py ADDED
@@ -0,0 +1,210 @@
+"""
+Chess Transformer Model - Final Stable Version with Inference Patch
+"""
+from __future__ import annotations
+
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+
+class ChessConfig(PretrainedConfig):
+    model_type = "chess_transformer"
+
+    def __init__(
+        self,
+        vocab_size: int = 1200,
+        n_embd: int = 128,
+        n_layer: int = 6,
+        n_head: int = 4,
+        n_ctx: int = 256,
+        n_inner: Optional[int] = None,
+        dropout: float = 0.1,
+        layer_norm_epsilon: float = 1e-5,
+        tie_weights: bool = True,
+        pad_token_id: int = 0,
+        bos_token_id: int = 1,
+        eos_token_id: int = 2,
+        **kwargs,
+    ):
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            **kwargs,
+        )
+        self.vocab_size = vocab_size
+        self.n_embd = n_embd
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.n_ctx = n_ctx
+        self.n_inner = n_inner if n_inner is not None else 3 * n_embd
+        self.dropout = dropout
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.tie_weights = tie_weights
+        self.tie_word_embeddings = bool(tie_weights)
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, config: ChessConfig):
+        super().__init__()
+        assert config.n_embd % config.n_head == 0
+        self.n_head = config.n_head
+        self.n_embd = config.n_embd
+        self.head_dim = config.n_embd // config.n_head
+
+        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+        self.dropout = nn.Dropout(config.dropout)
+
+        # Causal mask: (1, 1, n_ctx, n_ctx) lower-triangular buffer.
+        self.register_buffer(
+            "bias",
+            torch.tril(torch.ones(config.n_ctx, config.n_ctx)).view(1, 1, config.n_ctx, config.n_ctx),
+            persistent=False,
+        )
+
+    def forward(self, x, attention_mask=None):
+        B, T, C = x.size()
+        qkv = self.c_attn(x)
+        q, k, v = qkv.split(self.n_embd, dim=2)
+        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+
+        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
+
+        if attention_mask is not None:
+            att = att.masked_fill(attention_mask.view(B, 1, 1, T) == 0, float("-inf"))
+
+        att = F.softmax(att, dim=-1)
+        att = self.dropout(att)
+        y = att @ v
+        y = y.transpose(1, 2).contiguous().view(B, T, C)
+        return self.c_proj(y)
+
+
+class FeedForward(nn.Module):
+    def __init__(self, config: ChessConfig):
+        super().__init__()
+        self.c_fc = nn.Linear(config.n_embd, config.n_inner)
+        self.c_proj = nn.Linear(config.n_inner, config.n_embd)
+        self.dropout = nn.Dropout(config.dropout)
+
+    def forward(self, x):
+        return self.dropout(self.c_proj(F.gelu(self.c_fc(x))))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, config: ChessConfig):
+        super().__init__()
+        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.attn = MultiHeadAttention(config)
+        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.mlp = FeedForward(config)
+
+    def forward(self, x, attention_mask=None):
+        x = x + self.attn(self.ln_1(x), attention_mask=attention_mask)
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class ChessForCausalLM(PreTrainedModel):
+    config_class = ChessConfig
+    base_model_prefix = "transformer"
+    supports_gradient_checkpointing = True
+    keys_to_ignore_on_load_missing = ["lm_head.weight"]
+
+    def __init__(self, config: ChessConfig):
+        super().__init__(config)
+        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
+        self.wpe = nn.Embedding(config.n_ctx, config.n_embd)
+        self.drop = nn.Dropout(config.dropout)
+        self.h = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)])
+        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+        # Initialize weights, then tie lm_head to the input embeddings if requested.
+        self.post_init()
+        if config.tie_weights:
+            self.tie_weights()
+
+    def get_input_embeddings(self):
+        return self.wte
+
+    def set_input_embeddings(self, new_embeddings):
+        self.wte = new_embeddings
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def tie_weights(self):
+        if getattr(self.config, "tie_weights", False):
+            self._tie_or_clone_weights(self.lm_head, self.wte)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        device = input_ids.device
+        b, t = input_ids.size()
+        if position_ids is None:
+            position_ids = torch.arange(t, device=device).unsqueeze(0).expand(b, -1)
+
+        x = self.wte(input_ids) + self.wpe(position_ids)
+        x = self.drop(x)
+        for block in self.h:
+            x = block(x, attention_mask)
+        x = self.ln_f(x)
+        logits = self.lm_head(x)
+
+        # Inference patch ("EOS protection"): when no labels are given we are
+        # generating, so mask the special tokens out of the distribution.
+        if labels is None:
+            bad_tokens = [
+                self.config.eos_token_id,
+                self.config.pad_token_id,
+                self.config.bos_token_id,
+            ]
+            if getattr(self.config, "unk_token_id", None) is not None:
+                bad_tokens.append(self.config.unk_token_id)
+            bad_tokens = [tok for tok in bad_tokens if tok is not None]
+            if bad_tokens:
+                logits[:, :, bad_tokens] = float("-inf")
+
+        loss = None
+        if labels is not None:
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+        if not return_dict:
+            output = (logits,)
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=None,
+            hidden_states=None,
+            attentions=None,
+        )
+
+
+# Register the custom classes so the Auto* factories can resolve them.
+from transformers import AutoConfig, AutoModelForCausalLM
+
+AutoConfig.register("chess_transformer", ChessConfig)
+AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
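
The notable piece here is the inference patch in `forward()`: whenever the model is called without `labels`, the logits of the special tokens are forced to `-inf`, so greedy decoding or sampling can never emit `[PAD]`, `[BOS]`, or `[EOS]` mid-game. A minimal sketch of that effect with a freshly initialized (untrained) model under the default config:

```python
import torch

config = ChessConfig()                   # defaults from model.py above
model = ChessForCausalLM(config).eval()  # untrained; weights are random

input_ids = torch.tensor([[config.bos_token_id]])
with torch.no_grad():
    logits = model(input_ids).logits     # no labels -> special tokens masked

# Special-token logits are -inf, so argmax always lands on a move token.
assert torch.isinf(logits[0, -1, config.eos_token_id])
next_id = int(logits[0, -1].argmax())
assert next_id not in (config.pad_token_id, config.bos_token_id, config.eos_token_id)
```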
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d638adc2fb527d4c06e3a92895c9d10523d568048b92991aa434b6bbbe3ef338
+size 3719211
tokenizer.py CHANGED
@@ -1,32 +1,52 @@
 """
-Custom Chess Tokenizer - Final Fix
+Custom Chess Tokenizer for the Chess Challenge.
+
+This tokenizer treats each move as a single token using the extended UCI notation
+from the Lichess dataset (e.g., WPe2e4, BNg8f6).
+
+The dataset format uses:
+- W/B prefix for White/Black
+- Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King
+- Source and destination squares (e.g., e2e4)
+- Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling
 """
+
 from __future__ import annotations
+
 import json
 import os
-import shutil
 import re
 from typing import Dict, List, Optional
+
 from transformers import PreTrainedTokenizer
 
-# --- REGEX (to clean up the moves) ---
 MOVE_RE = re.compile(r"([a-h][1-8])([a-h][1-8])")
 PROMO_RE = re.compile(r"=([NBRQ])")
 
 def normalize_move(tok: str) -> str:
-    if tok.startswith("["): return tok
+    """Turns 'WPe2e4(x)' into 'WPe2e4' to shrink the vocabulary."""
     m = MOVE_RE.search(tok)
-    if not m: return tok
+    if not m:
+        return tok
+
     fr, to = m.group(1), m.group(2)
+
     promo = ""
     pm = PROMO_RE.search(tok)
-    if pm: promo = "=" + pm.group(1)
+    if pm:
+        promo = "=" + pm.group(1)
+
     prefix = tok[:2] if len(tok) >= 2 else "WP"
     return f"{prefix}{fr}{to}{promo}"
 
 class ChessTokenizer(PreTrainedTokenizer):
     model_input_names = ["input_ids", "attention_mask"]
-    vocab_files_names = {"vocab_file": "vocab.json"}
 
     PAD_TOKEN = "[PAD]"
     BOS_TOKEN = "[BOS]"
@@ -38,74 +58,69 @@ class ChessTokenizer(PreTrainedTokenizer):
         self._bos_token = self.BOS_TOKEN
         self._eos_token = self.EOS_TOKEN
         self._unk_token = self.UNK_TOKEN
-        for t in ["pad_token", "bos_token", "eos_token", "unk_token"]: kwargs.pop(t, None)
+
+        # Drop special-token kwargs so they don't clash with our fixed ones.
+        for t in ["pad_token", "bos_token", "eos_token", "unk_token"]:
+            kwargs.pop(t, None)
 
-        # PATH FIX
-        if vocab is None:
-            if vocab_file is None:
-                vocab_file = os.path.join(os.path.dirname(__file__), "vocab.json")
-            self.vocab_file = vocab_file
-            if os.path.exists(vocab_file):
-                with open(vocab_file, "r", encoding="utf-8") as f: self._vocab = json.load(f)
-            else: self._vocab = self._create_default_vocab()
-        else:
+        if vocab:
             self._vocab = vocab
-            self.vocab_file = vocab_file
-
+        elif vocab_file:
+            with open(vocab_file, "r", encoding="utf-8") as f:
+                self._vocab = json.load(f)
+        else:
+            self._vocab = {t: i for i, t in enumerate([self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN])}
+
         self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
         super().__init__(pad_token=self.PAD_TOKEN, bos_token=self.BOS_TOKEN, eos_token=self.EOS_TOKEN, unk_token=self.UNK_TOKEN, **kwargs)
 
-    # AUTO-COPY (vital for the submission)
-    def save_pretrained(self, save_directory: str, **kwargs):
-        super().save_pretrained(save_directory, **kwargs)
-        src_path = os.path.abspath(__file__)
-        dst_path = os.path.join(save_directory, "tokenizer.py")
-        if src_path != dst_path: shutil.copy(src_path, dst_path)
-
-        config_path = os.path.join(save_directory, "tokenizer_config.json")
-        if os.path.exists(config_path):
-            with open(config_path, "r") as f: cfg = json.load(f)
-            cfg["auto_map"] = {"AutoTokenizer": "tokenizer.ChessTokenizer"}
-            with open(config_path, "w") as f: json.dump(cfg, f, indent=2)
-
-    def _create_default_vocab(self):
-        return {t: i for i, t in enumerate([self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN])}
+    @property
+    def vocab_size(self):
+        return len(self._vocab)
+
+    def get_vocab(self):
+        return dict(self._vocab)
+
+    def _tokenize(self, text):
+        return [normalize_move(t) for t in text.strip().split()]
+
+    def _convert_token_to_id(self, token):
+        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))
+
+    def _convert_id_to_token(self, index):
+        return self._ids_to_tokens.get(index, self.UNK_TOKEN)
+
+    def convert_tokens_to_string(self, tokens):
+        return " ".join(t for t in tokens if t not in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN])
+
+    def save_vocabulary(self, save_directory, filename_prefix=None):
+        if not os.path.exists(save_directory):
+            os.makedirs(save_directory)
+        path = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json")
+        with open(path, "w") as f:
+            json.dump(self._vocab, f, indent=2)
+        return (path,)
 
-    # THE FUNCTION THAT ENFORCES THE FIXED VOCAB SIZE
     @classmethod
-    def build_vocab_from_dataset(cls, dataset_name, split="train", column="text", min_frequency=2, max_vocab_size=1700, max_samples=100000):
+    def build_vocab_from_dataset(cls, dataset_name, min_frequency=2, max_vocab_size=1200, **kwargs):
+        """Builds a compact, dense vocabulary."""
         from datasets import load_dataset
         from collections import Counter
 
-        ds = load_dataset(dataset_name, split=split, streaming=True)
-        ds = ds.take(max_samples)
+        # Stream the dataset so this stays fast.
+        ds = load_dataset(dataset_name, split="train", streaming=True)
+        ds = ds.take(50000)  # 50k games are enough to see every possible move
 
         counter = Counter()
         for ex in ds:
-            # Normalize
-            moves = [normalize_move(t) for t in ex[column].split()]
+            # Normalize before counting!
+            moves = [normalize_move(t) for t in ex["text"].split()]
             counter.update(moves)
 
-        # FORCE THE MAXIMUM VOCAB SIZE HERE
         special = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
-        # Take the N most frequent moves to fill up to max_vocab_size
+        # Keep the special tokens + the N most frequent moves.
         most_common = counter.most_common(max_vocab_size - len(special))
 
         vocab = {t: i for i, t in enumerate(special + [t for t, c in most_common])}
         return cls(vocab=vocab)
-
-    @property
-    def vocab_size(self): return len(self._vocab)
-    def get_vocab(self): return dict(self._vocab)
-
-    def _tokenize(self, text): return [normalize_move(t) for t in text.strip().split()]
-    def _convert_token_to_id(self, token): return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))
-    def _convert_id_to_token(self, index): return self._ids_to_tokens.get(index, self.UNK_TOKEN)
-    def convert_tokens_to_string(self, tokens): return " ".join(t for t in tokens if t not in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN])
-
-    def save_vocabulary(self, save_directory, filename_prefix=None):
-        if not os.path.isdir(save_directory): os.makedirs(save_directory, exist_ok=True)
-        path = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json")
-        with open(path, "w", encoding="utf-8") as f: json.dump(self._vocab, f, ensure_ascii=False, indent=2)
-        return (path,)
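
A short usage sketch of the tokenizer; the dataset id and the moves are hypothetical placeholders, and `build_vocab_from_dataset` streams the corpus exactly as above:

```python
# Hypothetical dataset id and move strings, for illustration only.
tok = ChessTokenizer.build_vocab_from_dataset("user/lichess-uci-games")

game = "WPe2e4 BPd7d5 WPe4d5(x)"
ids = tok(game)["input_ids"]

# normalize_move strips the (x)/(+)/... suffixes, so 'WPe4d5(x)' -> 'WPe4d5'.
print(tok.convert_ids_to_tokens(ids))  # ['WPe2e4', 'BPd7d5', 'WPe4d5']
print(tok.convert_tokens_to_string(tok.convert_ids_to_tokens(ids)))
```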