File size: 5,559 Bytes
d3b95ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""
Custom Atomic Chess Tokenizer for the Chess Challenge.
Strategy: Component-level tokenization (W, P, e2, e4) to save vocabulary size.
"""

from __future__ import annotations

import json
import os
from typing import Dict, List, Optional, Tuple


from transformers import PreTrainedTokenizer

class ChessTokenizer(PreTrainedTokenizer):
    """Component-level ("atomic") tokenizer for chess move strings.

    Each whitespace-separated move such as ``"WPe2e4"`` is split into its
    components (``W``, ``P``, ``e2``, ``e4``), so the vocabulary only needs
    colours, pieces, squares and a handful of suffix symbols.

    The vocabulary is hard-coded and rebuilt in memory on every construction;
    token ids are contiguous, starting at 0 with the special tokens.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, vocab_file: Optional[str] = None, **kwargs):
        """Build the fixed vocabulary and initialise the parent tokenizer.

        Args:
            vocab_file: Accepted for interface compatibility only; it is not
                read, because the vocabulary is always rebuilt from the
                hard-coded component lists below.
            **kwargs: Forwarded to the parent initialiser. Any ``pad_token``,
                ``bos_token``, ``eos_token`` or ``unk_token`` entries are
                discarded in favour of the fixed special tokens.
        """
        # 1. Atomic component lists.
        self.special_tokens = ["[PAD]", "[BOS]", "[EOS]", "[UNK]"]
        self.colors = ["W", "B"]
        self.pieces = ["P", "N", "B", "R", "Q", "K"]
        self.squares = [f"{c}{r}" for c in "abcdefgh" for r in range(1, 9)]  # a1...h8
        self.suffixes = ["x", "+", "#", "=", "O-O", "O-O-O"]  # captures, checks, castling

        # 2. Merge every component into one ordered token list.
        all_tokens = (
            self.special_tokens
            + self.colors
            + self.pieces
            + self.squares
            + self.suffixes
        )

        # "B" is both a colour and a piece, so it occurs twice above.
        # Enumerating the raw list would leave one id unmapped and make the
        # largest id equal to len(vocab), i.e. out of range for anything sized
        # by vocab_size. Dropping duplicates first (keeping the first
        # occurrence) yields contiguous ids 0 .. vocab_size - 1.
        unique_tokens = list(dict.fromkeys(all_tokens))

        # 3. In-memory lookup tables (must exist before the parent initialiser
        #    runs, as in the original ordering).
        self.vocab = {token: index for index, token in enumerate(unique_tokens)}
        self.ids_to_tokens = {index: token for token, index in self.vocab.items()}

        # Tokens that _tokenize may cut out of a move, and their distinct
        # lengths from longest to shortest for greedy longest-match scanning.
        # Special tokens are excluded so raw text is never split into them here.
        self._component_tokens = frozenset(
            token for token in unique_tokens if token not in self.special_tokens
        )
        self._match_lengths = sorted(
            {len(token) for token in self._component_tokens}, reverse=True
        )

        # The special tokens are fixed; ignore caller-supplied overrides so
        # they are not passed twice to the parent initialiser.
        for name in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(name, None)

        # 4. Initialise the parent class.
        super().__init__(
            pad_token="[PAD]",
            bos_token="[BOS]",
            eos_token="[EOS]",
            unk_token="[UNK]",
            **kwargs
        )

    @property
    def vocab_size(self) -> int:
        """Number of entries in the fixed vocabulary."""
        return len(self.vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return dict(self.vocab)

    def _tokenize(self, text: str) -> List[str]:
        """Split a move string into component tokens.

        Moves are separated on whitespace; each move is then scanned left to
        right, always cutting off the longest vocabulary token that matches at
        the current position. Characters that match no token are dropped.

        Example:
            ``"WPe2e4 BNg8f6"`` -> ``['W', 'P', 'e2', 'e4', 'B', 'N', 'g8', 'f6']``
            ``"WO-O"`` -> ``['W', 'O-O']``

        Args:
            text: Whitespace-separated moves.

        Returns:
            The component tokens in order of appearance.
        """
        tokens: List[str] = []

        for move in text.split():
            pos = 0
            while pos < len(move):
                for length in self._match_lengths:
                    candidate = move[pos:pos + length]
                    # Near the end of the move the slice may be shorter than
                    # `length`; it is then the whole remainder, so a hit is
                    # still the longest possible match.
                    if candidate in self._component_tokens:
                        tokens.append(candidate)
                        pos += len(candidate)
                        break
                else:
                    # No token starts here (dirty data): consume one character
                    # so the scan always makes progress. The character is
                    # intentionally dropped rather than emitted as [UNK].
                    pos += 1

        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the id of the unknown token."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id to its token, falling back to the unknown token."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the vocabulary to ``vocab.json`` inside ``save_directory``.

        Args:
            save_directory: Target directory; created if it does not exist.
            filename_prefix: Optional prefix; when given the file is named
                ``"<prefix>-vocab.json"``.

        Returns:
            A one-element tuple containing the path of the written file.
        """
        os.makedirs(save_directory, exist_ok=True)

        prefix = filename_prefix + "-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f, ensure_ascii=False)

        return (vocab_file,)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join component tokens back into a space-separated move string.

        Special tokens are skipped. A colour token or a castling token opens a
        new move and is preceded by a space (unless it is the first thing
        emitted). Because ``"B"`` is both a colour and the bishop, a ``"B"``
        directly after a move-opening colour is treated as the piece, and a
        castling token directly after a colour stays in that colour's move.

        Example:
            ``['W', 'P', 'e2', 'e4', 'B', 'B', 'f8', 'c5']`` -> ``"WPe2e4 BBf8c5"``

        Args:
            tokens: Component tokens, e.g. as produced by ``_tokenize``.

        Returns:
            The reconstructed move string.
        """
        parts: List[str] = []
        after_color = False  # previous kept token was a move-opening colour

        for token in tokens:
            if token in self.special_tokens:
                continue

            is_castling = "O-O" in token
            could_open = token in self.colors or is_castling
            # Right after a colour, a piece letter or a castling token belongs
            # to the move that colour opened.
            continues_move = after_color and (is_castling or token in self.pieces)
            opens_move = could_open and not continues_move

            if opens_move and parts:
                parts.append(" ")
            parts.append(token)

            after_color = opens_move and token in self.colors

        return "".join(parts).strip()

    @classmethod
    def build_vocab_from_dataset(cls, *args, **kwargs):
        """Return a tokenizer with the hard-coded vocabulary.

        Kept for interface compatibility; all arguments are ignored because
        the vocabulary does not depend on any dataset.
        """
        return cls()