import os
import json
import re
from typing import List, Optional, Tuple, Dict
from transformers import PreTrainedTokenizer

class SingleNucleotideTokenizer(PreTrainedTokenizer):
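    """Character-level tokenizer for nucleotide sequences.

    Splits input text into single nucleotides (A, C, G, T, N) and
    angle-bracketed annotation tokens such as <mam>, <cds>, <sp0> or <K>;
    anything unrecognized is mapped to the unknown token "N".
    """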
    def __init__(self, **kwargs):
        # Define the vocabulary
        self.vocab_list = [
            "<oov>", "<s>", "</s>", "<pad>", "<mask>", 
            "<bog>", "<eog>", "<bok>", "<eok>", "<+>", "<->", 
            "<mam>", "<vrt>", "<inv>", "<pln>", "<fng>", "<prt>", 
            "<arc>", "<bct>", "<mit>", "<plt>", "<plm>", "<vir>", 
            "<cds>", "<pseudo>", "<tRNA>", "<rRNA>", "<ncRNA>", 
            "<sp0>", "<sp1>", "<sp2>", "<sp3>", 
            "A", "C", "G", "<K>", "<M>", "N", "<R>", "<S>", "T", "<W>", "<Y>"
        ]
        
        # Build the token <-> id mappings
        self.vocab = {token: idx for idx, token in enumerate(self.vocab_list)}
        self.ids_to_tokens = {idx: token for token, idx in self.vocab.items()}
        self.tokens_to_ids = {token: idx for token, idx in self.vocab.items()}
        
        # Set the special tokens
        self.unk_token = "N"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.pad_token = "<pad>"
        self.mask_token = "<mask>"
        
        # Compile a regex that matches the angle-bracketed special tokens
        special_tokens_pattern = "|".join(re.escape(token) for token in self.vocab_list if token.startswith("<") and token.endswith(">"))
        self.special_token_re = re.compile(f"({special_tokens_pattern})")
        
        # Compile a regex that matches the single-nucleotide tokens
        self.normal_token_re = re.compile(r"[ACGTN]")
        
        # Set the special token IDs
        self.unk_token_id = self.vocab[self.unk_token]
        self.bos_token_id = self.vocab[self.bos_token]
        self.eos_token_id = self.vocab[self.eos_token]
        self.pad_token_id = self.vocab[self.pad_token]
        self.mask_token_id = self.vocab[self.mask_token]
        
        # Call the parent constructor
        super().__init__(
            unk_token=self.unk_token,
            bos_token=self.bos_token,
            eos_token=self.eos_token,
            pad_token=self.pad_token,
            mask_token=self.mask_token,
            **kwargs
        )
        self.clean_up_tokenization_spaces = True
    
    @property
    def vocab_size(self) -> int:
        return len(self.vocab)
    
    def get_vocab(self) -> Dict[str, int]:
        return self.vocab
    
    def _tokenize(self, text: str, **kwargs) -> List[str]:
        tokens = []
        pos = 0
        text_length = len(text)
        
        while pos < text_length:
            # First, try to match a special token
            special_match = self.special_token_re.match(text, pos)
            if special_match:
                token = special_match.group()
                tokens.append(token)
                pos = special_match.end()
                continue
            
            # Next, try to match a regular nucleotide token
            normal_match = self.normal_token_re.match(text, pos)
            if normal_match:
                token = normal_match.group()
                # Make sure the token is in the vocabulary
                if token in self.vocab:
                    tokens.append(token)
                else:
                    tokens.append(self.unk_token)
                pos = normal_match.end()
                continue
            
            # If nothing matches, skip the character and emit the unk token
            tokens.append(self.unk_token)
            pos += 1
        
        return tokens
    
    def _convert_token_to_id(self, token: str) -> int:
        return self.vocab.get(token, self.unk_token_id)
    
    def _convert_id_to_token(self, index: int) -> str:
        return self.ids_to_tokens.get(index, self.unk_token)
    
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        # Simply concatenate all tokens
        return "".join(tokens)
    
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        if token_ids_1 is None:
            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
    
    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False
    ) -> List[int]:
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True
            )
        
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
    
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # Llama-style models generally do not use token type IDs; return zeros
        # (and ones for the second segment) so the length matches the output of
        # build_inputs_with_special_tokens.
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 2)  # +2 for <s> and </s>
        return [0] * (len(token_ids_0) + 2) + [1] * (len(token_ids_1) + 1)
    
    def save_pretrained(self, save_directory: str, **kwargs):
        """重写save_pretrained以包含auto_map配置"""
        # 先调用父类方法保存词汇表等
        vocab_files = super().save_pretrained(save_directory, **kwargs)
        
        # Create or update tokenizer_config.json
        tokenizer_config_path = os.path.join(save_directory, "tokenizer_config.json")
        
        # Read the existing config or start a new one
        if os.path.exists(tokenizer_config_path):
            with open(tokenizer_config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
        else:
            config = {}
        
        # Add the auto_map configuration
        config.update({
            "auto_map": {
                "AutoTokenizer": [
                    "tokenizer.SingleNucleotideTokenizer",  # 如果是直接运行的脚本
                    None
                ]
            },
        })
        
        # Save the config
        with open(tokenizer_config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, ensure_ascii=False, indent=2)
        
        return vocab_files
    
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        
        # Make sure the directory exists
        if not os.path.exists(save_directory):
            os.makedirs(save_directory)
        
        # Build the vocabulary file path
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.txt"
        )
        
        # Write out the vocabulary
        with open(vocab_file, "w", encoding="utf-8") as f:
            for token, idx in sorted(self.vocab.items(), key=lambda x: x[1]):
                f.write(f"{token} {idx}\n")
        
        return (vocab_file,)
    
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
        # The vocabulary is hard-coded, so simply create a fresh tokenizer instance
        return cls(**kwargs)
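

# Minimal usage sketch (illustrative, not part of the original module): the
# sequence below is a made-up example that only exercises the methods defined
# above; it assumes the transformers library is installed.
if __name__ == "__main__":
    tokenizer = SingleNucleotideTokenizer()
    text = "<mam><+><cds>ACGTN"
    tokens = tokenizer.tokenize(text)
    print(tokens)                                    # e.g. ['<mam>', '<+>', '<cds>', 'A', 'C', 'G', 'T', 'N']
    print(tokenizer.convert_tokens_to_ids(tokens))   # ids from the hard-coded vocabulary
    print(tokenizer.encode(text))                    # encode() wraps the ids with <s> ... </s>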