import os
import json
import re
from typing import List, Optional, Tuple, Dict
from transformers import PreTrainedTokenizer
class SingleNucleotideTokenizer(PreTrainedTokenizer):
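    """Character-level tokenizer for nucleotide sequences with a fixed, hard-coded vocabulary.

    The vocabulary contains the single bases A, C, G, T and N plus a set of
    angle-bracketed special tokens (sequence boundaries, padding/masking,
    annotation tags such as <cds> or <+>, and bracketed single-letter codes
    such as <K> and <R>). Characters outside the vocabulary are mapped to the
    unk token "N".
    """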
def __init__(self, **kwargs):
        # Define the vocabulary
self.vocab_list = [
"<oov>", "<s>", "</s>", "<pad>", "<mask>",
"<bog>", "<eog>", "<bok>", "<eok>", "<+>", "<->",
"<mam>", "<vrt>", "<inv>", "<pln>", "<fng>", "<prt>",
"<arc>", "<bct>", "<mit>", "<plt>", "<plm>", "<vir>",
"<cds>", "<pseudo>", "<tRNA>", "<rRNA>", "<ncRNA>",
"<sp0>", "<sp1>", "<sp2>", "<sp3>",
"A", "C", "G", "<K>", "<M>", "N", "<R>", "<S>", "T", "<W>", "<Y>"
]
        # Build the token <-> ID mappings
self.vocab = {token: idx for idx, token in enumerate(self.vocab_list)}
self.ids_to_tokens = {idx: token for token, idx in self.vocab.items()}
self.tokens_to_ids = {token: idx for token, idx in self.vocab.items()}
        # Set the special tokens
self.unk_token = "N"
self.bos_token = "<s>"
self.eos_token = "</s>"
self.pad_token = "<pad>"
self.mask_token = "<mask>"
        # Compile a regex that matches bracketed special tokens
special_tokens_pattern = "|".join(re.escape(token) for token in self.vocab_list if token.startswith("<") and token.endswith(">"))
self.special_token_re = re.compile(f"({special_tokens_pattern})")
        # Compile a regex that matches ordinary single-nucleotide tokens
self.normal_token_re = re.compile(r"[ACGTN]")
        # Set the special token IDs
self.unk_token_id = self.vocab[self.unk_token]
self.bos_token_id = self.vocab[self.bos_token]
self.eos_token_id = self.vocab[self.eos_token]
self.pad_token_id = self.vocab[self.pad_token]
self.mask_token_id = self.vocab[self.mask_token]
        # Call the parent class initializer
super().__init__(
unk_token=self.unk_token,
bos_token=self.bos_token,
eos_token=self.eos_token,
pad_token=self.pad_token,
mask_token=self.mask_token,
**kwargs
)
self.clean_up_tokenization_spaces = True
@property
def vocab_size(self) -> int:
return len(self.vocab)
def get_vocab(self) -> Dict[str, int]:
return self.vocab
def _tokenize(self, text: str, **kwargs) -> List[str]:
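        """Scan `text` left to right: bracketed special tokens are matched first,
        then single nucleotides; any other character becomes the unk token."""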
tokens = []
pos = 0
text_length = len(text)
while pos < text_length:
            # First try to match a bracketed special token
special_match = self.special_token_re.match(text, pos)
if special_match:
token = special_match.group()
tokens.append(token)
pos = special_match.end()
continue
            # Then try to match an ordinary nucleotide token
normal_match = self.normal_token_re.match(text, pos)
if normal_match:
token = normal_match.group()
                # Make sure the token is in the vocabulary
if token in self.vocab:
tokens.append(token)
else:
tokens.append(self.unk_token)
pos = normal_match.end()
continue
            # If neither matches, skip the character and emit the unk token
tokens.append(self.unk_token)
pos += 1
return tokens
def _convert_token_to_id(self, token: str) -> int:
return self.vocab.get(token, self.unk_token_id)
def _convert_id_to_token(self, index: int) -> str:
return self.ids_to_tokens.get(index, self.unk_token)
def convert_tokens_to_string(self, tokens: List[str]) -> str:
        # Simply concatenate all tokens
return "".join(tokens)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
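        """Wrap the input(s) with BOS/EOS: <s> X </s>, or <s> X </s> Y </s> for a pair."""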
if token_ids_1 is None:
return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
def get_special_tokens_mask(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False
) -> List[int]:
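        """Return a mask with 1 at the positions of the special tokens added by
        build_inputs_with_special_tokens and 0 at sequence-token positions."""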
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0,
token_ids_1=token_ids_1,
already_has_special_tokens=True
)
if token_ids_1 is None:
return [1] + ([0] * len(token_ids_0)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
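        """Return BERT-style segment IDs (0 for the first sequence, 1 for the second)."""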
        # Llama-style models generally do not use token type IDs
if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 2)  # +2 for the <s> and </s> added around the sequence
        return [0] * (len(token_ids_0) + 2) + [1] * (len(token_ids_1) + 1)
def save_pretrained(self, save_directory: str, **kwargs):
"""重写save_pretrained以包含auto_map配置"""
# 先调用父类方法保存词汇表等
vocab_files = super().save_pretrained(save_directory, **kwargs)
# 创建或更新tokenizer_config.json
tokenizer_config_path = os.path.join(save_directory, "tokenizer_config.json")
# 读取现有的配置或创建新的
if os.path.exists(tokenizer_config_path):
with open(tokenizer_config_path, "r", encoding="utf-8") as f:
config = json.load(f)
else:
config = {}
        # Add the auto_map configuration
config.update({
"auto_map": {
"AutoTokenizer": [
"tokenizer.SingleNucleotideTokenizer", # 如果是直接运行的脚本
None
]
},
})
        # Save the config
with open(tokenizer_config_path, "w", encoding="utf-8") as f:
json.dump(config, f, ensure_ascii=False, indent=2)
return vocab_files
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
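        """Write the vocabulary to a vocab.txt file, one "token id" pair per line."""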
        # Make sure the save directory exists (os is already imported at module level)
if not os.path.exists(save_directory):
os.makedirs(save_directory)
        # Build the vocabulary file path
vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + "vocab.txt"
)
        # Write out the vocabulary
with open(vocab_file, "w", encoding="utf-8") as f:
for token, idx in sorted(self.vocab.items(), key=lambda x: x[1]):
f.write(f"{token} {idx}\n")
return (vocab_file,)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
        # The vocabulary is hard-coded, so simply create a fresh tokenizer instance
return cls(**kwargs)
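

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the example string and the choice
    # of leading special tokens are arbitrary, not a prescribed input format, and
    # the printed IDs assume the vocabulary order defined above.
    tok = SingleNucleotideTokenizer()
    example = "<+><cds>ACGTN"
    print(tok.tokenize(example))           # ['<+>', '<cds>', 'A', 'C', 'G', 'T', 'N']
    ids = tok(example)["input_ids"]        # BOS/EOS added by build_inputs_with_special_tokens
    print(ids)                             # [1, 9, 23, 32, 33, 34, 40, 37, 2]
    print(tok.convert_ids_to_tokens(ids))  # ['<s>', '<+>', '<cds>', 'A', 'C', 'G', 'T', 'N', '</s>']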