File size: 6,669 Bytes
bb88ae7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
"""
Atomic Chess Tokenizer.

Decomposes chess moves into atomic components:
[Piece] + [Source] + [Destination] + [Suffix]

Example: "WPe2e4(x)" -> ["WP", "e2", "e4", "(x)"]

Benefits:
- Drastically reduces vocab size (~1200 -> ~90)
- Saves ~140k parameters in the embedding layer
- Allows the model to learn spatial relationships (e2 is close to e3)
"""

from __future__ import annotations

import json
import os
import re
from typing import Dict, List, Optional

from transformers import PreTrainedTokenizer

class ChessTokenizer(PreTrainedTokenizer):
    """Tokenizer that decomposes extended-UCI chess moves into atomic tokens.

    A move string such as ``"WPe2e4(x)"`` is split into
    ``["WP", "e2", "e4", "(x)"]`` — piece (color + type), source square,
    destination square, and an optional suffix. The vocabulary is derived
    from the rules of chess rather than learned from a dataset, which keeps
    it small (~90 entries) and lets the model learn spatial relationships
    between square tokens.
    """

    model_input_names = ["input_ids", "attention_mask"]

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    # Regex to parse the extended UCI format.
    # Groups: 1=Piece, 2=Source, 3=Dest, 4=Suffix (possibly empty)
    MOVE_REGEX = re.compile(r"([WB][PNBRQK])([a-h][1-8])([a-h][1-8])(.*)")

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """Build the tokenizer.

        Args:
            vocab_file: Optional path to a JSON token->id mapping. Used only
                if ``vocab`` is not given and the file exists.
            vocab: Optional explicit token->id mapping; takes precedence.
            **kwargs: Forwarded to ``PreTrainedTokenizer``. Any special-token
                overrides are dropped — this tokenizer fixes its own.

        If neither source is provided, the static atomic vocabulary is built
        from the rules of chess (see ``_create_atomic_vocab``).
        """
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Drop caller-supplied special tokens: ours are fixed class constants.
        kwargs.pop("pad_token", None)
        kwargs.pop("bos_token", None)
        kwargs.pop("eos_token", None)
        kwargs.pop("unk_token", None)

        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._create_atomic_vocab()

        # Reverse mapping for id -> token lookups during decoding.
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # NOTE: the vocab must exist before super().__init__, because the
        # base class may call _tokenize/_convert_token_to_id during setup.
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _create_atomic_vocab(self) -> Dict[str, int]:
        """Manually build the vocabulary from the rules of chess.

        Returns a token->id dict covering special tokens, the 12 piece
        tokens, the 64 square tokens, and the known move suffixes. No
        dataset pass is needed because the token inventory is closed.
        """
        vocab: Dict[str, int] = {}
        idx = 0

        # 1. Special tokens
        for token in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]:
            vocab[token] = idx
            idx += 1

        # 2. Pieces (Color + Type)
        colors = ['W', 'B']
        pieces = ['P', 'N', 'B', 'R', 'Q', 'K']
        for c in colors:
            for p in pieces:
                vocab[f"{c}{p}"] = idx
                idx += 1

        # 3. Squares (a1 to h8)
        files = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
        ranks = ['1', '2', '3', '4', '5', '6', '7', '8']
        for f in files:
            for r in ranks:
                vocab[f"{f}{r}"] = idx
                idx += 1

        # 4. Common suffixes (derived from Lichess notation)
        # (x)=capture, (+)=check, (#)=mate, (o)=castling, =X=promotion
        suffixes = ["(x)", "(+)", "(+*)", "(o)", "(O)", "=", "=Q", "=R", "=B", "=N"]
        for s in suffixes:
            vocab[s] = idx
            idx += 1

        return vocab

    @property
    def vocab_size(self) -> int:
        """Number of tokens in the vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token->id mapping."""
        return dict(self._vocab)

    def _tokenize(self, text: str) -> List[str]:
        """Split a whitespace-separated string of moves into atomic tokens.

        ``"WPe2e4"`` -> ``["WP", "e2", "e4"]``. Moves that do not match the
        extended-UCI pattern are kept whole (they will map to [UNK] unless
        they happen to be in the vocab).
        """
        raw_moves = text.strip().split()
        tokens: List[str] = []

        for move in raw_moves:
            match = self.MOVE_REGEX.match(move)
            if match:
                # Piece, source, destination...
                tokens.extend([match.group(1), match.group(2), match.group(3)])
                # ...plus the suffix, when present.
                suffix = match.group(4)
                if suffix:
                    tokens.append(suffix)
            else:
                # Fallback for weird formatting (or UNK)
                tokens.append(move)

        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the [UNK] id."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its token, falling back to [UNK]."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Reconstruct a space-separated move string from atomic tokens.

        Tokens within one move are joined without spaces. Since every move
        starts with a piece token ([WB][PNBRQK]), a space is inserted
        before each piece token except at the start of the string.
        """
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}

        # Join everything (minus special tokens), then split moves apart.
        full_str = "".join(t for t in tokens if t not in special)

        # Insert a space before every piece token except the first one.
        return re.sub(r'(?<!^)([WB][PNBRQK])', r' \1', full_str)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        """Write the vocabulary to ``<prefix->vocab.json`` in *save_directory*.

        Returns a 1-tuple with the written file path (HF convention).
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)

    # We don't really need build_vocab_from_dataset anymore as we hardcoded
    # the rules, but we keep the method signature to satisfy the template.
    @classmethod
    def build_vocab_from_dataset(cls, *args, **kwargs):
        """Compatibility shim: the vocabulary is static, so just construct."""
        print("Note: Atomic tokenizer uses a static vocabulary rule set.")
        return cls()