GenerTeam committed
Commit 4cb5f70 · verified · 1 Parent(s): 1ab6665

Update tokenizer

Files changed (4)
  1. special_tokens_map.json +2 -7
  2. tokenizer.py +224 -0
  3. tokenizer_config.json +34 -17
  4. vocab.txt +43 -0
special_tokens_map.json CHANGED
@@ -1,12 +1,7 @@
  {
    "bos_token": "<s>",
    "eos_token": "</s>",
+   "mask_token": "<mask>",
    "pad_token": "<pad>",
-   "unk_token": {
-     "content": "<oov>",
-     "lstrip": false,
-     "normalized": false,
-     "rstrip": false,
-     "single_word": false
-   }
+   "unk_token": "N"
  }
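
With this change, characters outside the recognized alphabet fall back to the plain base "N" instead of a dedicated "<oov>" token, and "<mask>" is declared as the mask token. A minimal sketch of the new behavior, assuming the tokenizer.py added below is importable from the repository root:

from tokenizer import SingleNucleotideTokenizer

tok = SingleNucleotideTokenizer()
# "B" is not in the vocabulary, so it maps to the unk token "N" (id 37)
print(tok.tokenize("ACGB"))                              # expected: ['A', 'C', 'G', 'N']
print(tok.convert_tokens_to_ids(["A", "C", "G", "N"]))   # expected: [32, 33, 34, 37]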
tokenizer.py ADDED
@@ -0,0 +1,224 @@
+ import os
+ import json
+ import re
+ from typing import List, Optional, Tuple, Dict
+ from transformers import PreTrainedTokenizer, AutoTokenizer
+
+
+ class SingleNucleotideTokenizer(PreTrainedTokenizer):
+     def __init__(self, **kwargs):
+         # Define the vocabulary
+         self.vocab_list = [
+             "<oov>", "<s>", "</s>", "<pad>", "<mask>",
+             "<bog>", "<eog>", "<bok>", "<eok>", "<+>", "<->",
+             "<mam>", "<vrt>", "<inv>", "<pln>", "<fng>", "<prt>",
+             "<arc>", "<bct>", "<mit>", "<plt>", "<plm>", "<vir>",
+             "<cds>", "<pseudo>", "<tRNA>", "<rRNA>", "<ncRNA>",
+             "<sp0>", "<sp1>", "<sp2>", "<sp3>",
+             "A", "C", "G", "<K>", "<M>", "N", "<R>", "<S>", "T", "<W>", "<Y>"
+         ]
+
+         # Build the token <-> id mappings
+         self.vocab = {token: idx for idx, token in enumerate(self.vocab_list)}
+         self.ids_to_tokens = {idx: token for token, idx in self.vocab.items()}
+         self.tokens_to_ids = {token: idx for token, idx in self.vocab.items()}
+
+         # Special tokens
+         self.unk_token = "N"
+         self.bos_token = "<s>"
+         self.eos_token = "</s>"
+         self.pad_token = "<pad>"
+         self.mask_token = "<mask>"
+
+         # Regex matching bracketed special tokens such as <cds> or <K>
+         special_tokens_pattern = "|".join(
+             re.escape(token) for token in self.vocab_list
+             if token.startswith("<") and token.endswith(">")
+         )
+         self.special_token_re = re.compile(f"({special_tokens_pattern})")
+
+         # Regex matching single-nucleotide tokens
+         self.normal_token_re = re.compile(r"[ACGTN]")
+
+         # Special token ids
+         self.unk_token_id = self.vocab[self.unk_token]
+         self.bos_token_id = self.vocab[self.bos_token]
+         self.eos_token_id = self.vocab[self.eos_token]
+         self.pad_token_id = self.vocab[self.pad_token]
+         self.mask_token_id = self.vocab[self.mask_token]
+
+         # Initialize the base class
+         super().__init__(
+             unk_token=self.unk_token,
+             bos_token=self.bos_token,
+             eos_token=self.eos_token,
+             pad_token=self.pad_token,
+             mask_token=self.mask_token,
+             **kwargs
+         )
+         self.clean_up_tokenization_spaces = True
+
+     @property
+     def vocab_size(self) -> int:
+         return len(self.vocab)
+
+     def get_vocab(self) -> Dict[str, int]:
+         return self.vocab
+
+     def _tokenize(self, text: str, **kwargs) -> List[str]:
+         tokens = []
+         pos = 0
+         text_length = len(text)
+
+         while pos < text_length:
+             # First try to match a bracketed special token
+             special_match = self.special_token_re.match(text, pos)
+             if special_match:
+                 tokens.append(special_match.group())
+                 pos = special_match.end()
+                 continue
+
+             # Then try to match a single nucleotide
+             normal_match = self.normal_token_re.match(text, pos)
+             if normal_match:
+                 token = normal_match.group()
+                 # Make sure the token is in the vocabulary
+                 if token in self.vocab:
+                     tokens.append(token)
+                 else:
+                     tokens.append(self.unk_token)
+                 pos = normal_match.end()
+                 continue
+
+             # Otherwise skip the character and emit the unk token
+             tokens.append(self.unk_token)
+             pos += 1
+
+         return tokens
+
+     def _convert_token_to_id(self, token: str) -> int:
+         return self.vocab.get(token, self.unk_token_id)
+
+     def _convert_id_to_token(self, index: int) -> str:
+         return self.ids_to_tokens.get(index, self.unk_token)
+
+     def convert_tokens_to_string(self, tokens: List[str]) -> str:
+         # Simply concatenate all tokens
+         return "".join(tokens)
+
+     def build_inputs_with_special_tokens(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         if token_ids_1 is None:
+             return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
+         return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
+
+     def get_special_tokens_mask(
+         self,
+         token_ids_0: List[int],
+         token_ids_1: Optional[List[int]] = None,
+         already_has_special_tokens: bool = False
+     ) -> List[int]:
+         if already_has_special_tokens:
+             return super().get_special_tokens_mask(
+                 token_ids_0=token_ids_0,
+                 token_ids_1=token_ids_1,
+                 already_has_special_tokens=True
+             )
+
+         if token_ids_1 is None:
+             return [1] + ([0] * len(token_ids_0)) + [1]
+         return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+
+     def create_token_type_ids_from_sequences(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         # Llama-style tokenizers typically do not use token type ids
+         if token_ids_1 is None:
+             return [0] * (len(token_ids_0) + 2)  # +2 for bos and eos
+         return [0] * (len(token_ids_0) + 1) + [1] * (len(token_ids_1) + 1)
+
+     def save_pretrained(self, save_directory: str, **kwargs):
+         """Override save_pretrained so the saved config includes the auto_map entry."""
+         # Let the parent class save the vocabulary and config first
+         vocab_files = super().save_pretrained(save_directory, **kwargs)
+
+         # Create or update tokenizer_config.json
+         tokenizer_config_path = os.path.join(save_directory, "tokenizer_config.json")
+
+         # Read the existing config or start from an empty one
+         if os.path.exists(tokenizer_config_path):
+             with open(tokenizer_config_path, "r", encoding="utf-8") as f:
+                 config = json.load(f)
+         else:
+             config = {}
+
+         # Add the auto_map entry so AutoTokenizer can resolve this class
+         config.update({
+             "auto_map": {
+                 "AutoTokenizer": [
+                     "tokenizer.SingleNucleotideTokenizer",  # slow tokenizer class in tokenizer.py
+                     None  # no fast tokenizer
+                 ]
+             },
+         })
+
+         # Write the config back
+         with open(tokenizer_config_path, "w", encoding="utf-8") as f:
+             json.dump(config, f, ensure_ascii=False, indent=2)
+
+         return vocab_files
+
+     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+         # Make sure the target directory exists
+         if not os.path.exists(save_directory):
+             os.makedirs(save_directory)
+
+         # Build the vocabulary file path
+         vocab_file = os.path.join(
+             save_directory,
+             (filename_prefix + "-" if filename_prefix else "") + "vocab.txt"
+         )
+
+         # Write the vocabulary, one "token id" pair per line
+         with open(vocab_file, "w", encoding="utf-8") as f:
+             for token, idx in sorted(self.vocab.items(), key=lambda x: x[1]):
+                 f.write(f"{token} {idx}\n")
+
+         return (vocab_file,)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
+         # The vocabulary is hard-coded, so simply create a fresh instance
+         return cls(**kwargs)
+
+
+ AutoTokenizer.register("atcg_tokenizer", SingleNucleotideTokenizer)
+
+
+ if __name__ == "__main__":
+     loaded_tokenizer = AutoTokenizer.from_pretrained('/vepfs-mlp2/mlp-public/liqiuyi/GENERanno-eukaryote-0.5b-diffusion')
+
+     # Initialize the tokenizer
+     tokenizer = SingleNucleotideTokenizer()
+
+     # Test tokenization
+     text = "ACGTKMNRSTWY<cds><prt>"
+     tokens = tokenizer.tokenize(text)
+     print("Tokens:", tokens)
+
+     # Convert to ids
+     token_ids = tokenizer.convert_tokens_to_ids(tokens)
+     print("Token IDs:", token_ids)
+
+     # Convert back to text
+     decoded_text = tokenizer.decode(token_ids)
+     print("Decoded text:", decoded_text)
+
+     # Test adding special tokens
+     special_tokens = tokenizer.build_inputs_with_special_tokens(token_ids)
+     print("With special tokens:", special_tokens)
+
+     tokenizer.save_pretrained('/vepfs-mlp2/mlp-public/liqiuyi/GENERanno-eukaryote-0.5b-diffusion')
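
Note that only A, C, G, T and N are matched as bare single-character tokens; the IUPAC ambiguity codes exist in the vocabulary only in bracketed form (<K>, <M>, <R>, <S>, <W>, <Y>). A short illustration, assuming the class behaves as written:

tokenizer = SingleNucleotideTokenizer()
print(tokenizer.tokenize("ACGT<K><cds>"))
# expected: ['A', 'C', 'G', 'T', '<K>', '<cds>']
print(tokenizer.tokenize("ACGTK"))
# expected: ['A', 'C', 'G', 'T', 'N'] -- a bare "K" matches neither regex and becomes the unk token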
tokenizer_config.json CHANGED
@@ -1,26 +1,39 @@
  {
-   "add_bos_token": true,
-   "add_eos_token": false,
-   "add_prefix_space": false,
    "added_tokens_decoder": {
-     "0": {
-       "content": "<oov>",
+     "1": {
+       "content": "<s>",
        "lstrip": false,
        "normalized": false,
        "rstrip": false,
        "single_word": false,
        "special": true
      },
-     "1": {
-       "content": "<s>",
+     "2": {
+       "content": "</s>",
        "lstrip": false,
        "normalized": false,
        "rstrip": false,
        "single_word": false,
        "special": true
      },
-     "2": {
-       "content": "</s>",
+     "3": {
+       "content": "<pad>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "4": {
+       "content": "<mask>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "37": {
+       "content": "N",
        "lstrip": false,
        "normalized": false,
        "rstrip": false,
@@ -29,14 +42,18 @@
      }
    },
    "bos_token": "<s>",
-   "clean_up_tokenization_spaces": false,
+   "clean_up_tokenization_spaces": true,
    "eos_token": "</s>",
-   "legacy": true,
+   "extra_special_tokens": {},
+   "mask_token": "<mask>",
    "model_max_length": 1000000000000000019884624838656,
    "pad_token": "<pad>",
-   "sp_model_kwargs": {},
-   "spaces_between_special_tokens": false,
-   "tokenizer_class": "LlamaTokenizer",
-   "unk_token": "<oov>",
-   "use_default_system_prompt": false
- }
+   "tokenizer_class": "SingleNucleotideTokenizer",
+   "unk_token": "N",
+   "auto_map": {
+     "AutoTokenizer": [
+       "tokenizer.SingleNucleotideTokenizer",
+       null
+     ]
+   }
+ }
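
The auto_map entry is what lets AutoTokenizer resolve the custom class from tokenizer.py when remote code is trusted. A minimal loading sketch; the repository id here is an assumption:

from transformers import AutoTokenizer

# trust_remote_code is required so tokenizer.SingleNucleotideTokenizer can be imported from the repo
tok = AutoTokenizer.from_pretrained(
    "GenerTeam/GENERanno-eukaryote-0.5b-diffusion",  # hypothetical repository id
    trust_remote_code=True,
)
print(tok("ACGT")["input_ids"])  # expected: [1, 32, 33, 34, 40, 2] given the vocabulary below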
vocab.txt ADDED
@@ -0,0 +1,43 @@
+ <oov> 0
+ <s> 1
+ </s> 2
+ <pad> 3
+ <mask> 4
+ <bog> 5
+ <eog> 6
+ <bok> 7
+ <eok> 8
+ <+> 9
+ <-> 10
+ <mam> 11
+ <vrt> 12
+ <inv> 13
+ <pln> 14
+ <fng> 15
+ <prt> 16
+ <arc> 17
+ <bct> 18
+ <mit> 19
+ <plt> 20
+ <plm> 21
+ <vir> 22
+ <cds> 23
+ <pseudo> 24
+ <tRNA> 25
+ <rRNA> 26
+ <ncRNA> 27
+ <sp0> 28
+ <sp1> 29
+ <sp2> 30
+ <sp3> 31
+ A 32
+ C 33
+ G 34
+ <K> 35
+ <M> 36
+ N 37
+ <R> 38
+ <S> 39
+ T 40
+ <W> 41
+ <Y> 42