# chess-sam-model-v2 / tokenizer.py
"""
Factorized Chess Tokenizer for the Chess Challenge.
Instead of "1 move = 1 token", we represent a move as multiple tokens:
- Side: [W] / [B]
- Piece: [P], [N], [BISHOP], [R], [Q], [K]
- Squares: [e2], [e4], ...
- Optional suffix: [x], [+], [#], [O-O], [O-O-O]
- Optional promotion: [prom_Q], [prom_R], [prom_B], [prom_N]
Important:
- We KEEP squares as tokens so evaluation (regex [a-h][1-8]) can extract UCI moves.
- We decode squares to plain "e2" etc, and promotions to "q/r/b/n" so evaluate.py can detect promotions.
"""
from __future__ import annotations
import json
import os
import re
from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer
MOVE_RE = re.compile(
r"^(?P<side>[WB])"
r"(?P<piece>[PNBRQK])"
r"(?P<src>[a-h][1-8])"
r"(?P<dst>[a-h][1-8])"
r"(?P<suffix>.*)$"
)
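# e.g. "WPe2e4"    -> side="W", piece="P", src="e2", dst="e4", suffix=""
#      "BNg8f6(x)" -> side="B", piece="N", src="g8", dst="f6", suffix="(x)"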
SQUARE_TOKEN_RE = re.compile(r"^\[[a-h][1-8]\]$")
class ChessTokenizer(PreTrainedTokenizer):
model_input_names = ["input_ids", "attention_mask"]
vocab_files_names = {"vocab_file": "vocab.json"}
PAD_TOKEN = "[PAD]"
BOS_TOKEN = "[BOS]"
EOS_TOKEN = "[EOS]"
UNK_TOKEN = "[UNK]"
def __init__(
self,
vocab_file: Optional[str] = None,
vocab: Optional[Dict[str, int]] = None,
**kwargs,
):
# Special tokens
self._pad_token = self.PAD_TOKEN
self._bos_token = self.BOS_TOKEN
self._eos_token = self.EOS_TOKEN
self._unk_token = self.UNK_TOKEN
        # Drop any caller-supplied special tokens so they don't duplicate the
        # ones passed to super().__init__() below.
kwargs.pop("pad_token", None)
kwargs.pop("bos_token", None)
kwargs.pop("eos_token", None)
kwargs.pop("unk_token", None)
if vocab is not None:
self._vocab = vocab
elif vocab_file is not None and os.path.exists(vocab_file):
with open(vocab_file, "r", encoding="utf-8") as f:
self._vocab = json.load(f)
else:
self._vocab = self._create_default_vocab()
self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
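        # The vocab must be populated before super().__init__(), which in
        # recent transformers versions may call get_vocab() during setup.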
super().__init__(
pad_token=self._pad_token,
bos_token=self._bos_token,
eos_token=self._eos_token,
unk_token=self._unk_token,
**kwargs,
)
def _create_default_vocab(self) -> Dict[str, int]:
special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
side_tokens = ["[W]", "[B]"]
piece_tokens = ["[P]", "[N]", "[BISHOP]", "[R]", "[Q]", "[K]"]
square_tokens = [f"[{file}{rank}]" for rank in "12345678" for file in "abcdefgh"]
suffix_tokens = [
"[x]", "[+]", "[#]",
"[O-O]", "[O-O-O]",
"[prom_Q]", "[prom_R]", "[prom_B]", "[prom_N]",
]
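        # Totals: 4 special + 2 side + 6 piece + 64 square + 9 suffix = 85 tokens.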
vocab_list = special_tokens + side_tokens + piece_tokens + square_tokens + suffix_tokens
return {tok: i for i, tok in enumerate(vocab_list)}
# IMPORTANT: prevent HF from auto-adding BOS/EOS on top of your text
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
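        # Returning the ids unchanged means encode("WPe2e4") yields exactly
        # the four move tokens, with nothing wrapped around them.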
if token_ids_1 is None:
return token_ids_0
return token_ids_0 + token_ids_1
@classmethod
def build_vocab_from_iterator(cls, iterator, min_frequency: int = 1) -> "ChessTokenizer":
# Fixed vocab (we ignore dataset frequency)
return cls()
@classmethod
def build_vocab_from_dataset(
cls,
dataset_name: str = "dlouapre/lichess_2025-01_1M",
split: str = "train",
column: str = "text",
min_frequency: int = 500,
max_samples: Optional[int] = 100000,
) -> "ChessTokenizer":
# Fixed vocab (we ignore dataset frequency)
return cls()
@property
def vocab_size(self) -> int:
return len(self._vocab)
def get_vocab(self) -> Dict[str, int]:
return dict(self._vocab)
def _tokenize(self, text: str) -> List[str]:
tokens: List[str] = []
parts = str(text).strip().split()
specials = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
for p in parts:
if p in specials:
tokens.append(p)
continue
m = MOVE_RE.match(p)
if not m:
tokens.append(self.UNK_TOKEN)
continue
side = m.group("side")
piece = m.group("piece")
src = m.group("src")
dst = m.group("dst")
suffix = m.group("suffix") or ""
tokens.append("[W]" if side == "W" else "[B]")
if piece == "B":
tokens.append("[BISHOP]")
else:
tokens.append(f"[{piece}]")
tokens.append(f"[{src}]")
tokens.append(f"[{dst}]")
            # Capture / check / checkmate. Mate decodes to "(+*)", which also
            # contains "+", so test "*" first and keep "+" as the elif.
if "x" in suffix:
tokens.append("[x]")
if "*" in suffix:
tokens.append("[#]")
elif "+" in suffix:
tokens.append("[+]")
            # Castling annotation (optional; the squares already encode it).
            # "(o)" marks kingside and "(O)" queenside (case-sensitive).
if piece == "K":
if (src, dst) in (("e1", "g1"), ("e8", "g8")) or "(o)" in suffix:
tokens.append("[O-O]")
elif (src, dst) in (("e1", "c1"), ("e8", "c8")) or "(O)" in suffix:
tokens.append("[O-O-O]")
            # Promotion: the encoded text marks it as "=Q", "=R", "=B" or "=N".
            if "=" in suffix:
                i = suffix.find("=")
                if i + 1 < len(suffix):
                    promo = suffix[i + 1].upper()
                    if promo in ("Q", "R", "B", "N"):
                        tokens.append(f"[prom_{promo}]")
return tokens
def _convert_token_to_id(self, token: str) -> int:
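        # Fall back to [UNK]; 3 is its index in the default vocab.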
return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 3))
def _convert_id_to_token(self, index: int) -> str:
return self._ids_to_tokens.get(index, self.UNK_TOKEN)
def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""
Decode tokens into a compact string so evaluate.py can extract squares easily.
Examples:
[W] [P] [e2] [e4] -> "WPe2e4"
... [e7] [e8] [prom_Q] -> "WPe7e8q" (promotion detectable)
"""
out: List[str] = []
for t in tokens:
            if t == self.PAD_TOKEN:
                continue
# keep these literal so evaluator can compare EOS if needed
if t in (self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN):
out.append(t)
continue
if t == "[W]":
out.append("W")
elif t == "[B]":
out.append("B")
elif t == "[BISHOP]":
out.append("B")
elif t in ("[P]", "[N]", "[R]", "[Q]", "[K]"):
out.append(t.strip("[]"))
elif SQUARE_TOKEN_RE.match(t):
out.append(t[1:-1]) # "[e2]" -> "e2"
elif t == "[x]":
out.append("(x)")
elif t == "[+]":
out.append("(+)")
elif t == "[#]":
out.append("(+*)")
elif t == "[O-O]":
out.append("(o)")
elif t == "[O-O-O]":
out.append("(O)")
elif t == "[prom_Q]":
out.append("q")
elif t == "[prom_R]":
out.append("r")
elif t == "[prom_B]":
out.append("b")
elif t == "[prom_N]":
out.append("n")
else:
out.append(t)
return "".join(out)
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        os.makedirs(save_directory, exist_ok=True)
vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + "vocab.json",
)
with open(vocab_file, "w", encoding="utf-8") as f:
json.dump(self._vocab, f, ensure_ascii=False, indent=2)
return (vocab_file,)
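
if __name__ == "__main__":
    # Minimal smoke test (a sketch; assumes a recent transformers version and
    # the "WPe2e4"-style text format described in the module docstring).
    tok = ChessTokenizer()
    ids = tok("WPe2e4 BPe7e5 WPe7e8=Q")["input_ids"]
    print(tok.convert_ids_to_tokens(ids))
    # Tokens are re-joined without separators: "WPe2e4BPe7e5WPe7e8q"
    print(tok.decode(ids))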