| from typing import Sequence |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| |
| from src.data.esm.tokenization.function_tokenizer import ( |
| InterProQuantizedTokenizer as EsmFunctionTokenizer, |
| ) |
|
|
| from src.data.esm.tokenization.residue_tokenizer import ( |
| ResidueAnnotationsTokenizer, |
| ) |
| from src.data.esm.tokenization.sasa_tokenizer import ( |
| SASADiscretizingTokenizer, |
| ) |
| from src.data.esm.tokenization.sequence_tokenizer import ( |
| EsmSequenceTokenizer, |
| ) |
| from src.data.esm.tokenization.ss_tokenizer import ( |
| SecondaryStructureTokenizer, |
| ) |
| from src.data.esm.tokenization.structure_tokenizer import ( |
| StructureTokenizer, |
| ) |
| from src.data.esm.utils.constants import esm3 as C |
| from src.data.esm.utils.function.encode_decode import ( |
| encode_function_annotations, |
| ) |
| from src.data.esm.utils.structure.protein_chain import ProteinChain |
| from src.data.esm.utils.types import FunctionAnnotation |
|
|
|
|
| |
def get_default_sequence(sequence_length: int) -> str:
    """Return a fully masked sequence string of ``sequence_length`` residues.

    Each position is the short mask marker (``C.MASK_STR_SHORT``).
    """
    return "".join([C.MASK_STR_SHORT] * sequence_length)
|
|
|
|
def get_default_secondary_structure(sequence_length: int) -> str:
    """Return a fully masked secondary-structure string of ``sequence_length``
    positions, using the short mask marker (``C.MASK_STR_SHORT``).
    """
    return "".join([C.MASK_STR_SHORT] * sequence_length)
|
|
|
|
| def get_default_sasa(sequence_length: int) -> Sequence[float | str | None]: |
| return [None] * sequence_length |
|
|
|
|
| |
def tokenize_sequence(
    sequence: str,
    sequence_tokenizer: EsmSequenceTokenizer,
    add_special_tokens: bool = True,
) -> torch.Tensor:
    """Encode a residue string into an int64 token tensor.

    Short mask markers (``C.MASK_STR_SHORT``) in the input are first mapped
    onto the tokenizer's own mask token so they encode as mask ids.
    """
    normalized = sequence.replace(C.MASK_STR_SHORT, sequence_tokenizer.mask_token)
    token_ids = sequence_tokenizer.encode(
        normalized, add_special_tokens=add_special_tokens
    )
    return torch.tensor(token_ids, dtype=torch.int64)
|
|
|
|
def tokenize_structure(
    coordinates: torch.Tensor,
    structure_encoder,
    structure_tokenizer: StructureTokenizer,
    reference_sequence: str = "",
    add_special_tokens: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Tokenize atom37 coordinates into structure tokens.

    Args:
        coordinates: atom37 coordinate tensor; dim 0 is the residue axis.
        structure_encoder: model whose ``encode`` maps coordinates to tokens
            (also used to pick the device via its parameters).
        structure_tokenizer: supplies mask/BOS/EOS token ids for padding.
        reference_sequence: optional sequence matching the coordinates;
            must have one character per residue when provided.
        add_special_tokens: if True, pad each output by one position on each
            side and write BOS/EOS into the structure-token track.

    Returns:
        Tuple of (coordinates, plddt, structure_tokens), each squeezed to
        drop the batch dimension and, when ``add_special_tokens`` is set,
        padded with one special position at each end.

    Raises:
        ValueError: if ``reference_sequence`` length disagrees with the
            number of residues in ``coordinates``.
    """
    # Fail fast: validate BEFORE constructing the ProteinChain. The original
    # built the chain first, doing work (and handing an inconsistent
    # sequence to from_atom37) that is thrown away on error.
    if reference_sequence and len(reference_sequence) != coordinates.size(0):
        raise ValueError(
            f"Reference sequence length ({len(reference_sequence)}) does not match the number of residues in the coordinates ({coordinates.size(0)})"
        )

    device = next(structure_encoder.parameters()).device
    chain = ProteinChain.from_atom37(
        coordinates, sequence=reference_sequence if reference_sequence else None
    )

    # One extra position on each side for BOS/EOS when requested.
    left_pad = 0
    right_pad = 0
    if add_special_tokens:
        left_pad += 1
        right_pad += 1

    coordinates, plddt, residue_index = chain.to_structure_encoder_inputs()
    coordinates = coordinates.to(device)
    plddt = plddt.to(device)
    residue_index = residue_index.to(device)
    _, structure_tokens = structure_encoder.encode(
        coordinates, residue_index=residue_index
    )
    # Drop the batch dimension added by to_structure_encoder_inputs().
    coordinates = torch.squeeze(coordinates, dim=0)
    plddt = torch.squeeze(plddt, dim=0)
    structure_tokens = torch.squeeze(structure_tokens, dim=0)

    if add_special_tokens:
        # Pad all residue-aligned tracks, then overwrite the ends of the
        # token track with BOS/EOS. Coordinates are padded with inf
        # (no real atoms there); plddt with 0.
        coordinates = F.pad(
            coordinates, (0, 0, 0, 0, left_pad, right_pad), value=torch.inf
        )
        plddt = F.pad(plddt, (left_pad, right_pad), value=0)
        structure_tokens = F.pad(
            structure_tokens,
            (left_pad, right_pad),
            value=structure_tokenizer.mask_token_id,
        )
        structure_tokens[0] = structure_tokenizer.bos_token_id
        structure_tokens[-1] = structure_tokenizer.eos_token_id
    return coordinates, plddt, structure_tokens
|
|
|
|
def tokenize_secondary_structure(
    secondary_structure: str | Sequence[str],
    secondary_structure_tokenizer: SecondaryStructureTokenizer,
    add_special_tokens: bool = True,
) -> torch.Tensor:
    """Encode a secondary-structure annotation into tokens.

    Accepts either a string or a per-residue sequence of characters. In a
    string, occurrences of the tokenizer's (possibly multi-character) mask
    token are first collapsed to the one-character short marker so the
    string can be split per residue; the markers are then mapped back to
    the mask token before encoding.
    """
    if isinstance(secondary_structure, str):
        # Collapse the mask token to a single character so that each list
        # element below corresponds to exactly one residue.
        collapsed = secondary_structure.replace(
            secondary_structure_tokenizer.mask_token, C.MASK_STR_SHORT
        )
        secondary_structure = list(collapsed)

    # Map short mask markers back onto the tokenizer's mask token.
    normalized = [
        secondary_structure_tokenizer.mask_token if ch == C.MASK_STR_SHORT else ch
        for ch in secondary_structure
    ]
    return secondary_structure_tokenizer.encode(
        normalized, add_special_tokens=add_special_tokens
    )
|
|
|
|
| def tokenize_sasa( |
| sasa: Sequence[float | str | None], |
| sasa_tokenizer: SASADiscretizingTokenizer, |
| add_special_tokens: bool = True, |
| ): |
| sasa_tokens = sasa_tokenizer.encode( |
| [sasa_tokenizer.mask_token if value is None else value for value in sasa], |
| add_special_tokens=add_special_tokens, |
| ) |
| return sasa_tokens |
|
|
|
|
def tokenize_function_annotations(
    function_annotations: Sequence[FunctionAnnotation],
    reference_sequence: str,
    function_tokenizer: EsmFunctionTokenizer,
    residue_annotation_tokenizer: ResidueAnnotationsTokenizer,
    add_special_tokens: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode function annotations against a reference sequence.

    Thin wrapper around :func:`encode_function_annotations`; returns the
    (function_tokens, residue_annotation_tokens) pair it produces.
    """
    return encode_function_annotations(
        sequence=reference_sequence,
        function_annotations=function_annotations,
        function_tokens_tokenizer=function_tokenizer,
        residue_annotations_tokenizer=residue_annotation_tokenizer,
        add_special_tokens=add_special_tokens,
    )
|
|
|
|
|
|
|
|
| |
| def get_default_sequence_tokens( |
| sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer |
| ) -> torch.Tensor: |
| assert sequence_tokenizer.mask_token_id is not None |
| assert sequence_tokenizer.bos_token_id is not None |
| assert sequence_tokenizer.eos_token_id is not None |
|
|
| sequence_tokens = torch.full( |
| (sequence_length + 2,), sequence_tokenizer.mask_token_id |
| ) |
| sequence_tokens[0] = sequence_tokenizer.bos_token_id |
| sequence_tokens[-1] = sequence_tokenizer.eos_token_id |
|
|
| return sequence_tokens |
|
|
|
|
| def get_default_structure_tokens( |
| sequence_length: int, structure_tokenizer: StructureTokenizer |
| ) -> torch.Tensor: |
| structure_tokens = ( |
| torch.ones((sequence_length + 2,), dtype=torch.int64) |
| * structure_tokenizer.mask_token_id |
| ) |
| |
| structure_tokens[0] = structure_tokenizer.bos_token_id |
| structure_tokens[-1] = structure_tokenizer.eos_token_id |
| return structure_tokens |
|
|
|
|
| def get_default_secondary_structure_tokens( |
| sequence_length: int, secondary_structure_tokenizer: SecondaryStructureTokenizer |
| ) -> torch.Tensor: |
| ss8_tokens = torch.full( |
| (sequence_length + 2,), secondary_structure_tokenizer.mask_token_id |
| ) |
| ss8_tokens[0] = secondary_structure_tokenizer.bos_token_id |
| ss8_tokens[-1] = secondary_structure_tokenizer.eos_token_id |
|
|
| return ss8_tokens |
|
|
|
|
| def get_default_sasa_tokens( |
| sequence_length: int, sasa_tokenizer: SASADiscretizingTokenizer |
| ) -> torch.Tensor: |
| sasa_tokens = torch.full((sequence_length + 2,), sasa_tokenizer.mask_token_id) |
| sasa_tokens[0] = sasa_tokenizer.bos_token_id |
| sasa_tokens[-1] = sasa_tokenizer.eos_token_id |
| return sasa_tokens |
|
|
|
|
| def get_default_function_tokens( |
| sequence_length: int, function_tokenizer: EsmFunctionTokenizer |
| ) -> torch.Tensor: |
| function_tokens = ( |
| torch.ones((sequence_length + 2, function_tokenizer.depth), dtype=torch.int64) |
| * function_tokenizer.pad_token_id |
| ) |
| |
| function_tokens[0] = function_tokenizer.bos_token_id |
| function_tokens[-1] = function_tokenizer.eos_token_id |
| return function_tokens |
|
|
|
|
def get_default_residue_annotation_tokens(
    sequence_length: int, residue_annotation_tokenizer: ResidueAnnotationsTokenizer
) -> torch.Tensor:
    """Return a default residue-annotation token track of shape
    ``(sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS)`` filled with the pad
    token, with whole BOS/EOS rows at the ends.
    """
    # torch.full replaces the ones(...) * pad_token_id idiom: one allocation,
    # no elementwise multiply, consistent with the 1-D track builders.
    residue_annotation_tokens = torch.full(
        (sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS),
        residue_annotation_tokenizer.pad_token_id,
        dtype=torch.int64,
    )
    # Row assignment broadcasts the special id across all annotation slots.
    residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id
    residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id
    return residue_annotation_tokens
|
|
|
|
|
|