File size: 4,876 Bytes
b391e74 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 | from __future__ import annotations
import json, os, re
from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer
# Extended-UCI move grammar: side ("W"/"B"), piece letter, source square,
# destination square, then a free-form suffix, e.g. "WPe2e4" or "BNg8f6x+".
# Suffix markers: "x" capture, "+" check, "*" mate, "=Q" promotion,
# "(o)" / "(O)" castling.
_MOVE_RE = re.compile(
    r"^(?P<side>[WB])"
    r"(?P<piece>[PNBRQK])"
    r"(?P<src>[a-h][1-8])"
    r"(?P<dst>[a-h][1-8])"
    r"(?P<suffix>.*)$"
)
# Promotion marker inside a suffix ("=Q", "=n", ...); group 1 is the piece letter.
_PROMO_RE = re.compile(r"=([QRBNqrbn])")
def _parse_suffix(suffix: str) -> tuple[bool, bool, bool, Optional[str], Optional[str]]:
    """Decode the flag suffix that follows the destination square of a move.

    Args:
        suffix: Trailing text of an extended-UCI move, e.g. ``"x+"``,
            ``"(o)"`` or ``"=Q*"``. ``None`` or ``""`` yields all-negative flags.

    Returns:
        ``(is_capture, is_check, is_mate, castle, promo)``:
        the first three are booleans (markers ``"x"``, ``"+"``, ``"*"``);
        ``castle`` is ``"O-O-O"`` for ``"(O)"``, ``"O-O"`` for ``"(o)"``, else
        ``None``; ``promo`` is the lowercased promotion letter following ``"="``,
        or ``None``.
    """
    text = suffix.strip() if suffix else ""

    # "(O)" is checked first, so it wins if both markers are present.
    if "(O)" in text:
        castle = "O-O-O"
    elif "(o)" in text:
        castle = "O-O"
    else:
        castle = None

    promo_match = _PROMO_RE.search(text)
    promo = promo_match.group(1).lower() if promo_match else None

    return ("x" in text, "+" in text, "*" in text, castle, promo)
class ChessTokenizer(PreTrainedTokenizer):
    """Slow (pure-Python) Hugging Face tokenizer for chess move sequences.

    Input text is a whitespace-separated list of moves. Each move is either
    raw UCI (``"e2e4"``, ``"e7e8q"``) or the extended form matched by
    ``_MOVE_RE`` (``"WPe2e4"``, ``"BNg8f6x+"``). Each move becomes several
    bracketed sub-tokens:

    * extended moves: ``[<side><piece>] [<src>] [<dst>]``, followed in this
      order by an optional castle token (``[O-O]``/``[O-O-O]``), ``[x]``, at
      most one of ``[#]``/``[+]``, and a promotion token such as ``[=q]``;
    * raw UCI moves: ``[<src>] [<dst>]`` plus an optional promotion token;
    * anything else: a single ``[UNK]``.

    The four special tokens are fixed by the class constants below. Any
    ``pad_token``/``bos_token``/``eos_token``/``unk_token`` keyword the caller
    passes is discarded.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    # Raw UCI move with no side/piece prefix: "e2e4", or "e7e8q" for a
    # promotion. Compiled once rather than on every move in ``_tokenize``.
    _RAW_UCI_RE = re.compile(r"[a-h][1-8][a-h][1-8][qrbn]?")

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """Build the tokenizer from an explicit mapping, a JSON file, or defaults.

        Args:
            vocab_file: Path to a JSON ``{token: id}`` file. Used only when
                ``vocab`` is ``None`` and the path exists.
            vocab: Explicit ``{token: id}`` mapping. It takes precedence over
                ``vocab_file`` and is copied rather than stored by reference.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__`` after any
                ``pad_token``/``bos_token``/``eos_token``/``unk_token`` entries
                are removed.

        If neither source is usable, ``_create_default_vocab()`` is used.
        """
        # NOTE(review): the same values are passed to the base initialiser
        # below, which presumably re-assigns these attributes. Kept as-is in
        # case anything relies on them being set before that call.
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN
        # Special tokens are fixed by this class; drop caller overrides.
        for key in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(key, None)

        if vocab is not None:
            # Copy so later mutation of the caller's dict cannot make
            # ``_vocab`` and ``_ids_to_tokens`` disagree.
            self._vocab = dict(vocab)
        elif vocab_file and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._create_default_vocab()
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # ``_vocab`` must exist before this call: the base initialiser may
        # look tokens up while registering the special tokens.
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Size of the base vocabulary (tokens added later are not counted)."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a fresh ``{token: id}`` dict, including added tokens.

        Base entries come from ``_vocab``. Tokens registered later through
        ``add_tokens`` / ``add_special_tokens`` are merged on top, so that
        ``get_vocab()[tok] == convert_tokens_to_ids(tok)`` holds for them too.
        The base class also sizes new token ids from this mapping.
        """
        vocab = dict(self._vocab)
        # The base initialiser may call this before its added-token state
        # exists; treat that case as "no added tokens yet".
        vocab.update(getattr(self, "added_tokens_encoder", None) or {})
        return vocab

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id; unknown tokens map to the ``[UNK]`` id (0 if absent)."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its token; unknown ids map to ``"[UNK]"``."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def _create_default_vocab(self) -> Dict[str, int]:
        """Build the built-in 89-token vocabulary with ids assigned in order from 0.

        Layout:
        - 4 special tokens;
        - 12 side+piece tokens;
        - 64 squares in file-major order (a1..a8, b1..b8, ...);
        - capture/check/mate/castle flags;
        - 4 lowercase promotion tokens.
        """
        tokens: List[str] = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
        # Side + piece tokens (12)
        tokens += [f"[W{p}]" for p in "PNBRQK"]
        tokens += [f"[B{p}]" for p in "PNBRQK"]
        # 64 squares
        tokens += [f"[{f}{r}]" for f in "abcdefgh" for r in "12345678"]
        # Flags / castles / promotions
        tokens += ["[x]", "[+]", "[#]", "[O-O]", "[O-O-O]"]
        tokens += [f"[={p}]" for p in "qrbn"]
        return {tok: i for i, tok in enumerate(tokens)}

    def _tokenize(self, text: str) -> List[str]:
        """Split whitespace-separated moves into sub-tokens (see class docstring)."""
        out: List[str] = []
        # str.split() with no argument already ignores surrounding whitespace.
        for move in (text or "").split():
            if self._RAW_UCI_RE.fullmatch(move):
                # Raw UCI carries no side/piece information: squares only.
                out += [f"[{move[:2]}]", f"[{move[2:4]}]"]
                if len(move) == 5:
                    out.append(f"[={move[4]}]")
                continue

            m = _MOVE_RE.match(move)
            if not m:
                out.append(self.UNK_TOKEN)
                continue

            side = m.group("side")  # "W" or "B"
            piece = m.group("piece")  # P/N/B/R/Q/K
            is_cap, is_chk, is_mate, castle, promo = _parse_suffix(m.group("suffix") or "")
            out += [f"[{side}{piece}]", f"[{m.group('src')}]", f"[{m.group('dst')}]"]
            if castle:
                out.append(f"[{castle}]")
            if is_cap:
                out.append("[x]")
            # Mate subsumes check: emit at most one of the two markers.
            if is_mate:
                out.append("[#]")
            elif is_chk:
                out.append("[+]")
            if promo:
                out.append(f"[={promo}]")
        return out

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens with single spaces, dropping the four special tokens.

        The output keeps the bracketed sub-token form (e.g. ``"[WP] [e2] [e4]"``);
        it does not rebuild the original move strings.
        """
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        return " ".join(t for t in tokens if t not in special)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        """Write the base vocabulary as ``[<prefix>-]vocab.json`` in ``save_directory``.

        The directory is created if missing. Only ``_vocab`` is written;
        tokens added after construction are not part of this file.

        Returns:
            A 1-tuple containing the path of the written file.
        """
        os.makedirs(save_directory, exist_ok=True)
        prefix = filename_prefix + "-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + self.vocab_files_names["vocab_file"])
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)