import os
import json
import re
from typing import List, Optional, Tuple, Dict

from transformers import PreTrainedTokenizer


class SingleNucleotideTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer for nucleotide sequences.

    Each of ``A``, ``C``, ``G``, ``T`` and ``N`` becomes one token. Vocabulary
    entries of the form ``<...>`` (e.g. ``<+>``, ``<->``) are kept whole, and
    every other character is emitted as the unknown token ``N``. The
    vocabulary is hard-coded; ``from_pretrained`` does not read it from disk.
    """

    def __init__(self, **kwargs):
        """Build the fixed vocabulary and initialise the base tokenizer.

        Args:
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``.
                Must not contain ``unk_token``, ``bos_token``, ``eos_token``,
                ``pad_token`` or ``mask_token``; those are fixed here and
                passing them again raises ``TypeError``.
        """
        # Vocabulary in id order: a token's id is its index in this list.
        # NOTE(review): most entries -- and the bos/eos/pad/mask tokens below --
        # are empty strings; they look like `<...>` tokens that were stripped
        # from this file, so confirm the intended values. As written, the dict
        # comprehension keeps only the LAST index of a repeated key, so all ""
        # entries collapse to a single id and len(self.vocab) is smaller than
        # len(self.vocab_list).
        self.vocab_list = [
            "", "", "", "", "", "", "", "", "",
            "<+>", "<->",
            "", "", "", "", "", "", "", "", "", "", "",
            "", "", "", "", "", "", "", "", "", "",
            "A", "C", "G", "", "", "N", "", "", "T", "", "",
        ]

        # Token <-> id lookup tables.
        self.vocab = {token: idx for idx, token in enumerate(self.vocab_list)}
        self.ids_to_tokens = {idx: token for token, idx in self.vocab.items()}
        # Same contents as self.vocab; kept for callers that read this name.
        self.tokens_to_ids = dict(self.vocab)

        # Special-token strings. They are handed to the base class instead of
        # being assigned to self.* before super().__init__(): the base class
        # manages these attributes and derives unk_token_id etc. from them, and
        # (depending on the transformers version) assigning *_token_id before
        # the base initialiser has run either fails or leaves stale attributes
        # that shadow the base-managed values.
        unk_token = "N"
        bos_token = ""
        eos_token = ""
        pad_token = ""
        mask_token = ""

        # Regex for `<...>` vocabulary tokens. Longest alternatives first so a
        # token can never be cut short by a shorter one sharing its prefix.
        special_tokens = sorted(
            (tok for tok in self.vocab_list if tok.startswith("<") and tok.endswith(">")),
            key=len,
            reverse=True,
        )
        if special_tokens:
            special_tokens_pattern = "|".join(re.escape(tok) for tok in special_tokens)
            self.special_token_re = re.compile(f"({special_tokens_pattern})")
        else:
            # An empty alternation "()" would match the empty string at every
            # position and _tokenize would never advance; use a pattern that
            # can never match instead.
            self.special_token_re = re.compile(r"(?!x)x")

        # Regex for ordinary single-nucleotide tokens (uppercase only).
        self.normal_token_re = re.compile(r"[ACGTN]")

        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            mask_token=mask_token,
            **kwargs,
        )
        # Set after the base initialiser so it always ends up True.
        self.clean_up_tokenization_spaces = True

    @property
    def vocab_size(self) -> int:
        """Number of distinct tokens in ``self.vocab``."""
        return len(self.vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping.

        A copy is returned so callers cannot mutate the tokenizer's own table.
        """
        return dict(self.vocab)

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        """Split *text* into single-nucleotide and ``<...>`` tokens.

        At each position the special-token regex is tried first, then the
        nucleotide regex ``[ACGTN]``. Any other character (lowercase letters,
        whitespace, punctuation, ...) is consumed on its own and emitted as
        the unknown token.
        """
        tokens: List[str] = []
        pos = 0
        text_length = len(text)
        while pos < text_length:
            # Match objects are always truthy, so `or` falls through to the
            # nucleotide regex only when the special-token regex fails.
            match = self.special_token_re.match(text, pos) or self.normal_token_re.match(text, pos)
            # Require a non-empty match so the loop is guaranteed to advance.
            if match is not None and match.end() > pos:
                token = match.group()
                tokens.append(token if token in self.vocab else self.unk_token)
                pos = match.end()
            else:
                tokens.append(self.unk_token)
                pos += 1
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Map *token* to its id, falling back to the unknown-token id."""
        return self.vocab.get(token, self.unk_token_id)

    def _convert_id_to_token(self, index: int) -> str:
        """Map *index* to its token, falling back to the unknown token."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Concatenate *tokens* with no separator."""
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Add special tokens around one or two sequences.

        Single sequence: ``<bos> A <eos>``.
        Pair:            ``<bos> A <eos> B <eos>``.
        """
        if token_ids_1 is None:
            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Return a 0/1 mask marking special-token positions (1 = special).

        When ``already_has_special_tokens`` is false, the mask matches the
        layout produced by ``build_inputs_with_special_tokens``.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Return segment ids aligned with ``build_inputs_with_special_tokens``.

        The first segment (``<bos> A <eos>``) gets 0s and the second segment
        (``B <eos>``) gets 1s. The original author noted that Llama-style
        models usually ignore token type ids.
        """
        # +2 for the leading <bos> and the <eos> closing the first sequence.
        first_segment = [0] * (len(token_ids_0) + 2)
        if token_ids_1 is None:
            return first_segment
        # +1 for the <eos> closing the second sequence.
        return first_segment + [1] * (len(token_ids_1) + 1)

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save via the base class, then write an ``auto_map`` entry into tokenizer_config.json.

        The entry points ``AutoTokenizer`` at ``tokenizer.SingleNucleotideTokenizer``
        as the slow tokenizer with no fast counterpart (``None``), presumably so
        the directory can be loaded with ``trust_remote_code``.
        NOTE(review): if ``push_to_hub=True`` is passed, the base class may
        upload before this patch is written -- confirm.

        Returns:
            Whatever ``PreTrainedTokenizer.save_pretrained`` returned.
        """
        saved_files = super().save_pretrained(save_directory, **kwargs)

        tokenizer_config_path = os.path.join(save_directory, "tokenizer_config.json")

        # Start from the config the base class wrote, if any.
        config = {}
        if os.path.exists(tokenizer_config_path):
            with open(tokenizer_config_path, "r", encoding="utf-8") as f:
                config = json.load(f)

        # Module path used when this class lives in a file named tokenizer.py.
        config["auto_map"] = {
            "AutoTokenizer": [
                "tokenizer.SingleNucleotideTokenizer",
                None,
            ]
        }

        with open(tokenizer_config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, ensure_ascii=False, indent=2)

        return saved_files

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the vocabulary to ``[<prefix>-]vocab.txt`` in *save_directory*.

        Each line is ``"<token> <id>"``, in ascending id order. The directory is
        created if it does not exist.

        Returns:
            A one-element tuple holding the path of the written file.
        """
        os.makedirs(save_directory, exist_ok=True)

        prefix = f"{filename_prefix}-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.txt")

        with open(vocab_file, "w", encoding="utf-8") as f:
            for token, idx in sorted(self.vocab.items(), key=lambda item: item[1]):
                f.write(f"{token} {idx}\n")

        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
        """Return a fresh tokenizer built from the hard-coded vocabulary.

        ``pretrained_model_name_or_path`` and ``init_inputs`` are ignored and no
        files are read; only ``**kwargs`` are forwarded to ``__init__``.
        NOTE(review): settings saved in tokenizer_config.json (e.g.
        model_max_length) are therefore not restored -- confirm this is intended.
        """
        return cls(**kwargs)