Sunxt25 committed · verified
Commit a6243fc · 1 Parent(s): 2ab9c78

Update tokenizer.py

Files changed (1):
  tokenizer.py  +83 -57
tokenizer.py CHANGED
@@ -54,72 +54,98 @@ class ChessTokenizer(PreTrainedTokenizer):
 
     @property
     def vocab_size(self) -> int:
+        """Return the size of the vocabulary."""
         return len(self._vocab)
-
+
     def get_vocab(self) -> Dict[str, int]:
+        """Return the vocabulary as a dictionary."""
         return dict(self._vocab)
-
+
     def _tokenize(self, text: str) -> List[str]:
-        """Key: recognize tokens with suffixes so that eval detects decomposed mode."""
-        tokens = []
-        parts = text.strip().split()
-        for part in parts:
-            if part in self._vocab:
-                tokens.append(part)
-            elif len(part) >= 6:  # handle the compact WPe2e4 format
-                piece, f_sq, t_sq = part[:2], part[2:4] + "_f", part[4:6] + "_t"
-                if piece in self._vocab: tokens.append(piece)
-                if f_sq in self._vocab: tokens.append(f_sq)
-                if t_sq in self._vocab: tokens.append(t_sq)
-                if len(part) > 6 and part[6:] in self.suffixes:
-                    tokens.append(part[6:])
-        return tokens
-
+        """
+        Tokenize a string of moves into a list of tokens.
+
+        Args:
+            text: A string of space-separated moves.
+
+        Returns:
+            List of move tokens.
+        """
+        return text.strip().split()
+
+    def _convert_token_to_id(self, token: str) -> int:
+        """Convert a token to its ID."""
+        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))
+
     def _convert_id_to_token(self, index: int) -> str:
-        """Key: strip the suffixes so eval's [a-h][1-8] regex can capture the coordinates."""
-        token = self._ids_to_tokens.get(index, self.UNK_TOKEN)
-        if token in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]:
-            return ""
-        return token.replace("_f", "").replace("_t", "")
-
+        """Convert an ID to its token."""
+        return self._ids_to_tokens.get(index, self.UNK_TOKEN)
+
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """Key: add a space before each piece token so the move-history format stays correct."""
-        res = []
-        for t in tokens:
-            if not t: continue
-            # A piece token starts a new move, so prepend a space
-            if len(t) == 2 and (t.startswith('W') or t.startswith('B')):
-                res.append(" " + t)
-            else:
-                res.append(t)
-        return "".join(res).strip()
-    def _convert_token_to_id(self, token: str) -> int:
-        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))
-
-    def decode(self, token_ids, skip_special_tokens=True, **kwargs) -> str:
-        if hasattr(token_ids, "tolist"):
-            ids = token_ids.tolist()
-        elif isinstance(token_ids, (int, torch.LongTensor, torch.IntTensor)):
-            ids = [int(token_ids)] if isinstance(token_ids, int) else token_ids.tolist()
-        else:
-            ids = token_ids
-
-        tokens = [self._convert_id_to_token(i) for i in ids]
-        return self.convert_tokens_to_string(tokens)
-
-    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
+        """Convert a list of tokens back to a string."""
+        # Filter out special tokens for cleaner output
+        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
+        return " ".join(t for t in tokens if t not in special)
+
+    def save_vocabulary(
+        self,
+        save_directory: str,
+        filename_prefix: Optional[str] = None,
+    ) -> tuple:
+        """
+        Save the vocabulary to a JSON file.
+
+        Args:
+            save_directory: Directory to save the vocabulary.
+            filename_prefix: Optional prefix for the filename.
+
+        Returns:
+            Tuple containing the path to the saved vocabulary file.
+        """
         if not os.path.isdir(save_directory):
             os.makedirs(save_directory, exist_ok=True)
-        vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json")
+
+        vocab_file = os.path.join(
+            save_directory,
+            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
+        )
+
         with open(vocab_file, "w", encoding="utf-8") as f:
             json.dump(self._vocab, f, ensure_ascii=False, indent=2)
+
         return (vocab_file,)
 
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "ChessTokenizer":
-        vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
-        if not os.path.exists(vocab_file):
-            return cls()
-        with open(vocab_file, "r", encoding="utf-8") as f:
-            vocab = json.load(f)
-        return cls(vocab=vocab, **kwargs)
+
+def count_vocab_from_dataset(
+    dataset_name: str = "dlouapre/lichess_2025-01_1M",
+    split: str = "train",
+    column: str = "text",
+    max_samples: Optional[int] = 10000,
+) -> Dict[str, int]:
+    """
+    Count token frequencies in a dataset (useful for vocabulary analysis).
+
+    Args:
+        dataset_name: Name of the dataset on Hugging Face Hub.
+        split: Dataset split to use.
+        column: Column containing the game strings.
+        max_samples: Maximum number of samples to process.
+
+    Returns:
+        Dictionary mapping tokens to their frequencies.
+    """
+    from collections import Counter
+    from datasets import load_dataset
+
+    dataset = load_dataset(dataset_name, split=split)
+
+    if max_samples is not None:
+        dataset = dataset.select(range(min(max_samples, len(dataset))))
+
+    token_counts = Counter()
+
+    for example in dataset:
+        moves = example[column].strip().split()
+        token_counts.update(moves)
+
+    return dict(token_counts)
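
For reference, a minimal usage sketch of the simplified tokenizer after this commit. It assumes the ChessTokenizer constructor accepts a vocab dict (as the removed from_pretrained helper implied) and that games are stored as space-separated move strings; the example vocabulary, move tokens, and special-token names below are illustrative assumptions, not taken from the repository's vocab.json. The underscore-prefixed methods are called directly only to mirror the lines in the diff.

# Hypothetical usage sketch: vocab contents and special-token strings are illustrative.
from tokenizer import ChessTokenizer

vocab = {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3, "e2e4": 4, "e7e5": 5}
tok = ChessTokenizer(vocab=vocab)

tokens = tok._tokenize("e2e4 e7e5")                  # ["e2e4", "e7e5"]
ids = [tok._convert_token_to_id(t) for t in tokens]  # [4, 5]
back = [tok._convert_id_to_token(i) for i in ids]    # ["e2e4", "e7e5"]
print(tok.convert_tokens_to_string(back))            # "e2e4 e7e5"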
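
Similarly, a quick sketch of how the new count_vocab_from_dataset helper might be used for a coverage check. It requires the datasets package and access to the Hugging Face Hub; printing the ten most frequent tokens is just one way to inspect the returned frequency dict.

# Sketch: list the ten most frequent move tokens in a small sample of the dataset.
from tokenizer import count_vocab_from_dataset

counts = count_vocab_from_dataset(max_samples=1000)
for token, freq in sorted(counts.items(), key=lambda kv: kv[1], reverse=True)[:10]:
    print(f"{token}\t{freq}")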