from __future__ import annotations

from abc import ABC
from copy import deepcopy
from typing import List, Sequence

import attr
import torch
from attr import asdict, define

import src.data.esm.utils.constants.api as C
from src.data.esm.tokenization import (
    TokenizerCollectionProtocol,
    get_esm3_model_tokenizers,
)
from src.data.esm.utils import encoding
from src.data.esm.utils.constants.models import ESM3_OPEN_SMALL
from src.data.esm.utils.misc import (
    get_chainbreak_boundaries_from_sequence,
)
from src.data.esm.utils.structure.protein_chain import ProteinChain
from src.data.esm.utils.structure.protein_complex import ProteinComplex
from src.data.esm.utils.types import FunctionAnnotation, PathOrBuffer


class ProteinType(ABC): ...


## Basic Types


@define
class ESMProtein(ProteinType):
    # Tracks
    sequence: str | None = None
    secondary_structure: str | None = None
    sasa: list[float | None] | None = None
    function_annotations: list[FunctionAnnotation] | None = None
    coordinates: torch.Tensor | None = None

    # Metrics
    plddt: torch.Tensor | None = None
    ptm: torch.Tensor | None = None

    # When calling the EvolutionaryScale API, use this flag to disclose any
    # sequences that may potentially have concerns.
    # Such sequences may not go through the standard safety filter for approved
    # users. Reach out if interested in using this.
    potential_sequence_of_concern: bool = False

    def __len__(self):
        if self.sequence is not None:
            return len(self.sequence)
        elif self.secondary_structure is not None:
            return len(self.secondary_structure)
        elif self.sasa is not None:
            return len(self.sasa)
        elif self.coordinates is not None:
            return self.coordinates.size(0)
        else:
            raise ValueError("No track to determine length from.")

    @classmethod
    def from_pdb(
        cls,
        path: PathOrBuffer,
        chain_id: str = "detect",
        id: str | None = None,
        is_predicted: bool = False,
    ) -> ESMProtein:
        protein_chain = ProteinChain.from_pdb(
            path=path, chain_id=chain_id, id=id, is_predicted=is_predicted
        )
        return cls.from_protein_chain(protein_chain)

    @classmethod
    def from_protein_chain(
        cls, protein_chain: ProteinChain, with_annotations: bool = False
    ) -> ESMProtein:
        # By default, we don't annotate with DSSP / SASA, which are expensive.
        # If mkdssp is installed, we can annotate with a flag.
        if with_annotations:
            return ESMProtein(
                sequence=protein_chain.sequence,
                secondary_structure=protein_chain.dssp().tolist(),
                sasa=protein_chain.sasa().tolist(),
                function_annotations=None,
                coordinates=torch.tensor(protein_chain.atom37_positions),
            )
        else:
            return ESMProtein(
                sequence=protein_chain.sequence,
                secondary_structure=None,
                sasa=None,
                function_annotations=None,
                coordinates=torch.tensor(protein_chain.atom37_positions),
            )

    @classmethod
    def from_protein_complex(
        cls, protein_complex: ProteinComplex, with_annotations: bool = False
    ) -> ESMProtein:
        if with_annotations:
            raise NotImplementedError(
                "Annotations are not supported for ProteinComplex yet."
            )
        return ESMProtein(
            sequence=protein_complex.sequence,
            secondary_structure=None,
            sasa=None,
            function_annotations=None,
            coordinates=torch.tensor(
                protein_complex.atom37_positions, dtype=torch.float32
            ),
        )

    def to_pdb(self, pdb_path: PathOrBuffer) -> None:
        # Note: this also works for single chains and produces the same PDB file.
        protein_complex = self.to_protein_complex().infer_oxygen()
        protein_complex.to_pdb(pdb_path)

    def to_pdb_string(self) -> str:
        protein_chain = self.to_protein_chain()
        return protein_chain.to_pdb_string()

    def to_protein_chain(self) -> ProteinChain:
        if self.coordinates is None:
            raise ValueError("Coordinates are required to convert to a ProteinChain.")
        protein_chain = ProteinChain.from_atom37(
            atom37_positions=self.coordinates.to("cpu").numpy(),
            id=None,
            sequence=None if self.sequence is None else self.sequence.replace("_", "X"),
            chain_id=None,
            entity_id=None,
            residue_index=None,
            insertion_code=None,
            confidence=None
            if self.plddt is None
            else self.plddt.detach().cpu().numpy(),
        )
        return protein_chain

    def to_protein_complex(
        self, copy_annotations_from_ground_truth: ProteinComplex | None = None
    ) -> ProteinComplex:
        assert (
            self.sequence is not None
        ), "ESMProtein must have a sequence to convert to ProteinComplex"
        assert (
            self.coordinates is not None
        ), "ESMProtein must have coordinates to convert to ProteinComplex"
        coords = self.coordinates.to("cpu").numpy()
        chain_boundaries = get_chainbreak_boundaries_from_sequence(self.sequence)

        if copy_annotations_from_ground_truth is not None:
            gt_chains = list(copy_annotations_from_ground_truth.chain_iter())
        else:
            gt_chains = None
        pred_chains = []
        for i, (start, end) in enumerate(chain_boundaries):
            pred_chain = ProteinChain.from_atom37(
                atom37_positions=coords[start:end],
                sequence=self.sequence[start:end],
                chain_id=gt_chains[i].chain_id if gt_chains is not None else None,
                entity_id=gt_chains[i].entity_id if gt_chains is not None else None,
            )
            pred_chains.append(pred_chain)
        return ProteinComplex.from_chains(pred_chains)

    def copy(self) -> "ESMProtein":
        """Create a deep copy of the ESMProtein instance."""
        return deepcopy(self)
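
# A minimal usage sketch (illustrative helper, not part of the API): load a
# chain from a PDB file and round-trip it back to disk. The paths
# "example.pdb" and "example_out.pdb" are hypothetical placeholders; any
# single-chain PDB file should work.
def _example_esmprotein_roundtrip(pdb_path: str = "example.pdb") -> ESMProtein:
    # Load the first detected chain, materializing sequence + atom37 coordinates.
    protein = ESMProtein.from_pdb(pdb_path, chain_id="detect")
    assert protein.sequence is not None and protein.coordinates is not None
    # __len__ is derived from whichever track is populated (sequence here),
    # and agrees with the leading dimension of the atom37 coordinates.
    assert len(protein) == protein.coordinates.size(0)
    # Writing back out goes through ProteinComplex.infer_oxygen() internally.
    protein.to_pdb("example_out.pdb")
    return protein
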
@define
class ESMProteinTensor(ProteinType):
    sequence: torch.Tensor | None = None
    structure: torch.Tensor | None = None
    secondary_structure: torch.Tensor | None = None
    sasa: torch.Tensor | None = None
    function: torch.Tensor | None = None
    residue_annotations: torch.Tensor | None = None
    coordinates: torch.Tensor | None = None

    # When calling the EvolutionaryScale API, use this flag to disclose any
    # sequences that may potentially have concerns.
    # Such sequences may not go through the standard safety filter for approved
    # users. Reach out if interested in using this.
    potential_sequence_of_concern: bool = False

    def _detect_attribute(self, func, msg):
        mapped = {
            k: func(k, v)
            for k, v in asdict(self).items()
            if isinstance(v, torch.Tensor)
        }
        s = set(mapped.values())
        if len(s) <= 0:
            return None
        if len(s) != 1:
            raise ValueError(f"Either no tracks or inconsistent {msg}: {mapped}")
        return next(iter(s))

    def __len__(self) -> int:
        l = self._detect_attribute(lambda _, x: x.size(0), "length")
        return l if l is not None else 0

    @property
    def device(self) -> str | torch.device:
        d = self._detect_attribute(lambda _, x: x.device, "device")
        assert d is not None
        return d

    def to(
        self, device_or_dtype: str | torch.device | torch.dtype
    ) -> ESMProteinTensor:
        def _to(name):
            v = getattr(self, name)
            if v is not None and isinstance(v, torch.Tensor):
                setattr(self, name, v.to(device_or_dtype))

        for n in attr.fields(ESMProteinTensor):
            _to(n.name)
        return self

    @classmethod
    def empty(
        cls,
        length: int,
        tokenizers: TokenizerCollectionProtocol | None = None,
        device: torch.device | str = "cpu",
    ) -> ESMProteinTensor:
        if tokenizers is None:
            tokenizers = get_esm3_model_tokenizers(ESM3_OPEN_SMALL)

        return ESMProteinTensor(
            sequence=encoding.get_default_sequence_tokens(
                length, tokenizers.sequence
            ).to(device),
            structure=encoding.get_default_structure_tokens(
                length, tokenizers.structure
            ).to(device),
            secondary_structure=encoding.get_default_secondary_structure_tokens(
                length, tokenizers.secondary_structure
            ).to(device),
            sasa=encoding.get_default_sasa_tokens(length, tokenizers.sasa).to(device),
            function=encoding.get_default_function_tokens(
                length, tokenizers.function
            ).to(device),
            residue_annotations=encoding.get_default_residue_annotation_tokens(
                length, tokenizers.residue_annotations
            ).to(device),
        )

    def copy(self) -> ESMProteinTensor:
        """Create a deep copy of the ESMProteinTensor instance."""
        return deepcopy(self)
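
# A minimal sketch (illustrative helper, not part of the API): build a
# default-token ESMProteinTensor for a 64-residue protein and exercise the
# device helpers. Assumes the ESM3_OPEN_SMALL tokenizers can be constructed
# locally (they are the default when `tokenizers` is None). Note that tracks
# may include special tokens (e.g. BOS/EOS), so len() can exceed `length`.
def _example_empty_protein_tensor(length: int = 64) -> ESMProteinTensor:
    protein_tensor = ESMProteinTensor.empty(length, device="cpu")
    # All populated tracks agree on device, so the property is well defined.
    assert str(protein_tensor.device) == "cpu"
    # .to(...) mutates in place and returns self; it accepts a device or dtype.
    return protein_tensor.to("cpu")
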
@define
class ESMProteinError(Exception, ProteinType):
    error_code: int  # Error codes follow the HTTP convention, e.g., 404 NotFoundError, 500 InternalError.
    error_msg: str


## High Level Endpoint Types


@define
class GenerationConfig:
    track: str = ""
    # By default, avoid sampling the amino acid "X".
    invalid_ids: Sequence[int] = [24]
    # Controls the number of tokens to unmask during each round of iterative generation.
    schedule: str = attr.field(
        validator=attr.validators.in_(["cosine", "linear"]), default="cosine"
    )
    # Controls which tokens to unmask during each round of iterative generation.
    # "random" unmasks the scheduled number of tokens at random positions.
    # "entropy" unmasks the tokens with the lowest logit entropy first.
    strategy: str = attr.field(
        validator=attr.validators.in_(["random", "entropy"]), default="random"
    )
    # The default is 20, since there are diminishing returns beyond 20 decoding steps.
    # Note that this must be less than or equal to the sequence length.
    num_steps: int = 20
    temperature: float = 1.0
    temperature_annealing: bool = False
    top_p: float = 1.0
    condition_on_coordinates_only: bool = True

    def use_entropy_based_unmasking_strategy(self):
        """Use an entropy-based unmasking strategy during generation."""
        self.schedule = "cosine"
        self.strategy = "entropy"
        self.temperature_annealing = False

    def use_generative_unmasking_strategy(self):
        """Use an unmasking strategy that produces a greater variety of generations."""
        self.schedule = "cosine"
        self.strategy = "random"
        self.temperature_annealing = True
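
# A minimal sketch (illustrative helper, not part of the API): configure
# generation to fill in the sequence track in 8 decoding steps. Track names
# follow the ESMProtein/ESMProteinTensor field names (e.g., "sequence",
# "structure"); the values here are illustrative, not recommendations.
def _example_generation_config() -> GenerationConfig:
    config = GenerationConfig(track="sequence", num_steps=8, temperature=0.7)
    # Switch from the defaults ("cosine" schedule, "random" strategy) to
    # entropy-based unmasking without temperature annealing.
    config.use_entropy_based_unmasking_strategy()
    return config
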
@define
class InverseFoldingConfig:
    invalid_ids: Sequence[int] = []
    temperature: float = 1.0


## Low Level Endpoint Types


@define
class SamplingTrackConfig:
    temperature: float = 1.0
    top_p: float = 1.0
    only_sample_masked_tokens: bool = True
    invalid_ids: Sequence[int] = []
    topk_logprobs: int = 0


@define
class SamplingConfig:
    sequence: SamplingTrackConfig | None = attr.field(
        default=None, metadata={"max_topk": C.MAX_TOPK_SEQUENCE}
    )
    structure: SamplingTrackConfig | None = attr.field(
        default=None, metadata={"max_topk": C.MAX_TOPK_STRUCTURE}
    )
    secondary_structure: SamplingTrackConfig | None = attr.field(
        default=None, metadata={"max_topk": C.MAX_TOPK_SECONDARY_STRUCTURE}
    )
    sasa: SamplingTrackConfig | None = attr.field(
        default=None, metadata={"max_topk": C.MAX_TOPK_SASA}
    )
    function: SamplingTrackConfig | None = attr.field(
        default=None, metadata={"max_topk": C.MAX_TOPK_FUNCTION}
    )

    return_per_residue_embeddings: bool = False
    return_mean_embedding: bool = False


@define
class ForwardTrackData:
    sequence: torch.Tensor | None = None
    structure: torch.Tensor | None = None
    secondary_structure: torch.Tensor | None = None
    sasa: torch.Tensor | None = None
    function: torch.Tensor | None = None


@define
class LogitsConfig:
    # Logits.
    sequence: bool = False
    # Note that logits for tracks other than sequence are not supported by
    # Forge today, due to their impractical data sizes. They are, of course,
    # supported when running local OSS models.
    structure: bool = False
    secondary_structure: bool = False
    sasa: bool = False
    function: bool = False
    residue_annotations: bool = False

    # Embeddings.
    return_embeddings: bool = False
    return_hidden_states: bool = False
    ith_hidden_layer: int = -1


@define
class LogitsOutput:
    logits: ForwardTrackData | None = None
    embeddings: torch.Tensor | None = None

    # Residue annotations are multi-hot, so they deserve special treatment:
    # the distribution is not categorical but Bernoulli per annotation, so a
    # softmax across the last dimension is _wrong_.
    residue_annotation_logits: torch.Tensor | None = None

    hidden_states: torch.Tensor | None = None


@define
class ForwardAndSampleOutput(LogitsOutput):
    protein_tensor: ESMProteinTensor = ESMProteinTensor()

    entropy: ForwardTrackData | None = None
    # Probability of the sampled token.
    prob: ForwardTrackData | None = None
    logprob: ForwardTrackData | None = None
    # Top probability at this position.
    top_prob: ForwardTrackData | None = None
    topk_logprob: ForwardTrackData | None = None
    # Tokens corresponding to the top probabilities.
    topk_tokens: ForwardTrackData | None = None
    per_residue_embedding: torch.Tensor | None = None
    mean_embedding: torch.Tensor | None = None
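
# A minimal sketch (illustrative helper, not part of the API) of a low-level
# sampling setup: sample only masked sequence tokens at a reduced temperature
# and request the top-5 log-probabilities per position. The values are
# illustrative, not recommendations.
def _example_sampling_config() -> SamplingConfig:
    track_config = SamplingTrackConfig(
        temperature=0.7,
        top_p=0.95,
        only_sample_masked_tokens=True,
        topk_logprobs=5,
    )
    # Unconfigured tracks (structure, sasa, ...) default to None and are skipped.
    return SamplingConfig(sequence=track_config, return_mean_embedding=True)
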
class ESM3InferenceClient(ABC):
    def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
        # This is the easiest and most flexible way to run ESM3. Generate
        # iteratively samples tokens and returns an output with the specified
        # track completely filled out, according to the GenerationConfig provided.
        # It is a local function wrapping calls to encode -> iterative_sampling -> decode.
        # If an ESMProteinTensor is provided, encode and decode are skipped.
        raise NotImplementedError

    def batch_generate(
        self, inputs: Sequence[ProteinType], configs: Sequence[GenerationConfig]
    ) -> Sequence[ProteinType]:
        # Same as generate(...), but generates a batch of proteins at once.
        raise NotImplementedError

    def encode(self, input: ESMProtein) -> ESMProteinTensor:
        # Encode converts a raw representation into a tokenized representation.
        # This runs the structure_token_encoder, as well as handling the
        # PDB => atom37 conversion.
        raise NotImplementedError

    def decode(self, input: ESMProteinTensor) -> ESMProtein:
        # Decode is the inverse of encode; it runs a structure_token_decoder
        # to output coordinates.
        raise NotImplementedError

    def logits(
        self, input: ESMProteinTensor, config: LogitsConfig = LogitsConfig()
    ) -> LogitsOutput:
        # Our API generally discourages using raw forwards, because sending
        # logits can be prohibitively expensive. Please use forward_and_sample
        # instead.
        raise NotImplementedError

    def forward_and_sample(
        self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
    ) -> ForwardAndSampleOutput:
        # forward_and_sample runs a single model forward pass, sampling tokens
        # according to the given SamplingConfig. This is the intended entry
        # point for power users; we hope to design it to enable high-throughput
        # inference as well as arbitrary chain-of-thought invocations of ESM3.
        raise NotImplementedError

    @property
    def raw_model(self):
        # Get the underlying ESM3 model of an inference client.
        raise NotImplementedError


class ESMCInferenceClient(ABC):
    def encode(self, input: ESMProtein) -> ESMProteinTensor:
        # Encode converts a raw representation into a tokenized representation.
        raise NotImplementedError

    def decode(self, input: ESMProteinTensor) -> ESMProtein:
        # Decode is the inverse of encode.
        raise NotImplementedError

    def logits(
        self, input: ESMProteinTensor, config: LogitsConfig = LogitsConfig()
    ) -> LogitsOutput:
        raise NotImplementedError

    @property
    def raw_model(self):
        # Get the underlying ESMC model of an inference client.
        raise NotImplementedError
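
# A minimal sketch (illustrative helper, not part of the API) of the intended
# high-level workflow, assuming some concrete implementation of
# ESM3InferenceClient (e.g., a local model wrapper or a remote Forge client;
# neither is defined in this module). It assumes "_" marks masked residues in
# a sequence prompt.
def _example_generate_workflow(client: ESM3InferenceClient) -> ESMProtein:
    # Prompt: a partial sequence whose masked middle the model should fill in.
    prompt = ESMProtein(sequence="MA" + "_" * 30 + "KL")
    output = client.generate(prompt, GenerationConfig(track="sequence", num_steps=8))
    if isinstance(output, ESMProteinError):
        # Failures are surfaced as exception-typed ProteinType values.
        raise output
    assert isinstance(output, ESMProtein)
    return output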