| | import base64 |
| | from concurrent.futures import ThreadPoolExecutor |
| | from contextvars import ContextVar |
| | from functools import wraps |
| | from typing import Sequence |
| | from urllib.parse import urljoin |
| |
|
| | import requests |
| | import torch |
| | from tenacity import retry, retry_if_result, stop_after_attempt, wait_exponential |
| |
|
| | from src.data.esm.sdk.api import ( |
| | ESM3InferenceClient, |
| | ESMProtein, |
| | ESMProteinError, |
| | ESMProteinTensor, |
| | ForwardAndSampleOutput, |
| | ForwardTrackData, |
| | GenerationConfig, |
| | InverseFoldingConfig, |
| | LogitsConfig, |
| | LogitsOutput, |
| | ProteinType, |
| | SamplingConfig, |
| | SamplingTrackConfig, |
| | ) |
| | from src.data.esm.utils.misc import ( |
| | deserialize_tensors, |
| | maybe_list, |
| | maybe_tensor, |
| | ) |
| | from src.data.esm.utils.sampling import validate_sampling_config |
| | from src.data.esm.utils.types import FunctionAnnotation |
| |
|
# Context-local flag: when True, ESM3ForgeInferenceClient.retry_decorator
# bypasses tenacity retries entirely for the current context.
skip_retries_var = ContextVar("skip_retries", default=False)
| |
|
| |
|
| | def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None: |
| | if l is None or len(l) <= 0: |
| | return None |
| | return [FunctionAnnotation(*t) for t in l] |
| |
|
| |
|
def retry_if_specific_error(exception):
    """Return True when *exception* is a retryable ``ESMProteinError``.

    We only retry on specific errors. Currently we retry for
    429 (rate limit), 502 (bad gateway) and 504 (gateway timeout).

    Note: this is used with tenacity's ``retry_if_result`` — the client
    methods *return* ``ESMProteinError`` values rather than raising them.
    """
    return isinstance(exception, ESMProteinError) and exception.error_code in {
        429,
        502,
        504,
    }
| |
|
| |
|
def log_retry_attempt(retry_state):
    """Print a diagnostic line before tenacity sleeps between retry attempts."""
    attempt = retry_state.attempt_number
    delay = retry_state.next_action.sleep
    cause = retry_state.outcome.result()
    print(f"Retrying... Attempt {attempt} after {delay}s due to: {cause}")
| |
|
| |
|
def _validate_protein_tensor_input(input):
    """Raise ``ValueError`` unless *input* is an ``ESMProteinTensor``.

    Produces a more specific message when the caller accidentally passed
    through an ``ESMProteinError`` from an earlier step.
    """
    # Fast path: valid input, nothing to do.
    if isinstance(input, ESMProteinTensor):
        return
    if isinstance(input, ESMProteinError):
        raise ValueError(
            f"Input must be an ESMProteinTensor instance, but received an ESMProteinError instead: {input.error_code} {input.error_msg}"
        )
    raise ValueError(
        f"Input must be an ESMProteinTensor instance, but received {type(input)} instead. "
        "Use encode() API to encode an ESMProtein into ESMProteinTensor."
    )
| |
|
| |
|
class SequenceStructureForgeInferenceClient:
    """Forge client for folding and inverse folding between sequence and
    structure spaces.

    Each public method serializes its inputs to JSON, POSTs to a Forge
    endpoint, and converts the response back into an ``ESMProtein``.
    Server-side failures are returned as ``ESMProteinError`` values.
    """

    def __init__(
        self,
        url: str = "https://forge.evolutionaryscale.ai",
        model: str | None = None,
        token: str = "",
        request_timeout: int | None = None,
    ):
        """Create a client bound to a Forge server.

        Args:
            url: URL of the Forge server.
            model: Name of the model to be used for folding / inv folding.
            token: API token.
            request_timeout: Override the system default request timeout, in seconds.

        Raises:
            RuntimeError: when no API token is supplied.
        """
        if token == "":
            raise RuntimeError(
                "Please provide a token to connect to Forge via token=YOUR_API_TOKEN_HERE"
            )
        self.url = url
        self.model = model
        self.token = token
        self.headers = {"Authorization": f"Bearer {self.token}"}
        self.request_timeout = request_timeout

    def _select_model(self, override: str | None) -> str | None:
        # A per-call override wins; otherwise fall back to the client default.
        return override if override is not None else self.model

    def fold(
        self,
        sequence: str,
        potential_sequence_of_concern: bool,
        model_name: str | None = None,
    ) -> ESMProtein | ESMProteinError:
        """Predict coordinates for a protein sequence.

        Args:
            sequence: Protein sequence to be folded.
            potential_sequence_of_concern: Self disclosed potential_of_concern bit.
                This bit is largely ignored by the fold() endpoint.
            model_name: Override the client level model name if needed.
        """
        payload: dict = {"sequence": sequence}
        chosen_model = self._select_model(model_name)
        if chosen_model is not None:
            payload["model"] = chosen_model
        try:
            data = self._post("fold", payload, potential_sequence_of_concern)
        except ESMProteinError as err:
            # Errors are returned, not raised, so callers can batch safely.
            return err
        return ESMProtein(
            sequence=sequence,
            coordinates=maybe_tensor(data["coordinates"], convert_none_to_nan=True),
        )

    def inverse_fold(
        self,
        coordinates: torch.Tensor,
        config: InverseFoldingConfig,
        potential_sequence_of_concern: bool,
        model_name: str | None = None,
    ) -> ESMProtein | ESMProteinError:
        """Generate protein sequence from its structure.

        This endpoint is only supported by generative models like ESM3.

        Args:
            coordinates: Protein sequence coordinates to be inversely folded.
            config: Configurations related to inverse folding generation.
            potential_sequence_of_concern: Self disclosed potential_of_concern bit.
                Requires special permission to use.
            model_name: Override the client level model name if needed.
        """
        payload: dict = {
            # NaN coordinates (missing residues) travel as JSON nulls.
            "coordinates": maybe_list(coordinates, convert_nan_to_none=True),
            "inverse_folding_config": {
                "invalid_ids": config.invalid_ids,
                "temperature": config.temperature,
            },
        }
        chosen_model = self._select_model(model_name)
        if chosen_model is not None:
            payload["model"] = chosen_model
        try:
            data = self._post("inverse_fold", payload, potential_sequence_of_concern)
        except ESMProteinError as err:
            return err
        return ESMProtein(sequence=data["sequence"])

    def _post(self, endpoint, request, potential_sequence_of_concern):
        """POST *request* to ``/api/v1/<endpoint>`` and return the payload.

        Raises:
            ESMProteinError: for any non-2xx HTTP response.
        """
        request["potential_sequence_of_concern"] = potential_sequence_of_concern

        response = requests.post(
            urljoin(self.url, f"/api/v1/{endpoint}"),
            json=request,
            headers=self.headers,
            timeout=self.request_timeout,
        )
        if not response.ok:
            raise ESMProteinError(
                error_code=response.status_code,
                error_msg=f"Failure in {endpoint}: {response.text}",
            )

        payload = response.json()
        # Some deployments nest the result under a "data" envelope; unwrap it.
        if "outputs" not in payload and "data" in payload:
            payload = payload["data"]

        # Surface server-side warnings to the user in red.
        warnings = payload.get("warning_messages")
        if warnings is not None:
            for msg in warnings:
                print("\033[31m", msg, "\033[0m")

        return payload
| |
|
| |
|
class ESM3ForgeInferenceClient(ESM3InferenceClient):
    """Forge (remote) implementation of the ESM3 inference client.

    All computation happens server-side: each public method serializes its
    inputs to JSON, POSTs to a Forge endpoint, and deserializes the response
    back into ESM protein objects. Transient HTTP failures (429/502/504,
    see ``retry_if_specific_error``) are retried with exponential backoff.
    """

    def __init__(
        self,
        model: str,
        url: str = "https://forge.evolutionaryscale.ai",
        token: str = "",
        request_timeout: int | None = None,
        min_retry_wait: int = 1,
        max_retry_wait: int = 10,
        max_retry_attempts: int = 5,
    ):
        """
        Args:
            model: Name of the remote model to run.
            url: URL of the Forge server.
            token: API token (required).
            request_timeout: Override the system default request timeout, in seconds.
            min_retry_wait: Minimum backoff between retry attempts, in seconds.
            max_retry_wait: Maximum backoff between retry attempts, in seconds.
            max_retry_attempts: Total number of attempts before giving up.

        Raises:
            RuntimeError: when no API token is supplied.
        """
        if token == "":
            raise RuntimeError(
                "Please provide a token to connect to Forge via token=YOUR_API_TOKEN_HERE"
            )
        self.model = model
        self.url = url
        self.token = token
        self.headers = {"Authorization": f"Bearer {self.token}"}
        self.request_timeout = request_timeout
        self.min_retry_wait = min_retry_wait
        self.max_retry_wait = max_retry_wait
        self.max_retry_attempts = max_retry_attempts

    @staticmethod
    def retry_decorator(func):
        """
        A static method that returns a retry decorator. This decorator uses the
        instance's retry settings.

        NOTE: it is applied as ``@retry_decorator`` inside the class body,
        which calls the staticmethod object directly (valid on Python 3.10+).
        """

        @wraps(func)
        def wrapper(instance, *args, **kwargs):
            # Context-local escape hatch: skip all retry machinery.
            if skip_retries_var.get():
                return func(instance, *args, **kwargs)
            # Retries trigger on the *returned value* being a retryable
            # ESMProteinError (retry_if_result) -- these methods return
            # errors rather than raise them.
            retry_decorator = retry(
                retry=retry_if_result(retry_if_specific_error),
                wait=wait_exponential(
                    multiplier=1,
                    min=instance.min_retry_wait,
                    max=instance.max_retry_wait,
                ),
                stop=stop_after_attempt(instance.max_retry_attempts),
                before_sleep=log_retry_attempt,
            )
            # Built per call so the instance's retry settings are honored.
            return retry_decorator(func)(instance, *args, **kwargs)

        return wrapper

    @retry_decorator
    def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
        """Run iterative generation on a protein (raw or tokenized).

        Dispatches to the protein or tensor endpoint based on input type.
        NOTE: may mutate ``config.num_steps`` in place (clamped to the
        sequence length).

        Raises:
            ValueError: when an ESMProteinError is passed as input.
        """
        if isinstance(input, ESMProteinError):
            raise ValueError(
                f"Input must be an ESMProtein or ESMProteinTensor instance, but received an ESMProteinError instead: {input.error_code} {input.error_msg}"
            )
        assert isinstance(input, ESMProtein) or isinstance(input, ESMProteinTensor)
        if input.sequence is not None and config.num_steps > len(input.sequence):
            config.num_steps = len(input.sequence)
            print(
                "Warning: num_steps cannot exceed sequence length. Setting num_steps to sequence length."
            )
        if isinstance(input, ESMProtein):
            output = self.__generate_protein(input, config)
        elif isinstance(input, ESMProteinTensor):
            output = self.__generate_protein_tensor(input, config)
        else:
            return ESMProteinError(
                error_code=500, error_msg=f"Unknown input type {type(input)}"
            )

        if (
            isinstance(output, ESMProtein)
            and isinstance(input, ESMProtein)
            and config.track not in ["function", "residue_annotations"]
        ):
            # For non-function tracks, carry the caller's function
            # annotations through to the output unchanged.
            output.function_annotations = input.function_annotations

        return output

    def batch_generate(
        self, inputs: Sequence[ProteinType], configs: Sequence[GenerationConfig]
    ) -> Sequence[ProteinType]:
        """Forge supports auto-batching. So batch_generate() for the Forge client
        is as simple as running a collection of generate() in parallel using a
        thread pool.
        """
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.generate, protein, config)
                for protein, config in zip(inputs, configs)
            ]
            results = []
            for future in futures:
                try:
                    results.append(future.result())
                except Exception as e:
                    # Per-item failures become error results so one bad
                    # input does not abort the whole batch.
                    results.append(ESMProteinError(500, str(e)))
            return results

    def __generate_protein(
        self, input: ESMProtein, config: GenerationConfig
    ) -> ESMProtein | ESMProteinError:
        """Call the "generate" endpoint with a raw (untokenized) protein."""
        req = {}
        req["sequence"] = input.sequence
        req["secondary_structure"] = input.secondary_structure
        req["sasa"] = input.sasa
        if input.function_annotations is not None:
            req["function"] = [x.to_tuple() for x in input.function_annotations]
        # NaN coordinates (missing residues) travel as JSON nulls.
        req["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)

        request = {
            "model": self.model,
            "inputs": req,
            "track": config.track,
            "invalid_ids": config.invalid_ids,
            "schedule": config.schedule,
            "num_steps": config.num_steps,
            "temperature": config.temperature,
            "top_p": config.top_p,
            "condition_on_coordinates_only": config.condition_on_coordinates_only,
        }
        try:
            data = self._post("generate", request, input.potential_sequence_of_concern)
        except ESMProteinError as e:
            return e

        return ESMProtein(
            sequence=data["outputs"]["sequence"],
            secondary_structure=data["outputs"]["secondary_structure"],
            sasa=data["outputs"]["sasa"],
            function_annotations=_list_to_function_annotations(
                data["outputs"]["function"]
            ),
            coordinates=maybe_tensor(
                data["outputs"]["coordinates"], convert_none_to_nan=True
            ),
            plddt=maybe_tensor(data["outputs"]["plddt"]),
            ptm=maybe_tensor(data["outputs"]["ptm"]),
        )

    def __generate_protein_tensor(
        self, input: ESMProteinTensor, config: GenerationConfig
    ) -> ESMProteinTensor | ESMProteinError:
        """Call the "generate_tensor" endpoint with a tokenized protein."""
        req = {}
        req["sequence"] = maybe_list(input.sequence)
        req["structure"] = maybe_list(input.structure)
        req["secondary_structure"] = maybe_list(input.secondary_structure)
        req["sasa"] = maybe_list(input.sasa)
        req["function"] = maybe_list(input.function)
        req["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)
        # NOTE: the wire key is singular ("residue_annotation") while the
        # attribute is plural.
        req["residue_annotation"] = maybe_list(input.residue_annotations)

        request = {
            "model": self.model,
            "inputs": req,
            "track": config.track,
            "invalid_ids": config.invalid_ids,
            "schedule": config.schedule,
            "num_steps": config.num_steps,
            "temperature": config.temperature,
            "top_p": config.top_p,
            "condition_on_coordinates_only": config.condition_on_coordinates_only,
        }

        try:
            data = self._post(
                "generate_tensor", request, input.potential_sequence_of_concern
            )
        except ESMProteinError as e:
            return e

        def _field_to_tensor(field, convert_none_to_nan: bool = False):
            # Missing output tracks map to None rather than raising.
            if field not in data["outputs"]:
                return None
            return maybe_tensor(
                data["outputs"][field], convert_none_to_nan=convert_none_to_nan
            )

        output = ESMProteinTensor(
            sequence=_field_to_tensor("sequence"),
            structure=_field_to_tensor("structure"),
            secondary_structure=_field_to_tensor("secondary_structure"),
            sasa=_field_to_tensor("sasa"),
            function=_field_to_tensor("function"),
            residue_annotations=_field_to_tensor("residue_annotation"),
            coordinates=_field_to_tensor("coordinates", convert_none_to_nan=True),
        )

        return output

    @retry_decorator
    def forward_and_sample(
        self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
    ) -> ForwardAndSampleOutput | ESMProteinError:
        """Single forward pass plus sampling over the requested tracks.

        Returns sampled tokens together with per-track statistics (logprobs,
        entropy, top-k) and optional embeddings.
        """
        _validate_protein_tensor_input(input)
        validate_sampling_config(sampling_configuration, on_invalid="raise")

        req = {}
        sampling_config = {}
        embedding_config = {
            "sequence": sampling_configuration.return_mean_embedding,
            "per_residue": sampling_configuration.return_per_residue_embeddings,
        }

        req["sequence"] = maybe_list(input.sequence)
        req["structure"] = maybe_list(input.structure)
        req["secondary_structure"] = maybe_list(input.secondary_structure)
        req["sasa"] = maybe_list(input.sasa)
        req["function"] = maybe_list(input.function)
        req["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)
        req["residue_annotation"] = maybe_list(input.residue_annotations)

        def do_track(t: str):
            # Serialize one track's SamplingTrackConfig (or None if absent).
            track: SamplingTrackConfig | None
            if (track := getattr(sampling_configuration, t, None)) is None:
                sampling_config[t] = None
            else:
                sampling_config[t] = {
                    "temperature": track.temperature,
                    "top_p": track.top_p,
                    "only_sample_masked_tokens": track.only_sample_masked_tokens,
                    "invalid_ids": track.invalid_ids,
                    "topk_logprobs": track.topk_logprobs,
                }

        do_track("sequence")
        do_track("structure")
        do_track("secondary_structure")
        do_track("sasa")
        do_track("function")

        request = {
            "model": self.model,
            "inputs": req,
            "sampling_config": sampling_config,
            "embedding_config": embedding_config,
        }
        try:
            data = self._post(
                "forward_and_sample", request, input.potential_sequence_of_concern
            )
        except ESMProteinError as e:
            return e

        def get(k, field):
            # Fetch data[k][field] as a tensor, tolerating absent tracks.
            if data[k] is None:
                return None
            v = data[k][field]
            return torch.tensor(v) if v is not None else None

        tokens = ESMProteinTensor(
            sequence=get("sequence", "tokens"),
            structure=get("structure", "tokens"),
            secondary_structure=get("secondary_structure", "tokens"),
            sasa=get("sasa", "tokens"),
            function=get("function", "tokens"),
        )

        def get_track(field):
            # Gather one statistic (e.g. "logprobs") across all five tracks.
            return ForwardTrackData(
                sequence=get("sequence", field),
                structure=get("structure", field),
                secondary_structure=get("secondary_structure", field),
                sasa=get("sasa", field),
                function=get("function", field),
            )

        def operate_on_track(track: ForwardTrackData, fn):
            # Apply fn to every non-None member of a ForwardTrackData.
            apply = lambda x: fn(x) if x is not None else None
            return ForwardTrackData(
                sequence=apply(track.sequence),
                structure=apply(track.structure),
                secondary_structure=apply(track.secondary_structure),
                sasa=apply(track.sasa),
                function=apply(track.function),
            )

        logprob = get_track("logprobs")
        output = ForwardAndSampleOutput(
            protein_tensor=tokens,
            logprob=logprob,
            # Probabilities derived from logprobs rather than re-fetched.
            prob=operate_on_track(logprob, torch.exp),
            entropy=get_track("entropy"),
            topk_logprob=get_track("topk_logprobs"),
            topk_tokens=get_track("topk_tokens"),
            # NOTE(review): embeddings are passed through as-is (not run
            # through maybe_tensor like other fields) -- presumably the
            # server returns lists/None here; confirm.
            per_residue_embedding=data["embeddings"]["per_residue"],
            mean_embedding=data["embeddings"]["sequence"],
        )
        return output

    @retry_decorator
    def encode(self, input: ESMProtein) -> ESMProteinTensor | ESMProteinError:
        """Tokenize a raw ESMProtein into an ESMProteinTensor server-side."""
        tracks = {}
        tracks["sequence"] = input.sequence
        tracks["secondary_structure"] = input.secondary_structure
        tracks["sasa"] = input.sasa
        if input.function_annotations is not None:
            tracks["function"] = [x.to_tuple() for x in input.function_annotations]
        tracks["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)

        request = {"inputs": tracks, "model": self.model}

        try:
            data = self._post("encode", request, input.potential_sequence_of_concern)
        except ESMProteinError as e:
            return e

        return ESMProteinTensor(
            sequence=maybe_tensor(data["outputs"]["sequence"]),
            structure=maybe_tensor(data["outputs"]["structure"]),
            coordinates=maybe_tensor(
                data["outputs"]["coordinates"], convert_none_to_nan=True
            ),
            secondary_structure=maybe_tensor(data["outputs"]["secondary_structure"]),
            sasa=maybe_tensor(data["outputs"]["sasa"]),
            function=maybe_tensor(data["outputs"]["function"]),
            residue_annotations=maybe_tensor(data["outputs"]["residue_annotation"]),
        )

    @retry_decorator
    def decode(self, input: ESMProteinTensor) -> ESMProtein | ESMProteinError:
        """Detokenize an ESMProteinTensor back into a raw ESMProtein."""
        _validate_protein_tensor_input(input)

        tokens = {}
        tokens["sequence"] = maybe_list(input.sequence)
        tokens["structure"] = maybe_list(input.structure)
        tokens["secondary_structure"] = maybe_list(input.secondary_structure)
        tokens["sasa"] = maybe_list(input.sasa)
        tokens["function"] = maybe_list(input.function)
        tokens["residue_annotation"] = maybe_list(input.residue_annotations)
        tokens["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)

        request = {"model": self.model, "inputs": tokens}

        try:
            data = self._post("decode", request, input.potential_sequence_of_concern)
        except ESMProteinError as e:
            return e

        return ESMProtein(
            sequence=data["outputs"]["sequence"],
            secondary_structure=data["outputs"]["secondary_structure"],
            sasa=data["outputs"]["sasa"],
            function_annotations=_list_to_function_annotations(
                data["outputs"]["function"]
            ),
            coordinates=maybe_tensor(
                data["outputs"]["coordinates"], convert_none_to_nan=True
            ),
            plddt=maybe_tensor(data["outputs"]["plddt"]),
            ptm=maybe_tensor(data["outputs"]["ptm"]),
        )

    @retry_decorator
    def logits(
        self,
        input: ESMProteinTensor,
        config: LogitsConfig = LogitsConfig(),
        return_bytes: bool = True,
    ) -> LogitsOutput | ESMProteinError:
        """Single forward pass returning per-track logits (and optionally
        embeddings / hidden states).

        When ``return_bytes`` is True the server is asked to reply with
        base64-encoded serialized tensors, which are decoded locally.
        """
        _validate_protein_tensor_input(input)

        req = {}
        req["sequence"] = maybe_list(input.sequence)
        req["structure"] = maybe_list(input.structure)
        req["secondary_structure"] = maybe_list(input.secondary_structure)
        req["sasa"] = maybe_list(input.sasa)
        req["function"] = maybe_list(input.function)
        req["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)
        req["residue_annotation"] = maybe_list(input.residue_annotations)

        logits_config = {
            "sequence": config.sequence,
            "structure": config.structure,
            "secondary_structure": config.secondary_structure,
            "sasa": config.sasa,
            "function": config.function,
            "residue_annotations": config.residue_annotations,
            "return_embeddings": config.return_embeddings,
            "return_hidden_states": config.return_hidden_states,
            "ith_hidden_layer": config.ith_hidden_layer,
        }

        request = {"model": self.model, "inputs": req, "logits_config": logits_config}
        try:
            data = self._post(
                "logits",
                request,
                input.potential_sequence_of_concern,
                return_bytes=return_bytes,
            )
        except ESMProteinError as e:
            return e

        def _maybe_logits(track: str):
            # Absent tracks decode to None.
            if "logits" in data and track in data["logits"]:
                return maybe_tensor(data["logits"][track])
            return None

        def _maybe_b64_decode(obj):
            # Only decode when the server replied in bytes mode.
            return (
                deserialize_tensors(base64.b64decode(obj))
                if return_bytes and obj is not None
                else obj
            )

        logits = _maybe_b64_decode(data["logits"])
        data["logits"] = dict(logits) if logits is not None else logits
        data["embeddings"] = _maybe_b64_decode(data["embeddings"])
        data["hidden_states"] = _maybe_b64_decode(data["hidden_states"])

        output = LogitsOutput(
            logits=ForwardTrackData(
                sequence=_maybe_logits("sequence"),
                structure=_maybe_logits("structure"),
                secondary_structure=_maybe_logits("secondary_structure"),
                sasa=_maybe_logits("sasa"),
                function=_maybe_logits("function"),
            ),
            embeddings=maybe_tensor(data["embeddings"]),
            residue_annotation_logits=_maybe_logits("residue_annotation"),
            hidden_states=maybe_tensor(data["hidden_states"]),
        )

        return output

    def _post(
        self,
        endpoint,
        request,
        potential_sequence_of_concern,
        return_bytes: bool = False,
    ):
        """POST *request* to ``/api/v1/<endpoint>`` and return the payload.

        Args:
            return_bytes: when True, send the "return-bytes" header so the
                server replies with base64-serialized tensors.

        Raises:
            ESMProteinError: for any non-2xx HTTP response.
        """
        request["potential_sequence_of_concern"] = potential_sequence_of_concern
        headers = dict(self.headers)
        if return_bytes:
            headers["return-bytes"] = "true"
        response = requests.post(
            urljoin(self.url, f"/api/v1/{endpoint}"),
            json=request,
            headers=headers,
            timeout=self.request_timeout,
        )

        if not response.ok:
            raise ESMProteinError(
                error_code=response.status_code,
                error_msg=f"Failure in {endpoint}: {response.text}",
            )

        data = response.json()
        # Some deployments nest the result under a "data" envelope; unwrap it.
        if "outputs" not in data and "data" in data:
            data = data["data"]

        return data

    @property
    def raw_model(self):
        """The model runs remotely, so no local module is available."""
        raise NotImplementedError(
            f"Can not get underlying remote model {self.model} from a Forge client."
        )
| |
|