| | from abc import ABC, abstractmethod |
| |
|
| | import attr |
| | import torch |
| | from tqdm import tqdm |
| |
|
| | from src.data.esm.sdk.api import ( |
| | ESM3InferenceClient, |
| | ESMProtein, |
| | ESMProteinError, |
| | ESMProteinTensor, |
| | SamplingConfig, |
| | SamplingTrackConfig, |
| | ) |
| | from src.data.esm.sdk.forge import ESM3ForgeInferenceClient |
| | from src.data.esm.tokenization import get_esm3_model_tokenizers |
| |
|
| |
|
| | class GuidedDecodingScoringFunction(ABC): |
| | @abstractmethod |
| | def __call__(self, protein: ESMProtein) -> float: |
| | pass |
| |
|
| |
|
| | class ESM3GuidedDecoding: |
| | """This class can be used to perform derivative-free guided decoding, based on |
| | the method described in "Derivative-Free Guidance in Continuous and Discrete Diffusion Models with Soft Value-Based Decoding" |
| | https://arxiv.org/abs/2408.08252 |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | client: ESM3InferenceClient, |
| | scoring_function: GuidedDecodingScoringFunction, |
| | ): |
| | if isinstance(client, ESM3): |
| | self.tokenizers = client.tokenizers |
| | elif isinstance(client, ESM3ForgeInferenceClient): |
| | self.tokenizers = get_esm3_model_tokenizers(client.model) |
| | else: |
| | raise ValueError( |
| | "client must be an instance of ESM3 or ESM3ForgeInferenceClient" |
| | ) |
| |
|
| | self.client = client |
| | self.scoring_function = scoring_function |
| |
|
| | def guided_generate( |
| | self, |
| | protein: ESMProtein, |
| | num_decoding_steps: int, |
| | num_samples_per_step: int, |
| | denoised_prediction_temperature: float = 0.0, |
| | track: str = "sequence", |
| | verbose: bool = True, |
| | ) -> ESMProtein: |
| | protein_tensor = self.client.encode(protein) |
| |
|
| | assert not isinstance(protein_tensor, ESMProteinError) |
| |
|
| | if track == "structure": |
| | protein_tensor = self.maybe_add_default_structure_tokens(protein_tensor) |
| |
|
| | num_masked_positions = self.get_number_of_masked_positions( |
| | protein_tensor, track=track |
| | ) |
| | num_positions_to_unmask = num_masked_positions // num_decoding_steps |
| |
|
| | current_score = -1 |
| |
|
| | if verbose: |
| | pbar = tqdm(range(num_decoding_steps), desc="Current score: -1") |
| | else: |
| | pbar = range(num_decoding_steps) |
| |
|
| | for step in pbar: |
| | if step == num_decoding_steps - 1: |
| | |
| | num_positions_to_unmask = self.get_number_of_masked_positions( |
| | protein_tensor, track=track |
| | ) |
| |
|
| | samples = [] |
| | scores = [] |
| | for _ in range(num_samples_per_step): |
| | sample = self.randomly_unmask_positions( |
| | protein_tensor, num_positions_to_unmask, track=track |
| | ) |
| | scores.append( |
| | self.reward_function( |
| | sample, |
| | denoised_prediction_temperature=denoised_prediction_temperature, |
| | ) |
| | ) |
| | samples.append(sample) |
| |
|
| | |
| | best_sample = samples[scores.index(max(scores))] |
| | current_score = max(scores) |
| | protein_tensor = best_sample |
| |
|
| | if verbose: |
| | pbar.set_description(f"Current score: {current_score:.2f}") |
| |
|
| | |
| | protein_tensor_output = self.client.forward_and_sample( |
| | protein_tensor, |
| | SamplingConfig( |
| | sequence=SamplingTrackConfig(temperature=0.0), |
| | structure=SamplingTrackConfig(temperature=0.0), |
| | ), |
| | ) |
| |
|
| | assert not isinstance(protein_tensor_output, ESMProteinError) |
| | protein_tensor = protein_tensor_output.protein_tensor |
| |
|
| | decoded_protein = self.client.decode(protein_tensor) |
| | assert not isinstance(decoded_protein, ESMProteinError) |
| | return decoded_protein |
| |
|
| | def reward_function( |
| | self, |
| | protein_tensor: ESMProteinTensor, |
| | denoised_prediction_temperature: float = 0.0, |
| | ) -> float: |
| | denoised_protein = self.predict_denoised( |
| | protein_tensor, temperature=denoised_prediction_temperature |
| | ) |
| | return self.scoring_function(denoised_protein) |
| |
|
| | def get_number_of_masked_positions( |
| | self, protein_tensor: ESMProteinTensor, track: str = "sequence" |
| | ) -> int: |
| | assert isinstance(protein_tensor, ESMProteinTensor) |
| | track_tensor = getattr(protein_tensor, track) |
| | track_tokenizer = getattr(self.tokenizers, track) |
| | is_mask = track_tensor == track_tokenizer.mask_token_id |
| | return is_mask.sum().item() |
| |
|
| | def randomly_unmask_positions( |
| | self, |
| | protein_tensor: ESMProteinTensor, |
| | num_positions_to_unmask: int, |
| | temperature: float = 1.0, |
| | track: str = "sequence", |
| | ) -> ESMProteinTensor: |
| | track_tensor = getattr(protein_tensor, track) |
| | assert track_tensor is not None |
| | protein_tensor = attr.evolve(protein_tensor) |
| | setattr(protein_tensor, track, track_tensor.clone()) |
| |
|
| | track_tensor = getattr(protein_tensor, track) |
| | track_tokenizer = getattr(self.tokenizers, track) |
| |
|
| | is_mask = track_tensor == track_tokenizer.mask_token_id |
| | num_masked_positions = is_mask.sum().item() |
| |
|
| | if num_positions_to_unmask > num_masked_positions: |
| | num_positions_to_unmask = num_masked_positions |
| |
|
| | mask_indices = is_mask.nonzero(as_tuple=False) |
| | mask_indices = mask_indices[torch.randperm(mask_indices.size(0))] |
| | mask_indices = mask_indices[:num_positions_to_unmask] |
| |
|
| | sampling_config = SamplingConfig() |
| | setattr(sampling_config, track, SamplingTrackConfig(temperature=temperature)) |
| |
|
| | denoised_protein_tensor_output = self.client.forward_and_sample( |
| | protein_tensor, sampling_configuration=sampling_config |
| | ) |
| | assert not isinstance(denoised_protein_tensor_output, ESMProteinError) |
| | denoised_protein_tensor = denoised_protein_tensor_output.protein_tensor |
| | output_track_tensor = getattr(denoised_protein_tensor, track) |
| | assert output_track_tensor is not None |
| | track_tensor[mask_indices] = output_track_tensor[mask_indices] |
| | setattr(protein_tensor, track, track_tensor) |
| |
|
| | return protein_tensor |
| |
|
| | def predict_denoised( |
| | self, protein_tensor: ESMProteinTensor, temperature: float = 0.0 |
| | ) -> ESMProtein: |
| | denoised_protein_tensor_output = self.client.forward_and_sample( |
| | protein_tensor, |
| | sampling_configuration=SamplingConfig( |
| | sequence=SamplingTrackConfig(temperature=temperature), |
| | structure=SamplingTrackConfig(temperature=temperature), |
| | ), |
| | ) |
| | assert not isinstance(denoised_protein_tensor_output, ESMProteinError) |
| | denoised_protein_tensor = denoised_protein_tensor_output.protein_tensor |
| | denoised_protein = self.client.decode(denoised_protein_tensor) |
| | assert not isinstance(denoised_protein, ESMProteinError) |
| | return denoised_protein |
| |
|
| | def maybe_add_default_structure_tokens( |
| | self, protein_tensor: ESMProteinTensor |
| | ) -> ESMProteinTensor: |
| | empty_protein_tensor = ESMProteinTensor.empty( |
| | len(protein_tensor) - 2, |
| | tokenizers=self.tokenizers, |
| | device=protein_tensor.device, |
| | ) |
| | if protein_tensor.structure is None: |
| | setattr(protein_tensor, "structure", empty_protein_tensor.structure) |
| | else: |
| | print("Warning: structure already exists in protein_tensor") |
| | return protein_tensor |
| |
|