Sunxt25 committed
Commit 7de47b2 · verified · 1 Parent(s): a6243fc

Update tokenizer.py

Files changed (1):
  1. tokenizer.py +64 -84
tokenizer.py CHANGED
@@ -7,8 +7,7 @@ import torch
 
 class ChessTokenizer(PreTrainedTokenizer):
     """
-    A Chess Tokenizer that satisfies the evaluation script's requirements.
-    Vocabulary size: 149 (4 special + 12 pieces + 64 from_sq + 64 to_sq + 5 suffix)
+    vocab size: 149 (4 special + 12 pieces + 64 from_sq + 64 to_sq + 5 suffix)
     """
 
     model_input_names = ["input_ids", "attention_mask"]
@@ -54,98 +53,79 @@ class ChessTokenizer(PreTrainedTokenizer):
 
     @property
     def vocab_size(self) -> int:
-        """Return the size of the vocabulary."""
         return len(self._vocab)
-
+
     def get_vocab(self) -> Dict[str, int]:
-        """Return the vocabulary as a dictionary."""
         return dict(self._vocab)
-
+
     def _tokenize(self, text: str) -> List[str]:
-        """
-        Tokenize a string of moves into a list of tokens.
-
-        Args:
-            text: A string of space-separated moves.
-
-        Returns:
-            List of move tokens.
-        """
-        return text.strip().split()
-
+        tokens = []
+        parts = text.strip().split()
+        for part in parts:
+            if part in self._vocab:
+                tokens.append(part)
+            elif len(part) >= 6:
+                piece, f_sq, t_sq = part[:2], part[2:4] + "_f", part[4:6] + "_t"
+                if piece in self._vocab: tokens.append(piece)
+                if f_sq in self._vocab: tokens.append(f_sq)
+                if t_sq in self._vocab: tokens.append(t_sq)
+                if len(part) > 6 and part[6:] in self.suffixes:
+                    tokens.append(part[6:])
+        return tokens
+
+    def _convert_id_to_token(self, index: int) -> str:
+        token = self._ids_to_tokens.get(index, self.UNK_TOKEN)
+        if token in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]:
+            return ""
+        return token.replace("_f", "").replace("_t", "")
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        res = []
+        for t in tokens:
+            if not t: continue
+            # a piece token starts a new move, so prepend a space
+            if len(t) == 2 and (t.startswith('W') or t.startswith('B')):
+                res.append(" " + t)
+            else:
+                res.append(t)
+        return "".join(res).strip()
     def _convert_token_to_id(self, token: str) -> int:
-        """Convert a token to its ID."""
-        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))
-
+        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))
     def _convert_id_to_token(self, index: int) -> str:
-        """Convert an ID to its token."""
-        return self._ids_to_tokens.get(index, self.UNK_TOKEN)
-
+        token = self._ids_to_tokens.get(index, self.UNK_TOKEN)
+        if token in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]:
+            return ""
+        if token in self.suffixes:
+            return token
+        return token.replace("_f", "").replace("_t", "")
+
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """Convert a list of tokens back to a string."""
-        # Filter out special tokens for cleaner output
-        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
-        return " ".join(t for t in tokens if t not in special)
-
-    def save_vocabulary(
-        self,
-        save_directory: str,
-        filename_prefix: Optional[str] = None,
-    ) -> tuple:
-        """
-        Save the vocabulary to a JSON file.
-
-        Args:
-            save_directory: Directory to save the vocabulary.
-            filename_prefix: Optional prefix for the filename.
-
-        Returns:
-            Tuple containing the path to the saved vocabulary file.
-        """
+        return "".join([t for t in tokens if t])
+
+    def decode(self, token_ids, skip_special_tokens=True, **kwargs) -> str:
+        if hasattr(token_ids, "tolist"):
+            ids = token_ids.tolist()
+        elif isinstance(token_ids, (int, torch.LongTensor, torch.IntTensor)):
+            ids = [int(token_ids)] if isinstance(token_ids, int) else token_ids.tolist()
+        else:
+            ids = token_ids
+
+        tokens = [self._convert_id_to_token(i) for i in ids]
+        return self.convert_tokens_to_string(tokens)
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
         if not os.path.isdir(save_directory):
             os.makedirs(save_directory, exist_ok=True)
-
-        vocab_file = os.path.join(
-            save_directory,
-            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
-        )
-
+        vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json")
         with open(vocab_file, "w", encoding="utf-8") as f:
             json.dump(self._vocab, f, ensure_ascii=False, indent=2)
-
         return (vocab_file,)
 
-
-def count_vocab_from_dataset(
-    dataset_name: str = "dlouapre/lichess_2025-01_1M",
-    split: str = "train",
-    column: str = "text",
-    max_samples: Optional[int] = 10000,
-) -> Dict[str, int]:
-    """
-    Count token frequencies in a dataset (useful for vocabulary analysis).
-
-    Args:
-        dataset_name: Name of the dataset on Hugging Face Hub.
-        split: Dataset split to use.
-        column: Column containing the game strings.
-        max_samples: Maximum number of samples to process.
-
-    Returns:
-        Dictionary mapping tokens to their frequencies.
-    """
-    from collections import Counter
-    from datasets import load_dataset
-
-    dataset = load_dataset(dataset_name, split=split)
-
-    if max_samples is not None:
-        dataset = dataset.select(range(min(max_samples, len(dataset))))
-
-    token_counts = Counter()
-
-    for example in dataset:
-        moves = example[column].strip().split()
-        token_counts.update(moves)
-
-    return dict(token_counts)
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "ChessTokenizer":
+        vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
+        if not os.path.exists(vocab_file):
+            return cls()
+        with open(vocab_file, "r", encoding="utf-8") as f:
+            vocab = json.load(f)
+        return cls(vocab=vocab, **kwargs)
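
For context on what the rewritten _tokenize and decode imply, here is a minimal usage sketch. The constructor is not part of this diff, so the concrete vocabulary entries ("WP", "e2_f", "e4_t") and the "WPe2e4" move encoding are assumptions inferred from the slicing in _tokenize, not confirmed repository behavior. Note also that the committed file defines _convert_id_to_token and convert_tokens_to_string twice; Python binds the later definitions, so decoding uses the suffix-preserving converter and the plain join rather than the earlier space-inserting variant.

    # Illustrative only: assumes ChessTokenizer() builds the 149-token
    # vocabulary the docstring describes (piece tokens such as "WP",
    # from-squares such as "e2_f", to-squares such as "e4_t").
    tok = ChessTokenizer()

    # _tokenize slices a move like "WPe2e4" into piece / from-square / to-square:
    tokens = tok._tokenize("WPe2e4 BPe7e5")
    # -> ["WP", "e2_f", "e4_t", "BP", "e7_f", "e5_t"]

    ids = [tok._convert_token_to_id(t) for t in tokens]

    # decode() maps ids back to tokens, strips the "_f"/"_t" markers, and
    # joins. Because the later convert_tokens_to_string definition (the plain
    # join) wins, the moves come back without separating spaces:
    print(tok.decode(ids))  # -> "WPe2e4BPe7e5"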
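The new save_vocabulary/from_pretrained pair implies a one-file persistence format: a single vocab.json in the target directory. A short round-trip sketch, assuming the constructor accepts the vocab= keyword that from_pretrained passes:

    # Persistence round-trip implied by the diff. The vocab= constructor
    # parameter is assumed from the call in from_pretrained.
    tok = ChessTokenizer()
    paths = tok.save_vocabulary("./chess_tokenizer")  # -> ("./chess_tokenizer/vocab.json",)
    restored = ChessTokenizer.from_pretrained("./chess_tokenizer")
    assert restored.get_vocab() == tok.get_vocab()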