| import torch |
| from typing import List, Optional, Union, Dict |
| from torch import Tensor |
|
|
| from itertools import compress |
|
|
| |
| from tokenizers import Tokenizer |
| from transformers import PreTrainedTokenizerFast, BatchEncoding |
| from tokenizers.models import WordPiece |
| from tokenizers.pre_tokenizers import Split |
|
|
|
|
# Token -> id mapping used to build the tokenizer vocabulary.
# Ids 0-4 are the special tokens <pad>, <unk>, <mask>, <bos>, <eos>;
# ids 5-31 are the single-character tokens "|" followed by the residue letters,
# in exactly the order they are enumerated below.
VOCAB = {
    token: token_id
    for token_id, token in enumerate(
        ["<pad>", "<unk>", "<mask>", "<bos>", "<eos>", *"|XBOUZJLAGVSERTIDPKQNFYMHWC"]
    )
}
|
|
|
|
class ProteinTokenizer(PreTrainedTokenizerFast):
    """Character-level protein tokenizer built on the module-level ``VOCAB``.

    Wraps a ``tokenizers.Tokenizer`` (a WordPiece model over ``VOCAB`` with a
    ``Split("")`` pre-tokenizer) in a ``PreTrainedTokenizerFast``. On top of
    the parent it adds:

    * truncation, optionally at a random offset;
    * removal of ambiguous tokens;
    * ``position_ids``;
    * its own right-padding.
    """

    def __init__(
        self,
        pad_token_id: int,
        mask_token_id: int,
        bos_token_id: int,
        eos_token_id: int,
        unk_token_id: int,
        max_length: int,
        other_special_token_ids: Optional[List[int]] = None,
        ambiguous_token_ids: Optional[List[int]] = None,
        **kwargs,
    ):
        """Vocabulary comprising the amino acids, and the special tokens <unk>, <bos>, <eos>, <pad> and <mask>.

        The vocabulary comes from the module-level ``VOCAB`` mapping. Each
        special token string is resolved by looking up its id in that mapping.

        Args:
            pad_token_id (int): <pad> token index.
            mask_token_id (int): <mask> token index.
            bos_token_id (int): <bos> token index.
            eos_token_id (int): <eos> token index.
            unk_token_id (int): <unk> token index.
            max_length (int): Maximum sequence length. Also used as ``model_max_length``.
            other_special_token_ids (Optional[List[int]]): Ids of additional tokens to register
                as special tokens.
            ambiguous_token_ids (Optional[List[int]]): Ids of tokens removed by ``remove_ambiguous``.
            **kwargs: Accepted for signature compatibility. NOT forwarded to the parent constructor.
        """
        # Build both directions of the vocabulary lookup.
        token_to_id = dict()
        id_to_token = dict()
        for token, token_id in VOCAB.items():
            token = token.strip()
            token_to_id[token] = token_id
            id_to_token[token_id] = token

        # Backend tokenizer: WordPiece over the fixed vocabulary, with unknown
        # pieces mapped to the <unk> token string.
        tokenizer_object = Tokenizer(WordPiece(vocab=token_to_id, unk_token=id_to_token.get(unk_token_id)))

        # Split on the empty pattern and drop the matches. This is intended to
        # break a sequence into single characters (NOTE(review): confirm
        # against the installed `tokenizers` version).
        tokenizer_object.pre_tokenizer = Split("", behavior="removed")

        super().__init__(
            pad_token_id=pad_token_id,
            mask_token_id=mask_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            unk_token_id=unk_token_id,
            pad_token=id_to_token.get(pad_token_id),
            bos_token=id_to_token.get(bos_token_id),
            eos_token=id_to_token.get(eos_token_id),
            unk_token=id_to_token.get(unk_token_id),
            mask_token=id_to_token.get(mask_token_id),
            max_length=max_length,
            ambiguous_token_ids=ambiguous_token_ids,
            model_max_length=max_length,
            padding_side="right",
            truncation_side="right",
            model_input_names=["input_ids", "attention_mask", "special_tokens_mask"],
            tokenizer_object=tokenizer_object,
        )

        if other_special_token_ids is not None:
            self.add_special_tokens({"additional_special_tokens": list(id_to_token.get(i) for i in other_special_token_ids)})

        self.ambiguous_token_ids = ambiguous_token_ids

        # Per-key padding values used by `_pad`. Padded positions get
        # attention_mask 0 and special_tokens_mask 1.
        self.key_to_padding = {"input_ids": self.pad_token_id, "attention_mask": 0, "special_tokens_mask": 1, "position_ids": 0}
        # Not referenced inside this class; presumably consumed elsewhere (verify).
        self.key_to_dtype = {
            "input_ids": torch.long,
            "attention_mask": torch.bool,
            "special_tokens_mask": torch.bool,
            "position_ids": torch.int,
        }

    def truncate(
        self,
        encoded_inputs: Dict[str, List[int]],
        max_length: Optional[int] = None,
        random_truncate: bool = True,
    ) -> Dict[str, List[List[int]]]:
        """Truncate every sequence longer than ``max_length`` to a window of ``max_length`` tokens.

        Mutates ``encoded_inputs`` in place and also returns it.

        Args:
            encoded_inputs: Mapping from feature name (e.g. ``"input_ids"``) to a
                list of per-sequence token lists. All keys are sliced with the
                same window.
            max_length: Window length. If ``None``, ``self.model_max_length`` is used.
            random_truncate: If True, the window start is drawn uniformly from
                every offset at which a full window fits. Otherwise the window
                starts at 0 (keeps the prefix).

        Returns:
            The truncated ``encoded_inputs``.
        """
        # Honour the documented default; previously `len(seq) > None` raised TypeError.
        if max_length is None:
            max_length = self.model_max_length

        for i, sequence in enumerate(encoded_inputs["input_ids"]):
            if len(sequence) <= max_length:
                continue
            if random_truncate:
                # torch.randint's upper bound is exclusive, so offsets span 0..len-max_length.
                offset = torch.randint(0, len(sequence) - max_length + 1, (1,)).item()
            else:
                offset = 0
            for key in encoded_inputs:
                encoded_inputs[key][i] = encoded_inputs[key][i][offset : offset + max_length]

        return encoded_inputs

    def remove_ambiguous(self, encoded_inputs: Dict[str, List[int]]) -> Dict[str, List[List[int]]]:
        """Remove tokens whose id is in ``self.ambiguous_token_ids`` from every sequence.

        The same positions are removed from every key, so features stay aligned.
        A sequence with no tokens left (including an empty input sequence) is
        dropped from the batch entirely. The returned batch can therefore be
        shorter than the input batch.

        Args:
            encoded_inputs: Mapping from feature name to a list of per-sequence
                token lists. Token ids are read from ``"input_ids"``.

        Returns:
            A new dict with the same keys, holding the filtered sequences.
        """
        # Build the lookup set once per call instead of scanning a list per token.
        ambiguous = frozenset(self.ambiguous_token_ids or ())
        filtered_inputs = {key: [] for key in encoded_inputs}

        for i, sequence in enumerate(encoded_inputs["input_ids"]):
            keep = [token not in ambiguous for token in sequence]

            # Nothing left to keep: drop this sequence from the batch.
            if not any(keep):
                continue

            for key in encoded_inputs:
                filtered_inputs[key].append(list(compress(encoded_inputs[key][i], keep)))

        return filtered_inputs

    def _pad(
        self,
        encoded_inputs: Union[Dict[str, List[List[int]]], List[Dict[str, List[int]]]],
        padding: Union[bool, str] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: int = 8,
        **kwargs,
    ) -> Dict[str, List[List[int]]]:
        """Right-pad every feature of a batch to a common length.

        Args:
            encoded_inputs: Either a mapping from feature name to a list of
                per-sequence lists, or a list of per-example dicts, which is
                transposed into the former.
            padding: Selects the target length.
                * Falsy or ``"do_not_pad"``: no padding.
                * ``True`` or ``"longest"``: the longest sequence, capped at ``max_length``.
                * Any other string (e.g. ``"max_length"``): ``max_length``.
            max_length: Upper bound / target length. Defaults to ``self.model_max_length``.
            pad_to_multiple_of: If set (non-zero), the target length is rounded up to
                a multiple of this value.
            **kwargs: Ignored.

        Returns:
            The padded mapping. Dict inputs are padded in place. Padding values
            come from ``self.key_to_padding``.
        """
        # Transpose a list of per-example dicts (e.g. from a data collator).
        if isinstance(encoded_inputs, list):
            if not encoded_inputs:
                return {}
            encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0]}

        # Explicit "no padding" request, or an empty batch (max() of nothing would raise).
        if not padding or padding == "do_not_pad" or not encoded_inputs["input_ids"]:
            return encoded_inputs

        if max_length is None:
            max_length = self.model_max_length

        sequence_lengths = [len(sequence) for sequence in encoded_inputs["input_ids"]]
        # Any truthy non-string (e.g. True) means "longest", as does the literal string.
        if padding == "longest" or not isinstance(padding, str):
            max_length = min(max_length, max(sequence_lengths))

        if pad_to_multiple_of and max_length % pad_to_multiple_of != 0:
            max_length = (max_length // pad_to_multiple_of + 1) * pad_to_multiple_of

        for i, seq_len in enumerate(sequence_lengths):
            n_pad = max_length - seq_len
            if n_pad <= 0:
                # Sequences already at or beyond the target length are left untouched.
                continue
            for key in encoded_inputs:
                encoded_inputs[key][i] = encoded_inputs[key][i] + [self.key_to_padding[key]] * n_pad

        return encoded_inputs

    def pad(
        self,
        encoded_inputs: Union[Dict[str, List[List[int]]], List[Dict[str, List[int]]]],
        padding: Union[bool, str] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: int = 8,
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ) -> Union[BatchEncoding, Dict[str, List[List[int]]]]:
        """Pad a batch with ``_pad`` and optionally convert it to tensors.

        Args:
            encoded_inputs: Dict of per-key lists, or a list of per-example dicts.
            padding: Padding strategy. See ``_pad``.
            max_length: Maximum / target length. See ``_pad``.
            pad_to_multiple_of: Round the target length up to this multiple.
            return_tensors: If not None, wrap the result in a ``BatchEncoding``
                with this tensor type (e.g. ``"pt"``).
            **kwargs: Forwarded to ``_pad`` (where they are ignored).

        Returns:
            A ``BatchEncoding`` when ``return_tensors`` is set, otherwise the padded mapping.
        """
        encoded_inputs = self._pad(
            encoded_inputs,
            padding,
            max_length,
            pad_to_multiple_of,
            **kwargs,
        )

        if return_tensors is not None:
            return BatchEncoding(encoded_inputs, tensor_type=return_tensors)

        return encoded_inputs

    def __call__(
        self,
        text: str | List[str],
        max_length: Optional[int] = None,
        padding: Union[bool, str] = False,
        truncation: bool = False,
        random_truncate: bool = True,
        remove_ambiguous: bool = False,
        return_special_tokens_mask: bool = True,
        return_tensors: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Tensor]:
        """Tokenize one sequence or a batch, then apply the custom post-processing.

        Processing order:
        1. Tokenize with the parent (no HF padding/truncation).
        2. Optionally truncate.
        3. Add ``position_ids``.
        4. Optionally remove ambiguous tokens.
        5. Optionally pad.
        6. Optionally convert to tensors.

        Args:
            text: A single sequence or a list of sequences.
            max_length: Truncation/padding length. Defaults to ``self.model_max_length``.
            padding: Padding strategy passed to ``_pad``. Falsy disables padding.
            truncation: Whether to truncate sequences longer than ``max_length``.
            random_truncate: Use a random window when truncating (see ``truncate``).
            remove_ambiguous: Drop tokens listed in ``self.ambiguous_token_ids``
                (only if that attribute is not None).
            return_special_tokens_mask: Forwarded to the parent tokenizer.
            return_tensors: If not None, return a ``BatchEncoding`` with this tensor type.
            **kwargs: Forwarded to the parent tokenizer's ``__call__``.

        Returns:
            The encoded inputs. For a single string the batch dimension is removed.
        """
        if isinstance(text, str):
            # Encode as a batch of one, forwarding every option (including
            # kwargs, which were previously dropped), then unwrap the batch.
            encoded_inputs = self.__call__(
                [text],
                max_length=max_length,
                padding=padding,
                truncation=truncation,
                random_truncate=random_truncate,
                remove_ambiguous=remove_ambiguous,
                return_special_tokens_mask=return_special_tokens_mask,
                return_tensors=return_tensors,
                **kwargs,
            )
            for key in encoded_inputs:
                encoded_inputs[key] = encoded_inputs[key][0]
            return encoded_inputs

        # Padding and truncation are handled below, not by the parent.
        encoded_inputs = super().__call__(
            text,
            padding=False,
            truncation=False,
            return_special_tokens_mask=return_special_tokens_mask,
            **kwargs,
        )

        if max_length is None:
            max_length = self.model_max_length

        if truncation:
            encoded_inputs = self.truncate(
                encoded_inputs,
                max_length=max_length,
                random_truncate=random_truncate,
            )

        # Positions are assigned before ambiguous tokens are removed, so the
        # surviving tokens keep their pre-removal indices.
        encoded_inputs["position_ids"] = [list(range(len(seq))) for seq in encoded_inputs["input_ids"]]

        if remove_ambiguous and self.ambiguous_token_ids is not None:
            encoded_inputs = self.remove_ambiguous(encoded_inputs)

        if padding:
            # Pass the requested strategy through; previously it was ignored and
            # "max_length" silently behaved like "longest".
            encoded_inputs = self._pad(encoded_inputs, padding=padding, max_length=max_length)

        if return_tensors is not None:
            return BatchEncoding(encoded_inputs, tensor_type=return_tensors)

        return encoded_inputs
|
|