"""Chess-move tokenizer for Hugging Face transformers.

Each move is split into exactly four tokens: piece, origin square,
destination square, and a normalized suffix (capture/check/mate/castle/
promotion marker).
"""
from __future__ import annotations

import json
import os
import re
from typing import Dict, List, Optional

from transformers import AutoConfig, AutoTokenizer, PretrainedConfig, PreTrainedTokenizer

class ChessTokenizer(PreTrainedTokenizer):
    """Splits every chess move into a fixed four-token frame:
    piece, origin square, destination square, suffix."""

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    # Piece tokens: side (W/B) followed by the piece letter.
    PIECES = ["WP", "WN", "WB", "WR", "WQ", "WK", "BP", "BN", "BB", "BR", "BQ", "BK"]
    # All 64 board squares, files a-h crossed with ranks 1-8.
    SQUARES = [f"{c}{r}" for c in "abcdefgh" for r in "12345678"]
    # Normalized suffixes: quiet move, capture, check, mate, combinations,
    # castle markers, and queen promotion.
    SUFFIXES = ["(-)", "(x)", "(+)", "(#)", "(x+)", "(x#)", "(O)", "(o)", "(Q)", "=Q"]

    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    def __init__(self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs):
        # Load the vocabulary from disk if a file is given; otherwise use the
        # passed-in mapping, falling back to the static split vocabulary.
        self._vocab = vocab
        if vocab_file and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)

        if not self._vocab:
            self._vocab = self._build_split_vocab()

        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # The vocabulary must exist before super().__init__ runs, because the
        # base class resolves the special tokens against it.
        pad_token = kwargs.pop("pad_token", self.PAD_TOKEN)
        bos_token = kwargs.pop("bos_token", self.BOS_TOKEN)
        eos_token = kwargs.pop("eos_token", self.EOS_TOKEN)
        unk_token = kwargs.pop("unk_token", self.UNK_TOKEN)

        super().__init__(
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            **kwargs,
        )

    def _build_split_vocab(self) -> Dict[str, int]:
        """Build the static vocabulary; ids are assigned in sorted token order."""
        tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
        tokens += self.PIECES + self.SQUARES + self.SUFFIXES
        return {t: i for i, t in enumerate(sorted(set(tokens)))}

    def get_vocab(self) -> Dict[str, int]:
        return dict(self._vocab)

    @property
    def vocab_size(self) -> int:
        return len(self._vocab)

    # Moves are serialized as PIECE + FROM + TO + optional suffix,
    # e.g. "WNg1f3" or "BQd8h4x+". Compiled once at class level.
    _MOVE_RE = re.compile(r"([WB][PNBRQK])([a-h][1-8])([a-h][1-8])(.*)")

    def _tokenize(self, text: str) -> List[str]:
        tokens: List[str] = []
        for move in text.strip().split():
            match = self._MOVE_RE.match(move)
            if match:
                piece, src, dst, suffix = match.groups()
                tokens.extend([piece, src, dst, self._normalize_suffix(suffix)])
            else:
                # Unparseable move: emit four UNK tokens instead of a fake
                # move, so the four-token frame stays aligned for decoding.
                tokens.extend([self.UNK_TOKEN] * 4)
        return tokens

    def _normalize_suffix(self, suf: str) -> str:
        """Map a raw move suffix onto one of the SUFFIXES vocabulary entries."""
        suf = suf.strip()
        if not suf:
            return "(-)"
        if suf.startswith("x"):
            if "+" in suf:
                return "(x+)"
            if "#" in suf:
                return "(x#)"
            return "(x)"
        if suf == "+":
            return "(+)"
        if suf == "#":
            return "(#)"
        if suf in {"O", "o"}:
            return f"({suf})"
        if suf in {"Q", "=Q"}:
            return "=Q"
        # Anything unrecognized is treated as a quiet move.
        return "(-)"

    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))

    def _convert_id_to_token(self, index: int) -> str:
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        specials = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        clean = [t for t in tokens if t not in specials]

        out = []
        current_move = ""
        for i, t in enumerate(clean):
            # The quiet-move suffix "(-)" is purely internal and is dropped.
            if t != "(-)":
                current_move += t
            # Every move occupies exactly four tokens.
            if (i + 1) % 4 == 0:
                out.append(current_move)
                current_move = ""

        if current_move:
            out.append(current_move)
        return " ".join(out)
    
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        os.makedirs(save_directory, exist_ok=True)
        path = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f)
        return (path,)

    @classmethod
    def build_vocab_from_dataset(cls, *args, **kwargs):
        # The vocabulary is fully static, so the dataset arguments are ignored.
        print("Using static 4-Step Split vocabulary.")
        return cls()

# Register with the Auto* machinery. AutoTokenizer.register expects a
# PretrainedConfig subclass as its first argument (not a string), so the
# tokenizer is paired with a minimal config class here.
class ChessConfig(PretrainedConfig):
    model_type = "chess"


AutoConfig.register("chess", ChessConfig)
AutoTokenizer.register(ChessConfig, slow_tokenizer_class=ChessTokenizer)
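

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module):
    # round-trips one white and one black move through encode/decode. The
    # move strings below assume the PIECE+FROM+TO+SUFFIX serialization
    # handled by _MOVE_RE above.
    tok = ChessTokenizer()
    ids = tok("WPe2e4 BNg8f6", add_special_tokens=False)["input_ids"]
    print(ids)
    print(tok.decode(ids))  # expected: "WPe2e4 BNg8f6"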