chess_model_lucas_meb_3 / tokenizer.py
luluM's picture
Chess Challenge submission by luluM
b391e74 verified
from __future__ import annotations
import json, os, re
from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer
# UCI étendu: WPe2e4, BNg8f6(x+*), promotions "=Q", roque "(o)/(O)"
_MOVE_RE = re.compile(r"^(?P<side>[WB])(?P<piece>[PNBRQK])(?P<src>[a-h][1-8])(?P<dst>[a-h][1-8])(?P<suffix>.*)$")
_PROMO_RE = re.compile(r"=([QRBNqrbn])")
def _parse_suffix(suffix: str):
    """Decode the annotation text that follows a move's destination square.

    Args:
        suffix: Trailing text captured after the destination square. It may
            be empty or ``None``; surrounding whitespace is ignored.

    Returns:
        A 5-tuple ``(is_capture, is_check, is_mate, castle, promo)``:
        the three booleans report whether "x", "+" and "*" occur anywhere in
        the suffix; ``castle`` is "O-O-O" when "(O)" is present, otherwise
        "O-O" when "(o)" is present, otherwise ``None``; ``promo`` is the
        lower-cased piece letter following "=" (first occurrence) or ``None``.
    """
    text = (suffix or "").strip()

    # "(O)" (long castle) is checked before "(o)" (short castle).
    if "(O)" in text:
        castle = "O-O-O"
    elif "(o)" in text:
        castle = "O-O"
    else:
        castle = None

    promo_match = _PROMO_RE.search(text)
    promo = promo_match.group(1).lower() if promo_match else None

    return ("x" in text), ("+" in text), ("*" in text), castle, promo
class ChessTokenizer(PreTrainedTokenizer):
    """Tokenizer that expands each chess move into several vocabulary tokens.

    Input text is split on whitespace and every move is handled as follows:

    * raw UCI moves such as ``"e2e4"`` / ``"e7e8q"`` become
      ``[<src>] [<dst>]`` plus an optional ``[=<piece>]`` promotion token;
    * extended moves such as ``"WPe2e4"`` / ``"BNg8f6(x+)"`` become
      ``[<side><piece>] [<src>] [<dst>]`` followed by optional castle,
      capture, check-or-mate and promotion tokens;
    * any other string becomes ``[UNK]``.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    # Raw UCI move (no side/piece info): source square, destination square
    # and an optional lower-case promotion piece. Compiled once here rather
    # than handing a pattern string to re.fullmatch for every move.
    _RAW_UCI_RE = re.compile(r"[a-h][1-8][a-h][1-8][qrbn]?")

    def __init__(self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs):
        """Build the tokenizer from an explicit vocab, a vocab file, or the default vocab.

        Args:
            vocab_file: Path to a JSON file mapping token -> id. Used only when
                ``vocab`` is not given AND the file exists; a missing file
                silently falls back to the default vocabulary.
            vocab: In-memory token -> id mapping; takes precedence over
                ``vocab_file``.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``. Any
                ``pad_token`` / ``bos_token`` / ``eos_token`` / ``unk_token``
                entries are discarded so the class constants are always used.
        """
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN
        # Special tokens are fixed by this class; ignore caller/config overrides.
        for special_name in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(special_name, None)

        # The vocab is set up before the base initializer runs, which may
        # consult it (e.g. via get_vocab) — NOTE(review): depends on the
        # installed transformers version; keep this ordering.
        if vocab is not None:
            self._vocab = vocab
        elif vocab_file and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._create_default_vocab()
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Number of entries in the token -> id vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a shallow copy of the token -> id mapping."""
        return dict(self._vocab)

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id; unknown tokens get the [UNK] id (or 0 if [UNK] is absent)."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its token; unknown ids map to ``[UNK]``."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def _create_default_vocab(self) -> Dict[str, int]:
        """Build the built-in vocabulary (89 tokens), with ids assigned in listing order.

        Layout: 4 special tokens ([PAD]=0, [BOS]=1, [EOS]=2, [UNK]=3),
        12 side+piece tokens, 64 square tokens (a1..a8, b1..b8, ..., h8),
        5 flag/castle tokens and 4 promotion tokens.
        """
        tokens: List[str] = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
        # Side + piece tokens (12)
        tokens += [f"[W{p}]" for p in "PNBRQK"]
        tokens += [f"[B{p}]" for p in "PNBRQK"]
        # 64 squares, file-major order
        tokens += [f"[{f}{r}]" for f in "abcdefgh" for r in "12345678"]
        # Capture / check / mate flags and the two castle tokens
        tokens += ["[x]", "[+]", "[#]", "[O-O]", "[O-O-O]"]
        # Promotion targets (lower-case piece letters)
        tokens += [f"[={p}]" for p in "qrbn"]
        return {tok: i for i, tok in enumerate(tokens)}

    def _tokenize(self, text: str) -> List[str]:
        """Split whitespace-separated moves into vocabulary tokens.

        Token order for an extended move is: side+piece, source square,
        destination square, then optionally castle, capture, check-or-mate
        and promotion. Unrecognised move strings yield a single ``[UNK]``.
        """
        out: List[str] = []
        for move in (text or "").strip().split():
            # Raw UCI like e2e4 / e7e8q: only squares and an optional promotion.
            if self._RAW_UCI_RE.fullmatch(move):
                out += [f"[{move[:2]}]", f"[{move[2:4]}]"]
                if len(move) == 5:
                    out.append(f"[={move[4]}]")
                continue

            m = _MOVE_RE.match(move)
            if not m:
                out.append(self.UNK_TOKEN)
                continue

            side = m.group("side")    # "W" or "B"
            piece = m.group("piece")  # one of P/N/B/R/Q/K
            is_cap, is_chk, is_mate, castle, promo = _parse_suffix(m.group("suffix") or "")

            out += [f"[{side}{piece}]", f"[{m.group('src')}]", f"[{m.group('dst')}]"]
            if castle:
                out.append(f"[{castle}]")
            if is_cap:
                out.append("[x]")
            # Mate supersedes check: at most one of [#] / [+] is emitted.
            if is_mate:
                out.append("[#]")
            elif is_chk:
                out.append("[+]")
            if promo:
                out.append(f"[={promo}]")
        return out

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens with single spaces, dropping [PAD]/[BOS]/[EOS]/[UNK]."""
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        return " ".join(t for t in tokens if t not in special)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        """Write the vocabulary as JSON into ``save_directory``.

        Args:
            save_directory: Target directory; created (with parents) if missing.
            filename_prefix: Optional prefix; the file is named
                ``"<prefix>-vocab.json"`` when given, else ``"vocab.json"``.

        Returns:
            A 1-tuple containing the path of the written file.
        """
        # exist_ok=True already tolerates an existing directory.
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json")
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)