File size: 5,510 Bytes
f9222c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
"""
4-Step Split Tokenizer
Splits moves into: [Piece] -> [From] -> [To] -> [Suffix]
Minimizes vocabulary to ~150 tokens.
"""
from __future__ import annotations
import json
import os
import re
from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer, AutoTokenizer
class ChessTokenizer(PreTrainedTokenizer):
vocab_files_names = {"vocab_file": "vocab.json"}
model_input_names = ["input_ids", "attention_mask"]
# 1. Pieces
PIECES = ["WP", "WN", "WB", "WR", "WQ", "WK", "BP", "BN", "BB", "BR", "BQ", "BK"]
# 2. Squares
SQUARES = [f"{c}{r}" for c in "abcdefgh" for r in "12345678"]
# 3. Suffixes (Crucial: (-) represents "No Suffix/Quiet Move")
SUFFIXES = ["(-)", "(x)", "(+)", "(#)", "(x+)", "(x#)", "(O)", "(o)", "(Q)", "=Q"]
PAD_TOKEN = "[PAD]"
BOS_TOKEN = "[BOS]"
EOS_TOKEN = "[EOS]"
UNK_TOKEN = "[UNK]"
# def __init__(self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs):
# # 1. Build or Load Vocab first
# self._vocab = vocab
# if vocab_file and os.path.exists(vocab_file):
# with open(vocab_file, "r", encoding="utf-8") as f:
# self._vocab = json.load(f)
# if not self._vocab:
# self._vocab = self._build_split_vocab()
# self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
# # 2. Call parent init with explicit tokens to prevent auto-add errors
# super().__init__(
# pad_token=self.PAD_TOKEN,
# bos_token=self.BOS_TOKEN,
# eos_token=self.EOS_TOKEN,
# unk_token=self.UNK_TOKEN,
# **kwargs,
# )
def __init__(self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs):
# 1. Build or Load Vocab
self._vocab = vocab
if vocab_file and os.path.exists(vocab_file):
with open(vocab_file, "r", encoding="utf-8") as f:
self._vocab = json.load(f)
if not self._vocab:
self._vocab = self._build_split_vocab()
self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
# 2. Handle Special Tokens Safely
# We "pop" them from kwargs to prevent the "multiple values" error.
# This prioritizes the loaded config (kwargs) if it exists,
# falling back to your class constants if it doesn't.
pad_token = kwargs.pop("pad_token", self.PAD_TOKEN)
bos_token = kwargs.pop("bos_token", self.BOS_TOKEN)
eos_token = kwargs.pop("eos_token", self.EOS_TOKEN)
unk_token = kwargs.pop("unk_token", self.UNK_TOKEN)
# 3. Call parent
super().__init__(
pad_token=pad_token,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
**kwargs,
)
def _build_split_vocab(self):
tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
tokens += self.PIECES + self.SQUARES + self.SUFFIXES
# Sort and unique to be safe
unique_tokens = sorted(list(set(tokens)))
return {t: i for i, t in enumerate(unique_tokens)}
def get_vocab(self) -> Dict[str, int]:
"""Required by Hugging Face PreTrainedTokenizer"""
return dict(self._vocab)
@property
def vocab_size(self) -> int:
return len(self._vocab)
def _tokenize(self, text: str) -> List[str]:
moves = text.strip().split()
tokens = []
# Regex: (Piece)(Square)(Square)(Optional Suffix)
pattern = re.compile(r"([WB][PNBRQK])([a-h][1-8])([a-h][1-8])(.*)")
for move in moves:
match = pattern.match(move)
if match:
p, s, t, suf = match.groups()
tokens.extend([p, s, t])
tokens.append(suf if suf else "(-)")
else:
tokens.append(self.UNK_TOKEN)
return tokens
def _convert_token_to_id(self, token: str) -> int:
return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))
def _convert_id_to_token(self, index: int) -> str:
return self._ids_to_tokens.get(index, self.UNK_TOKEN)
def convert_tokens_to_string(self, tokens: List[str]) -> str:
out = []
specials = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
clean = [t for t in tokens if t not in specials]
current_move = ""
for i, t in enumerate(clean):
if t == "(-)":
pass
else:
current_move += t
# Every 4th token completes a move
if (i + 1) % 4 == 0:
out.append(current_move)
current_move = ""
if current_move: out.append(current_move)
return " ".join(out)
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
path = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json")
with open(path, "w") as f:
json.dump(self._vocab, f)
return (path,)
@classmethod
def build_vocab_from_dataset(cls, *args, **kwargs):
print("Using static 4-Step Split vocabulary.")
return cls()
# Register
AutoTokenizer.register("ChessTokenizer", ChessTokenizer) |