| | from __future__ import annotations |
| |
|
| | from abc import ABC |
| | from copy import deepcopy |
| | from typing import List, Sequence |
| |
|
| | import attr |
| | import torch |
| | from attr import asdict, define |
| |
|
| | import src.data.esm.utils.constants.api as C |
| | from src.data.esm.tokenization import ( |
| | TokenizerCollectionProtocol, |
| | get_esm3_model_tokenizers, |
| | ) |
| | from src.data.esm.utils import encoding |
| | from src.data.esm.utils.constants.models import ESM3_OPEN_SMALL |
| | from src.data.esm.utils.misc import ( |
| | get_chainbreak_boundaries_from_sequence, |
| | ) |
| | from src.data.esm.utils.structure.protein_chain import ProteinChain |
| | from src.data.esm.utils.structure.protein_complex import ProteinComplex |
| | from src.data.esm.utils.types import FunctionAnnotation, PathOrBuffer |
| |
|
| |
|
# Marker base class for every protein representation handled by the inference
# clients below (ESMProtein, ESMProteinTensor, ESMProteinError).
class ProteinType(ABC): ...
| |
|
| |
|
| | |
@define
class ESMProtein(ProteinType):
    """High-level, human-readable representation of a protein.

    Each attribute is an optional "track"; any subset may be populated, and
    all populated tracks are expected to describe the same residues (same
    length), as enforced loosely by ``__len__``.
    """

    # Tracks.
    sequence: str | None = None
    secondary_structure: str | None = None
    sasa: list[float | None] | None = None
    function_annotations: list[FunctionAnnotation] | None = None
    # atom37 coordinates; presumably shape (num_residues, 37, 3) — TODO confirm.
    coordinates: torch.Tensor | None = None

    # Metrics produced by the model during generation (may stay None).
    plddt: torch.Tensor | None = None
    ptm: torch.Tensor | None = None

    # When True, the sequence was flagged by a safety screen upstream.
    potential_sequence_of_concern: bool = False

    def __len__(self):
        """Residue count, taken from the first populated track.

        Raises:
            ValueError: if no track is populated.
        """
        if self.sequence is not None:
            return len(self.sequence)
        elif self.secondary_structure is not None:
            return len(self.secondary_structure)
        elif self.sasa is not None:
            return len(self.sasa)
        elif self.coordinates is not None:
            return self.coordinates.size(0)
        else:
            raise ValueError("No track to determine length from.")

    @classmethod
    def from_pdb(
        cls,
        path: PathOrBuffer,
        chain_id: str = "detect",
        id: str | None = None,
        is_predicted: bool = False,
    ) -> ESMProtein:
        """Load a single chain from a PDB file and wrap it as an ESMProtein."""
        protein_chain = ProteinChain.from_pdb(
            path=path, chain_id=chain_id, id=id, is_predicted=is_predicted
        )
        return cls.from_protein_chain(protein_chain)

    @classmethod
    def from_protein_chain(
        cls, protein_chain: ProteinChain, with_annotations: bool = False
    ) -> ESMProtein:
        """Build an ESMProtein from a ProteinChain.

        When ``with_annotations`` is True, secondary structure (DSSP) and SASA
        are computed from the structure; otherwise those tracks are left None.
        The two original branches differed only in those two tracks, so they
        are merged here. ``cls(...)`` (rather than a hard-coded ESMProtein)
        lets subclasses round-trip through this constructor.
        """
        return cls(
            sequence=protein_chain.sequence,
            secondary_structure=(
                protein_chain.dssp().tolist() if with_annotations else None
            ),
            sasa=protein_chain.sasa().tolist() if with_annotations else None,
            function_annotations=None,
            coordinates=torch.tensor(protein_chain.atom37_positions),
        )

    @classmethod
    def from_protein_complex(
        cls, protein_complex: ProteinComplex, with_annotations: bool = False
    ) -> ESMProtein:
        """Build an ESMProtein from a (multi-chain) ProteinComplex.

        Raises:
            NotImplementedError: annotations are not supported for complexes.
        """
        if with_annotations:
            raise NotImplementedError(
                "Annotations are not supported for ProteinComplex yet."
            )

        return cls(
            sequence=protein_complex.sequence,
            secondary_structure=None,
            sasa=None,
            function_annotations=None,
            coordinates=torch.tensor(
                protein_complex.atom37_positions, dtype=torch.float32
            ),
        )

    def to_pdb(self, pdb_path: PathOrBuffer) -> None:
        """Write this protein to a PDB file.

        Goes through ProteinComplex (not ProteinChain) so chainbreaks in the
        sequence are respected; oxygen atoms are inferred before writing.
        """
        protein_complex = self.to_protein_complex().infer_oxygen()
        protein_complex.to_pdb(pdb_path)

    def to_pdb_string(self) -> str:
        """Render this protein as a single-chain PDB string."""
        protein_chain = self.to_protein_chain()
        return protein_chain.to_pdb_string()

    def to_protein_chain(self) -> ProteinChain:
        """Convert to a ProteinChain; requires coordinates.

        Raises:
            ValueError: if ``self.coordinates`` is None.
        """
        if self.coordinates is None:
            raise ValueError("Coordinates are required to convert to a ProteinChain.")
        protein_chain = ProteinChain.from_atom37(
            atom37_positions=self.coordinates.to("cpu").numpy(),
            id=None,
            # "_" marks a chainbreak in the sequence track; ProteinChain has no
            # chain concept, so map it to the unknown-residue code "X".
            sequence=None if self.sequence is None else self.sequence.replace("_", "X"),
            chain_id=None,
            entity_id=None,
            residue_index=None,
            insertion_code=None,
            confidence=None
            if self.plddt is None
            else self.plddt.detach().cpu().numpy(),
        )
        return protein_chain

    def to_protein_complex(
        self, copy_annotations_from_ground_truth: ProteinComplex | None = None
    ) -> ProteinComplex:
        """Split this protein at chainbreak boundaries into a ProteinComplex.

        Args:
            copy_annotations_from_ground_truth: optional complex whose per-chain
                ``chain_id`` / ``entity_id`` annotations are copied onto the
                predicted chains, positionally.

        Raises:
            ValueError: if the ground-truth complex has a different number of
                chains than this sequence implies (previously a raw IndexError).
        """
        assert (
            self.sequence is not None
        ), "ESMProtein must have a sequence to convert to ProteinComplex"
        assert (
            self.coordinates is not None
        ), "ESMProtein must have coordinates to convert to ProteinComplex"
        coords = self.coordinates.to("cpu").numpy()

        chain_boundaries = get_chainbreak_boundaries_from_sequence(self.sequence)
        if copy_annotations_from_ground_truth is not None:
            gt_chains = list(copy_annotations_from_ground_truth.chain_iter())
            if len(gt_chains) != len(chain_boundaries):
                raise ValueError(
                    "Ground truth complex chain count does not match the "
                    "number of chains in the sequence."
                )
        else:
            gt_chains = None
        pred_chains = []
        for i, (start, end) in enumerate(chain_boundaries):
            pred_chain = ProteinChain.from_atom37(
                atom37_positions=coords[start:end],
                sequence=self.sequence[start:end],
                chain_id=gt_chains[i].chain_id if gt_chains is not None else None,
                entity_id=gt_chains[i].entity_id if gt_chains is not None else None,
            )
            pred_chains.append(pred_chain)
        return ProteinComplex.from_chains(pred_chains)

    def copy(self) -> "ESMProtein":
        """Create a deep copy of the ESMProtein instance."""
        return deepcopy(self)
| |
|
| |
|
@define
class ESMProteinTensor(ProteinType):
    """Tokenized representation of a protein: one tensor per track.

    All populated tensor tracks must agree on length (dim 0) and device;
    ``_detect_attribute`` enforces that consistency on access.
    """

    sequence: torch.Tensor | None = None
    structure: torch.Tensor | None = None
    secondary_structure: torch.Tensor | None = None
    sasa: torch.Tensor | None = None
    function: torch.Tensor | None = None
    residue_annotations: torch.Tensor | None = None
    coordinates: torch.Tensor | None = None

    # When True, the sequence was flagged by a safety screen upstream.
    potential_sequence_of_concern: bool = False

    def _detect_attribute(self, func, msg):
        """Apply ``func(name, tensor)`` to every tensor track and return the
        single value they all share.

        Returns None when no track is populated; raises ValueError when the
        tracks disagree (``msg`` names the property being checked).
        """
        mapped = {
            k: func(k, v)
            for k, v in asdict(self).items()
            if isinstance(v, torch.Tensor)
        }
        values = set(mapped.values())
        if not values:
            # No tensor tracks at all.
            return None
        if len(values) != 1:
            raise ValueError(f"Either no tracks or inconsistent {msg}: {mapped}")
        return next(iter(values))

    def __len__(self) -> int:
        """Token length shared by all tracks; 0 for a fully empty object."""
        length = self._detect_attribute(lambda _, x: x.size(0), "length")
        return length if length is not None else 0

    @property
    def device(self) -> str | torch.device:
        """Device shared by all tensor tracks (asserts at least one exists)."""
        d = self._detect_attribute(lambda _, x: x.device, "device")
        assert d is not None
        return d

    def to(self, device_or_dtype: str | torch.device | torch.dtype) -> ESMProteinTensor:
        """Move/cast every tensor track in place; returns self for chaining."""
        for field in attr.fields(ESMProteinTensor):
            value = getattr(self, field.name)
            # isinstance() already rejects None, so no separate None check.
            if isinstance(value, torch.Tensor):
                setattr(self, field.name, value.to(device_or_dtype))
        return self

    @classmethod
    def empty(
        cls,
        length: int,
        tokenizers: TokenizerCollectionProtocol | None = None,
        device: torch.device | str = "cpu",
    ) -> ESMProteinTensor:
        """Create a fully-masked protein tensor of ``length`` residues.

        Every track is filled with that tokenizer's default (mask) tokens.
        Falls back to the ESM3-open-small tokenizers when none are given.
        """
        if tokenizers is None:
            tokenizers = get_esm3_model_tokenizers(ESM3_OPEN_SMALL)

        return ESMProteinTensor(
            sequence=encoding.get_default_sequence_tokens(
                length, tokenizers.sequence
            ).to(device),
            structure=encoding.get_default_structure_tokens(
                length, tokenizers.structure
            ).to(device),
            secondary_structure=encoding.get_default_secondary_structure_tokens(
                length, tokenizers.secondary_structure
            ).to(device),
            sasa=encoding.get_default_sasa_tokens(length, tokenizers.sasa).to(device),
            function=encoding.get_default_function_tokens(
                length, tokenizers.function
            ).to(device),
            residue_annotations=encoding.get_default_residue_annotation_tokens(
                length, tokenizers.residue_annotations
            ).to(device),
        )

    def copy(self) -> ESMProteinTensor:
        """Create a deep copy of the ESMProteinTensor instance."""
        return deepcopy(self)
| |
|
| |
|
@define
class ESMProteinError(Exception, ProteinType):
    """Error returned in place of a protein by inference clients.

    Doubles as an Exception so it can be raised or returned as a value.
    """

    error_code: int  # presumably an HTTP-style status code — TODO confirm
    error_msg: str  # human-readable description of the failure
| |
|
| |
|
| | |
@define
class GenerationConfig:
    """Configuration for iterative masked decoding of a single track."""

    # Name of the track to generate (e.g. "sequence", "structure").
    track: str = ""
    # Token ids that must never be sampled. Built with a factory so each
    # config owns a fresh list — the previous plain `[24]` default was one
    # list shared (and mutable) across every instance.
    # NOTE(review): id 24 presumably excludes a special token — confirm
    # against the tokenizer vocabulary.
    invalid_ids: Sequence[int] = attr.field(factory=lambda: [24])
    # Unmasking schedule across decoding steps.
    schedule: str = attr.field(
        validator=attr.validators.in_(["cosine", "linear"]), default="cosine"
    )
    # How positions are chosen for unmasking at each step.
    strategy: str = attr.field(
        validator=attr.validators.in_(["random", "entropy"]), default="random"
    )
    num_steps: int = 20
    temperature: float = 1.0
    temperature_annealing: bool = False
    top_p: float = 1.0
    condition_on_coordinates_only: bool = True

    def use_entropy_based_unmasking_strategy(self):
        """Use entropy based unmasking strategy during generation."""
        self.schedule = "cosine"
        self.strategy = "entropy"
        self.temperature_annealing = False

    def use_generative_unmasking_strategy(self):
        """Use an unmasking strategy that produces more variety of generations."""
        self.schedule = "cosine"
        self.strategy = "random"
        self.temperature_annealing = True
| |
|
| |
|
@define
class InverseFoldingConfig:
    """Sampling configuration for inverse folding (sequence from structure)."""

    # Token ids that must never be sampled. factory=list gives each config a
    # fresh list — the previous plain `[]` default was one shared mutable list.
    invalid_ids: Sequence[int] = attr.field(factory=list)
    temperature: float = 1.0
| |
|
| |
|
| | |
@define
class SamplingTrackConfig:
    """Per-track sampling settings used inside a SamplingConfig."""

    temperature: float = 1.0
    top_p: float = 1.0
    # When True, positions that already hold real tokens are left untouched.
    only_sample_masked_tokens: bool = True
    # Token ids that must never be sampled. factory=list gives each config a
    # fresh list — the previous plain `[]` default was one shared mutable list.
    invalid_ids: Sequence[int] = attr.field(factory=list)
    # Number of top log-probabilities to return per position (0 = none).
    topk_logprobs: int = 0
| |
|
| |
|
@define
class SamplingConfig:
    """Per-track sampling configuration; None disables sampling for a track.

    Each field's ``metadata["max_topk"]`` records the maximum allowed
    ``topk_logprobs`` for that track (constants from ``constants.api``).
    """

    sequence: SamplingTrackConfig | None = attr.field(
        default=None, metadata={"max_topk": C.MAX_TOPK_SEQUENCE}
    )
    structure: SamplingTrackConfig | None = attr.field(
        default=None, metadata={"max_topk": C.MAX_TOPK_STRUCTURE}
    )
    secondary_structure: SamplingTrackConfig | None = attr.field(
        default=None, metadata={"max_topk": C.MAX_TOPK_SECONDARY_STRUCTURE}
    )
    sasa: SamplingTrackConfig | None = attr.field(
        default=None, metadata={"max_topk": C.MAX_TOPK_SASA}
    )
    function: SamplingTrackConfig | None = attr.field(
        default=None, metadata={"max_topk": C.MAX_TOPK_FUNCTION}
    )

    # Optionally return embeddings alongside sampled tokens.
    return_per_residue_embeddings: bool = False
    return_mean_embedding: bool = False
| |
|
| |
|
@define
class ForwardTrackData:
    """Container holding one optional tensor per model track.

    Used for logits and per-track statistics; any subset may be populated.
    """

    sequence: torch.Tensor | None = None
    structure: torch.Tensor | None = None
    secondary_structure: torch.Tensor | None = None
    sasa: torch.Tensor | None = None
    function: torch.Tensor | None = None
| |
|
| |
|
@define
class LogitsConfig:
    """Selects which outputs a ``logits`` call should return."""

    # Request logits for each enabled track.
    sequence: bool = False
    structure: bool = False
    secondary_structure: bool = False
    sasa: bool = False
    function: bool = False
    residue_annotations: bool = False

    # Request raw model embeddings / hidden states alongside logits.
    return_embeddings: bool = False
    return_hidden_states: bool = False
    # Index of the hidden layer to return; -1 presumably means the last
    # layer — confirm against the model implementation.
    ith_hidden_layer: int = -1
| |
|
| |
|
@define
class LogitsOutput:
    """Result of a ``logits`` call."""

    # Per-track logits, as requested in LogitsConfig.
    logits: ForwardTrackData | None = None
    # Model embeddings, when LogitsConfig.return_embeddings was set.
    embeddings: torch.Tensor | None = None

    # Residue-annotation logits are kept separate from the per-track
    # container above (they are not a ForwardTrackData field).
    residue_annotation_logits: torch.Tensor | None = None
    # Hidden states, when LogitsConfig.return_hidden_states was set.
    hidden_states: torch.Tensor | None = None
| |
|
| |
|
@define
class ForwardAndSampleOutput(LogitsOutput):
    """Result of ``forward_and_sample``: sampled tokens plus statistics."""

    # factory= gives every output its own tensor container. The previous
    # `= ESMProteinTensor()` default was a single instance shared by ALL
    # ForwardAndSampleOutput objects, so mutating one leaked into the rest.
    protein_tensor: ESMProteinTensor = attr.field(factory=ESMProteinTensor)

    # Per-position entropy of the predictive distribution.
    entropy: ForwardTrackData | None = None
    # Probability / log-probability of the sampled tokens.
    prob: ForwardTrackData | None = None
    logprob: ForwardTrackData | None = None
    # Top-1 probability and top-k log-probabilities/tokens per position.
    top_prob: ForwardTrackData | None = None
    topk_logprob: ForwardTrackData | None = None
    topk_tokens: ForwardTrackData | None = None
    # Embeddings, when requested via SamplingConfig.
    per_residue_embedding: torch.Tensor | None = None
    mean_embedding: torch.Tensor | None = None
| |
|
| |
|
class ESM3InferenceClient(ABC):
    """Interface for ESM3 generative inference.

    Concrete clients (local model, remote service) override these methods;
    every method here raises NotImplementedError.
    """

    def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
        """Generate the track named in ``config`` for ``input``.

        Accepts and returns the same ProteinType flavor (ESMProtein in,
        ESMProtein out; ESMProteinTensor in, ESMProteinTensor out).
        """
        raise NotImplementedError

    def batch_generate(
        self, inputs: Sequence[ProteinType], configs: Sequence[GenerationConfig]
    ) -> Sequence[ProteinType]:
        """Generate for many proteins at once; inputs and configs are paired
        positionally."""
        raise NotImplementedError

    def encode(self, input: ESMProtein) -> ESMProteinTensor:
        """Tokenize a human-readable ESMProtein into tensor tracks."""
        raise NotImplementedError

    def decode(self, input: ESMProteinTensor) -> ESMProtein:
        """Detokenize tensor tracks back into a human-readable ESMProtein."""
        raise NotImplementedError

    def logits(
        self, input: ESMProteinTensor, config: LogitsConfig = LogitsConfig()
    ) -> LogitsOutput:
        """Run a forward pass and return the outputs selected by ``config``.

        NOTE(review): the default LogitsConfig() is evaluated once at class
        definition time and shared between calls — safe only while callers
        do not mutate it.
        """
        raise NotImplementedError

    def forward_and_sample(
        self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
    ) -> ForwardAndSampleOutput:
        """Run a forward pass and sample one step per configured track."""
        raise NotImplementedError

    @property
    def raw_model(self):
        """Access the underlying model object, when the client has one."""
        raise NotImplementedError
| |
|
| |
|
class ESMCInferenceClient(ABC):
    """Interface for ESM-C (representation-only) inference.

    Like ESM3InferenceClient but without generation/sampling endpoints.
    """

    def encode(self, input: ESMProtein) -> ESMProteinTensor:
        """Tokenize a human-readable ESMProtein into tensor tracks."""
        raise NotImplementedError

    def decode(self, input: ESMProteinTensor) -> ESMProtein:
        """Detokenize tensor tracks back into a human-readable ESMProtein."""
        raise NotImplementedError

    def logits(
        self, input: ESMProteinTensor, config: LogitsConfig = LogitsConfig()
    ) -> LogitsOutput:
        # NOTE(review): shared default LogitsConfig() instance, as in
        # ESM3InferenceClient.logits — safe only if never mutated.
        raise NotImplementedError

    @property
    def raw_model(self):
        """Access the underlying model object, when the client has one."""
        raise NotImplementedError
| |
|