clmrie committed on
Commit
8deeee2
·
verified ·
1 Parent(s): ba6ede2

Chess Challenge submission by clmrie

Browse files
README.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags:
4
+ - chess
5
+ - llm-course
6
+ - chess-challenge
7
+ license: mit
8
+ ---
9
+
10
+ # chess-clmrie
11
+
12
+ Chess model submitted to the LLM Course Chess Challenge.
13
+
14
+ ## Submission Info
15
+
16
+ - Submitted by: clmrie
17
+ - Parameters: 991,168
18
+ - Organization: LLM-course
19
+
20
+ ## Model Details
21
+
22
+ - Architecture: chess_transformer (custom)
23
+ - Vocab size: 85
24
+ - Embedding dim: 136
25
+ - Layers: 5
26
+ - Heads: 8
config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ChessForCausalLM"
4
+ ],
5
+ "bos_token_id": 1,
6
+ "dropout": 0.1,
7
+ "dtype": "float32",
8
+ "eos_token_id": 2,
9
+ "layer_norm_epsilon": 1e-05,
10
+ "model_type": "chess_transformer",
11
+ "n_ctx": 256,
12
+ "n_embd": 136,
13
+ "n_head": 8,
14
+ "n_inner": 408,
15
+ "n_layer": 5,
16
+ "pad_token_id": 0,
17
+ "tie_weights": false,
18
+ "tie_word_embeddings": false,
19
+ "transformers_version": "4.57.6",
20
+ "vocab_size": 85,
21
+ "auto_map": {
22
+ "AutoConfig": "model.ChessConfig",
23
+ "AutoModelForCausalLM": "model.ChessForCausalLM"
24
+ },
25
+ "unk_token_id": 3
26
+ }
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.57.6"
7
+ }
model.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import math
5
+ import os
6
+ from dataclasses import dataclass
7
+ from typing import Dict, List, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from transformers import PretrainedConfig, PreTrainedModel
14
+ from transformers.modeling_outputs import CausalLMOutputWithPast
15
+ from transformers.utils.hub import cached_file
16
+
17
+
18
+ def _is_square(tok: str) -> bool:
19
+ return len(tok) == 2 and tok[0] in "abcdefgh" and tok[1] in "12345678"
20
+
21
+
22
+ def _resolve_file(name_or_path: str, filename: str) -> str:
23
+ if isinstance(name_or_path, str) and os.path.isdir(name_or_path):
24
+ p = os.path.join(name_or_path, filename)
25
+ if os.path.exists(p):
26
+ return p
27
+ return cached_file(name_or_path, filename)
28
+
29
+
30
def _load_vocab(name_or_path: str) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Load vocab.json for the model and return (token->id, id->token) maps."""
    path = _resolve_file(name_or_path, "vocab.json")
    with open(path, "r", encoding="utf-8") as fh:
        token_to_id = json.load(fh)
    # Build the inverse map, coercing ids to int (JSON object keys are strings).
    id_to_token = {int(idx): tok for tok, idx in token_to_id.items()}
    return token_to_id, id_to_token
36
+
37
+
38
@dataclass
class TokenScheme:
    """Describes which vocabulary strings encode each component of a move."""

    W: str                    # color-prefix token for a White move
    B: str                    # color-prefix token for a Black move
    pieces: Dict[str, str]    # piece letter (P/N/B/R/Q/K) -> piece token
    sep: Optional[str]        # separator token between move components, if the vocab has one
    suffix: Dict[str, str]    # event name ("cap", "check", ...) -> suffix token, e.g. "(x)"
    prom: Dict[str, str]      # promotion piece letter -> promotion token, e.g. "(Q)"
    pad_id: int               # special-token ids, taken from the model config
    bos_id: int
    eos_id: int
    unk_id: int
50
+
51
+
52
def _detect_scheme(tok2id: Dict[str, int], config) -> TokenScheme:
    """Inspect the vocabulary and build a TokenScheme describing the move encoding.

    Raises ValueError if the mandatory color or piece tokens are missing;
    separator, suffix, and promotion tokens are all optional.
    """
    # Note: the string "B" does double duty — it is looked up here as the
    # Black color token and again below as the bishop piece token.
    W = "W" if "W" in tok2id else None
    B = "B" if "B" in tok2id else None
    if W is None or B is None:
        raise ValueError("Cannot find W/B tokens in vocab")

    pieces = {}
    for p in ["P", "N", "B", "R", "Q", "K"]:
        if p in tok2id:
            pieces[p] = p
        else:
            raise ValueError(f"Cannot find piece token {p} in vocab")

    sep = " " if " " in tok2id else None

    # Optional move-suffix tokens (captures, checks, mates, castling).
    suffix = {}
    for k, v in [
        ("cap", "(x)"),
        ("cap_check", "(x*)"),
        ("cap_mate", "(x+*)"),
        ("check", "(+)"),
        ("mate", "(+*)"),
        ("o", "(o)"),
        ("O", "(O)"),
    ]:
        if v in tok2id:
            suffix[k] = v

    # Optional promotion tokens. NOTE(review): the shipped vocab.json defines
    # only "(Q)" and "(K)", so under-promotions (R/B/N) get no token here —
    # confirm that matches the training data.
    prom = {}
    for p, v in [("Q", "(Q)"), ("R", "(R)"), ("B", "(B)"), ("N", "(N)")]:
        if v in tok2id:
            prom[p] = v

    # Special-token ids come from the model config, with standard defaults.
    pad_id = int(getattr(config, "pad_token_id", 0))
    bos_id = int(getattr(config, "bos_token_id", 1))
    eos_id = int(getattr(config, "eos_token_id", 2))
    unk_id = int(getattr(config, "unk_token_id", 3))

    return TokenScheme(W=W, B=B, pieces=pieces, sep=sep, suffix=suffix, prom=prom,
                       pad_id=pad_id, bos_id=bos_id, eos_id=eos_id, unk_id=unk_id)
92
+
93
+
94
class ChessConfig(PretrainedConfig):
    """Configuration for the custom chess transformer.

    Hyperparameter names follow the GPT-2 convention (n_embd, n_layer,
    n_head, n_ctx, n_inner). Special-token ids are forwarded to
    PretrainedConfig through kwargs so they are registered and serialized
    the standard transformers way.
    """

    model_type = "chess_transformer"

    def __init__(
        self,
        vocab_size: int = 85,
        n_embd: int = 128,
        n_layer: int = 5,
        n_head: int = 4,
        n_ctx: int = 256,
        n_inner: Optional[int] = None,
        dropout: float = 0.1,
        layer_norm_epsilon: float = 1e-5,
        tie_weights: bool = False,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        unk_token_id: int = 3,
        **kwargs,
    ):
        self.vocab_size = int(vocab_size)
        self.n_embd = int(n_embd)
        self.n_layer = int(n_layer)
        self.n_head = int(n_head)
        self.n_ctx = int(n_ctx)
        # FFN width defaults to 3x the embedding size when not given.
        self.n_inner = int(n_inner) if n_inner is not None else 3 * int(n_embd)
        self.dropout = float(dropout)
        self.layer_norm_epsilon = float(layer_norm_epsilon)
        self.tie_weights = bool(tie_weights)

        # Route the special-token ids through kwargs so PretrainedConfig's
        # __init__ stores them as its own attributes.
        kwargs["pad_token_id"] = pad_token_id
        kwargs["bos_token_id"] = bos_token_id
        kwargs["eos_token_id"] = eos_token_id
        kwargs["unk_token_id"] = unk_token_id
        super().__init__(**kwargs)
129
+
130
+
131
class MLP(nn.Module):
    """Position-wise feed-forward block: linear -> GELU -> linear -> dropout."""

    def __init__(self, config: ChessConfig):
        super().__init__()
        # Submodule names and creation order are fixed: they define the
        # state_dict keys and the parameter-init RNG sequence.
        self.c_fc = nn.Linear(config.n_embd, config.n_inner)
        self.c_proj = nn.Linear(config.n_inner, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """Apply the MLP; output has the same shape as the input."""
        projected = self.c_proj(F.gelu(self.c_fc(hidden)))
        return self.dropout(projected)
144
+
145
+
146
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with a precomputed triangular mask."""

    def __init__(self, config: ChessConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head

        # Fused QKV projection plus an output projection (GPT-2 naming).
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

        # Lower-triangular causal mask for the maximum context length;
        # persistent=False keeps it out of the state_dict.
        bias = torch.tril(torch.ones(config.n_ctx, config.n_ctx)).view(1, 1, config.n_ctx, config.n_ctx)
        self.register_buffer("bias", bias, persistent=False)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend over x of shape (B, T, C); attention_mask is (B, T), 0 = masked key."""
        B, T, C = x.size()
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)

        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores, causally masked to the first T positions.
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))

        if attention_mask is not None:
            # Broadcast the (B, T) key-padding mask over heads and query rows.
            # NOTE(review): a query row whose keys are all masked softmaxes
            # over all -inf and yields NaN; presumably outputs at padded
            # positions are ignored downstream — confirm.
            att = att.masked_fill(attention_mask.view(B, 1, 1, T) == 0, float("-inf"))

        att = F.softmax(att, dim=-1)
        att = self.dropout(att)

        y = att @ v
        # Merge heads back to (B, T, C) before the output projection.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        y = self.dropout(y)
        return y
183
+
184
+
185
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention then MLP, each residual."""

    def __init__(self, config: ChessConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = MultiHeadAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Norm is applied before each sub-layer; the residual adds the raw input.
        x = x + self.attn(self.ln_1(x), attention_mask=attention_mask)
        x = x + self.mlp(self.ln_2(x))
        return x
197
+
198
+
199
class ChessForCausalLM(PreTrainedModel):
    """Small GPT-style causal LM over a chess-move vocabulary.

    `generate` is overridden for single-sequence prompts: the token history is
    replayed on a `chess.Board`, every legal move is re-encoded and scored with
    the LM, and the best-scoring move's tokens are appended — instead of free
    token-by-token decoding. Multi-sequence prompts, or histories that fail to
    parse, fall back to the stock `PreTrainedModel.generate`.
    """

    config_class = ChessConfig
    base_model_prefix = ""

    def __init__(self, config: ChessConfig):
        super().__init__(config)
        # Token + learned absolute position embeddings (GPT-2 style naming).
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_ctx, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Optional tying of the output head to the input embedding matrix.
        if getattr(config, "tie_weights", False):
            self.lm_head.weight = self.wte.weight

        self.post_init()

        # Lazily-loaded vocabulary (vocab.json) and detected token scheme.
        self._tok2id = None
        self._id2tok = None
        self._scheme = None

    def _ensure_vocab(self):
        """Load vocab.json once, resolving it from the model directory or hub repo."""
        if self._tok2id is None or self._id2tok is None:
            name_or_path = getattr(self.config, "_name_or_path", None) or getattr(self, "name_or_path", None)
            if not name_or_path:
                raise ValueError("Cannot resolve model path to load vocab.json")
            self._tok2id, self._id2tok = _load_vocab(name_or_path)

    def _get_scheme(self) -> TokenScheme:
        """Return the cached TokenScheme, detecting it from the vocab on first use."""
        if self._scheme is None:
            self._ensure_vocab()
            self._scheme = _detect_scheme(self._tok2id, self.config)
        return self._scheme

    def forward(self, input_ids, attention_mask=None, labels=None, return_dict=True, **kwargs):
        """Causal-LM forward pass.

        Sequences longer than n_ctx are truncated on the LEFT so the most
        recent moves are kept. When `labels` is given, a shifted cross-entropy
        loss with ignore_index=-100 is computed.
        """
        B, T = input_ids.shape
        if T > self.config.n_ctx:
            input_ids = input_ids[:, -self.config.n_ctx :]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -self.config.n_ctx :]
            if labels is not None:
                labels = labels[:, -self.config.n_ctx :]
            B, T = input_ids.shape

        pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
        x = self.wte(input_ids) + self.wpe(pos)
        x = self.drop(x)

        for block in self.h:
            x = block(x, attention_mask=attention_mask)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift so that position t predicts the token at position t+1.
            shift_logits = logits[:, :-1].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        if not return_dict:
            # NOTE(review): the HF tuple convention is (loss, logits); this
            # returns (logits, loss) — confirm no caller depends on the order.
            return (logits, loss)
        return CausalLMOutputWithPast(logits=logits, loss=loss)

    def _ids_to_tokens(self, ids: List[int]) -> List[str]:
        """Map token ids back to strings; unknown ids become "[UNK]"."""
        self._ensure_vocab()
        return [self._id2tok.get(int(i), "[UNK]") for i in ids]

    def _parse_history_to_board(self, input_ids_1d: List[int]):
        """Replay a tokenized move history onto a fresh `chess.Board`.

        Parses greedily: color token, piece token, source square, destination
        square, then any suffix tokens (promotion suffixes set the UCI
        promotion letter). Parsing stops at the first malformed or illegal
        move and returns the board built so far.
        """
        import chess
        scheme = self._get_scheme()
        toks = self._ids_to_tokens(input_ids_1d)

        # Drop special tokens before parsing moves.
        specials = {"[PAD]", "[BOS]", "[EOS]", "[UNK]"}
        toks = [t for t in toks if t not in specials]

        b = chess.Board()
        i = 0
        while i < len(toks):
            # Seek the next color token, which starts a move.
            while i < len(toks) and toks[i] not in (scheme.W, scheme.B):
                i += 1
            if i >= len(toks):
                break

            i += 1

            while i < len(toks) and scheme.sep is not None and toks[i] == scheme.sep:
                i += 1

            # Piece letter must follow the color token.
            if i >= len(toks) or toks[i] not in scheme.pieces.values():
                break
            i += 1

            while i < len(toks) and scheme.sep is not None and toks[i] == scheme.sep:
                i += 1

            # Source square.
            if i >= len(toks) or not _is_square(toks[i]):
                break
            src = toks[i]
            i += 1

            while i < len(toks) and scheme.sep is not None and toks[i] == scheme.sep:
                i += 1

            # Destination square.
            if i >= len(toks) or not _is_square(toks[i]):
                break
            dst = toks[i]
            i += 1

            # Collect suffix tokens up to the next color token.
            suffixes = []
            while i < len(toks) and toks[i] not in (scheme.W, scheme.B):
                if scheme.sep is not None and toks[i] == scheme.sep:
                    i += 1
                    continue
                suffixes.append(toks[i])
                i += 1

            uci = f"{src}{dst}"
            # A promotion suffix (e.g. "(Q)") appends its letter to the UCI move.
            promo = None
            for p, ptok in scheme.prom.items():
                if ptok in suffixes:
                    promo = p.lower()
                    break
            if promo is not None:
                uci += promo

            # Apply the move only if it is legal; otherwise stop parsing.
            try:
                mv = chess.Move.from_uci(uci)
                if mv in b.legal_moves:
                    b.push(mv)
                else:
                    break
            except Exception:
                break

        return b

    def _move_to_ids(self, board, move_uci: str) -> List[int]:
        """Encode a legal UCI move on *board* into vocabulary token ids.

        Emits: color, piece, source, destination (with separators when the
        vocab has them), then at most one event suffix and one promotion
        suffix, then a trailing separator. The board is left unmodified
        (push/pop is used only to probe check/mate).
        """
        import chess

        scheme = self._get_scheme()
        self._ensure_vocab()
        tok2id = self._tok2id

        mv = chess.Move.from_uci(move_uci)

        color_tok = scheme.W if board.turn == chess.WHITE else scheme.B
        piece = board.piece_at(mv.from_square)
        # Fall back to pawn if the square is unexpectedly empty.
        pl = piece.symbol().upper() if piece is not None else "P"
        if pl not in scheme.pieces:
            pl = "P"

        src = chess.square_name(mv.from_square)
        dst = chess.square_name(mv.to_square)

        toks = [color_tok, pl]
        if scheme.sep is not None:
            toks += [scheme.sep, src, scheme.sep, dst]
        else:
            toks += [src, dst]

        # Probe the post-move position for check/mate, then restore the board.
        is_capture = board.is_capture(mv)
        board.push(mv)
        is_mate = board.is_checkmate()
        is_check = board.is_check()
        board.pop()

        # Pick the single most specific event suffix available in the vocab.
        suffix_tok = None
        if is_capture and is_mate:
            suffix_tok = scheme.suffix.get("cap_mate")
        elif is_capture and is_check:
            suffix_tok = scheme.suffix.get("cap_check")
        elif is_capture:
            suffix_tok = scheme.suffix.get("cap")
        elif is_mate:
            suffix_tok = scheme.suffix.get("mate")
        elif is_check:
            suffix_tok = scheme.suffix.get("check")

        if suffix_tok is not None:
            toks.append(suffix_tok)

        if mv.promotion is not None:
            prom = chess.piece_symbol(mv.promotion).upper()
            if prom in scheme.prom:
                toks.append(scheme.prom[prom])

        # Trailing separator matches the move-boundary convention of the vocab.
        if scheme.sep is not None:
            toks.append(scheme.sep)

        return [tok2id.get(t, scheme.unk_id) for t in toks]

    @torch.no_grad()
    def _score_candidates(self, prefix_ids, cand_ids_list, attention_mask, temperature, batch_size=64):
        """Return the summed log-probability of each candidate continuation.

        Each candidate (a list of token ids) is appended to `prefix_ids`,
        right-padded to the batch maximum, and scored in mini-batches of
        `batch_size`. Temperature rescales logits before log_softmax.
        """
        device = prefix_ids.device
        T0 = prefix_ids.size(1)
        scores = torch.empty(len(cand_ids_list), device=device, dtype=torch.float32)
        pad_id = int(self.config.pad_token_id)

        for start in range(0, len(cand_ids_list), batch_size):
            batch = cand_ids_list[start : start + batch_size]
            max_c = max(len(c) for c in batch)

            input_ids_list = []
            attn_list = []

            for c in batch:
                c_ids = torch.tensor(c, device=device, dtype=torch.long).unsqueeze(0)
                seq = torch.cat([prefix_ids, c_ids], dim=1)
                # Right-pad so every row in the mini-batch has length T0 + max_c.
                pad_len = (T0 + max_c) - seq.size(1)
                if pad_len > 0:
                    pad = torch.full((1, pad_len), pad_id, device=device, dtype=torch.long)
                    seq = torch.cat([seq, pad], dim=1)
                input_ids_list.append(seq)

                if attention_mask is None:
                    # NOTE(review): this branch marks the padded candidate tail
                    # as attended (all-ones), unlike the else-branch below which
                    # zeroes it; harmless only when all candidates in the batch
                    # share one length — confirm.
                    a = torch.ones((1, seq.size(1)), device=device, dtype=torch.long)
                else:
                    a = attention_mask
                    if a.size(1) != T0:
                        a = a[:, -T0:]
                    ones = torch.ones((1, len(c)), device=device, dtype=torch.long)
                    zeros = torch.zeros((1, max_c - len(c)), device=device, dtype=torch.long)
                    a = torch.cat([a, ones, zeros], dim=1)
                attn_list.append(a)

            input_ids = torch.cat(input_ids_list, dim=0)
            attn_mask = torch.cat(attn_list, dim=0)

            out = self.forward(input_ids=input_ids, attention_mask=attn_mask, return_dict=True)
            # Guard against zero/negative temperature with a small floor.
            logits = out.logits / float(max(1e-6, temperature))
            logp = torch.log_softmax(logits, dim=-1)

            for bi, c in enumerate(batch):
                lp = 0.0
                for j in range(len(c)):
                    # Logits at position pos predict the token at pos+1, i.e.
                    # candidate token j sits at absolute position T0 + j.
                    pos = T0 + j - 1
                    if pos < 0:
                        continue
                    tok_id = int(c[j])
                    lp += float(logp[bi, pos, tok_id].item())
                scores[start + bi] = lp

        return scores

    def generate(self, input_ids=None, attention_mask=None, max_new_tokens=16, temperature=1.0, do_sample=False, **kwargs):
        """Legality-constrained generation for a single-sequence prompt.

        Replays the history into a board, scores every legal move with
        `_score_candidates`, and appends the best move's tokens. Falls back to
        the stock HF `generate` for batched prompts, unparseable histories, or
        finished games.
        """
        import chess

        if input_ids is None:
            raise ValueError("generate() requires input_ids")
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        # The legality-aware path only supports batch size 1.
        if input_ids.size(0) != 1:
            return super().generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=do_sample,
                **kwargs,
            )

        try:
            board = self._parse_history_to_board(input_ids[0].tolist())
        except Exception:
            board = None

        if board is None or board.is_game_over():
            return super().generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=do_sample,
                **kwargs,
            )

        legal = list(board.legal_moves)
        if not legal:
            return input_ids

        cand_ids_list = [self._move_to_ids(board, mv.uci()) for mv in legal]

        scores = self._score_candidates(
            prefix_ids=input_ids,
            cand_ids_list=cand_ids_list,
            attention_mask=attention_mask,
            temperature=float(temperature),
            batch_size=64,
        )

        # Greedy choice: do_sample is accepted but not honored on this path.
        best = int(torch.argmax(scores).item())
        best_ids = torch.tensor(cand_ids_list[best], device=input_ids.device, dtype=torch.long).unsqueeze(0)

        # Respect the caller's token budget, even if it truncates the move.
        if best_ids.size(1) > int(max_new_tokens):
            best_ids = best_ids[:, : int(max_new_tokens)]

        return torch.cat([input_ids, best_ids], dim=1)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b66b88123670200735099e5106816a20fc5f05d2aae87bc01ee747ac7f1f2fc
3
+ size 3970192
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "[BOS]",
3
+ "eos_token": "[EOS]",
4
+ "pad_token": "[PAD]",
5
+ "unk_token": "[UNK]"
6
+ }
tokenizer.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom Chess Tokenizer for the Chess Challenge.
3
+
4
+ This tokenizer treats each move as a single token using the extended UCI notation
5
+ from the Lichess dataset (e.g., WPe2e4, BNg8f6).
6
+
7
+ The dataset format uses:
8
+ - W/B prefix for White/Black
9
+ - Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King
10
+ - Source and destination squares (e.g., e2e4)
11
+ - Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import json
17
+ import os
18
+ from pathlib import Path
19
+ from typing import Dict, List, Optional
20
+
21
+ from transformers import PreTrainedTokenizer
22
+
23
+
24
class ChessTokenizer(PreTrainedTokenizer):
    """Tokenizer for extended-UCI chess moves.

    A move such as "WPe2e4(x)" is split into atomic tokens — color, piece,
    source square, destination square, optional suffix — separated by a
    literal " " token.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """
        Initialize the chess tokenizer.

        Args:
            vocab_file: Path to a JSON file containing the vocabulary mapping.
            vocab: Dictionary mapping tokens to IDs (alternative to vocab_file).
            **kwargs: Additional arguments passed to PreTrainedTokenizer.
        """
        # Initialize special tokens before the parent __init__ reads them.
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Remove any duplicate special-token entries passed through kwargs
        # to avoid "multiple values for keyword" errors when loading from disk.
        kwargs.pop("pad_token", None)
        kwargs.pop("bos_token", None)
        kwargs.pop("eos_token", None)
        kwargs.pop("unk_token", None)

        # Load or create vocabulary: explicit dict > file on disk > built-in default.
        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            # Fall back to the fixed default vocabulary defined below.
            self._vocab = self._create_fixed_vocab()

        # Create reverse mapping (id -> token).
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # Call parent init AFTER setting up vocab: PreTrainedTokenizer's
        # __init__ invokes get_vocab()/_convert_token_to_id, which need it.
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _create_fixed_vocab(self) -> Dict[str, int]:
        """Build the default vocabulary: specials, separator, colors, pieces,
        the 64 squares, and move suffixes — ids assigned by list order."""
        vocab_list = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]

        # Separate moves
        vocab_list.append(" ")

        # Colors
        vocab_list.extend(["W", "B"])

        # Pieces ("B" duplicates the Black color token and is de-duplicated below).
        vocab_list.extend(["P", "N", "B", "R", "Q", "K"])

        # 3. Squares
        files = "abcdefgh"
        ranks = "12345678"
        squares = [f"{f}{r}" for f in files for r in ranks]
        vocab_list.extend(sorted(squares))

        # Suffixes
        suffixes = ["(x)", "(+)", "(+*)", "(o)", "(O)", "(Q)", "(K)", "(x*)", "(x+*)"]
        vocab_list.extend(suffixes)

        # Drop duplicates while preserving first-seen order.
        vocab_list = list(dict.fromkeys(vocab_list))


        return {token: idx for idx, token in enumerate(vocab_list)}


    @property
    def vocab_size(self) -> int:
        """Return the size of the vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return the vocabulary as a dictionary."""
        return dict(self._vocab)

    def _tokenize(self, text: str) -> List[str]:
        """
        Tokenize a string of moves into atomic components.

        Each whitespace-separated move is split into color, piece, source,
        destination, and suffix groups, followed by a " " separator token;
        moves that do not match the pattern become a single [UNK].
        """
        import re
        tokens = []
        moves = text.strip().split()

        # groups: 1=color, 2=piece, 3=src square, 4=dst square, 5=suffix tail
        pattern = re.compile(r"^([WB])([PNBRQK])([a-h][1-8])([a-h][1-8])(.*)$")

        for move in moves:
            match = pattern.match(move)
            if match:
                for i in range(1,6):
                    # NOTE(review): a suffix group not present in the vocab
                    # (e.g. a combined capture+promotion tail) is silently
                    # dropped rather than emitted as [UNK] — confirm intended.
                    if match.group(i) in self._vocab:
                        tokens.append(match.group(i))
                tokens.append(' ')
            else:
                tokens.append(self.UNK_TOKEN)

        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to its ID; unknown tokens map to [UNK]'s id (0 if absent)."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an ID to its token."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens back into text, dropping special tokens."""
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        return "".join(t for t in tokens if t not in special)

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None,
    ) -> tuple:
        """
        Save the vocabulary to a JSON file.

        Args:
            save_directory: Directory to save the vocabulary.
            filename_prefix: Optional prefix for the filename.

        Returns:
            Tuple containing the path to the saved vocabulary file.
        """
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory, exist_ok=True)

        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)

        return (vocab_file,)
189
+
190
+
tokenizer_config.json ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "[BOS]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "[EOS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "[UNK]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ }
35
+ },
36
+ "auto_map": {
37
+ "AutoTokenizer": [
38
+ "tokenizer.ChessTokenizer",
39
+ null
40
+ ]
41
+ },
42
+ "bos_token": "[BOS]",
43
+ "clean_up_tokenization_spaces": false,
44
+ "eos_token": "[EOS]",
45
+ "extra_special_tokens": {},
46
+ "model_max_length": 1000000000000000019884624838656,
47
+ "pad_token": "[PAD]",
48
+ "tokenizer_class": "ChessTokenizer",
49
+ "unk_token": "[UNK]"
50
+ }
vocab.json ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "[PAD]": 0,
3
+ "[BOS]": 1,
4
+ "[EOS]": 2,
5
+ "[UNK]": 3,
6
+ " ": 4,
7
+ "W": 5,
8
+ "B": 6,
9
+ "P": 7,
10
+ "N": 8,
11
+ "R": 9,
12
+ "Q": 10,
13
+ "K": 11,
14
+ "a1": 12,
15
+ "a2": 13,
16
+ "a3": 14,
17
+ "a4": 15,
18
+ "a5": 16,
19
+ "a6": 17,
20
+ "a7": 18,
21
+ "a8": 19,
22
+ "b1": 20,
23
+ "b2": 21,
24
+ "b3": 22,
25
+ "b4": 23,
26
+ "b5": 24,
27
+ "b6": 25,
28
+ "b7": 26,
29
+ "b8": 27,
30
+ "c1": 28,
31
+ "c2": 29,
32
+ "c3": 30,
33
+ "c4": 31,
34
+ "c5": 32,
35
+ "c6": 33,
36
+ "c7": 34,
37
+ "c8": 35,
38
+ "d1": 36,
39
+ "d2": 37,
40
+ "d3": 38,
41
+ "d4": 39,
42
+ "d5": 40,
43
+ "d6": 41,
44
+ "d7": 42,
45
+ "d8": 43,
46
+ "e1": 44,
47
+ "e2": 45,
48
+ "e3": 46,
49
+ "e4": 47,
50
+ "e5": 48,
51
+ "e6": 49,
52
+ "e7": 50,
53
+ "e8": 51,
54
+ "f1": 52,
55
+ "f2": 53,
56
+ "f3": 54,
57
+ "f4": 55,
58
+ "f5": 56,
59
+ "f6": 57,
60
+ "f7": 58,
61
+ "f8": 59,
62
+ "g1": 60,
63
+ "g2": 61,
64
+ "g3": 62,
65
+ "g4": 63,
66
+ "g5": 64,
67
+ "g6": 65,
68
+ "g7": 66,
69
+ "g8": 67,
70
+ "h1": 68,
71
+ "h2": 69,
72
+ "h3": 70,
73
+ "h4": 71,
74
+ "h5": 72,
75
+ "h6": 73,
76
+ "h7": 74,
77
+ "h8": 75,
78
+ "(x)": 76,
79
+ "(+)": 77,
80
+ "(+*)": 78,
81
+ "(o)": 79,
82
+ "(O)": 80,
83
+ "(Q)": 81,
84
+ "(K)": 82,
85
+ "(x*)": 83,
86
+ "(x+*)": 84
87
+ }