# chess-khanghoang0902-v4 / tokenizer.py
from __future__ import annotations
import json
import os
import re
from typing import Dict, List, Optional, Sequence, Tuple, Union
from transformers import PreTrainedTokenizer
class ChessTokenizer(PreTrainedTokenizer):
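    """Word-level tokenizer for encoded chess move sequences.

    Each whitespace-separated move such as "WPe2e4" is split into a side
    token, a piece token, from/to square tokens, and optional
    promotion/capture/check/mate/castle flag tokens.
    """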
model_input_names = ["input_ids", "attention_mask"]
vocab_files_names = {"vocab_file": "vocab.json"}
PAD_TOKEN = "[PAD]"
BOS_TOKEN = "[BOS]"
EOS_TOKEN = "[EOS]"
UNK_TOKEN = "[UNK]"
SIDE_W = "SIDE_W"
SIDE_B = "SIDE_B"
PROMO_PREFIX = "PROMO_"
CAPTURE = "CAPTURE"
CHECK = "CHECK"
MATE = "MATE"
CASTLE = "CASTLE"
PIECES = ["P", "N", "B", "R", "Q", "K"]
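    # A move string is "<W|B><piece><from><to><suffix>", e.g. "WPe2e4",
    # "BNg8f6(x)" or "WPe7e8Q"; the optional suffix carries promotion,
    # capture, check, mate and castle annotations (see _tokenize_move).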
MOVE_RE = re.compile(
r"^(?P<side>[WB])"
r"(?P<piece>[PNBRQK])"
r"(?P<from>[a-h][1-8])"
r"(?P<to>[a-h][1-8])"
r"(?P<rest>.*)$"
)
def __init__(
self,
vocab_file: Optional[str] = None,
vocab: Optional[Dict[str, int]] = None,
**kwargs,
):
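        # Ignore any caller-supplied special tokens; this tokenizer always
        # uses its own fixed [PAD]/[BOS]/[EOS]/[UNK] symbols.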
kwargs.pop("pad_token", None)
kwargs.pop("bos_token", None)
kwargs.pop("eos_token", None)
kwargs.pop("unk_token", None)
self._pad_token = self.PAD_TOKEN
self._bos_token = self.BOS_TOKEN
self._eos_token = self.EOS_TOKEN
self._unk_token = self.UNK_TOKEN
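        # The vocabulary must exist before super().__init__ runs, since the
        # base class may look up special tokens during initialization.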
if vocab is not None:
self._vocab = vocab
elif vocab_file is not None and os.path.exists(vocab_file):
with open(vocab_file, "r", encoding="utf-8") as f:
self._vocab = json.load(f)
else:
self._vocab = self._build_fixed_vocab()
self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
super().__init__(
pad_token=self._pad_token,
bos_token=self._bos_token,
eos_token=self._eos_token,
unk_token=self._unk_token,
**kwargs,
)
def _build_fixed_vocab(self) -> Dict[str, int]:
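        """Build the default vocabulary: special tokens, sides, pieces,
        all 64 squares, promotion pieces and move flags."""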
special = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
sides = [self.SIDE_W, self.SIDE_B]
pieces = [f"PIECE_{p}" for p in self.PIECES]
squares = [f"SQ_{file}{rank}" for file in "abcdefgh" for rank in "12345678"]
promos = [f"{self.PROMO_PREFIX}{p}" for p in ["Q", "R", "B", "N"]]
flags = [self.CAPTURE, self.CHECK, self.MATE, self.CASTLE]
tokens = special + sides + pieces + squares + promos + flags
return {tok: i for i, tok in enumerate(tokens)}
@property
def vocab_size(self) -> int:
return len(self._vocab)
def get_vocab(self) -> Dict[str, int]:
return dict(self._vocab)
def _tokenize(self, text: str) -> List[str]:
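        """Split a whitespace-separated move sequence into sub-tokens."""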
out: List[str] = []
for move in text.strip().split():
out.extend(self._tokenize_move(move))
return out
def _tokenize_move(self, move: str) -> List[str]:
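        # Suffix annotations recognised in the "rest" group: "(x)" capture,
        # "+"/"(+)" check, "++"/"(+*)"/"#" mate, "(o)"/"(O)" castle, and a
        # promotion piece written either as "=Q" or as a bare trailing letter.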
m = self.MOVE_RE.match(move)
        if not m:
            return [self.UNK_TOKEN]
side = m.group("side")
piece = m.group("piece")
frm = m.group("from")
to = m.group("to")
rest = m.group("rest") or ""
tokens: List[str] = []
tokens.append(self.SIDE_W if side == "W" else self.SIDE_B)
tokens.append(f"PIECE_{piece}")
tokens.append(f"SQ_{frm}")
tokens.append(f"SQ_{to}")
promo = self._parse_promotion(rest)
if promo is not None:
tokens.append(f"{self.PROMO_PREFIX}{promo}")
if "(x)" in rest: tokens.append(self.CAPTURE)
if "++" in rest or "(+*)" in rest or "#" in rest:
tokens.append(self.MATE)
elif "+" in rest or "(+)" in rest:
tokens.append(self.CHECK)
if "(o)" in rest or "(O)" in rest:
tokens.append(self.CASTLE)
return tokens
    def _parse_promotion(self, rest: str) -> Optional[str]:
        # Explicit "=Q"-style promotion suffix.
        m = re.search(r"=([QRBNqrbn])", rest)
        if m:
            return m.group(1).upper()
        # A bare promotion letter with no other annotation, e.g. a trailing "q".
        if rest.strip() in ("Q", "R", "B", "N", "q", "r", "b", "n"):
            return rest.strip().upper()
        return None
def _convert_token_to_id(self, token: str) -> int:
return self._vocab.get(token, self._vocab[self.UNK_TOKEN])
def _convert_id_to_token(self, index: int) -> str:
return self._ids_to_tokens.get(index, self.UNK_TOKEN)
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        output = []
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        for t in tokens:
            if t in special:
                continue
            if t == self.SIDE_W:
                output.append("W")
            elif t == self.SIDE_B:
                output.append("B")
            elif t.startswith("PIECE_"):
                output.append(t[len("PIECE_"):])
            elif t.startswith("SQ_"):
                output.append(t[len("SQ_"):])
            elif t.startswith(self.PROMO_PREFIX):
                output.append("=" + t[len(self.PROMO_PREFIX):])
            elif t == self.CAPTURE:
                output.append("(x)")
            elif t == self.CHECK:
                output.append("(+)")
            elif t == self.MATE:
                output.append("(+*)")
            elif t == self.CASTLE:
                output.append("(o)")
            # Any other token is dropped from the surface string.
        return "".join(output)
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)
    def decode(self, token_ids: Union[int, Sequence[int]], skip_special_tokens: bool = False, **kwargs) -> str:
        if isinstance(token_ids, int):
            ids = [token_ids]
        elif "torch" in str(type(token_ids)):
            # Duck-typed check so torch tensors are handled without importing torch.
            ids = token_ids.detach().cpu().flatten().tolist()
        else:
            ids = list(token_ids)
        toks = [self._convert_id_to_token(i) for i in ids]
        if skip_special_tokens:
            special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
            toks = [t for t in toks if t not in special]
        return self.convert_tokens_to_string(toks)
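

if __name__ == "__main__":
    # Illustrative smoke test only (not part of the model package): build the
    # tokenizer with its fixed default vocabulary, encode two example moves
    # that match MOVE_RE, and decode them back.
    tok = ChessTokenizer()
    enc = tok("WPe2e4 BNg8f6")
    print(enc["input_ids"])
    print(tok.decode(enc["input_ids"], skip_special_tokens=True))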