#!/usr/bin/env python
# encoding: utf-8
'''
@license: (C) Copyright 2025, Hey.
@author: Hey
@email: sanyuan.hy@alibaba-inc.com
@tel: 137****6540
@datetime: 2025/12/30 11:33
@project: lucaone
@file: tokenization_lucaone
@desc: tokenization_lucaone
'''

import os
import json
import itertools
from typing import List, Optional, Dict, Any, Tuple, Union
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

def gene_seq_replace(seq):
    """
    Gene sequence preprocessing: A->1, U/T->2, C->3, G->4, anything else (e.g. N)->5.
    Matching is case-insensitive.
    """
    # A dict lookup per character is faster than an if/else chain
    mapping = {
        'A': '1', 'a': '1',
        'T': '2', 't': '2', 'U': '2', 'u': '2',
        'C': '3', 'c': '3',
        'G': '4', 'g': '4'
    }
    # Characters not in the dict (e.g. N) default to '5'
    return "".join([mapping.get(ch, '5') for ch in seq])
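
# Illustrative sketch, not used by LucaGPLMTokenizer: the same gene mapping expressed with
# str.translate. str.translate looks each character up via __getitem__, so a dict subclass
# whose __missing__ returns '5' (the hypothetical _GeneFallbackTable) sends every unmapped
# character (N, gaps, digits, ...) to '5', just like mapping.get(ch, '5') above.
# Whether this is faster than gene_seq_replace is an assumption worth benchmarking.
class _GeneFallbackTable(dict):
    def __missing__(self, key: int) -> str:
        # Any code point without an explicit entry becomes the unknown-nucleotide token
        return '5'


# str.maketrans pairs the two strings position by position: A/a->1, T/t/U/u->2, C/c->3, G/g->4
_GENE_TRANSLATE_TABLE = _GeneFallbackTable(str.maketrans("AaTtUuCcGg", "1122223344"))


def gene_seq_replace_translate(seq: str) -> str:
    """Hypothetical translate-based equivalent of gene_seq_replace (sketch only)."""
    return seq.translate(_GENE_TRANSLATE_TABLE)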

class LucaGPLMTokenizer(PreTrainedTokenizer):
    """
    HuggingFace-compatible tokenizer that performs identical tokenization 
    to the old model's Alphabet class.
    """
    
    # Vocabulary definitions matching the old model
    gene_prepend_toks = ['[PAD]', '[UNK]']
    gene_append_toks = ['[CLS]', '[SEP]', '[MASK]']
    gene_standard_toks = ['1', '2', '3', '4', '5', '.', '-', '*']
    
    prot_prepend_toks = ['[PAD]', '[UNK]']
    prot_append_toks = ['[CLS]', '[SEP]', '[MASK]']
    prot_standard_toks = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', 'J', '.', '-', '*']
    
    gene_prot_prepend_toks = ['[PAD]', '[UNK]']
    gene_prot_append_toks = ['[CLS]', '[SEP]', '[MASK]']
    # EXACT VOCABULARY ORDER FROM ORIGINAL ALPHABET CLASS

    gene_prot_standard_toks = [
        '1',      # 5 - gene A (after gene_seq_replace)
        '2',      # 6 - gene T/U (after gene_seq_replace) 
        '3',      # 7 - gene C (after gene_seq_replace)
        '4',      # 8 - gene G (after gene_seq_replace)
        '5',      # 9 - gene N/unknown
        'L',      # 10 - protein
        'A',      # 11 - protein
        'G',      # 12 - protein
        'V',      # 13 - protein
        'S',      # 14 - protein
        'E',      # 15 - protein
        'R',      # 16 - protein
        'T',      # 17 - protein
        'I',      # 18 - protein
        'D',      # 19 - protein
        'P',      # 20 - protein
        'K',      # 21 - protein
        'Q',      # 22 - protein
        'N',      # 23 - protein
        'F',      # 24 - protein
        'Y',      # 25 - protein
        'M',      # 26 - protein
        'H',      # 27 - protein
        'W',      # 28 - protein
        'C',      # 29 - protein
        'X',      # 30 - protein unknown
        'B',      # 31 - protein
        'U',      # 32 - protein
        'Z',      # 33 - protein
        'O',      # 34 - protein
        'J',      # 35 - protein
        '.',      # 36 - special
        '-',      # 37 - special
        '*'       # 38 - special
    ]

    def __init__(
        self,
        vocab_type: str = "gene_prot",
        prepend_bos: bool = True,
        append_eos: bool = True,
        unk_token="[UNK]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        sep_token="[SEP]",
        mask_token="[MASK]",
        **kwargs
    ):
        # Set vocabulary based on type
        if vocab_type.lower() == "prot":
            prepend_toks = self.prot_prepend_toks
            append_toks = self.prot_append_toks
            standard_toks = self.prot_standard_toks
        elif vocab_type.lower() == "gene":
            prepend_toks = self.gene_prepend_toks
            append_toks = self.gene_append_toks
            standard_toks = self.gene_standard_toks
        elif vocab_type.lower() in ["gene_prot", "prot_gene"]:
            prepend_toks = self.gene_prot_prepend_toks
            append_toks = self.gene_prot_append_toks
            standard_toks = self.gene_prot_standard_toks
        else:
            raise ValueError(f"Unsupported tokenizer vocab_type: {vocab_type}")
        
        # Build vocabulary
        self.all_toks = list(prepend_toks) + list(append_toks) + list(standard_toks)
        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
        self.idx_to_tok = {i: tok for i, tok in enumerate(self.all_toks)}
        
        # Store configuration
        self.vocab_type = vocab_type
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        self.unique_no_split_tokens = self.all_toks.copy()
        
        # Special token indices
        self.unk_idx = self.tok_to_idx.get("[UNK]", 1)
        self.padding_idx = self.tok_to_idx.get("[PAD]", 0)
        self.cls_idx = self.tok_to_idx.get("[CLS]", 2)
        self.mask_idx = self.tok_to_idx.get("[MASK]", 4)
        self.eos_idx = self.tok_to_idx.get("[SEP]", 3)

        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            cls_token=cls_token,
            sep_token=sep_token,
            mask_token=mask_token,
            **kwargs
        )

    def get_vocab(self) -> Dict[str, int]:
        return self.tok_to_idx.copy()

    @property
    def vocab_size(self) -> int:
        return len(self.all_toks)

    def get_idx(self, tok):
        return self.tok_to_idx.get(tok, self.unk_idx)

    def get_tok(self, idx):
        return self.idx_to_tok.get(idx, "[UNK]")

    def _tokenize_char_level(self, text: str) -> List[str]:
        """Simple character-level tokenization (fallback)"""
        return list(text)

    def _tokenize(self, text: str) -> List[str]:
        """
        Tokenize text using the same logic as the old Alphabet.tokenize() method
        """
        text = text.strip()
        if not text:
            return []
            
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        return self.get_idx(token)

    def _convert_id_to_token(self, index: int) -> str:
        return self.get_tok(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return "".join(tokens)

    def _convert_text_to_ids(self, text: str, seq_type: str) -> List[int]:
        """Internal helper to convert text to IDs without special tokens."""
        if seq_type == "gene":
            text = gene_seq_replace(text)
        tokens = self._tokenize(text)
        return [self._convert_token_to_id(token) for token in tokens]

    def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
        """
        Build model inputs from a sequence by adding special tokens.
        This mimics the old model's prepend_bos and append_eos behavior.
        Sequence pairs are not supported: token_ids_1 is ignored.
        """
        result = token_ids_0.copy()
        
        if self.prepend_bos:
            result = [self.cls_idx] + result
        if self.append_eos:
            result = result + [self.eos_idx]
            
        return result

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
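        """
        Return a mask with 1 at special-token positions ([CLS]/[SEP]) and 0 at sequence tokens,
        matching the layout produced by build_inputs_with_special_tokens.
        """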
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        result = [0] * len(token_ids_0)
        if self.prepend_bos:
            result = [1] + result
        if self.append_eos:
            result = result + [1]
        return result

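    def encode(
        self,
        text: str,
        seq_type: str = "gene",
        add_special_tokens: bool = True,
        padding: Union[bool, str] = False, # Accepted for API compatibility; encode() never pads
        truncation: bool = False,          # <--- key parameter: truncate to max_length when True
        max_length: Optional[int] = None,  # <--- key parameter: counts [CLS]/[SEP] if they are added
        **kwargs
    ) -> List[int]:
        """
        Convert one sequence to token ids. Gene sequences (seq_type="gene") are first mapped
        with gene_seq_replace; any other seq_type is tokenized character by character as is
        (case-sensitive, so lowercase residues map to [UNK]).
        """
        # 1. Basic conversion
        token_ids = self._convert_text_to_ids(text, seq_type)

        # 2. Add special tokens
        if add_special_tokens:
            token_ids = self.build_inputs_with_special_tokens(token_ids)

        # 3. Truncate (fix: this logic was previously missing)
        if truncation and max_length is not None and len(token_ids) > max_length:
            token_ids = token_ids[:max_length]
            # If append_eos is enabled, force the last position of the truncated ids back to [SEP]
            if add_special_tokens and self.append_eos:
                token_ids[-1] = self.eos_idx

        return token_ids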

    def __call__(
        self,
        text: Union[str, List[str]],
        text_pair: Optional[Union[str, List[str]]] = None,
        seq_type: str = "gene",
        add_special_tokens: bool = True,
        padding: Union[bool, str] = False,
        max_length: Optional[int] = None,
        return_attention_mask: bool = True,
        return_token_type_ids: bool = True,
        return_tensors: Optional[str] = None,
        truncation: bool = False,
        **kwargs
    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Main callable method for tokenization - HuggingFace standard interface
        """
        if isinstance(text, list):
            # Handle batch processing
            return self.batch_encode_plus(
                text,
                text_pair=text_pair,
                seq_type=seq_type,
                add_special_tokens=add_special_tokens,
                padding=padding,
                max_length=max_length,
                return_attention_mask=return_attention_mask,
                return_token_type_ids=return_token_type_ids,
                return_tensors=return_tensors,
                truncation=truncation,
                **kwargs
            )
        else:
            # Handle single text
            return self.encode_plus(
                text,
                text_pair=text_pair,
                seq_type=seq_type,
                add_special_tokens=add_special_tokens,
                padding=padding,
                max_length=max_length,
                return_attention_mask=return_attention_mask,
                return_token_type_ids=return_token_type_ids,
                return_tensors=return_tensors,
                truncation=truncation,
                **kwargs
            )

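    def batch_encode_plus(self, batch_text: Optional[List[str]] = None, **kwargs) -> Dict[str, Any]:
        """
        Encode each sequence with encode_plus and merge the results into a
        Dict[str, List[List[int]]], the layout Dataset.map(batched=True) expects.
        All sequences in the batch share the same seq_type.
        """
        # The batch may be passed positionally (as __call__ does) or as text=...
        if batch_text is None:
            batch_text = kwargs.pop("text")
        kwargs.pop("text", None)
        # Sentence pairs are not supported; text_pair is dropped rather than forwarded
        kwargs.pop("text_pair", None)
        # Tensors are built once for the whole batch, not per sequence
        return_tensors = kwargs.pop("return_tensors", None)

        # Encode each sequence; seq_type and the other options are forwarded via kwargs
        batch_outputs = [self.encode_plus(text, **kwargs) for text in batch_text]
        if not batch_outputs:
            return {}

        # Merge the per-sequence outputs into Dict[str, List[List[int]]]
        combined = {key: [] for key in batch_outputs[0].keys()}
        for output in batch_outputs:
            for key, value in output.items():
                combined[key].append(value)

        if return_tensors == "pt":
            import torch
            # Stacking requires equal lengths, e.g. padding="max_length" with max_length set
            combined = {key: torch.tensor(value, dtype=torch.long) for key, value in combined.items()}

        return combined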

    def encode_plus(
        self,
        text: str,
        text_pair: Optional[str] = None,
        seq_type: str = "gene",
        add_special_tokens: bool = True,
        padding: Union[bool, str] = False,
        max_length: Optional[int] = None,
        return_attention_mask: bool = True,
        return_token_type_ids: bool = True,
        return_tensors: Optional[str] = None,
        truncation: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
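        # text_pair is accepted only for API compatibility and is ignored (pairs are not
        # supported); unrecognised keyword arguments in **kwargs are ignored as well.
        # encode() handles special tokens and truncation.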
        token_ids = self.encode(
            text, 
            seq_type=seq_type, 
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            max_length=max_length
        )
        
        # Padding
        attention_mask = [1] * len(token_ids)
        if padding == "max_length" and max_length is not None:
            if len(token_ids) < max_length:
                pad_length = max_length - len(token_ids)
                token_ids.extend([self.padding_idx] * pad_length)
                attention_mask.extend([0] * pad_length)
        # Note: only padding="max_length" is applied. Dynamic padding (padding=True or
        # "longest") is not performed here, and batch_encode_plus does not pad either.
        
        result = {"input_ids": token_ids}
        
        if return_attention_mask:
            result["attention_mask"] = attention_mask
        
        if return_token_type_ids:
            # 0 for gene, 1 for protein
            type_value = 0 if seq_type == "gene" else 1
            result["token_type_ids"] = [type_value] * len(token_ids)
        
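        # Only return_tensors="pt" is supported; each value becomes a (1, seq_len) LongTensor.
        # Any other value is ignored and plain Python lists are returned.
        if return_tensors == "pt":
            import torch
            for key, value in result.items():
                result[key] = torch.tensor(value, dtype=torch.long).unsqueeze(0)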
        
        return result

    def encode_old_model_style(
        self,
        text: str,
        seq_type: str = "gene", 
        max_length: Optional[int] = None
    ) -> List[int]:
        """
        Encode using the EXACT same process as the old model's encoder function.
        This replicates the logic from src/llm/lucaone_virus/get_embedding.py:encoder()
        """
        # The old pipeline applied gene_seq_replace in get_embedding() BEFORE calling encoder().
        # encode() already applies it for seq_type="gene", so the raw text is passed through:
        # applying it twice would turn every digit into '5' (N) and destroy the sequence.

        # Call encode() without BOS/EOS (the old model's tokenizer.encode did not add them);
        # they are placed manually below.
        seq_encoded = self.encode(text, seq_type=seq_type, add_special_tokens=False)
        
        # Apply max_length truncation if specified  
        if max_length and len(seq_encoded) > max_length:
            seq_encoded = seq_encoded[:max_length]
        
        # Calculate processed_seq_len (as done in old model)
        processed_seq_len = len(seq_encoded) + int(self.prepend_bos) + int(self.append_eos)
        
        # Create input_ids tensor (as done in old model encoder function)
        input_ids = [self.padding_idx] * processed_seq_len
        
        # Add BOS token if enabled
        if self.prepend_bos:
            input_ids[0] = self.cls_idx
            
        # Place the encoded sequence
        start_idx = int(self.prepend_bos)
        for i, token_id in enumerate(seq_encoded):
            input_ids[start_idx + i] = token_id
            
        # Add EOS token if enabled  
        if self.append_eos:
            input_ids[len(seq_encoded) + int(self.prepend_bos)] = self.eos_idx
            
        return input_ids

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the tokenizer vocabulary to a JSON file.
        Required by HuggingFace tokenizer interface.
        """
        if filename_prefix is None:
            filename_prefix = ""
        else:
            filename_prefix = filename_prefix + "-"
        
        vocab_file = os.path.join(save_directory, f"{filename_prefix}vocab.json")
        vocab_dict = self.get_vocab()
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(vocab_dict, f, ensure_ascii=False, indent=2)
        
        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Load tokenizer from pretrained model path (standard HuggingFace interface).
        Loading a saved vocab.json is not implemented: the vocabulary is fully determined by
        vocab_type, which defaults to "gene_prot" unless the caller passes another value.
        """
        kwargs.setdefault("vocab_type", "gene_prot")
        return cls(**kwargs)
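
# Illustrative sketch (hypothetical helper, not part of the tokenizer API): recomputes the
# "gene_prot" id layout that LucaGPLMTokenizer.__init__ builds as prepend + append + standard
# tokens, so the index comments on gene_prot_standard_toks can be checked without importing
# transformers. Gene letters are renamed to digits before lookup, so gene "A" ('1', id 5)
# and protein "A" (id 11) get different ids.
def _sketch_gene_prot_vocab() -> dict:
    prepend = ['[PAD]', '[UNK]']                  # ids 0-1
    append = ['[CLS]', '[SEP]', '[MASK]']         # ids 2-4
    standard = (
        list("12345")                             # ids 5-9: gene tokens after gene_seq_replace
        + list("LAGVSERTIDPKQNFYMHWCXBUZOJ")      # ids 10-35: protein residues
        + ['.', '-', '*']                         # ids 36-38: special symbols
    )
    return {tok: i for i, tok in enumerate(prepend + append + standard)}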

class LucaGPLMTokenizerFast(PreTrainedTokenizerFast):
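    """
    Placeholder fast tokenizer. It does not delegate to LucaGPLMTokenizer: without a
    `tokenizers` backend (tokenizer_object or tokenizer_file), PreTrainedTokenizerFast tries
    to convert the slow tokenizer, and no converter is registered for LucaGPLMTokenizer,
    so construction is expected to fail. Use LucaGPLMTokenizer instead.
    """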
    slow_tokenizer_class = LucaGPLMTokenizer
    
    def __init__(self, **kwargs):
        # For now, this is just a placeholder
        # In a full implementation, you would use the tokenizers library
        super().__init__(**kwargs)
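
# Illustrative sketch, not implemented by LucaGPLMTokenizer: dynamic ("longest") padding of the
# Dict[str, List[List[int]]] returned by batch_encode_plus. The helper name is hypothetical; the
# pad values follow what encode_plus does for padding="max_length": padding_idx (0 in all three
# default vocabularies) for input_ids, 0 for attention_mask, and the row's own type id for
# token_type_ids. Keys other than these three are dropped.
def _sketch_pad_to_longest(batch: dict, pad_id: int = 0) -> dict:
    longest = max((len(ids) for ids in batch["input_ids"]), default=0)
    keys = [k for k in ("input_ids", "attention_mask", "token_type_ids") if k in batch]
    padded = {key: [] for key in keys}
    for i, ids in enumerate(batch["input_ids"]):
        n_pad = longest - len(ids)
        padded["input_ids"].append(list(ids) + [pad_id] * n_pad)
        if "attention_mask" in padded:
            # Padded positions are masked out
            padded["attention_mask"].append(list(batch["attention_mask"][i]) + [0] * n_pad)
        if "token_type_ids" in padded:
            types = list(batch["token_type_ids"][i])
            # encode_plus fills a whole row with one value (0 gene, 1 protein)
            fill = types[0] if types else 0
            padded["token_type_ids"].append(types + [fill] * n_pad)
    return padded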

__all__ = ["LucaGPLMTokenizer", "LucaGPLMTokenizerFast", "gene_seq_replace"]
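
# Illustrative sketch (hypothetical helpers, assuming the default ids [CLS]=2 and [SEP]=3 and
# prepend_bos=append_eos=True): encode() and encode_old_model_style() read max_length
# differently. encode(truncation=True) counts [CLS]/[SEP] inside max_length and rewrites the
# last kept position to [SEP]; encode_old_model_style() truncates the raw sequence first, so
# its output can be up to max_length + 2 ids long.
def _sketch_encode_truncation(ids: list, max_length: int, cls_idx: int = 2, eos_idx: int = 3) -> list:
    # Mirrors encode(..., add_special_tokens=True, truncation=True, max_length=max_length)
    out = [cls_idx] + list(ids) + [eos_idx]
    if len(out) > max_length:
        out = out[:max_length]
        out[-1] = eos_idx
    return out


def _sketch_old_model_truncation(ids: list, max_length: int, cls_idx: int = 2, eos_idx: int = 3) -> list:
    # Mirrors encode_old_model_style: a falsy max_length means no truncation
    body = list(ids)[:max_length] if max_length else list(ids)
    return [cls_idx] + body + [eos_idx]

# For the gene sequence "ACGT" (ids 5, 7, 8, 6 in the gene_prot vocabulary) and max_length=4,
# the first helper gives [2, 5, 7, 3] and the second gives [2, 5, 7, 8, 6, 3].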