# nas / PFMBench / src / data / esm / sdk / experimental.py
# Repository page residue: uploaded by yuccaaa, commit 9627ce0 (verified),
# message "Add files using upload-large-folder tool".
from abc import ABC, abstractmethod
import attr
import torch
from tqdm import tqdm
from src.data.esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,
ESMProteinError,
ESMProteinTensor,
SamplingConfig,
SamplingTrackConfig,
)
from src.data.esm.sdk.forge import ESM3ForgeInferenceClient
from src.data.esm.tokenization import get_esm3_model_tokenizers
class GuidedDecodingScoringFunction(ABC):
    """Interface for the scoring callables consumed by guided decoding.

    Concrete subclasses implement ``__call__`` and return a score for a
    protein; ``ESM3GuidedDecoding`` keeps the candidate with the highest score.
    """

    @abstractmethod
    def __call__(self, protein: ESMProtein) -> float:
        """Return the score of ``protein``; must be overridden by subclasses."""
        ...
class ESM3GuidedDecoding:
    """Derivative-free guided decoding on top of an ESM3 inference client.

    Based on the method described in "Derivative-Free Guidance in Continuous and
    Discrete Diffusion Models with Soft Value-Based Decoding"
    https://arxiv.org/abs/2408.08252

    At every decoding step several candidates are produced by unmasking a random
    subset of the still-masked positions. Each candidate is fully denoised and
    scored with ``scoring_function``, and the best-scoring candidate is carried
    over to the next step.
    """

    def __init__(
        self,
        client: ESM3InferenceClient,
        scoring_function: GuidedDecodingScoringFunction,
    ):
        """Store the client and scoring function and resolve the tokenizers.

        Args:
            client: Either an ``ESM3`` model instance or an
                ``ESM3ForgeInferenceClient``.
            scoring_function: Callable mapping a decoded ``ESMProtein`` to a
                score; higher scores are preferred.

        Raises:
            ValueError: If ``client`` is neither of the supported types.
        """
        if isinstance(client, ESM3ForgeInferenceClient):
            self.tokenizers = get_esm3_model_tokenizers(client.model)
        else:
            # ESM3 was referenced by this class but never imported at module
            # level, which made every construction fail with NameError. It is
            # imported lazily so Forge-only users never need the local model.
            # NOTE(review): module path assumed to mirror the sibling
            # ``src.data.esm.*`` imports of this file - confirm.
            from src.data.esm.models.esm3 import ESM3

            if not isinstance(client, ESM3):
                raise ValueError(
                    "client must be an instance of ESM3 or ESM3ForgeInferenceClient"
                )
            self.tokenizers = client.tokenizers

        self.client = client
        self.scoring_function = scoring_function

    def guided_generate(
        self,
        protein: ESMProtein,
        num_decoding_steps: int,
        num_samples_per_step: int,
        denoised_prediction_temperature: float = 0.0,
        track: str = "sequence",
        verbose: bool = True,
    ) -> ESMProtein:
        """Iteratively unmask ``track`` of ``protein``, guided by the scoring function.

        Args:
            protein: Protein to complete; it is encoded with the client first.
            num_decoding_steps: Number of unmasking rounds (at least 1). The
                masked positions are split evenly across rounds and the last
                round unmasks whatever is left.
            num_samples_per_step: Number of candidates drawn and scored per
                round (at least 1).
            denoised_prediction_temperature: Sampling temperature used when a
                candidate is denoised for scoring.
            track: Name of the track to unmask, e.g. ``"sequence"`` or
                ``"structure"``.
            verbose: Show a tqdm progress bar labelled with the current score.

        Returns:
            The decoded protein predicted at temperature 0.0 from the final
            best-scoring candidate.

        Raises:
            ValueError: If ``num_decoding_steps`` or ``num_samples_per_step``
                is smaller than 1.
        """
        # Fail fast, before the client is contacted: zero steps would divide by
        # zero below and zero samples would leave nothing to pick a best from.
        if num_decoding_steps < 1:
            raise ValueError("num_decoding_steps must be at least 1")
        if num_samples_per_step < 1:
            raise ValueError("num_samples_per_step must be at least 1")

        protein_tensor = self.client.encode(protein)
        assert not isinstance(protein_tensor, ESMProteinError)

        if track == "structure":
            protein_tensor = self.maybe_add_default_structure_tokens(protein_tensor)

        num_masked_positions = self.get_number_of_masked_positions(
            protein_tensor, track=track
        )
        num_positions_to_unmask = num_masked_positions // num_decoding_steps

        if verbose:
            pbar = tqdm(range(num_decoding_steps), desc="Current score: -1")
        else:
            pbar = range(num_decoding_steps)

        for step in pbar:
            if step == num_decoding_steps - 1:
                # At the last step, unmask all remaining positions
                num_positions_to_unmask = self.get_number_of_masked_positions(
                    protein_tensor, track=track
                )

            samples = []
            scores = []
            for _ in range(num_samples_per_step):
                sample = self.randomly_unmask_positions(
                    protein_tensor, num_positions_to_unmask, track=track
                )
                scores.append(
                    self.reward_function(
                        sample,
                        denoised_prediction_temperature=denoised_prediction_temperature,
                    )
                )
                samples.append(sample)

            # Select best scoring sample (the first one wins on ties)
            current_score = max(scores)
            protein_tensor = samples[scores.index(current_score)]

            if verbose:
                pbar.set_description(f"Current score: {current_score:.2f}")  # type: ignore

        # Fully predict and decode final protein
        return self.predict_denoised(protein_tensor, temperature=0.0)

    def reward_function(
        self,
        protein_tensor: ESMProteinTensor,
        denoised_prediction_temperature: float = 0.0,
    ) -> float:
        """Score a partially unmasked tensor by denoising it and scoring the result.

        Args:
            protein_tensor: Candidate tensor, possibly still containing masks.
            denoised_prediction_temperature: Temperature passed to
                ``predict_denoised``.

        Returns:
            Whatever ``scoring_function`` returns for the denoised protein.
        """
        denoised_protein = self.predict_denoised(
            protein_tensor, temperature=denoised_prediction_temperature
        )
        return self.scoring_function(denoised_protein)

    def _track_mask(
        self, protein_tensor: ESMProteinTensor, track: str
    ) -> torch.Tensor:
        """Return the boolean mask of positions in ``track`` holding the mask token."""
        track_tensor = getattr(protein_tensor, track)
        track_tokenizer = getattr(self.tokenizers, track)
        return track_tensor == track_tokenizer.mask_token_id

    def get_number_of_masked_positions(
        self, protein_tensor: ESMProteinTensor, track: str = "sequence"
    ) -> int:
        """Count the positions of ``track`` that equal the tokenizer's mask token id.

        Args:
            protein_tensor: Tensor whose ``track`` attribute is inspected.
            track: Attribute name shared by the tensor and the tokenizers.

        Returns:
            Number of masked positions.
        """
        assert isinstance(protein_tensor, ESMProteinTensor)
        return int(self._track_mask(protein_tensor, track).sum().item())

    def randomly_unmask_positions(
        self,
        protein_tensor: ESMProteinTensor,
        num_positions_to_unmask: int,
        temperature: float = 1.0,
        track: str = "sequence",
    ) -> ESMProteinTensor:
        """Return a copy of ``protein_tensor`` with random masked positions filled in.

        The input tensor is left untouched: the container is copied with
        ``attr.evolve`` and only the selected track is cloned before writing.

        Args:
            protein_tensor: Tensor whose ``track`` must not be ``None``.
            num_positions_to_unmask: How many masked positions to fill; capped
                at the number of positions that are currently masked.
            temperature: Sampling temperature for the selected track.
            track: Name of the track to unmask.

        Returns:
            The copied tensor with the chosen positions replaced by the tokens
            the client sampled for them.
        """
        original_track = getattr(protein_tensor, track)
        assert original_track is not None

        protein_tensor = attr.evolve(protein_tensor)
        setattr(protein_tensor, track, original_track.clone())
        track_tensor = getattr(protein_tensor, track)

        is_mask = self._track_mask(protein_tensor, track)
        num_masked_positions = is_mask.sum().item()
        num_positions_to_unmask = min(num_positions_to_unmask, num_masked_positions)  # type: ignore

        # One row per masked position; shuffle the rows and keep the first few.
        mask_indices = is_mask.nonzero(as_tuple=False)
        mask_indices = mask_indices[torch.randperm(mask_indices.size(0))]
        mask_indices = mask_indices[:num_positions_to_unmask]

        sampling_config = SamplingConfig()
        setattr(sampling_config, track, SamplingTrackConfig(temperature=temperature))

        denoised_protein_tensor_output = self.client.forward_and_sample(
            protein_tensor, sampling_configuration=sampling_config
        )
        assert not isinstance(denoised_protein_tensor_output, ESMProteinError)
        denoised_protein_tensor = denoised_protein_tensor_output.protein_tensor
        output_track_tensor = getattr(denoised_protein_tensor, track)
        assert output_track_tensor is not None

        # In-place write into the cloned track already attached to the copy.
        # NOTE(review): this indexing assumes a 1-D track tensor - confirm.
        track_tensor[mask_indices] = output_track_tensor[mask_indices]

        return protein_tensor

    def predict_denoised(
        self, protein_tensor: ESMProteinTensor, temperature: float = 0.0
    ) -> ESMProtein:
        """Sample the sequence and structure tracks and decode the result.

        Args:
            protein_tensor: Tensor to denoise, possibly containing masks.
            temperature: Sampling temperature used for both tracks.

        Returns:
            The decoded protein.
        """
        denoised_protein_tensor_output = self.client.forward_and_sample(
            protein_tensor,
            sampling_configuration=SamplingConfig(
                sequence=SamplingTrackConfig(temperature=temperature),
                structure=SamplingTrackConfig(temperature=temperature),
            ),
        )
        assert not isinstance(denoised_protein_tensor_output, ESMProteinError)
        denoised_protein_tensor = denoised_protein_tensor_output.protein_tensor
        denoised_protein = self.client.decode(denoised_protein_tensor)
        assert not isinstance(denoised_protein, ESMProteinError)
        return denoised_protein

    def maybe_add_default_structure_tokens(
        self, protein_tensor: ESMProteinTensor
    ) -> ESMProteinTensor:
        """Attach default structure tokens when ``protein_tensor`` has none.

        The tensor is modified in place and also returned. When a structure
        track already exists it is kept and a warning is printed.

        Args:
            protein_tensor: Tensor to complete.

        Returns:
            The same ``protein_tensor`` object.
        """
        if protein_tensor.structure is None:
            # Built only when needed.
            # NOTE(review): the "- 2" presumably drops the BOS/EOS tokens
            # counted by len() - confirm against ESMProteinTensor.empty.
            empty_protein_tensor = ESMProteinTensor.empty(
                len(protein_tensor) - 2,
                tokenizers=self.tokenizers,
                device=protein_tensor.device,
            )
            setattr(protein_tensor, "structure", empty_protein_tensor.structure)
        else:
            print("Warning: structure already exists in protein_tensor")
        return protein_tensor