"""
4-Step Split Tokenizer
Splits moves into: [Piece] -> [From] -> [To] -> [Suffix]
Minimizes vocabulary to ~150 tokens.
"""
from __future__ import annotations
import json
import os
import re
from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer, AutoTokenizer

class ChessTokenizer(PreTrainedTokenizer):
    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]
    
    # 1. Pieces
    PIECES = ["WP", "WN", "WB", "WR", "WQ", "WK", "BP", "BN", "BB", "BR", "BQ", "BK"]
    # 2. Squares
    SQUARES = [f"{c}{r}" for c in "abcdefgh" for r in "12345678"]
    # 3. Suffixes (Crucial: (-) represents "No Suffix/Quiet Move")
    SUFFIXES = ["(-)", "(x)", "(+)", "(#)", "(x+)", "(x#)", "(O)", "(o)", "(Q)", "=Q"]
    
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    def __init__(self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs):
        # 1. Build or Load Vocab
        self._vocab = vocab
        if vocab_file and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        
        if not self._vocab:
            self._vocab = self._build_split_vocab()
            
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
        
        # 2. Handle special tokens safely.
        # Pop them from kwargs so they are not passed twice to super().__init__
        # (which would raise "got multiple values for keyword argument").
        # A loaded config (kwargs) takes priority; otherwise fall back to the
        # class constants above.
        pad_token = kwargs.pop("pad_token", self.PAD_TOKEN)
        bos_token = kwargs.pop("bos_token", self.BOS_TOKEN)
        eos_token = kwargs.pop("eos_token", self.EOS_TOKEN)
        unk_token = kwargs.pop("unk_token", self.UNK_TOKEN)
        
        # 3. Call parent
        super().__init__(
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            **kwargs,
        )

    def _build_split_vocab(self):
        tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
        tokens += self.PIECES + self.SQUARES + self.SUFFIXES
        # De-duplicate and sort so token ids are assigned deterministically
        unique_tokens = sorted(list(set(tokens)))
        return {t: i for i, t in enumerate(unique_tokens)}

    def get_vocab(self) -> Dict[str, int]:
        """Required by Hugging Face PreTrainedTokenizer"""
        return dict(self._vocab)

    @property
    def vocab_size(self) -> int:
        return len(self._vocab)

    def _tokenize(self, text: str) -> List[str]:
        moves = text.strip().split()
        tokens = []
        
        # Regex: (Piece)(Square)(Square)(Optional Suffix)
        pattern = re.compile(r"([WB][PNBRQK])([a-h][1-8])([a-h][1-8])(.*)")
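        # e.g. pattern.match("WPe2e4").groups()     -> ("WP", "e2", "e4", "")
        #      pattern.match("BRa8a1(x+)").groups() -> ("BR", "a8", "a1", "(x+)")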
        
        for move in moves:
            match = pattern.match(move)
            if match:
                p, s, t, suf = match.groups()
                tokens.extend([p, s, t])
                tokens.append(suf if suf else "(-)")
            else:
                tokens.append(self.UNK_TOKEN)
        
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))

    def _convert_id_to_token(self, index: int) -> str:
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
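        # e.g. ["WP", "e2", "e4", "(-)", "WP", "e4", "d5", "(x)"] -> "WPe2e4 WPe4d5(x)"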
        out = []
        specials = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        clean = [t for t in tokens if t not in specials]
        
        current_move = ""
        for i, t in enumerate(clean):
            if t == "(-)":
                pass
            else:
                current_move += t
            
            # Every 4th token completes a move
            if (i + 1) % 4 == 0:
                out.append(current_move)
                current_move = ""
                
        if current_move:
            out.append(current_move)
        return " ".join(out)
    
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        path = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f)
        return (path,)

    @classmethod
    def build_vocab_from_dataset(cls, *args, **kwargs):
        print("Using static 4-Step Split vocabulary.")
        return cls()

# Register with AutoTokenizer. Note: most transformers versions expect a
# PretrainedConfig subclass (not a string) as the first argument to
# AutoTokenizer.register, so a matching config class may be needed before
# AutoTokenizer.from_pretrained can resolve ChessTokenizer automatically.
AutoTokenizer.register("ChessTokenizer", ChessTokenizer)
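
# Minimal usage sketch: run this file directly for a quick round-trip smoke test.
# The move strings are illustrative inputs that only need to match the
# _tokenize regex; they are not validated as legal chess moves.
if __name__ == "__main__":
    tok = ChessTokenizer()
    text = "WPe2e4 BPd7d5 WPe4d5(x)"
    encoding = tok(text)
    print("input_ids:", encoding["input_ids"])
    print("decoded  :", tok.decode(encoding["input_ids"], skip_special_tokens=True))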