Sunxt25 committed on
Commit
a4c10fe
·
verified ·
1 Parent(s): a17994d

Update tokenizer.py

Browse files
Files changed (1) hide show
  1. tokenizer.py +37 -25
tokenizer.py CHANGED
@@ -1,13 +1,14 @@
1
  from __future__ import annotations
2
  import json
3
  import os
 
4
  from typing import Dict, List, Optional
5
  from transformers import PreTrainedTokenizer
6
  import torch
7
 
8
  class ChessTokenizer(PreTrainedTokenizer):
9
  """
10
- vocab size: 149 (4 special + 12 pieces + 64 from_sq + 64 to_sq + 5 suffix)
11
  """
12
 
13
  model_input_names = ["input_ids", "attention_mask"]
@@ -24,6 +25,7 @@ class ChessTokenizer(PreTrainedTokenizer):
24
  self.colors_pieces = [f'{c}{p}' for c in ['W','B'] for p in ['P','N','B','R','Q','K']]
25
  self.squares = [f'{f}{r}' for r in '12345678' for f in 'abcdefgh']
26
  self.suffixes = ["(x)", "(+)", "(+*)", "(o)", "(O)"]
 
27
 
28
  if vocab is not None:
29
  self._vocab = vocab
@@ -32,14 +34,11 @@ class ChessTokenizer(PreTrainedTokenizer):
32
  self._vocab = json.load(f)
33
  else:
34
  self._vocab = {t: i for i, t in enumerate(special_tokens)}
35
- for cp in self.colors_pieces:
36
- self._vocab[cp] = len(self._vocab)
37
- for sq in self.squares:
38
- self._vocab[f"{sq}_f"] = len(self._vocab)
39
- for sq in self.squares:
40
- self._vocab[f"{sq}_t"] = len(self._vocab)
41
- for suf in self.suffixes:
42
- self._vocab[suf] = len(self._vocab)
43
 
44
  self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
45
 
@@ -54,8 +53,9 @@ class ChessTokenizer(PreTrainedTokenizer):
54
  @property
55
  def vocab_size(self) -> int:
56
  return len(self._vocab)
57
-
58
  def get_vocab(self) -> Dict[str, int]:
 
59
  return dict(self._vocab)
60
 
61
  def _tokenize(self, text: str) -> List[str]:
@@ -64,43 +64,55 @@ class ChessTokenizer(PreTrainedTokenizer):
64
  for part in parts:
65
  if part in self._vocab:
66
  tokens.append(part)
67
- elif len(part) >= 6:
68
- piece, f_sq, t_sq = part[:2], part[2:4] + "_f", part[4:6] + "_t"
 
 
 
 
 
69
  if piece in self._vocab: tokens.append(piece)
70
  if f_sq in self._vocab: tokens.append(f_sq)
71
  if t_sq in self._vocab: tokens.append(t_sq)
72
- if len(part) > 6 and part[6:] in self.suffixes:
73
- tokens.append(part[6:])
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  return tokens
75
 
76
  def _convert_id_to_token(self, index: int) -> str:
77
  token = self._ids_to_tokens.get(index, self.UNK_TOKEN)
78
  if token in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]:
79
  return ""
 
 
 
80
  return token.replace("_f", "").replace("_t", "")
81
 
82
  def convert_tokens_to_string(self, tokens: List[str]) -> str:
83
  res = []
84
  for t in tokens:
85
  if not t: continue
86
- # if piece token,new moveadd space
87
  if len(t) == 2 and (t.startswith('W') or t.startswith('B')):
88
  res.append(" " + t)
89
  else:
90
  res.append(t)
91
  return "".join(res).strip()
 
92
  def _convert_token_to_id(self, token: str) -> int:
93
  return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))
94
- def _convert_id_to_token(self, index: int) -> str:
95
- token = self._ids_to_tokens.get(index, self.UNK_TOKEN)
96
- if token in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]:
97
- return ""
98
- if token in self.suffixes:
99
- return token
100
- return token.replace("_f", "").replace("_t", "")
101
-
102
- def convert_tokens_to_string(self, tokens: List[str]) -> str:
103
- return "".join([t for t in tokens if t])
104
 
105
  def decode(self, token_ids, skip_special_tokens=True, **kwargs) -> str:
106
  if hasattr(token_ids, "tolist"):
 
1
  from __future__ import annotations
2
  import json
3
  import os
4
+ import re
5
  from typing import Dict, List, Optional
6
  from transformers import PreTrainedTokenizer
7
  import torch
8
 
9
  class ChessTokenizer(PreTrainedTokenizer):
10
  """
11
+ vocab size: 153 (4 special + 12 pieces + 64 from_sq + 64 to_sq + 5 suffixes + 4 promotions: q/r/b/n)
12
  """
13
 
14
  model_input_names = ["input_ids", "attention_mask"]
 
25
  self.colors_pieces = [f'{c}{p}' for c in ['W','B'] for p in ['P','N','B','R','Q','K']]
26
  self.squares = [f'{f}{r}' for r in '12345678' for f in 'abcdefgh']
27
  self.suffixes = ["(x)", "(+)", "(+*)", "(o)", "(O)"]
28
+ self.promotions = ["q", "r", "b", "n"] # promotion Token
29
 
30
  if vocab is not None:
31
  self._vocab = vocab
 
34
  self._vocab = json.load(f)
35
  else:
36
  self._vocab = {t: i for i, t in enumerate(special_tokens)}
37
+ for cp in self.colors_pieces: self._vocab[cp] = len(self._vocab)
38
+ for sq in self.squares: self._vocab[f"{sq}_f"] = len(self._vocab)
39
+ for sq in self.squares: self._vocab[f"{sq}_t"] = len(self._vocab)
40
+ for suf in self.suffixes: self._vocab[suf] = len(self._vocab)
41
+ for promo in self.promotions: self._vocab[promo] = len(self._vocab)
 
 
 
42
 
43
  self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
44
 
 
53
  @property
54
  def vocab_size(self) -> int:
55
  return len(self._vocab)
56
+
57
  def get_vocab(self) -> Dict[str, int]:
58
+ """Return the vocabulary as a dictionary."""
59
  return dict(self._vocab)
60
 
61
  def _tokenize(self, text: str) -> List[str]:
 
64
  for part in parts:
65
  if part in self._vocab:
66
  tokens.append(part)
67
+ continue
68
+
69
+ # Deal with WPe7e8q
70
+ if len(part) >= 6:
71
+ piece = part[:2]
72
+ f_sq = part[2:4] + "_f"
73
+ t_sq = part[4:6] + "_t"
74
  if piece in self._vocab: tokens.append(piece)
75
  if f_sq in self._vocab: tokens.append(f_sq)
76
  if t_sq in self._vocab: tokens.append(t_sq)
77
+
78
+ # Check if rest part include promotion or suffix
79
+ rest = part[6:]
80
+ if not rest: continue
81
+
82
+ # Extract promotion letter (q, r, b, n)
83
+ promo_match = re.search(r'[qrbnQRBN]', rest)
84
+ if promo_match:
85
+ p_char = promo_match.group(0).lower()
86
+ if p_char in self._vocab: tokens.append(p_char)
87
+
88
+ # Extract suffixes
89
+ for suf in self.suffixes:
90
+ if suf in rest:
91
+ tokens.append(suf)
92
  return tokens
93
 
94
  def _convert_id_to_token(self, index: int) -> str:
95
  token = self._ids_to_tokens.get(index, self.UNK_TOKEN)
96
  if token in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]:
97
  return ""
98
+ # Same if promotion or suffix, delete _f or _t
99
+ if token in self.promotions or token in self.suffixes:
100
+ return token
101
  return token.replace("_f", "").replace("_t", "")
102
 
103
  def convert_tokens_to_string(self, tokens: List[str]) -> str:
104
  res = []
105
  for t in tokens:
106
  if not t: continue
107
+ # If piece (WP), new move, add space
108
  if len(t) == 2 and (t.startswith('W') or t.startswith('B')):
109
  res.append(" " + t)
110
  else:
111
  res.append(t)
112
  return "".join(res).strip()
113
+
114
  def _convert_token_to_id(self, token: str) -> int:
115
  return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))
 
 
 
 
 
 
 
 
 
 
116
 
117
  def decode(self, token_ids, skip_special_tokens=True, **kwargs) -> str:
118
  if hasattr(token_ids, "tolist"):