from __future__ import annotations
import json
import os
import re
from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer, AutoTokenizer

class ChessTokenizer(PreTrainedTokenizer):
    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    # Building blocks of the fixed 4-token move encoding:
    # piece, from-square, to-square, and a normalized suffix.
    PIECES = ["WP", "WN", "WB", "WR", "WQ", "WK", "BP", "BN", "BB", "BR", "BQ", "BK"]
    SQUARES = [f"{c}{r}" for c in "abcdefgh" for r in "12345678"]
    SUFFIXES = ["(-)", "(x)", "(+)", "(#)", "(x+)", "(x#)", "(O)", "(o)", "(Q)", "=Q"]

    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"
    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        # Build or load the vocabulary *before* calling the parent constructor,
        # since PreTrainedTokenizer resolves special tokens against it.
        self._vocab = vocab
        if vocab_file and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        if not self._vocab:
            self._vocab = self._build_split_vocab()
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        pad_token = kwargs.pop("pad_token", self.PAD_TOKEN)
        bos_token = kwargs.pop("bos_token", self.BOS_TOKEN)
        eos_token = kwargs.pop("eos_token", self.EOS_TOKEN)
        unk_token = kwargs.pop("unk_token", self.UNK_TOKEN)
        super().__init__(
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            **kwargs,
        )
    def _build_split_vocab(self) -> Dict[str, int]:
        tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
        tokens += self.PIECES + self.SQUARES + self.SUFFIXES
        unique_tokens = sorted(set(tokens))
        return {t: i for i, t in enumerate(unique_tokens)}
    def get_vocab(self) -> Dict[str, int]:
        return dict(self._vocab)

    @property
    def vocab_size(self) -> int:
        return len(self._vocab)
    def _tokenize(self, text: str) -> List[str]:
        # A move string such as "WPe2e4" or "BNg8f6x" splits into piece,
        # from-square, to-square, and an optional suffix.
        moves = text.strip().split()
        tokens = []
        pattern = re.compile(r"([WB][PNBRQK])([a-h][1-8])([a-h][1-8])(.*)")
        for move in moves:
            match = pattern.match(move)
            if match:
                p, s, t, suf = match.groups()
                tokens.extend([p, s, t])
                tokens.append(self._normalize_suffix(suf))
            else:
                # Moves that do not match the pattern fall back to a fixed placeholder.
                tokens.extend(["WP", "a1", "a1", "(-)"])
        return tokens
    def _normalize_suffix(self, suf: str) -> str:
        suf = suf.strip()
        if not suf:
            return "(-)"
        if suf.startswith("x"):
            if "+" in suf:
                return "(x+)"
            if "#" in suf:
                return "(x#)"
            return "(x)"
        if suf == "+":
            return "(+)"
        if suf == "#":
            return "(#)"
        if suf in {"O", "o"}:
            return f"({suf})"
        if suf in {"Q", "=Q"}:
            return "=Q"
        return "(-)"
    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))

    def _convert_id_to_token(self, index: int) -> str:
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        out = []
        specials = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        clean = [t for t in tokens if t not in specials]
        current_move = ""
        for i, t in enumerate(clean):
            # "(-)" means "no suffix" and is dropped from the surface form;
            # every group of four tokens (piece, from, to, suffix) is one move.
            if t != "(-)":
                current_move += t
            if (i + 1) % 4 == 0:
                out.append(current_move)
                current_move = ""
        if current_move:
            out.append(current_move)
        return " ".join(out)
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        prefix = filename_prefix + "-" if filename_prefix else ""
        path = os.path.join(save_directory, prefix + "vocab.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f)
        return (path,)
    @classmethod
    def build_vocab_from_dataset(cls, *args, **kwargs):
        # The vocabulary is static, so there is nothing to learn from a dataset.
        print("Using static 4-Step Split vocabulary.")
        return cls()


# Register the tokenizer so it can be resolved through AutoTokenizer.
# Note: AutoTokenizer.register normally expects a config class (a PretrainedConfig
# subclass) as its first argument rather than a plain string.
AutoTokenizer.register("ChessTokenizer", ChessTokenizer)
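

# Minimal usage sketch (not part of the original module; the example moves are
# illustrative only): round-trips "WPe2e4"-style move strings, the format the
# regex in _tokenize expects.
if __name__ == "__main__":
    tokenizer = ChessTokenizer()
    moves = "WPe2e4 BPe7e5 WNg1f3 BNb8c6x"

    encoded = tokenizer(moves)
    print(encoded["input_ids"])

    decoded = tokenizer.decode(encoded["input_ids"], skip_special_tokens=True)
    print(decoded)  # expected: "WPe2e4 BPe7e5 WNg1f3 BNb8c6(x)"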