Sunxt25 committed on
Commit 6156855 · verified · 1 Parent(s): acfaf7d

Delete tokenizer.py

Files changed (1)
  1. tokenizer.py +0 -121
tokenizer.py DELETED
@@ -1,121 +0,0 @@
- from __future__ import annotations
- import json
- import os
- from typing import Dict, List, Optional
- from transformers import PreTrainedTokenizer
-
- class ChessTokenizer(PreTrainedTokenizer):
-     """
-     Chess tokenizer conforming to the requirements of the evaluation script.
-     1. Vocabulary size is 144 (4 special + 12 pieces + 64 from-squares + 64 to-squares).
-     2. Decoding yields the compact format (e.g. "WPe2e4"), so the slices [2:4] and [4:6] in evaluate.py remain correct.
-     3. Distinguishes origin squares from destination squares semantically.
-     """
-
-     model_input_names = ["input_ids", "attention_mask"]
-     vocab_files_names = {"vocab_file": "vocab.json"}
-
-     PAD_TOKEN = "[PAD]"
-     BOS_TOKEN = "[BOS]"
-     EOS_TOKEN = "[EOS]"
-     UNK_TOKEN = "[UNK]"
-
-     def __init__(self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs):
-         special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
-
-         # Must be uppercase to match the game records generated by evaluate.py
-         self.colors_pieces = [f'{c}{p}' for c in ['W', 'B'] for p in ['P', 'N', 'B', 'R', 'Q', 'K']]  # 12 tokens
-         self.squares = [f'{f}{r}' for r in '12345678' for f in 'abcdefgh']  # 64 tokens
-
-         if vocab is not None:
-             self._vocab = vocab
-         elif vocab_file is not None and os.path.exists(vocab_file):
-             with open(vocab_file, "r", encoding="utf-8") as f:
-                 self._vocab = json.load(f)
-         else:
-             # Build the 144-entry vocabulary
-             self._vocab = {t: i for i, t in enumerate(special_tokens)}  # 0-3
-
-             # 4-15: piece tokens
-             for cp in self.colors_pieces:
-                 self._vocab[cp] = len(self._vocab)
-
-             # 16-79: from-square tokens (internal suffix prevents name clashes)
-             for sq in self.squares:
-                 self._vocab[f"{sq}_f"] = len(self._vocab)
-
-             # 80-143: to-square tokens
-             for sq in self.squares:
-                 self._vocab[f"{sq}_t"] = len(self._vocab)
-
-         self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
-
-         super().__init__(
-             pad_token=self.PAD_TOKEN,
-             bos_token=self.BOS_TOKEN,
-             eos_token=self.EOS_TOKEN,
-             unk_token=self.UNK_TOKEN,
-             **kwargs,
-         )
-
-     @property
-     def vocab_size(self) -> int:
-         return len(self._vocab)
-
-     def get_vocab(self) -> Dict[str, int]:
-         return dict(self._vocab)
-
-     def _tokenize(self, text: str) -> List[str]:
-         """Split a move such as WPe2e4 into three tokens."""
-         tokens = []
-         # Handle possible space separation (e.g. a move history)
-         moves = text.strip().split()
-         for move in moves:
-             # Pass special-token strings through unchanged
-             if move in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]:
-                 tokens.append(move)
-                 continue
-
-             if len(move) >= 6:
-                 cp = move[:2]               # e.g. "WP"
-                 from_sq = move[2:4] + "_f"  # e.g. "e2_f"
-                 to_sq = move[4:6] + "_t"    # e.g. "e4_t"
-                 tokens.extend([cp, from_sq, to_sq])
-         return tokens
-
-     def _convert_token_to_id(self, token: str) -> int:
-         return self._vocab.get(token, self._vocab[self.UNK_TOKEN])
-
-     def _convert_id_to_token(self, index: int) -> str:
-         token = self._ids_to_tokens.get(index, self.UNK_TOKEN)
-         # Key step: strip the internal suffix on decode, restoring "e2", "e4"
-         return token.replace("_f", "").replace("_t", "")
-
-     def convert_tokens_to_string(self, tokens: List[str]) -> str:
-         """
-         Join the token list.
-         evaluate.py expects output such as "WPe2e4", so no spaces are inserted.
-         """
-         # Filter out special tokens, keeping only move content
-         clean_tokens = [t for t in tokens if t not in self.all_special_tokens]
-         return "".join(clean_tokens)
-
-     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
-         if not os.path.isdir(save_directory):
-             os.makedirs(save_directory, exist_ok=True)
-         vocab_file = os.path.join(
-             save_directory,
-             (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
-         )
-         with open(vocab_file, "w", encoding="utf-8") as f:
-             json.dump(self._vocab, f, ensure_ascii=False, indent=2)
-         return (vocab_file,)
-
-     @classmethod
-     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "ChessTokenizer":
-         vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
-         if not os.path.exists(vocab_file):
-             return cls()  # Fall back to the default vocabulary if no file exists
-         with open(vocab_file, "r", encoding="utf-8") as f:
-             vocab = json.load(f)
-         return cls(vocab=vocab, **kwargs)
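For context, a minimal round-trip sketch of the deleted tokenizer, assuming the file above is saved as tokenizer.py on the import path and transformers is installed; it only illustrates the compact encode/decode behavior the docstring describes, and the variable names are illustrative:

    # Sketch only: assumes the deleted file above is importable as tokenizer.py
    from tokenizer import ChessTokenizer

    tok = ChessTokenizer()  # builds the default 144-entry vocabulary

    # "WPe2e4" splits into a piece token, a from-square token, and a to-square token
    ids = tok.encode("WPe2e4", add_special_tokens=False)
    print(tok.convert_ids_to_tokens(ids))  # ['WP', 'e2', 'e4'] (internal suffixes stripped)

    # Decoding joins tokens without spaces, so the slices used by evaluate.py
    # recover the squares: text[2:4] == "e2", text[4:6] == "e4"
    text = tok.decode(ids, skip_special_tokens=True)
    assert text == "WPe2e4"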