File size: 4,907 Bytes
95a9cb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from __future__ import annotations

import json
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from transformers import PreTrainedTokenizer


# Matches one board square: a file letter a-h followed by a rank digit 1-8.
_SQUARE_RE = re.compile(r"[a-h][1-8]")
# Matches a promotion suffix such as "=Q"; group 1 is the piece letter (Q/R/B/N, either case).
_PROMO_RE = re.compile(r"=([QRBNqrbn])")


def _all_squares() -> List[str]:
    files = "abcdefgh"
    ranks = "12345678"
    return [f + r for r in ranks for f in files]


class ChessSquareTokenizer(PreTrainedTokenizer):
    """
    Tokenizer for space-separated chess moves such as "WPe2e4" or "BPd7d8=Q".

    Each move chunk is turned into:
      * the side-to-move token "W" or "B" (when the chunk starts with one),
      * the from-square and to-square (e.g. "e2", "e4"),
      * an optional lowercase promotion piece ("q", "r", "b", "n"),
      * a trailing [EOS], so generation can stop cleanly after one move.

    Chunks that are already vocabulary tokens (e.g. the "W e2 e4" text produced
    by ``convert_tokens_to_string``) are passed through unchanged, so decoded
    output can be re-encoded.
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"
    W_TOKEN = "W"
    B_TOKEN = "B"

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """
        Build the tokenizer from an explicit vocab, a vocab file, or the default vocab.

        Args:
            vocab_file: Path to a JSON ``{token: id}`` mapping. Used only when
                ``vocab`` is not given and the file exists; a missing file
                silently falls back to the default vocabulary.
            vocab: Explicit ``{token: id}`` mapping; takes precedence over ``vocab_file``.
            **kwargs: Forwarded to ``PreTrainedTokenizer``. Any ``pad_token``,
                ``bos_token``, ``eos_token`` or ``unk_token`` entries are
                discarded; the class constants are always used instead.
        """
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Drop caller-supplied special tokens so they cannot override the fixed
        # ones and are not passed twice to super().__init__ below.
        kwargs.pop("pad_token", None)
        kwargs.pop("bos_token", None)
        kwargs.pop("eos_token", None)
        kwargs.pop("unk_token", None)

        if vocab is not None:
            self._vocab = dict(vocab)
        elif vocab_file is not None and Path(vocab_file).exists():
            self._vocab = json.loads(Path(vocab_file).read_text(encoding="utf-8"))
        else:
            self._vocab = self._build_default_vocab()

        self._ids_to_tokens = {i: t for t, i in self._vocab.items()}

        # NOTE: the vocab is populated before super().__init__(), presumably
        # because the base initializer may query get_vocab()/vocab_size.
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    @staticmethod
    def _build_default_vocab() -> Dict[str, int]:
        """
        Return the default vocabulary (74 tokens).

        Ids are assigned in this order: the 4 special tokens (0-3), "W"/"B" (4-5),
        the 64 squares a1..h8 rank by rank (6-69), then promotions q, r, b, n (70-73).
        """
        special = [
            ChessSquareTokenizer.PAD_TOKEN,
            ChessSquareTokenizer.BOS_TOKEN,
            ChessSquareTokenizer.EOS_TOKEN,
            ChessSquareTokenizer.UNK_TOKEN,
        ]
        turns = [ChessSquareTokenizer.W_TOKEN, ChessSquareTokenizer.B_TOKEN]
        squares = _all_squares()
        promos = ["q", "r", "b", "n"]
        tokens = special + turns + squares + promos
        return {t: i for i, t in enumerate(tokens)}

    @property
    def vocab_size(self) -> int:
        """Number of entries in the token-to-id mapping."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token-to-id mapping."""
        return dict(self._vocab)

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id; unknown tokens map to the [UNK] id."""
        idx = self._vocab.get(token)
        if idx is not None:
            return idx
        # Looked up only on a miss, so a vocab without [UNK] still resolves
        # known tokens (raises KeyError only for unknown ones).
        return self._vocab[self.UNK_TOKEN]

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its token; unknown ids map to "[UNK]"."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def _tokenize(self, text: str) -> List[str]:
        """
        Split whitespace-separated moves into vocabulary tokens.

        A move chunk becomes ``[side] from_sq to_sq [promo] [EOS]``. A chunk
        whose squares cannot be parsed becomes ``[side] [UNK] [EOS]``, so the
        move boundary is still marked. Chunks that are special tokens or
        existing vocabulary tokens (e.g. "W", "e2", "q") are emitted unchanged.
        """
        specials = (self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN)
        tokens: List[str] = []

        for chunk in text.split():
            # Pass-through keeps decoded text (space-separated tokens) round-trippable.
            if chunk in specials or chunk in self._vocab:
                tokens.append(chunk)
                continue

            # Moves in the dataset start with the side to move: W or B.
            if chunk[0] in (self.W_TOKEN, self.B_TOKEN):
                tokens.append(chunk[0])

            from_sq, to_sq, promo = self._parse_move_chunk(chunk)
            if from_sq is None or to_sq is None:
                tokens.append(self.UNK_TOKEN)
            else:
                tokens.append(from_sq)
                tokens.append(to_sq)
                if promo is not None:
                    tokens.append(promo)

            # End-of-move marker, emitted for every move chunk (even malformed ones).
            tokens.append(self.EOS_TOKEN)

        return tokens

    @staticmethod
    def _parse_move_chunk(chunk: str) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """
        Extract ``(from_square, to_square, promotion)`` from one move chunk.

        The first two ``[a-h][1-8]`` matches are the from/to squares. A
        promotion suffix such as "=Q" yields the lowercase piece letter,
        otherwise ``None``. Returns ``(None, None, None)`` when fewer than two
        squares are found.
        """
        squares = _SQUARE_RE.findall(chunk)
        if len(squares) < 2:
            return None, None, None

        # _PROMO_RE only captures Q/R/B/N (either case), so no further validation is needed.
        match = _PROMO_RE.search(chunk)
        promo = match.group(1).lower() if match else None

        return squares[0], squares[1], promo

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens with single spaces, dropping [PAD] tokens."""
        return " ".join(t for t in tokens if t != self.PAD_TOKEN)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        """
        Write the vocabulary as JSON to ``save_directory``.

        The file is named ``vocab.json``, or ``<prefix>-vocab.json`` when
        ``filename_prefix`` is given. The directory is created if needed.

        Returns:
            A one-element tuple containing the written file path as a string.
        """
        save_dir = Path(save_directory)
        save_dir.mkdir(parents=True, exist_ok=True)
        fname = (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        path = save_dir / fname
        path.write_text(json.dumps(self._vocab, indent=2), encoding="utf-8")
        return (str(path),)