import torch
from typing import List, Union, Optional
from rust_trie import Trie


class Tokenizer:
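    """Trie-backed tokenizer that maps tokens to integer ids.

    Token ids are positions in `ids_to_tokens`; the string-to-id matching
    itself is delegated to the Rust `Trie` extension.
    """
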
    def __init__(self, tokens: List[str], unk_token_id: Optional[int] = None):
        # Copy so that appending "<unk>" below does not mutate the caller's list.
        self.ids_to_tokens = list(tokens)
        self.trie = Trie(unk_token_id)
        for token in tokens:
            self.trie.add(token)

        # With no explicit unk id, "<unk>" is appended as the final vocabulary
        # entry (the trie is expected to map unknown characters to this last id).
        if unk_token_id is None:
            self.ids_to_tokens += ["<unk>"]
        self.pad_token_id = self.ids_to_tokens.index("<pad>")
        self.mask_token_id = self.ids_to_tokens.index("<mask>")


    def __call__(self, sequences: Union[str, List[str]], *args, **kwargs):
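        """Encode a single sequence (str) or a batch of sequences (list of str)."""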
        if isinstance(sequences, str):
            return self.encode(sequences, *args, **kwargs)
        else:
            return self.batch_encode(sequences, *args, **kwargs)


    def encode(
        self,
        sequence: str,
        add_special_tokens: bool = False,
        return_tensor: bool = False,
        max_sequence_length: Optional[int] = None,
    ) -> List[int]:
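        """Encode one sequence into a list of token ids.

        If the sequence is longer than max_sequence_length, a random window of
        that length is cropped. With add_special_tokens, <cls> and <eos> are
        added and the window is shortened by 2 to make room for them.
        """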
        if max_sequence_length is not None:
            if add_special_tokens:
                # Reserve room for the <cls> and <eos> tokens added below.
                max_sequence_length -= 2
            if len(sequence) > max_sequence_length:
                # Crop a random window of max_sequence_length characters.
                start_idx = torch.randint(
                    0, len(sequence) - max_sequence_length + 1, (1,)
                ).item()
                sequence = sequence[start_idx : start_idx + max_sequence_length]

        if add_special_tokens:
            sequence = "<cls>" + sequence + "<eos>"
        output = self.trie.tokenize(sequence)
        if return_tensor:
            output = torch.tensor(output, dtype=torch.long)
        return output


    def batch_encode(
        self,
        sequences: List[str],
        add_special_tokens: bool = False,
        return_tensors: bool = False,
        max_sequence_length: Optional[int] = None,
    ) -> List[List[int]]:
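        """Encode a batch of sequences.

        With return_tensors=True the encoded sequences are padded with the
        <pad> id to a common length and returned as a (batch, length) LongTensor;
        otherwise a list of id lists is returned.
        """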
        output = []
        if max_sequence_length is None and return_tensors:
            # Pad to the longest sequence in the batch when returning a tensor.
            max_sequence_length = max(len(sequence) for sequence in sequences)
            if add_special_tokens:
                max_sequence_length += 2

        for sequence in sequences:
            output.append(
                self.encode(
                    sequence,
                    add_special_tokens,
                    return_tensors,
                    max_sequence_length=max_sequence_length,
                )
            )
        if return_tensors:
            # Pad every row with the <pad> id up to max_sequence_length.
            tensor_out = torch.full(
                (len(output), max_sequence_length), self.pad_token_id, dtype=torch.long
            )
            for i, sequence in enumerate(output):
                tensor_out[i, : len(sequence)] = sequence
            output = tensor_out
        return output


    def decode(self, tokens: List[int]) -> str:
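        """Map token ids back to their string tokens and concatenate them."""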
        return "".join([self.ids_to_tokens[idx] for idx in tokens])


class EsmTokenizer(Tokenizer):
    def __init__(self):
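        # Token order follows the standard ESM alphabet (e.g. <cls>=0, <pad>=1,
        # <mask>=32), so ids line up with ESM conventions.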
        tokens = [
            "<cls>",
            "<pad>",
            "<eos>",
            "<unk>",
            "L",
            "A",
            "G",
            "V",
            "S",
            "E",
            "R",
            "T",
            "I",
            "D",
            "P",
            "K",
            "Q",
            "N",
            "F",
            "Y",
            "M",
            "H",
            "W",
            "C",
            "X",
            "B",
            "U",
            "Z",
            "O",
            ".",
            "-",
            "<null_1>",
            "<mask>",
        ]
        super().__init__(tokens, unk_token_id=3)


class PTMTokenizer(Tokenizer):
    def __init__(self):
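        # Special tokens first, then the amino-acid alphabet, then a "PTM"
        # boundary marker followed by one token per post-translational
        # modification (see is_ptm_token below).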
        tokens = [
            "<cls>",
            "<pad>",
            "<eos>",
            "<unk>",
            ".",
            "-",
            "<null_1>",
            "<mask>",
            "L",
            "A",
            "G",
            "V",
            "S",
            "E",
            "R",
            "T",
            "I",
            "D",
            "P",
            "K",
            "Q",
            "N",
            "F",
            "Y",
            "M",
            "H",
            "W",
            "C",
            "X",
            "B",
            "U",
            "Z",
            "O",
            "PTM",
            "<N-linked (GlcNAc...) asparagine>",
            "<Pyrrolidone carboxylic acid>",
            "<Phosphoserine>",
            "<Phosphothreonine>",
            "<N-acetylalanine>",
            "<N-acetylmethionine>",
            "<N6-acetyllysine>",
            "<Phosphotyrosine>",
            "<S-diacylglycerol cysteine>",
            "<N6-(pyridoxal phosphate)lysine>",
            "<N-acetylserine>",
            "<N6-carboxylysine>",
            "<N6-succinyllysine>",
            "<S-palmitoyl cysteine>",
            "<O-(pantetheine 4'-phosphoryl)serine>",
            "<Sulfotyrosine>",
            "<O-linked (GalNAc...) threonine>",
            "<Omega-N-methylarginine>",
            "<N-myristoyl glycine>",
            "<4-hydroxyproline>",
            "<Asymmetric dimethylarginine>",
            "<N5-methylglutamine>",
            "<4-aspartylphosphate>",
            "<S-geranylgeranyl cysteine>",
            "<4-carboxyglutamate>",
        ]
        super().__init__(tokens, unk_token_id=3)
        self.ptm_token_start = self.ids_to_tokens.index("PTM")


    def is_ptm_token(self, input_ids: torch.Tensor):
        # Ids strictly greater than the "PTM" marker id are PTM annotation tokens.
        return input_ids > self.ptm_token_start


    def is_special_token(self, input_ids: torch.Tensor):
        # Everything before "L", the first amino-acid token, is a special token.
        l_id = self.ids_to_tokens.index("L")
        return input_ids < l_id


    def __len__(self):
        return len(self.ids_to_tokens)


    def get_vocab_size(self):
        return len(self.ids_to_tokens)


class AptTokenizer(Tokenizer):
    def __init__(self):
        # No explicit "<unk>" entry here: with unk_token_id left as None, the
        # base class appends "<unk>" as the final vocabulary id automatically.
        tokens = [
            "<cls>",
            "<pad>",
            "<eos>",
            "L",
            "A",
            "G",
            "V",
            "S",
            "E",
            "R",
            "T",
            "I",
            "D",
            "P",
            "K",
            "Q",
            "N",
            "F",
            "Y",
            "M",
            "H",
            "W",
            "C",
            "B",
            "U",
            "Z",
            "O",
            "<mask>",
        ]
        super().__init__(tokens)
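

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library API):
    # encode one sequence and a padded batch with the APT tokenizer.
    tokenizer = AptTokenizer()
    single = tokenizer("MKTAYIAKQR", add_special_tokens=True)
    batch = tokenizer(
        ["MKTAYIAKQR", "GAVLI"], add_special_tokens=True, return_tensors=True
    )
    print(single)       # list of token ids, including <cls>/<eos>
    print(batch.shape)  # (2, 12): longest sequence (10) plus 2 special tokens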