# nas/PFMBench/src/data/esm/utils/generation_test.py
# yuccaaa's picture
# Add files using upload-large-folder tool
# 9627ce0 verified
import pytest
import torch
from evolutionaryscale.models.esm3v2 import Esm3v2
from src.data.esm.sdk.api import (
ESMProtein,
ESMProteinTensor,
GenerationConfig,
)
from evolutionaryscale.utils.env import ModelName
from evolutionaryscale.utils.remote_inference.api_v1 import (
ESM3RemoteModelInferenceClient,
)
from projects.forge.fastapi.utils.model import _load_esm_model
@pytest.fixture()
def esm3_remote_inference_client():
    """Provide an ``ESM3RemoteModelInferenceClient`` for the tests below.

    Loads the ``ModelName.ESM3_TINY_DEV`` model (non-distributed, without the
    function decoder), checks that it is an ``Esm3v2`` instance, and wraps it
    in a client bound to the CUDA device with the batched runner disabled.

    Returns:
        The constructed ``ESM3RemoteModelInferenceClient``.
    """
    esm3_model = _load_esm_model(
        ModelName.ESM3_TINY_DEV,
        distributed_model=False,
        load_function_decoder=False,
    )
    # The client is built from the model's own tokenizers, so make sure the
    # loader really handed back the expected model class first.
    assert isinstance(esm3_model, Esm3v2)
    return ESM3RemoteModelInferenceClient(
        esm3_model,
        tokenizers=esm3_model.tokenizers,
        device=torch.device("cuda"),
        enable_batched_runner=False,
    )
@pytest.mark.gpu
def test_chain_break_tokens(esm3_remote_inference_client):
    """Generate structure for a tokenized sequence containing chain breaks.

    The input is BOS, three runs of residue tokens (4 + 3 + 3) separated by
    two chain-break tokens, then EOS. The test only checks that generation on
    the structure track returns an ``ESMProteinTensor`` with a structure.
    """
    seq_tokenizer = esm3_remote_inference_client.tokenizers.sequence
    chain_break = seq_tokenizer.chain_break_token_id

    # 3 separate chains with 2 chainbreak tokens.
    token_ids = (
        [seq_tokenizer.bos_token_id]
        + [20] * 4
        + [chain_break]
        + [21] * 3
        + [chain_break]
        + [22] * 3
        + [seq_tokenizer.eos_token_id]
    )
    model_input = ESMProteinTensor(sequence=torch.tensor(token_ids))

    # There are 10 tokens that actually need to be sampled.
    generation_config = GenerationConfig(track="structure", num_steps=10)

    generated = esm3_remote_inference_client.generate(
        model_input, generation_config
    )

    assert isinstance(generated, ESMProteinTensor)
    assert generated.structure is not None
@pytest.mark.gpu
def test_num_decoding_steps_more_than_mask_tokens(esm3_remote_inference_client):
    """Request more decoding steps than the encoded sequence has residues.

    A 5-residue protein is encoded and structure generation is asked to run
    10 steps; the call is expected to still return an ``ESMProteinTensor``
    carrying a structure.
    """
    client = esm3_remote_inference_client

    encoded = client.encode(ESMProtein(sequence="CDEFG"))  # sequence of 5.
    # use 10 decoding steps.
    generation_config = GenerationConfig(track="structure", num_steps=10)

    generated = client.generate(encoded, generation_config)

    # Client should handle over-specification of decoding steps.
    # TODO: This should be a warning.
    assert isinstance(generated, ESMProteinTensor)
    assert generated.structure is not None
@pytest.mark.gpu
def test_num_decoding_steps_more_than_mask_tokens_batched(esm3_remote_inference_client):
    """Batched variant of the step-count test with a mixed set of configs.

    Three encoded proteins are submitted through ``batch_generate``, each with
    its own ``GenerationConfig`` (two on the structure track, one on the
    sequence track). Every result must be an ``ESMProteinTensor`` whose
    requested track is populated.
    """
    client = esm3_remote_inference_client

    encoded_inputs = [
        client.encode(ESMProtein(sequence=raw_sequence))
        for raw_sequence in ("CDEFG", "ABCDEFG", "AB__EFG")
    ]
    generation_configs = [
        GenerationConfig(track="structure", num_steps=10),
        GenerationConfig(track="structure", num_steps=3),
        GenerationConfig(track="sequence", num_steps=20),
    ]

    protein_list = client.batch_generate(
        inputs=encoded_inputs,
        configs=generation_configs,
    )

    # Client should handle over-specification of decoding steps.
    # TODO: This should be a warning.
    # Index explicitly (rather than zip) so that a short result list still
    # surfaces as an IndexError instead of being silently truncated.
    expected_tracks = ("structure", "structure", "sequence")
    for position, track_name in enumerate(expected_tracks):
        result = protein_list[position]
        assert isinstance(result, ESMProteinTensor)
        assert getattr(result, track_name) is not None
@pytest.mark.gpu
def test_encode_chainbreak_token(esm3_remote_inference_client):
    """Check that ``|`` in a sequence string encodes to the chain-break token.

    The ``|`` follows five residues in the input string and is expected at
    index 6 of the encoded sequence tensor.
    NOTE(review): the one-position offset is presumably a leading BOS token
    added by ``encode`` — confirm against the tokenizer.
    """
    client = esm3_remote_inference_client

    encoded = client.encode(ESMProtein(sequence="MSTNP|KPQKK"))

    assert isinstance(encoded, ESMProteinTensor)
    assert encoded.sequence is not None

    token_at_break = encoded.sequence[6]
    chain_break_id = client.tokenizers.sequence.chain_break_token_id
    assert token_at_break == chain_break_id
@pytest.mark.gpu
def test_generation_with_chainbreak_token(esm3_remote_inference_client):
    """Check that a sequence chain break carries over to the structure track.

    The input tokens are BOS, five residue tokens, a chain-break token (index
    6), five more residue tokens, and EOS. After one structure-generation
    step, index 6 of the generated structure must equal the structure
    tokenizer's chain-break token id.
    """
    client = esm3_remote_inference_client
    seq_tokenizer = client.tokenizers.sequence

    token_ids = [seq_tokenizer.bos_token_id]
    token_ids += [20, 8, 11, 17, 14]
    token_ids.append(seq_tokenizer.chain_break_token_id)
    token_ids += [15, 14, 16, 15, 15]
    token_ids.append(seq_tokenizer.eos_token_id)

    # Original author's note: "Can't specify more decoding steps than masks
    # available."
    # NOTE(review): sibling tests in this file pass num_steps larger than the
    # number of positions to decode — confirm whether this note is current.
    generated = client.generate(
        ESMProteinTensor(sequence=torch.tensor(token_ids)),
        GenerationConfig(track="structure", num_steps=1),
    )

    assert isinstance(generated, ESMProteinTensor)
    assert generated.structure is not None

    structure_token = generated.structure[6]
    expected_token = client.tokenizers.structure.chain_break_token_id
    assert structure_token == expected_token