File size: 4,747 Bytes
9627ce0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import pytest
import torch

from evolutionaryscale.models.esm3v2 import Esm3v2
from src.data.esm.sdk.api import (
    ESMProtein,
    ESMProteinTensor,
    GenerationConfig,
)
from evolutionaryscale.utils.env import ModelName
from evolutionaryscale.utils.remote_inference.api_v1 import (
    ESM3RemoteModelInferenceClient,
)
from projects.forge.fastapi.utils.model import _load_esm_model


@pytest.fixture(scope="module")
def esm3_remote_inference_client():
    """Build an ESM3 inference client around the tiny dev model on CUDA.

    The fixture is module-scoped so the model weights are loaded once and
    shared by every test in this module, instead of being reloaded (and
    re-uploaded to the GPU) for each individual test.

    NOTE(review): sharing one client across tests assumes ``generate`` /
    ``batch_generate`` / ``encode`` do not leave per-call state behind on the
    client -- confirm if tests ever start interfering with each other.

    Returns:
        An ``ESM3RemoteModelInferenceClient`` wrapping an ``Esm3v2`` model.
    """
    model = _load_esm_model(
        ModelName.ESM3_TINY_DEV, distributed_model=False, load_function_decoder=False
    )
    # Fail fast if the loader returned something other than an Esm3v2 model.
    assert isinstance(model, Esm3v2)
    client = ESM3RemoteModelInferenceClient(
        model,
        tokenizers=model.tokenizers,
        device=torch.device("cuda"),
        # Presumably makes each call run directly instead of going through a
        # batching runner -- TODO confirm against the client implementation.
        enable_batched_runner=False,
    )
    return client


@pytest.mark.gpu
def test_chain_break_tokens(esm3_remote_inference_client):
    """Structure generation must succeed on a multi-chain sequence.

    The prompt holds three chains of 4, 3 and 3 residues, separated by two
    chain-break tokens and framed by BOS/EOS. That leaves 10 residue
    positions to sample, which matches the requested number of steps.
    """
    seq_tok = esm3_remote_inference_client.tokenizers.sequence
    chain_break = seq_tok.chain_break_token_id
    token_ids = (
        [seq_tok.bos_token_id]
        + [20] * 4
        + [chain_break]
        + [21] * 3
        + [chain_break]
        + [22] * 3
        + [seq_tok.eos_token_id]
    )
    prompt = ESMProteinTensor(sequence=torch.tensor(token_ids))
    config = GenerationConfig(track="structure", num_steps=10)

    result = esm3_remote_inference_client.generate(prompt, config)

    assert isinstance(result, ESMProteinTensor)
    assert result.structure is not None


@pytest.mark.gpu
def test_num_decoding_steps_more_than_mask_tokens(esm3_remote_inference_client):
    """Requesting more decoding steps than residues must still generate."""
    client = esm3_remote_inference_client
    # A 5-residue sequence, but 10 structure decoding steps are requested.
    encoded = client.encode(ESMProtein(sequence="CDEFG"))
    protein = client.generate(
        encoded, GenerationConfig(track="structure", num_steps=10)
    )

    # The client is expected to tolerate the over-specified step count.
    # TODO: This should be a warning.
    assert isinstance(protein, ESMProteinTensor)
    assert protein.structure is not None


@pytest.mark.gpu
def test_num_decoding_steps_more_than_mask_tokens_batched(esm3_remote_inference_client):
    """Batched generation with per-item configs, some over-specifying steps.

    Each item is checked for an ``ESMProteinTensor`` result whose generated
    track (the one named in its config) is populated.
    """
    client = esm3_remote_inference_client
    # (sequence, track, num_steps) per batch item. "_" presumably denotes a
    # masked residue for the sequence track to fill -- verify against encode.
    cases = [
        ("CDEFG", "structure", 10),
        ("ABCDEFG", "structure", 3),
        ("AB__EFG", "sequence", 20),
    ]
    inputs = [client.encode(ESMProtein(sequence=seq)) for seq, _, _ in cases]
    configs = [
        GenerationConfig(track=track, num_steps=steps) for _, track, steps in cases
    ]

    protein_list = client.batch_generate(inputs=inputs, configs=configs)

    # The client is expected to tolerate over-specified step counts.
    # TODO: This should be a warning.
    for idx, (_, track, _) in enumerate(cases):
        # Index explicitly so a short result list still fails loudly.
        protein = protein_list[idx]
        assert isinstance(protein, ESMProteinTensor)
        assert getattr(protein, track) is not None


@pytest.mark.gpu
def test_encode_chainbreak_token(esm3_remote_inference_client):
    """Encoding must map '|' in a sequence string to the chain-break token."""
    client = esm3_remote_inference_client
    encoded = client.encode(ESMProtein(sequence="MSTNP|KPQKK"))

    assert isinstance(encoded, ESMProteinTensor)
    assert encoded.sequence is not None
    # One leading token plus the 5 residues "MSTNP" puts '|' at index 6
    # (assumes encode prepends a single BOS token).
    chain_break_pos = 6
    expected_id = client.tokenizers.sequence.chain_break_token_id
    assert encoded.sequence[chain_break_pos] == expected_id


@pytest.mark.gpu
def test_generation_with_chainbreak_token(esm3_remote_inference_client):
    """A sequence chain break must appear as a structure chain break.

    The prompt is BOS, five residues, a chain break at index 6, five more
    residues, then EOS. After structure generation, index 6 of the structure
    track must hold the structure tokenizer's chain-break id.
    """
    client = esm3_remote_inference_client
    seq_tok = client.tokenizers.sequence
    first_chain = [20, 8, 11, 17, 14]
    second_chain = [15, 14, 16, 15, 15]
    token_ids = [
        seq_tok.bos_token_id,
        *first_chain,
        seq_tok.chain_break_token_id,
        *second_chain,
        seq_tok.eos_token_id,
    ]
    prompt = ESMProteinTensor(sequence=torch.tensor(token_ids))

    # Only a single decoding step is requested here.
    protein = client.generate(
        prompt, GenerationConfig(track="structure", num_steps=1)
    )

    assert isinstance(protein, ESMProteinTensor)
    assert protein.structure is not None
    structure_break_id = client.tokenizers.structure.chain_break_token_id
    assert protein.structure[6] == structure_break_id