from __future__ import annotations
import json
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from transformers import PreTrainedTokenizer
# Matches a single board square, e.g. "e4": file a-h followed by rank 1-8.
_SQUARE_RE = re.compile(r"[a-h][1-8]")
# Matches a promotion suffix such as "=Q"; group 1 captures the piece letter
# (either case — it is lower-cased during parsing).
_PROMO_RE = re.compile(r"=([QRBNqrbn])")
def _all_squares() -> List[str]:
files = "abcdefgh"
ranks = "12345678"
return [f + r for r in ranks for f in files]
class ChessSquareTokenizer(PreTrainedTokenizer):
    """
    Square-level tokenizer for chess move strings.

    Input text is a space-separated list of moves such as "WPe2e4" or
    "BPd7d8=Q": a side marker (W/B), a piece letter, a from-square, a
    to-square, and an optional "=Q"-style promotion suffix.  Each move is
    tokenized as::

        side  from-square  to-square  [promotion]  [EOS]

    e.g. "WPe2e4" -> ["W", "e2", "e4", "[EOS]"].  The [EOS] inserted after
    each move lets generation stop cleanly at move boundaries.
    """

    # Standard HF hooks: filename used by save_vocabulary, and the tensor
    # names the model expects from this tokenizer.
    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    # Special-token string literals; these exact strings are also keys in
    # the default vocabulary built in _build_default_vocab.
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"
    # Side-to-move markers (White / Black).
    W_TOKEN = "W"
    B_TOKEN = "B"

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ) -> None:
        """Initialize from an explicit mapping, a vocab file, or the default.

        Vocabulary source precedence: ``vocab`` dict, then ``vocab_file``
        (only if the path exists), then the built-in default vocabulary.

        Args:
            vocab_file: Path to a JSON token->id mapping.
            vocab: Explicit token->id mapping; takes precedence over
                ``vocab_file``.
            **kwargs: Forwarded to ``PreTrainedTokenizer``; any special
                tokens the caller passes are discarded in favor of the
                class constants above.
        """
        # NOTE(review): self._vocab must be populated before the
        # super().__init__ call below — the base class presumably invokes
        # get_vocab()/_convert_token_to_id() during its own init, so do
        # not reorder these statements.
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN
        # Strip caller-supplied special tokens so the fixed set always wins.
        kwargs.pop("pad_token", None)
        kwargs.pop("bos_token", None)
        kwargs.pop("eos_token", None)
        kwargs.pop("unk_token", None)
        if vocab is not None:
            self._vocab = dict(vocab)  # copy so we never alias the caller's dict
        elif vocab_file is not None and Path(vocab_file).exists():
            self._vocab = json.loads(Path(vocab_file).read_text(encoding="utf-8"))
        else:
            self._vocab = self._build_default_vocab()
        # Reverse mapping for id -> token lookups.
        self._ids_to_tokens = {i: t for t, i in self._vocab.items()}
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    @staticmethod
    def _build_default_vocab() -> Dict[str, int]:
        """Return the default token->id mapping.

        Layout (ids assigned in listed order): 4 special tokens, 2 side
        markers, 64 squares, 4 lowercase promotion letters = 74 tokens.
        """
        special = [
            ChessSquareTokenizer.PAD_TOKEN,
            ChessSquareTokenizer.BOS_TOKEN,
            ChessSquareTokenizer.EOS_TOKEN,
            ChessSquareTokenizer.UNK_TOKEN,
        ]
        turns = [ChessSquareTokenizer.W_TOKEN, ChessSquareTokenizer.B_TOKEN]
        squares = _all_squares()
        promos = ["q", "r", "b", "n"]
        tokens = special + turns + squares + promos
        return {t: i for i, t in enumerate(tokens)}

    @property
    def vocab_size(self) -> int:
        """Number of tokens in the base vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token->id mapping (copy: callers may mutate)."""
        return dict(self._vocab)

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token string to its id; unknown tokens map to [UNK]'s id."""
        return self._vocab.get(token, self._vocab[self.UNK_TOKEN])

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its token string; out-of-range ids yield [UNK]."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def _tokenize(self, text: str) -> List[str]:
        """Split a space-separated move list into square-level tokens.

        Special-token chunks pass through unchanged.  A chunk whose squares
        cannot be parsed becomes a single [UNK] (with no trailing [EOS]);
        every successfully parsed move is terminated with [EOS].
        """
        # Input is a list of moves separated by spaces.
        tokens: List[str] = []
        for chunk in text.strip().split():
            # Pass special tokens through untouched.
            if chunk in (self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN):
                tokens.append(chunk)
                continue
            # Moves in the dataset start with W or B.
            if chunk and chunk[0] in ("W", "B"):
                tokens.append(chunk[0])
            from_sq, to_sq, promo = self._parse_move_chunk(chunk)
            if from_sq is None or to_sq is None:
                # Unparseable move: emit [UNK] and skip the [EOS] marker.
                tokens.append(self.UNK_TOKEN)
                continue
            tokens.append(from_sq)
            tokens.append(to_sq)
            if promo is not None:
                tokens.append(promo)
            # End-of-move marker.
            tokens.append(self.EOS_TOKEN)
        return tokens

    @staticmethod
    def _parse_move_chunk(chunk: str) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """Extract (from_square, to_square, promotion) from one move chunk.

        Returns (None, None, None) when fewer than two squares are found;
        promotion is a lowercase piece letter or None.
        """
        # Grab the first two squares we see.
        squares = _SQUARE_RE.findall(chunk)
        if len(squares) < 2:
            return None, None, None
        from_sq, to_sq = squares[0], squares[1]
        # Promotion shows up like "=Q".
        promo = None
        m = _PROMO_RE.search(chunk)
        if m:
            promo = m.group(1).lower()
            # Defensive: the regex already restricts to QRBN, but keep the
            # whitelist in case the pattern is ever widened.
            if promo not in {"q", "r", "b", "n"}:
                promo = None
        return from_sq, to_sq, promo

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens with spaces, dropping [PAD] (other specials are kept)."""
        # Keep squares and promo tokens, drop PAD for cleanliness.
        return " ".join(t for t in tokens if t != self.PAD_TOKEN)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the vocab to ``<prefix->vocab.json`` under *save_directory*.

        Creates the directory if needed and returns a 1-tuple with the path,
        per the PreTrainedTokenizer convention.
        """
        save_dir = Path(save_directory)
        save_dir.mkdir(parents=True, exist_ok=True)
        fname = (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        path = save_dir / fname
        path.write_text(json.dumps(self._vocab, indent=2), encoding="utf-8")
        return (str(path),)