"""
Custom Atomic Chess Tokenizer for the Chess Challenge.
Strategy: Component-level tokenization (W, P, e2, e4) to save vocabulary size.
"""
from __future__ import annotations
import json
import os
from typing import Dict, List, Optional, Tuple
from transformers import PreTrainedTokenizer
class ChessTokenizer(PreTrainedTokenizer):
model_input_names = ["input_ids", "attention_mask"]
    def __init__(self, vocab_file: Optional[str] = None, **kwargs):
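        # vocab_file is accepted to mirror the usual Hugging Face constructor signature,
        # but it is never read here: the vocab is rebuilt deterministically below.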
        # 1. Define the atomic vocabulary
self.special_tokens = ["[PAD]", "[BOS]", "[EOS]", "[UNK]"]
self.colors = ["W", "B"]
self.pieces = ["P", "N", "B", "R", "Q", "K"]
self.squares = [f"{c}{r}" for c in "abcdefgh" for r in range(1, 9)] # a1...h8
self.suffixes = ["x", "+", "#", "=", "O-O", "O-O-O"] # captures, checks, castling
        # 2. Merge all tokens; deduplicate while preserving order so ids stay contiguous
        #    ("B" is both the color Black and the bishop and must map to a single id).
        all_tokens = self.special_tokens + self.colors + self.pieces + self.squares + self.suffixes
        all_tokens = list(dict.fromkeys(all_tokens))
        # 3. Build the in-memory lookup tables
        self.vocab = {t: i for i, t in enumerate(all_tokens)}
        self.ids_to_tokens = {i: t for t, i in self.vocab.items()}
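        # Drop any special-token values coming back from a saved tokenizer_config.json
        # so they do not clash with the explicit arguments passed to super().__init__.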
kwargs.pop("pad_token", None)
kwargs.pop("bos_token", None)
kwargs.pop("eos_token", None)
kwargs.pop("unk_token", None)
        # 4. Initialize the parent class
super().__init__(
pad_token="[PAD]",
bos_token="[BOS]",
eos_token="[EOS]",
unk_token="[UNK]",
**kwargs
)
@property
def vocab_size(self) -> int:
return len(self.vocab)
def get_vocab(self) -> Dict[str, int]:
return dict(self.vocab)
def _tokenize(self, text: str) -> List[str]:
"""
Input: "WPe2e4 BNg8f6"
Output: ['W', 'P', 'e2', 'e4', 'B', 'N', 'g8', 'f6']
"""
tokens = []
moves = text.strip().split()
for move in moves:
            # 1. Castling is kept as a single token
if "O-O" in move:
tokens.append(move)
continue
            # 2. Greedy left-to-right scan:
            #    repeatedly peel the longest matching token off the front of the string.
remaining = move
while remaining:
matched = False
                # Try 2-character tokens first (mainly the squares a1-h8); apart from
                # castling, which was handled above, these are the longest tokens in the vocab.
if len(remaining) >= 2 and remaining[:2] in self.vocab:
tokens.append(remaining[:2])
remaining = remaining[2:]
matched = True
continue
                # Fall back to 1-character tokens (W, B, P, N, x, +, ...)
if len(remaining) >= 1 and remaining[:1] in self.vocab:
tokens.append(remaining[:1])
remaining = remaining[1:]
matched = True
continue
                # Nothing matched: the input is malformed; skip it (or emit UNK).
if not matched:
                    # Force-consume one character to avoid an infinite loop.
                    # In practice you could use tokens.append(self.unk_token) instead.
remaining = remaining[1:]
return tokens
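
    # Example traces of the greedy scan in _tokenize:
    #   "WNb1c3x"  -> ['W', 'N', 'b1', 'c3', 'x']        (capture)
    #   "WPe7e8=Q" -> ['W', 'P', 'e7', 'e8', '=', 'Q']   (promotion)
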
def _convert_token_to_id(self, token: str) -> int:
return self.vocab.get(token, self.vocab.get(self.unk_token))
def _convert_id_to_token(self, index: int) -> str:
return self.ids_to_tokens.get(index, self.unk_token)
    # --- Key method 1: save the vocabulary ---
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
"""
保存 vocab.json 到指定目录。没有这个,save_pretrained 会出问题。
"""
if not os.path.isdir(save_directory):
os.makedirs(save_directory, exist_ok=True)
vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + "vocab.json"
)
with open(vocab_file, "w", encoding="utf-8") as f:
json.dump(self.vocab, f, ensure_ascii=False)
return (vocab_file,)
    # --- Key method 2: reconstruct the string ---
def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""
将 Token 列表还原为棋谱字符串。
Input: ['W', 'P', 'e2', 'e4', 'B', 'P', 'e7', 'e5']
Output: "WPe2e4 BPe7e5"
"""
out_string = []
for t in tokens:
            # Skip special tokens
if t in self.special_tokens:
continue
            # A color token ('W'/'B') or a castling token ('O-O'/'O-O-O') starts a new
            # move, so prepend a space unless it is the very first token.
if t in self.colors or "O-O" in t:
                if out_string:  # not the first move
out_string.append(" ")
out_string.append(t)
return "".join(out_string).strip()
    # Optional: classmethod for building the vocab (hard-coded here, kept for interface compatibility)
@classmethod
def build_vocab_from_dataset(cls, *args, **kwargs):
        return cls()
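

# A minimal round-trip sketch (the sample moves and the save path "./chess_tokenizer"
# are illustrative only):
if __name__ == "__main__":
    tokenizer = ChessTokenizer()

    # Encode a short opening and inspect the atomic tokens.
    moves = "WPe2e4 BNg8f6"
    ids = tokenizer(moves)["input_ids"]
    tokens = tokenizer.convert_ids_to_tokens(ids)
    print(tokens)                                       # ['W', 'P', 'e2', 'e4', 'B', 'N', 'g8', 'f6']
    print(tokenizer.convert_tokens_to_string(tokens))   # WPe2e4 BNg8f6

    # Persist the vocabulary and config so the tokenizer can be reloaded later.
    tokenizer.save_pretrained("./chess_tokenizer")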