LucaOne-default-step36M / tokenization_lucaone.py
LucaGroup's picture
Update weights and modeling code to latest version
6c59832 verified
#!/usr/bin/env python
# encoding: utf-8
'''
@license: (C) Copyright 2025, Hey.
@author: Hey
@email: sanyuan.hy@alibaba-inc.com
@tel: 137****6540
@datetime: 2025/12/30 11:33
@project: lucaone
@file: tokenization_lucaone
@desc: tokenization_lucaone
'''
import os
import json
import itertools
from typing import List, Optional, Dict, Any, Tuple, Union
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
def gene_seq_replace(seq):
    """
    Translate a nucleotide sequence into the digit alphabet used by the vocabulary.

    Mapping (case-insensitive): A -> 1, T/U -> 2, C -> 3, G -> 4.
    Every other character (N, gaps, digits, ...) becomes 5.

    :param seq: iterable of single characters, normally a ``str``.
    :return: a string with one digit per input character.
    """
    # One lookup table built from per-code groups; a dict lookup per character
    # avoids a chain of if/elif comparisons.
    digit_for_base = {
        **dict.fromkeys("Aa", "1"),
        **dict.fromkeys("TtUu", "2"),
        **dict.fromkeys("Cc", "3"),
        **dict.fromkeys("Gg", "4"),
    }
    # Anything that is not a recognised base falls back to the "unknown" digit.
    encoded = [digit_for_base.get(symbol, "5") for symbol in seq]
    return "".join(encoded)
class LucaGPLMTokenizer(PreTrainedTokenizer):
    """
    HuggingFace-compatible tokenizer that performs identical tokenization
    to the old model's Alphabet class.

    Tokenization is character level: every character of the (stripped) input is
    one token. Gene sequences are first rewritten with ``gene_seq_replace`` so
    that nucleotides become the digits '1'..'5'. The vocabulary is laid out as
    ``prepend tokens + append tokens + standard tokens``.
    """

    # Vocabulary definitions matching the old model
    gene_prepend_toks = ['[PAD]', '[UNK]']
    gene_append_toks = ['[CLS]', '[SEP]', '[MASK]']
    gene_standard_toks = ['1', '2', '3', '4', '5', '.', '-', '*']

    prot_prepend_toks = ['[PAD]', '[UNK]']
    prot_append_toks = ['[CLS]', '[SEP]', '[MASK]']
    prot_standard_toks = [
        'L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D',
        'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C',
        'X', 'B', 'U', 'Z', 'O', 'J', '.', '-', '*',
    ]

    gene_prot_prepend_toks = ['[PAD]', '[UNK]']
    gene_prot_append_toks = ['[CLS]', '[SEP]', '[MASK]']
    # EXACT VOCABULARY ORDER FROM ORIGINAL ALPHABET CLASS
    # (the number in each comment is the resulting token id)
    gene_prot_standard_toks = [
        '1',  # 5 - gene A (after gene_seq_replace)
        '2',  # 6 - gene T/U (after gene_seq_replace)
        '3',  # 7 - gene C (after gene_seq_replace)
        '4',  # 8 - gene G (after gene_seq_replace)
        '5',  # 9 - gene N/unknown
        'L',  # 10 - protein
        'A',  # 11 - protein
        'G',  # 12 - protein
        'V',  # 13 - protein
        'S',  # 14 - protein
        'E',  # 15 - protein
        'R',  # 16 - protein
        'T',  # 17 - protein
        'I',  # 18 - protein
        'D',  # 19 - protein
        'P',  # 20 - protein
        'K',  # 21 - protein
        'Q',  # 22 - protein
        'N',  # 23 - protein
        'F',  # 24 - protein
        'Y',  # 25 - protein
        'M',  # 26 - protein
        'H',  # 27 - protein
        'W',  # 28 - protein
        'C',  # 29 - protein
        'X',  # 30 - protein unknown
        'B',  # 31 - protein
        'U',  # 32 - protein
        'Z',  # 33 - protein
        'O',  # 34 - protein
        'J',  # 35 - protein
        '.',  # 36 - special
        '-',  # 37 - special
        '*',  # 38 - special
    ]

    def __init__(
        self,
        vocab_type: str = "gene_prot",
        prepend_bos: bool = True,
        append_eos: bool = True,
        unk_token="[UNK]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        sep_token="[SEP]",
        mask_token="[MASK]",
        **kwargs
    ):
        """
        Build the vocabulary for ``vocab_type`` and initialise the base tokenizer.

        :param vocab_type: "prot", "gene", "gene_prot" or "prot_gene" (case-insensitive).
        :param prepend_bos: whether [CLS] is put in front of every encoded sequence.
        :param append_eos: whether [SEP] is put after every encoded sequence.
        :param unk_token, pad_token, cls_token, sep_token, mask_token:
            special-token strings forwarded to the base class.
        :param kwargs: forwarded unchanged to ``PreTrainedTokenizer.__init__``.
        :raises ValueError: if ``vocab_type`` is not one of the supported values.
        """
        # Set vocabulary based on type
        vocab_key = vocab_type.lower()
        if vocab_key == "prot":
            prepend_toks = self.prot_prepend_toks
            append_toks = self.prot_append_toks
            standard_toks = self.prot_standard_toks
        elif vocab_key == "gene":
            prepend_toks = self.gene_prepend_toks
            append_toks = self.gene_append_toks
            standard_toks = self.gene_standard_toks
        elif vocab_key in ("gene_prot", "prot_gene"):
            prepend_toks = self.gene_prot_prepend_toks
            append_toks = self.gene_prot_append_toks
            standard_toks = self.gene_prot_standard_toks
        else:
            raise ValueError(f"Not support tokenizer vocab_type: {vocab_type}")

        # Build vocabulary: the list position of a token is its id.
        self.all_toks = list(prepend_toks) + list(append_toks) + list(standard_toks)
        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
        self.idx_to_tok = dict(enumerate(self.all_toks))

        # Store configuration
        self.vocab_type = vocab_type
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        self.unique_no_split_tokens = self.all_toks.copy()

        # Special token indices
        self.unk_idx = self.tok_to_idx.get("[UNK]", 1)
        self.padding_idx = self.tok_to_idx.get("[PAD]", 0)
        self.cls_idx = self.tok_to_idx.get("[CLS]", 2)
        self.mask_idx = self.tok_to_idx.get("[MASK]", 4)
        self.eos_idx = self.tok_to_idx.get("[SEP]", 3)

        # NOTE(review): the vocabulary is populated before the base constructor
        # runs, presumably because the base class queries get_vocab()/vocab_size
        # while registering special tokens - keep this ordering.
        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            cls_token=cls_token,
            sep_token=sep_token,
            mask_token=mask_token,
            **kwargs
        )

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return self.tok_to_idx.copy()

    @property
    def vocab_size(self) -> int:
        """Number of tokens in the vocabulary (special tokens included)."""
        return len(self.all_toks)

    def get_idx(self, tok):
        """Return the id of ``tok``, or the [UNK] id when it is not in the vocabulary."""
        return self.tok_to_idx.get(tok, self.unk_idx)

    def get_tok(self, idx):
        """Return the token for id ``idx``, or "[UNK]" when the id is unknown."""
        return self.idx_to_tok.get(idx, "[UNK]")

    def _tokenize_char_level(self, text: str) -> List[str]:
        """Simple character-level tokenization (fallback)"""
        return list(text)

    def _tokenize(self, text: str) -> List[str]:
        """
        Tokenize text using the same logic as the old Alphabet.tokenize() method:
        strip surrounding whitespace, then emit one token per character.
        """
        text = text.strip()
        if not text:
            return []
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        """Map one token to its id (unknown tokens map to [UNK])."""
        return self.get_idx(token)

    def _convert_id_to_token(self, index: int) -> str:
        """Map one id to its token (unknown ids map to "[UNK]")."""
        return self.get_tok(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Concatenate tokens without separators."""
        return "".join(tokens)

    def _convert_text_to_ids(self, text: str, seq_type: str) -> List[int]:
        """
        Internal helper to convert text to IDs without special tokens.

        Gene sequences (``seq_type == "gene"``) are passed through
        ``gene_seq_replace`` first; any other ``seq_type`` is tokenized as is.
        """
        if seq_type == "gene":
            text = gene_seq_replace(text)
        return [self._convert_token_to_id(token) for token in self._tokenize(text)]

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence by adding special tokens.
        This mimics the old model's prepend_bos and append_eos behavior.

        ``token_ids_1`` is accepted for interface compatibility and ignored.
        The input list is not modified.
        """
        result = list(token_ids_0)
        if self.prepend_bos:
            result.insert(0, self.cls_idx)
        if self.append_eos:
            result.append(self.eos_idx)
        return result

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """
        Return a mask with 1 at special-token positions and 0 elsewhere.

        When ``already_has_special_tokens`` is True the base-class implementation
        is used; otherwise the mask describes the layout produced by
        ``build_inputs_with_special_tokens`` for ``token_ids_0``.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        result = [0] * len(token_ids_0)
        if self.prepend_bos:
            result = [1] + result
        if self.append_eos:
            result = result + [1]
        return result

    def encode(
        self,
        text: str,
        seq_type: str = "gene",
        add_special_tokens: bool = True,
        padding: Union[bool, str] = False,  # unused here; kept for API compatibility
        truncation: bool = False,
        max_length: Optional[int] = None,
        **kwargs
    ) -> List[int]:
        """
        Convert one sequence into a list of token ids.

        :param text: the raw sequence.
        :param seq_type: "gene" applies ``gene_seq_replace`` before tokenizing.
        :param add_special_tokens: add [CLS]/[SEP] according to
            ``prepend_bos``/``append_eos``.
        :param padding: ignored (padding is done by ``encode_plus``).
        :param truncation: when True and ``max_length`` is set, cut the result
            (special tokens included) down to ``max_length`` ids.
        :param max_length: maximum number of ids when truncating.
        :param kwargs: ignored.
        :return: list of token ids.
        """
        # 1. Basic conversion
        token_ids = self._convert_text_to_ids(text, seq_type)
        # 2. Add special tokens
        if add_special_tokens:
            token_ids = self.build_inputs_with_special_tokens(token_ids)
        # 3. Truncate
        if truncation and max_length is not None and len(token_ids) > max_length:
            # Clamp so that a non-positive max_length yields an empty list
            # instead of slicing from the end.
            token_ids = token_ids[:max(max_length, 0)]
            # If append_eos is enabled, force the last kept position back to
            # [SEP]. The emptiness check avoids an IndexError for max_length == 0.
            if token_ids and add_special_tokens and self.append_eos:
                token_ids[-1] = self.eos_idx
        return token_ids

    def __call__(
        self,
        text: Union[str, List[str]],
        text_pair: Optional[Union[str, List[str]]] = None,
        seq_type: str = "gene",
        add_special_tokens: bool = True,
        padding: Union[bool, str] = False,
        max_length: Optional[int] = None,
        return_attention_mask: bool = True,
        return_token_type_ids: bool = True,
        return_tensors: Optional[str] = None,
        truncation: bool = False,
        **kwargs
    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Main callable method for tokenization - HuggingFace standard interface.

        A ``list`` input is dispatched to ``batch_encode_plus``; anything else
        to ``encode_plus``. All options are forwarded unchanged.
        """
        encode_fn = self.batch_encode_plus if isinstance(text, list) else self.encode_plus
        return encode_fn(
            text,
            text_pair=text_pair,
            seq_type=seq_type,
            add_special_tokens=add_special_tokens,
            padding=padding,
            max_length=max_length,
            return_attention_mask=return_attention_mask,
            return_token_type_ids=return_token_type_ids,
            return_tensors=return_tensors,
            truncation=truncation,
            **kwargs
        )

    def batch_encode_plus(self, *args, **kwargs):
        """
        Encode a batch of sequences.

        The batch may be given as the first positional argument or as the
        keyword ``text`` (``batch_text_or_text_pairs`` is accepted as an alias).
        Remaining keyword arguments are the ones understood by ``encode_plus``.

        Padding:
          * ``padding="max_length"`` pads every row to ``max_length``
            (done per row by ``encode_plus``).
          * ``padding=True`` / ``"longest"`` pads every row to the longest row
            of this batch.

        :return: ``Dict[str, List[List[int]]]`` (one list per row under each
            key, the layout expected by ``Dataset.map(batched=True)``), or, with
            ``return_tensors="pt"``, a dict of ``torch.long`` tensors shaped
            ``(batch, seq_len)``. An empty batch yields an empty dict.
        :raises TypeError: if no batch is supplied or more than one positional
            argument is given.
        :raises ValueError: if ``return_tensors="pt"`` is requested for rows of
            different lengths.
        """
        if len(args) > 1:
            raise TypeError(
                "batch_encode_plus() takes a single positional argument: the batch of sequences"
            )
        # Work on a copy so the caller's dict is never mutated.
        options = dict(kwargs)
        keyword_batch = options.pop("text", None)
        alias_batch = options.pop("batch_text_or_text_pairs", None)
        if args:
            batch_text = args[0]
        elif keyword_batch is not None:
            batch_text = keyword_batch
        else:
            batch_text = alias_batch
        if batch_text is None:
            raise TypeError("batch_encode_plus() requires the batch of sequences to encode")
        if isinstance(batch_text, str):
            # A bare string is a batch of one, not a batch of characters.
            batch_text = [batch_text]

        # These options are handled at batch level, so they must not reach
        # encode_plus a second time through **options.
        seq_type = options.pop("seq_type", "gene")
        padding = options.pop("padding", False)
        return_tensors = options.pop("return_tensors", None)

        # Encode row by row as plain lists; tensors are built once at the end.
        combined: Dict[str, Any] = {}
        for text in batch_text:
            row = self.encode_plus(
                text,
                seq_type=seq_type,
                padding=padding,
                return_tensors=None,
                **options
            )
            for key, value in row.items():
                combined.setdefault(key, []).append(value)

        if not combined:
            return combined

        if padding is True or padding == "longest":
            self._pad_batch_to_longest(combined, seq_type)

        if return_tensors == "pt":
            import torch

            if len({len(ids) for ids in combined["input_ids"]}) > 1:
                raise ValueError(
                    "Cannot build a tensor from sequences of different lengths; "
                    "use padding=True or padding='max_length' with max_length."
                )
            combined = {
                key: torch.tensor(rows, dtype=torch.long)
                for key, rows in combined.items()
            }
        return combined

    def _pad_batch_to_longest(self, batch: Dict[str, List[List[int]]], seq_type: str) -> None:
        """Right-pad, in place, every row of ``batch`` to the longest ``input_ids`` row."""
        target_length = max(len(ids) for ids in batch["input_ids"])
        # Fill values mirror encode_plus: [PAD] ids, 0 attention, and the
        # sequence-type id for token_type_ids.
        fill_values = {
            "input_ids": self.padding_idx,
            "attention_mask": 0,
            "token_type_ids": 0 if seq_type == "gene" else 1,
        }
        for key, rows in batch.items():
            for row in rows:
                row.extend([fill_values[key]] * (target_length - len(row)))

    def encode_plus(
        self,
        text: str,
        text_pair: Optional[str] = None,
        seq_type: str = "gene",
        add_special_tokens: bool = True,
        padding: Union[bool, str] = False,
        max_length: Optional[int] = None,
        return_attention_mask: bool = True,
        return_token_type_ids: bool = True,
        return_tensors: Optional[str] = None,
        truncation: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Encode one sequence into a dict of model inputs.

        :param text: the raw sequence.
        :param text_pair: accepted for interface compatibility and ignored.
        :param seq_type: "gene" or anything else (treated as protein).
        :param add_special_tokens: see ``encode``.
        :param padding: only ``"max_length"`` has an effect here: the row is
            right-padded with [PAD] up to ``max_length``.
        :param max_length: target length for truncation and/or padding.
        :param return_attention_mask: include ``attention_mask``
            (1 = real token, 0 = padding).
        :param return_token_type_ids: include ``token_type_ids``
            (0 for gene, 1 otherwise, padding positions included).
        :param return_tensors: ``"pt"`` converts every value to a ``torch.long``
            tensor of shape ``(1, seq_len)``; other values leave plain lists.
        :param truncation: see ``encode``.
        :param kwargs: ignored.
        :return: dict with ``input_ids`` and the requested optional entries.
        """
        token_ids = self.encode(
            text,
            seq_type=seq_type,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            max_length=max_length
        )

        # Padding to a fixed length. Dynamic padding (padding=True) only makes
        # sense across a batch and is handled by batch_encode_plus.
        attention_mask = [1] * len(token_ids)
        if padding == "max_length" and max_length is not None:
            pad_length = max_length - len(token_ids)
            if pad_length > 0:
                token_ids.extend([self.padding_idx] * pad_length)
                attention_mask.extend([0] * pad_length)

        result = {"input_ids": token_ids}
        if return_attention_mask:
            result["attention_mask"] = attention_mask
        if return_token_type_ids:
            # 0 for gene, 1 for protein
            type_value = 0 if seq_type == "gene" else 1
            result["token_type_ids"] = [type_value] * len(token_ids)

        if return_tensors == "pt":
            import torch

            # unsqueeze(0) adds the batch dimension for a single sequence.
            result = {
                key: torch.tensor(value, dtype=torch.long).unsqueeze(0)
                for key, value in result.items()
            }
        return result

    def encode_old_model_style(
        self,
        text: str,
        seq_type: str = "gene",
        max_length: int = None
    ) -> List[int]:
        """
        Encode using the EXACT same process as the old model's encoder function.
        This replicates the logic from src/llm/lucaone_virus/get_embedding.py:encoder()

        :param text: the raw sequence (nucleotide letters for genes, not digits).
        :param seq_type: "gene" or anything else (treated as protein).
        :param max_length: if truthy, the sequence is cut to this many tokens
            BEFORE [CLS]/[SEP] are added, so the result can be up to two ids longer.
        :return: ``[CLS] + ids + [SEP]`` with each special token present only
            when ``prepend_bos`` / ``append_eos`` is enabled.
        """
        # self.encode already applies gene_seq_replace for gene sequences.
        # Replacing here as well would run the mapping twice and turn every
        # base into '5', because the digits are not keys of the mapping.
        seq_encoded = self.encode(text, seq_type=seq_type, add_special_tokens=False)

        # Apply max_length truncation if specified
        if max_length and len(seq_encoded) > max_length:
            seq_encoded = seq_encoded[:max_length]

        input_ids: List[int] = []
        if self.prepend_bos:
            input_ids.append(self.cls_idx)
        input_ids.extend(seq_encoded)
        if self.append_eos:
            input_ids.append(self.eos_idx)
        return input_ids

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the tokenizer vocabulary to a JSON file.
        Required by HuggingFace tokenizer interface.

        :param save_directory: existing directory to write into.
        :param filename_prefix: optional prefix; the file is named
            ``<prefix>-vocab.json`` or ``vocab.json`` without a prefix.
        :return: one-element tuple holding the path of the written file.
        """
        prefix = "" if filename_prefix is None else filename_prefix + "-"
        vocab_file = os.path.join(save_directory, f"{prefix}vocab.json")
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.get_vocab(), f, ensure_ascii=False, indent=2)
        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Load tokenizer from pretrained model path (standard HuggingFace interface)

        Loading a saved ``vocab.json`` is not implemented: the built-in
        vocabulary is always used. ``vocab_type`` defaults to "gene_prot" and
        may be overridden through ``kwargs``; all ``kwargs`` are forwarded to
        the constructor.
        """
        vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
        if os.path.exists(vocab_file):
            print("Load from saved vocabulary (not implemented yet, use default)")
        # setdefault (rather than a hard-coded keyword) avoids a
        # "multiple values for keyword argument" error when the caller
        # supplies vocab_type.
        kwargs.setdefault("vocab_type", "gene_prot")
        return cls(**kwargs)
class LucaGPLMTokenizerFast(PreTrainedTokenizerFast):
    """
    Placeholder "fast" counterpart of LucaGPLMTokenizer.

    No dedicated fast backend is implemented: the class only declares
    LucaGPLMTokenizer as its slow tokenizer class and forwards construction
    to PreTrainedTokenizerFast unchanged.
    """

    slow_tokenizer_class = LucaGPLMTokenizer

    def __init__(self, **kwargs):
        # Placeholder for now; a full implementation would build its backend
        # with the `tokenizers` library.
        super().__init__(**kwargs)
# Public API of this module: the names exported by a star-import.
__all__ = ["LucaGPLMTokenizer", "LucaGPLMTokenizerFast", "gene_seq_replace"]