import os
import json
import re
from typing import List, Optional, Tuple, Dict
from transformers import PreTrainedTokenizer

class SingleNucleotideTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer that maps single nucleotides (A, C, G, T, N)
    and a fixed set of angle-bracketed annotation tokens to integer ids."""

    def __init__(self, **kwargs):
        # Define the vocabulary.
        self.vocab_list = [
            "<oov>", "<s>", "</s>", "<pad>", "<mask>",
            "<bog>", "<eog>", "<bok>", "<eok>", "<+>", "<->",
            "<mam>", "<vrt>", "<inv>", "<pln>", "<fng>", "<prt>",
            "<arc>", "<bct>", "<mit>", "<plt>", "<plm>", "<vir>",
            "<cds>", "<pseudo>", "<tRNA>", "<rRNA>", "<ncRNA>",
            "<sp0>", "<sp1>", "<sp2>", "<sp3>",
            "A", "C", "G", "<K>", "<M>", "N", "<R>", "<S>", "T", "<W>", "<Y>"
        ]
        # Build the token <-> id mappings.
        self.vocab = {token: idx for idx, token in enumerate(self.vocab_list)}
        self.ids_to_tokens = {idx: token for token, idx in self.vocab.items()}
        self.tokens_to_ids = {token: idx for token, idx in self.vocab.items()}
        # Set the special tokens. These (and the vocabulary above) must exist
        # before the parent constructor runs, because it resolves them to ids.
        self.unk_token = "N"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.pad_token = "<pad>"
        self.mask_token = "<mask>"
        # Compile a regex that matches any angle-bracketed special token.
        special_tokens_pattern = "|".join(
            re.escape(token) for token in self.vocab_list
            if token.startswith("<") and token.endswith(">")
        )
        self.special_token_re = re.compile(f"({special_tokens_pattern})")
        # Compile a regex that matches a single nucleotide token.
        self.normal_token_re = re.compile(r"[ACGTN]")
        # Resolve the special token ids.
        self.unk_token_id = self.vocab[self.unk_token]
        self.bos_token_id = self.vocab[self.bos_token]
        self.eos_token_id = self.vocab[self.eos_token]
        self.pad_token_id = self.vocab[self.pad_token]
        self.mask_token_id = self.vocab[self.mask_token]
        # Call the parent constructor.
        super().__init__(
            unk_token=self.unk_token,
            bos_token=self.bos_token,
            eos_token=self.eos_token,
            pad_token=self.pad_token,
            mask_token=self.mask_token,
            **kwargs
        )
        self.clean_up_tokenization_spaces = True

    @property
    def vocab_size(self) -> int:
        return len(self.vocab)

    def get_vocab(self) -> Dict[str, int]:
        return self.vocab

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        tokens = []
        pos = 0
        text_length = len(text)
        while pos < text_length:
            # First try to match a special token.
            special_match = self.special_token_re.match(text, pos)
            if special_match:
                token = special_match.group()
                tokens.append(token)
                pos = special_match.end()
                continue
            # Then try to match an ordinary nucleotide token.
            normal_match = self.normal_token_re.match(text, pos)
            if normal_match:
                token = normal_match.group()
                # Make sure the token is in the vocabulary.
                if token in self.vocab:
                    tokens.append(token)
                else:
                    tokens.append(self.unk_token)
                pos = normal_match.end()
                continue
            # Neither pattern matched: skip the character and emit the unknown token.
            tokens.append(self.unk_token)
            pos += 1
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        return self.vocab.get(token, self.unk_token_id)

    def _convert_id_to_token(self, index: int) -> str:
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        # Simply concatenate all tokens.
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        if token_ids_1 is None:
            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False
    ) -> List[int]:
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True
            )
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # Llama-style models typically do not use token type ids, so a single
        # sequence gets all zeros.
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 2)  # +2 for <s> and </s>
        # For a pair, segment 0 covers <s> + tokens + </s>, segment 1 covers tokens + </s>,
        # matching build_inputs_with_special_tokens above.
        return [0] * (len(token_ids_0) + 2) + [1] * (len(token_ids_1) + 1)

    def save_pretrained(self, save_directory: str, **kwargs):
        """Override save_pretrained to also write the auto_map configuration."""
        # Let the parent class save the vocabulary and related files first.
        vocab_files = super().save_pretrained(save_directory, **kwargs)
        # Create or update tokenizer_config.json.
        tokenizer_config_path = os.path.join(save_directory, "tokenizer_config.json")
        # Read the existing configuration, or start from an empty one.
        if os.path.exists(tokenizer_config_path):
            with open(tokenizer_config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
        else:
            config = {}
        # Add the auto_map entry so AutoTokenizer can locate this class.
        config.update({
            "auto_map": {
                "AutoTokenizer": [
                    "tokenizer.SingleNucleotideTokenizer",  # slow tokenizer, assuming this module is saved as tokenizer.py
                    None  # no fast tokenizer counterpart
                ]
            },
        })
        # Write the configuration back to disk.
        with open(tokenizer_config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, ensure_ascii=False, indent=2)
        return vocab_files

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # Make sure the target directory exists.
        os.makedirs(save_directory, exist_ok=True)
        # Build the vocabulary file path.
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.txt"
        )
        # Write the vocabulary, one "token id" pair per line, ordered by id.
        with open(vocab_file, "w", encoding="utf-8") as f:
            for token, idx in sorted(self.vocab.items(), key=lambda x: x[1]):
                f.write(f"{token} {idx}\n")
        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
        # The vocabulary is hard-coded above, so simply construct a fresh tokenizer
        # instance instead of loading files from pretrained_model_name_or_path.
        return cls(**kwargs)
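

# A minimal usage sketch, not part of the original module: it round-trips a short
# annotated sequence through the tokenizer defined above. The example string and
# chosen annotation tokens are illustrative assumptions only.
if __name__ == "__main__":
    tok = SingleNucleotideTokenizer()
    example = "<bog><mam><+>ACGTN<eog>"
    # __call__ tokenizes and, by default, wraps the ids with <s> ... </s>
    # via build_inputs_with_special_tokens.
    ids = tok(example)["input_ids"]
    print(tok.convert_ids_to_tokens(ids))
    # Decoding with skip_special_tokens=True drops <s>/</s> but keeps the
    # annotation markers, which are ordinary vocabulary entries here.
    print(tok.decode(ids, skip_special_tokens=True))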