Update tokenizer.py
Browse files- tokenizer.py +37 -25
tokenizer.py
CHANGED
|
@@ -1,13 +1,14 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
import json
|
| 3 |
import os
|
|
|
|
| 4 |
from typing import Dict, List, Optional
|
| 5 |
from transformers import PreTrainedTokenizer
|
| 6 |
import torch
|
| 7 |
|
| 8 |
class ChessTokenizer(PreTrainedTokenizer):
|
| 9 |
"""
|
| 10 |
-
vocab size:
|
| 11 |
"""
|
| 12 |
|
| 13 |
model_input_names = ["input_ids", "attention_mask"]
|
|
@@ -24,6 +25,7 @@ class ChessTokenizer(PreTrainedTokenizer):
|
|
| 24 |
self.colors_pieces = [f'{c}{p}' for c in ['W','B'] for p in ['P','N','B','R','Q','K']]
|
| 25 |
self.squares = [f'{f}{r}' for r in '12345678' for f in 'abcdefgh']
|
| 26 |
self.suffixes = ["(x)", "(+)", "(+*)", "(o)", "(O)"]
|
|
|
|
| 27 |
|
| 28 |
if vocab is not None:
|
| 29 |
self._vocab = vocab
|
|
@@ -32,14 +34,11 @@ class ChessTokenizer(PreTrainedTokenizer):
|
|
| 32 |
self._vocab = json.load(f)
|
| 33 |
else:
|
| 34 |
self._vocab = {t: i for i, t in enumerate(special_tokens)}
|
| 35 |
-
for cp in self.colors_pieces:
|
| 36 |
-
|
| 37 |
-
for sq in self.squares:
|
| 38 |
-
|
| 39 |
-
for
|
| 40 |
-
self._vocab[f"{sq}_t"] = len(self._vocab)
|
| 41 |
-
for suf in self.suffixes:
|
| 42 |
-
self._vocab[suf] = len(self._vocab)
|
| 43 |
|
| 44 |
self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
|
| 45 |
|
|
@@ -54,8 +53,9 @@ class ChessTokenizer(PreTrainedTokenizer):
|
|
| 54 |
@property
|
| 55 |
def vocab_size(self) -> int:
|
| 56 |
return len(self._vocab)
|
| 57 |
-
|
| 58 |
def get_vocab(self) -> Dict[str, int]:
|
|
|
|
| 59 |
return dict(self._vocab)
|
| 60 |
|
| 61 |
def _tokenize(self, text: str) -> List[str]:
|
|
@@ -64,43 +64,55 @@ class ChessTokenizer(PreTrainedTokenizer):
|
|
| 64 |
for part in parts:
|
| 65 |
if part in self._vocab:
|
| 66 |
tokens.append(part)
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
if piece in self._vocab: tokens.append(piece)
|
| 70 |
if f_sq in self._vocab: tokens.append(f_sq)
|
| 71 |
if t_sq in self._vocab: tokens.append(t_sq)
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
return tokens
|
| 75 |
|
| 76 |
def _convert_id_to_token(self, index: int) -> str:
|
| 77 |
token = self._ids_to_tokens.get(index, self.UNK_TOKEN)
|
| 78 |
if token in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]:
|
| 79 |
return ""
|
|
|
|
|
|
|
|
|
|
| 80 |
return token.replace("_f", "").replace("_t", "")
|
| 81 |
|
| 82 |
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
| 83 |
res = []
|
| 84 |
for t in tokens:
|
| 85 |
if not t: continue
|
| 86 |
-
#
|
| 87 |
if len(t) == 2 and (t.startswith('W') or t.startswith('B')):
|
| 88 |
res.append(" " + t)
|
| 89 |
else:
|
| 90 |
res.append(t)
|
| 91 |
return "".join(res).strip()
|
|
|
|
| 92 |
def _convert_token_to_id(self, token: str) -> int:
|
| 93 |
return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))
|
| 94 |
-
def _convert_id_to_token(self, index: int) -> str:
|
| 95 |
-
token = self._ids_to_tokens.get(index, self.UNK_TOKEN)
|
| 96 |
-
if token in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]:
|
| 97 |
-
return ""
|
| 98 |
-
if token in self.suffixes:
|
| 99 |
-
return token
|
| 100 |
-
return token.replace("_f", "").replace("_t", "")
|
| 101 |
-
|
| 102 |
-
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
| 103 |
-
return "".join([t for t in tokens if t])
|
| 104 |
|
| 105 |
def decode(self, token_ids, skip_special_tokens=True, **kwargs) -> str:
|
| 106 |
if hasattr(token_ids, "tolist"):
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
import json
|
| 3 |
import os
|
| 4 |
+
import re
|
| 5 |
from typing import Dict, List, Optional
|
| 6 |
from transformers import PreTrainedTokenizer
|
| 7 |
import torch
|
| 8 |
|
| 9 |
class ChessTokenizer(PreTrainedTokenizer):
|
| 10 |
"""
|
| 11 |
+
vocab size: 4 special + 12 pieces + 64 from_sq + 64 to_sq + 5 suffix + 4 promotions (qrbn)
|
| 12 |
"""
|
| 13 |
|
| 14 |
model_input_names = ["input_ids", "attention_mask"]
|
|
|
|
| 25 |
self.colors_pieces = [f'{c}{p}' for c in ['W','B'] for p in ['P','N','B','R','Q','K']]
|
| 26 |
self.squares = [f'{f}{r}' for r in '12345678' for f in 'abcdefgh']
|
| 27 |
self.suffixes = ["(x)", "(+)", "(+*)", "(o)", "(O)"]
|
| 28 |
+
self.promotions = ["q", "r", "b", "n"] # promotion Token
|
| 29 |
|
| 30 |
if vocab is not None:
|
| 31 |
self._vocab = vocab
|
|
|
|
| 34 |
self._vocab = json.load(f)
|
| 35 |
else:
|
| 36 |
self._vocab = {t: i for i, t in enumerate(special_tokens)}
|
| 37 |
+
for cp in self.colors_pieces: self._vocab[cp] = len(self._vocab)
|
| 38 |
+
for sq in self.squares: self._vocab[f"{sq}_f"] = len(self._vocab)
|
| 39 |
+
for sq in self.squares: self._vocab[f"{sq}_t"] = len(self._vocab)
|
| 40 |
+
for suf in self.suffixes: self._vocab[suf] = len(self._vocab)
|
| 41 |
+
for promo in self.promotions: self._vocab[promo] = len(self._vocab)
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
|
| 44 |
|
|
|
|
| 53 |
@property
def vocab_size(self) -> int:
    """Number of entries currently held in the token vocabulary."""
    vocab = self._vocab
    return len(vocab)
| 56 |
+
|
| 57 |
def get_vocab(self) -> Dict[str, int]:
    """Return a shallow copy of the token -> id mapping."""
    return {**self._vocab}
| 60 |
|
| 61 |
def _tokenize(self, text: str) -> List[str]:
|
|
|
|
| 64 |
for part in parts:
|
| 65 |
if part in self._vocab:
|
| 66 |
tokens.append(part)
|
| 67 |
+
continue
|
| 68 |
+
|
| 69 |
+
# Deal with WPe7e8q
|
| 70 |
+
if len(part) >= 6:
|
| 71 |
+
piece = part[:2]
|
| 72 |
+
f_sq = part[2:4] + "_f"
|
| 73 |
+
t_sq = part[4:6] + "_t"
|
| 74 |
if piece in self._vocab: tokens.append(piece)
|
| 75 |
if f_sq in self._vocab: tokens.append(f_sq)
|
| 76 |
if t_sq in self._vocab: tokens.append(t_sq)
|
| 77 |
+
|
| 78 |
+
# Check if rest part include promotion or suffix
|
| 79 |
+
rest = part[6:]
|
| 80 |
+
if not rest: continue
|
| 81 |
+
|
| 82 |
+
# Extract promotion letter (q, r, b, n)
|
| 83 |
+
promo_match = re.search(r'[qrbnQRBN]', rest)
|
| 84 |
+
if promo_match:
|
| 85 |
+
p_char = promo_match.group(0).lower()
|
| 86 |
+
if p_char in self._vocab: tokens.append(p_char)
|
| 87 |
+
|
| 88 |
+
# Extract suffixes
|
| 89 |
+
for suf in self.suffixes:
|
| 90 |
+
if suf in rest:
|
| 91 |
+
tokens.append(suf)
|
| 92 |
return tokens
|
| 93 |
|
| 94 |
def _convert_id_to_token(self, index: int) -> str:
|
| 95 |
token = self._ids_to_tokens.get(index, self.UNK_TOKEN)
|
| 96 |
if token in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]:
|
| 97 |
return ""
|
| 98 |
+
# Same if promotion or suffix, delete _f or _t
|
| 99 |
+
if token in self.promotions or token in self.suffixes:
|
| 100 |
+
return token
|
| 101 |
return token.replace("_f", "").replace("_t", "")
|
| 102 |
|
| 103 |
def convert_tokens_to_string(self, tokens: List[str]) -> str:
    """Join decoded tokens into a move string.

    A two-character token starting with ``W`` or ``B`` is a
    colour+piece marker, i.e. the start of a new move, so it is
    prefixed with a space; empty tokens (decoded specials) are skipped.
    """
    out = []
    for tok in tokens:
        if not tok:
            continue
        starts_move = len(tok) == 2 and tok[0] in ('W', 'B')
        out.append(" " + tok if starts_move else tok)
    return "".join(out).strip()
|
| 113 |
+
|
| 114 |
def _convert_token_to_id(self, token: str) -> int:
|
| 115 |
return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
def decode(self, token_ids, skip_special_tokens=True, **kwargs) -> str:
|
| 118 |
if hasattr(token_ids, "tolist"):
|