| | from typing import Sequence |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | |
| | from src.data.esm.tokenization.function_tokenizer import ( |
| | InterProQuantizedTokenizer as EsmFunctionTokenizer, |
| | ) |
| |
|
| | from src.data.esm.tokenization.residue_tokenizer import ( |
| | ResidueAnnotationsTokenizer, |
| | ) |
| | from src.data.esm.tokenization.sasa_tokenizer import ( |
| | SASADiscretizingTokenizer, |
| | ) |
| | from src.data.esm.tokenization.sequence_tokenizer import ( |
| | EsmSequenceTokenizer, |
| | ) |
| | from src.data.esm.tokenization.ss_tokenizer import ( |
| | SecondaryStructureTokenizer, |
| | ) |
| | from src.data.esm.tokenization.structure_tokenizer import ( |
| | StructureTokenizer, |
| | ) |
| | from src.data.esm.utils.constants import esm3 as C |
| | from src.data.esm.utils.function.encode_decode import ( |
| | encode_function_annotations, |
| | ) |
| | from src.data.esm.utils.structure.protein_chain import ProteinChain |
| | from src.data.esm.utils.types import FunctionAnnotation |
| |
|
| |
|
| | |
def get_default_sequence(sequence_length: int) -> str:
    """Return an all-masked sequence string of the given length."""
    return "".join(C.MASK_STR_SHORT for _ in range(sequence_length))
| |
|
| |
|
def get_default_secondary_structure(sequence_length: int) -> str:
    """Return an all-masked secondary-structure string of the given length."""
    return "".join(C.MASK_STR_SHORT for _ in range(sequence_length))
| |
|
| |
|
| | def get_default_sasa(sequence_length: int) -> Sequence[float | str | None]: |
| | return [None] * sequence_length |
| |
|
| |
|
| | |
def tokenize_sequence(
    sequence: str,
    sequence_tokenizer: EsmSequenceTokenizer,
    add_special_tokens: bool = True,
) -> torch.Tensor:
    """Encode an amino-acid sequence into an int64 token tensor.

    Occurrences of the short mask placeholder are first rewritten to the
    tokenizer's own mask token so they encode as mask ids.
    """
    masked = sequence.replace(C.MASK_STR_SHORT, sequence_tokenizer.mask_token)
    token_ids = sequence_tokenizer.encode(
        masked, add_special_tokens=add_special_tokens
    )
    return torch.tensor(token_ids, dtype=torch.int64)
| |
|
| |
|
def tokenize_structure(
    coordinates: torch.Tensor,
    structure_encoder,
    structure_tokenizer: StructureTokenizer,
    reference_sequence: str = "",
    add_special_tokens: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Encode atom coordinates into discrete structure tokens.

    Args:
        coordinates: Per-residue atom coordinates; first dimension is the
            number of residues (atom37 layout — assumed from
            ``ProteinChain.from_atom37``, TODO confirm).
        structure_encoder: Model exposing ``.encode(coords, residue_index=...)``
            returning ``(embedding, structure_tokens)``; its parameters
            determine the device inputs are moved to.
        structure_tokenizer: Supplies mask/BOS/EOS token ids used for padding.
        reference_sequence: Optional residue sequence; when non-empty it must
            have exactly one character per residue in ``coordinates``.
        add_special_tokens: When True, pad one slot on each side and write
            BOS/EOS ids into the structure-token track.

    Returns:
        Tuple ``(coordinates, plddt, structure_tokens)``, each with the batch
        dimension squeezed out and, when ``add_special_tokens`` is True, padded
        at both ends.

    Raises:
        ValueError: If ``reference_sequence`` length does not match the number
            of residues in ``coordinates``.
    """
    # Validate BEFORE building the ProteinChain: the original code constructed
    # the chain first, so a length mismatch could fail opaquely inside
    # from_atom37 instead of raising this clear error.
    if reference_sequence and len(reference_sequence) != coordinates.size(0):
        raise ValueError(
            f"Reference sequence length ({len(reference_sequence)}) does not match the number of residues in the coordinates ({coordinates.size(0)})"
        )

    device = next(structure_encoder.parameters()).device
    chain = ProteinChain.from_atom37(
        coordinates, sequence=reference_sequence if reference_sequence else None
    )

    left_pad = 0
    right_pad = 0
    if add_special_tokens:
        left_pad += 1  # BOS slot
        right_pad += 1  # EOS slot

    coordinates, plddt, residue_index = chain.to_structure_encoder_inputs()
    coordinates = coordinates.to(device)
    plddt = plddt.to(device)
    residue_index = residue_index.to(device)
    _, structure_tokens = structure_encoder.encode(
        coordinates, residue_index=residue_index
    )
    # Drop the batch dimension added by to_structure_encoder_inputs().
    coordinates = torch.squeeze(coordinates, dim=0)
    plddt = torch.squeeze(plddt, dim=0)
    structure_tokens = torch.squeeze(structure_tokens, dim=0)

    if add_special_tokens:
        # Special positions carry no geometry: coordinates are inf ("missing"),
        # pLDDT is 0, and the token track gets mask ids overwritten by BOS/EOS.
        coordinates = F.pad(
            coordinates, (0, 0, 0, 0, left_pad, right_pad), value=torch.inf
        )
        plddt = F.pad(plddt, (left_pad, right_pad), value=0)
        structure_tokens = F.pad(
            structure_tokens,
            (left_pad, right_pad),
            value=structure_tokenizer.mask_token_id,
        )
        structure_tokens[0] = structure_tokenizer.bos_token_id
        structure_tokens[-1] = structure_tokenizer.eos_token_id
    return coordinates, plddt, structure_tokens
| |
|
| |
|
def tokenize_secondary_structure(
    secondary_structure: str | Sequence[str],
    secondary_structure_tokenizer: SecondaryStructureTokenizer,
    add_special_tokens: bool = True,
) -> torch.Tensor:
    """Encode a secondary-structure annotation into tokens.

    Accepts either a per-residue string or a sequence of per-residue labels.
    The short mask placeholder is mapped to the tokenizer's mask token before
    encoding.
    """
    if isinstance(secondary_structure, str):
        # Collapse the tokenizer's (possibly multi-character) mask token to the
        # single-character placeholder so the string splits cleanly per residue.
        secondary_structure = list(
            secondary_structure.replace(
                secondary_structure_tokenizer.mask_token, C.MASK_STR_SHORT
            )
        )

    # Map placeholders back to the tokenizer's mask token, element-wise.
    normalized = [
        secondary_structure_tokenizer.mask_token if label == C.MASK_STR_SHORT else label
        for label in secondary_structure
    ]

    return secondary_structure_tokenizer.encode(
        normalized, add_special_tokens=add_special_tokens
    )
| |
|
| |
|
def tokenize_sasa(
    sasa: Sequence[float | str | None],
    sasa_tokenizer: SASADiscretizingTokenizer,
    add_special_tokens: bool = True,
):
    """Encode a per-residue SASA track, mapping None entries to the mask token."""
    filled = [sasa_tokenizer.mask_token if value is None else value for value in sasa]
    return sasa_tokenizer.encode(filled, add_special_tokens=add_special_tokens)
| |
|
| |
|
def tokenize_function_annotations(
    function_annotations: Sequence[FunctionAnnotation],
    reference_sequence: str,
    function_tokenizer: EsmFunctionTokenizer,
    residue_annotation_tokenizer: ResidueAnnotationsTokenizer,
    add_special_tokens: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode function annotations into (function, residue-annotation) token tensors.

    Thin wrapper over ``encode_function_annotations``; see that helper for the
    encoding details.
    """
    return encode_function_annotations(
        sequence=reference_sequence,
        function_annotations=function_annotations,
        function_tokens_tokenizer=function_tokenizer,
        residue_annotations_tokenizer=residue_annotation_tokenizer,
        add_special_tokens=add_special_tokens,
    )
| |
|
| |
|
| |
|
| |
|
| | |
| | def get_default_sequence_tokens( |
| | sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer |
| | ) -> torch.Tensor: |
| | assert sequence_tokenizer.mask_token_id is not None |
| | assert sequence_tokenizer.bos_token_id is not None |
| | assert sequence_tokenizer.eos_token_id is not None |
| |
|
| | sequence_tokens = torch.full( |
| | (sequence_length + 2,), sequence_tokenizer.mask_token_id |
| | ) |
| | sequence_tokens[0] = sequence_tokenizer.bos_token_id |
| | sequence_tokens[-1] = sequence_tokenizer.eos_token_id |
| |
|
| | return sequence_tokens |
| |
|
| |
|
| | def get_default_structure_tokens( |
| | sequence_length: int, structure_tokenizer: StructureTokenizer |
| | ) -> torch.Tensor: |
| | structure_tokens = ( |
| | torch.ones((sequence_length + 2,), dtype=torch.int64) |
| | * structure_tokenizer.mask_token_id |
| | ) |
| | |
| | structure_tokens[0] = structure_tokenizer.bos_token_id |
| | structure_tokens[-1] = structure_tokenizer.eos_token_id |
| | return structure_tokens |
| |
|
| |
|
| | def get_default_secondary_structure_tokens( |
| | sequence_length: int, secondary_structure_tokenizer: SecondaryStructureTokenizer |
| | ) -> torch.Tensor: |
| | ss8_tokens = torch.full( |
| | (sequence_length + 2,), secondary_structure_tokenizer.mask_token_id |
| | ) |
| | ss8_tokens[0] = secondary_structure_tokenizer.bos_token_id |
| | ss8_tokens[-1] = secondary_structure_tokenizer.eos_token_id |
| |
|
| | return ss8_tokens |
| |
|
| |
|
| | def get_default_sasa_tokens( |
| | sequence_length: int, sasa_tokenizer: SASADiscretizingTokenizer |
| | ) -> torch.Tensor: |
| | sasa_tokens = torch.full((sequence_length + 2,), sasa_tokenizer.mask_token_id) |
| | sasa_tokens[0] = sasa_tokenizer.bos_token_id |
| | sasa_tokens[-1] = sasa_tokenizer.eos_token_id |
| | return sasa_tokens |
| |
|
| |
|
| | def get_default_function_tokens( |
| | sequence_length: int, function_tokenizer: EsmFunctionTokenizer |
| | ) -> torch.Tensor: |
| | function_tokens = ( |
| | torch.ones((sequence_length + 2, function_tokenizer.depth), dtype=torch.int64) |
| | * function_tokenizer.pad_token_id |
| | ) |
| | |
| | function_tokens[0] = function_tokenizer.bos_token_id |
| | function_tokens[-1] = function_tokenizer.eos_token_id |
| | return function_tokens |
| |
|
| |
|
def get_default_residue_annotation_tokens(
    sequence_length: int, residue_annotation_tokenizer: ResidueAnnotationsTokenizer
) -> torch.Tensor:
    """Build a default residue-annotation track of shape
    ``(sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS)``.

    Interior rows are filled with the PAD id (note: pad, not mask — matching
    the original behavior); row 0 is BOS and the last row is EOS, broadcast
    across the annotation-slot dimension.
    """
    # torch.full replaces the previous ones-then-multiply, avoiding an extra
    # tensor op and matching the sibling default-token helpers.
    residue_annotation_tokens = torch.full(
        (sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS),
        residue_annotation_tokenizer.pad_token_id,
        dtype=torch.int64,
    )
    residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id
    residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id
    return residue_annotation_tokens
| |
|
| |
|
| |
|