import warnings
from typing import Literal

import attr
import torch
import torch.nn.functional as F

from src.data.esm.sdk.api import (
    ESMProteinTensor,
    SamplingConfig,
    SamplingTrackConfig,
)
from src.data.esm.tokenization import (
    TokenizerCollectionProtocol,
    get_invalid_tokenizer_ids,
)
from src.data.esm.tokenization.function_tokenizer import (
    InterProQuantizedTokenizer,
)
from src.data.esm.utils.constants.esm3 import (
    MAX_RESIDUE_ANNOTATIONS,
    SASA_DISCRETIZATION_BOUNDARIES,
)


def _non_batched_dims(k: str, v: torch.Tensor):
    """Return the number of per-sample (non-batch) dimensions for track k."""
    match k:
        case "sequence":
            return 1
        case "structure":
            if v.is_floating_point():
                # This is the one hot soft structure token.
                return 2
            else:
                # This is the normal int structure token.
                return 1
        case "secondary_structure":
            return 1
        case "sasa":
            return 1
        case "function":
            return 2
        case "residue_annotations":
            return 2
        case "coordinates":
            return 3
        case _:
            raise ValueError(f"Unknown dim for track {k}")


class _BatchedESMProteinTensor(ESMProteinTensor):
    @staticmethod
    def from_protein_tensor(protein: ESMProteinTensor):
        def _maybe_unsqueeze(x: torch.Tensor | None):
            return x.unsqueeze(0) if x is not None else None

        return _BatchedESMProteinTensor(
            sequence=_maybe_unsqueeze(protein.sequence),
            structure=_maybe_unsqueeze(protein.structure),
            secondary_structure=_maybe_unsqueeze(protein.secondary_structure),
            sasa=_maybe_unsqueeze(protein.sasa),
            function=_maybe_unsqueeze(protein.function),
            residue_annotations=_maybe_unsqueeze(protein.residue_annotations),
            coordinates=_maybe_unsqueeze(protein.coordinates),
        )

    def __len__(self) -> int:
        def get_len(k, v) -> int:
            assert len(v.shape) == _non_batched_dims(k, v) + 1
            return v.size(1)

        length = self._detect_attribute(get_len, "length")
        return length if length is not None else 0

    @property
    def batch_size(self) -> int:
        def get_batch_size(k, v) -> int:
            assert len(v.shape) == _non_batched_dims(k, v) + 1
            return v.size(0)

        d = self._detect_attribute(get_batch_size, "batch size")
        assert d is not None
        return d

    def slice(self, i: int, sequence_len: int | None = None) -> ESMProteinTensor:
        def _maybe_slice(x: torch.Tensor | None):
            if x is None:
                return None
            row = x[i]
            if sequence_len is not None:
                row = row[:sequence_len]
            return row

        return ESMProteinTensor(
            sequence=_maybe_slice(self.sequence),
            structure=_maybe_slice(self.structure),
            secondary_structure=_maybe_slice(self.secondary_structure),
            sasa=_maybe_slice(self.sasa),
            function=_maybe_slice(self.function),
            residue_annotations=_maybe_slice(self.residue_annotations),
            coordinates=_maybe_slice(self.coordinates),
        )

    def set_slice(self, i: int, slice: ESMProteinTensor):
        """Update the i-th slice of this tensor data class."""
        for f in attr.fields(ESMProteinTensor):
            s = getattr(self, f.name)
            v = getattr(slice, f.name)

            assert (
                v is None or s is not None
            ), f"Trying to set a slice on None tensor ({f.name})."

            if v is not None:
                s[i, ...] = v
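

# Illustrative usage sketch (not part of the original module): lifting a single
# ESMProteinTensor to a batch of size 1 and reading it back. The constructor
# keywords mirror those used by from_protein_tensor/slice above; the token ids
# are arbitrary and only the sequence track is populated.
def _example_batched_protein_tensor() -> None:
    protein = ESMProteinTensor(
        sequence=torch.tensor([0, 5, 7, 2]),  # arbitrary token ids, length 4
        structure=None,
        secondary_structure=None,
        sasa=None,
        function=None,
        residue_annotations=None,
        coordinates=None,
    )
    batched = _BatchedESMProteinTensor.from_protein_tensor(protein)
    assert batched.batch_size == 1 and len(batched) == 4
    # slice() recovers a single (optionally length-truncated) protein.
    restored = batched.slice(0, sequence_len=4)
    assert torch.equal(restored.sequence, protein.sequence)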


def get_default_sampling_config(
    tokenizers: TokenizerCollectionProtocol,
) -> SamplingConfig:
    tracks = [f.name for f in attr.fields(SamplingConfig)]
    sampling_config = SamplingConfig()
    for current_track in tracks:
        setattr(
            sampling_config,
            current_track,
            SamplingTrackConfig(
                invalid_ids=get_invalid_tokenizer_ids(
                    getattr(tokenizers, current_track)
                ),
                temperature=1.0,
                top_p=1.0,
                # TODO: Add different mask and padding tokens for all tracks
                # Some tracks have the same pad and mask, which causes ambiguity when sampling
                only_sample_masked_tokens=current_track
                not in ["secondary_structure", "sasa", "function"],
            ),
        )
    return sampling_config


def validate_sampling_config(
    sampling_config: SamplingConfig, on_invalid: Literal["raise", "warn"] = "warn"
):
    # Check that each track's topk_logprobs is at most its max_topk.
    for track in attr.fields(SamplingConfig):
        track: attr.Attribute
        track_config = getattr(sampling_config, track.name, None)
        if isinstance(track_config, SamplingTrackConfig):
            max_topk = track.metadata["max_topk"]
            if track_config.topk_logprobs > max_topk:
                msg = (
                    f"Sampling track {track.name} has topk_logprobs={track_config.topk_logprobs} "
                    f"greater than MAX_TOPK={max_topk}."
                )
                if on_invalid == "raise":
                    raise AssertionError(msg)
                else:
                    warnings.warn(msg)


def sample_logits(
    logits: torch.Tensor,
    temperature: float | torch.Tensor,
    valid_ids: list[int] = [],
    top_p: float | torch.Tensor = 1.0,
    mask_logits_of_invalid_ids: bool = True,
):
    """Default sampling from logits.

    Args:
        logits is shape (..., vocab_size)
        temperature is broadcastable to (...)
    """
    if len(valid_ids) == 0:
        raise ValueError(
            "Can not sample logits if there are no valid ids to sample from."
        )

    if top_p < 1.0:
        logits = top_p_logits(logits, top_p=top_p)

    temperature = _tensorize_like(temperature, logits)
    batch_dims = logits.size()[:-1]
    logits = logits.reshape(-1, logits.shape[-1])

    # Only sample from valid ids
    # the /logits endpoint should receive unmodified logits
    if mask_logits_of_invalid_ids:
        mask = torch.ones_like(logits, dtype=torch.bool)
        mask[..., valid_ids] = False
        logits[mask] = -torch.inf

    if torch.all(temperature == 0):
        ids = logits.argmax(-1)
        return ids.reshape(*batch_dims)

    assert not torch.any(temperature == 0), "Partial temperature 0 not supported."

    # Sample from all logits
    probs = F.softmax(logits / temperature[..., None], dim=-1)
    ids = torch.multinomial(probs, 1).squeeze(1)

    ids = ids.reshape(*batch_dims)
    return ids
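

# Illustrative usage sketch (not part of the original module): greedy vs.
# stochastic decoding with sample_logits over a toy vocabulary. The shapes and
# valid_ids below are made up for illustration only.
def _example_sample_logits() -> None:
    logits = torch.randn(2, 8, 6)  # (batch, length, vocab_size)
    valid_ids = [1, 2, 3, 4]  # pretend ids 0 and 5 are special tokens
    # temperature=0 selects the argmax among the valid ids (greedy decoding).
    greedy = sample_logits(logits, temperature=0.0, valid_ids=valid_ids)
    # temperature>0 samples from the tempered softmax, restricted to valid_ids.
    sampled = sample_logits(logits, temperature=0.7, valid_ids=valid_ids)
    assert greedy.shape == sampled.shape == (2, 8)
    assert set(greedy.unique().tolist()) <= set(valid_ids)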


def sample_function_logits(
    logits: torch.Tensor,
    tokenizer: InterProQuantizedTokenizer,
    top_p: float | torch.Tensor = 1.0,
    temperature: float | torch.Tensor = 1.0,
    p_none_threshold: float = 0.05,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Works with inputs that have batch dimension."""
    [B, L, D, V] = logits.shape
    assert D == tokenizer.depth

    if top_p < 1.0:
        logits = top_p_logits(logits, top_p=top_p)

    temperature = torch.ones_like(logits[..., 0]) * temperature

    log_p = F.log_softmax(logits / temperature[..., None], dim=-1)  # (B, L, D, V)

    # Choose which positions have no predicted function.
    none_index = tokenizer.vocab_to_index["<none>"]
    log_p_nones = log_p[..., none_index]  # (B, L, D)
    p_none = torch.exp(log_p_nones).mean(dim=-1)  # "Ensemble of <none> predictions"
    where_none = p_none > p_none_threshold  # (B, L)

    # Set probability of <none> to 0 for all not-none positions
    batch_size, seq_len, depth = log_p.shape[:-1]
    expanded_where_not_none = ~where_none.unsqueeze(-1).unsqueeze(-1)  # (B, L, 1, 1)
    expanded_where_not_none = expanded_where_not_none.expand(
        batch_size, seq_len, depth, 1
    )  # (B, L, D, 1)
    indices = torch.arange(log_p.shape[-1], device=log_p.device)  # (V,)
    mask = indices == none_index  # (V,)
    mask = expanded_where_not_none & mask  # (B, L, D, 1) x (V,) -> (B, L, D, V)
    log_p[mask] = -torch.inf

    ids = torch.argmax(log_p, dim=-1)  # (B, L, D)
    ids[where_none, :] = none_index

    return ids, log_p


def sample_residue_annotation_logits(
    logits: torch.Tensor, annotation_threshold: float = 0.5
) -> tuple[torch.Tensor, torch.Tensor]:
    # Take top residue annotations
    top_residue_annotations_idx = logits.argsort(dim=-1, descending=True)[
        ..., :MAX_RESIDUE_ANNOTATIONS
    ]  # (B, L, MAX_R)
    top_residue_annotations_logprobs = torch.gather(
        F.logsigmoid(logits), -1, top_residue_annotations_idx
    )  # (B, L, MAX_R)
    top_residue_annotations_probs = top_residue_annotations_logprobs.exp()
    # Keep only positive predictions
    is_negative = top_residue_annotations_probs < annotation_threshold
    top_residue_annotations_idx[is_negative] = 0

    return top_residue_annotations_idx, top_residue_annotations_logprobs
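

# Illustrative usage sketch (not part of the original module): residue
# annotations are multi-label, so this helper keeps the MAX_RESIDUE_ANNOTATIONS
# highest-scoring annotation ids per position and zeroes out ids whose sigmoid
# probability falls below the threshold. The class count below is made up.
def _example_sample_residue_annotations() -> None:
    logits = torch.randn(2, 10, 200)  # (batch, length, num_annotation_classes)
    idx, logprobs = sample_residue_annotation_logits(logits, annotation_threshold=0.5)
    assert idx.shape == logprobs.shape == (2, 10, MAX_RESIDUE_ANNOTATIONS)
    # Below-threshold entries are replaced with annotation id 0.
    assert ((torch.exp(logprobs) >= 0.5) | (idx == 0)).all()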


def sample_sasa_logits(
    logits: torch.Tensor,
    tokens: torch.Tensor,
    sampling_track_config: SamplingTrackConfig,
    mask_idx: int,
    valid_ids: list[int],
    mask_logits_of_invalid_ids: bool = True,
) -> torch.Tensor:
    # Only sample from valid ids
    # the /logits endpoint should receive unmodified logits
    if mask_logits_of_invalid_ids:
        mask = torch.ones_like(logits, dtype=torch.bool)
        mask[..., valid_ids] = False
        logits[mask] = -torch.inf

    sasa_probs = torch.nn.functional.softmax(logits, dim=-1)
    max_prob_idx = torch.argmax(sasa_probs, dim=-1)
    # Midpoints of the SASA discretization bins, used to compute an expected value.
    sasa_bins = torch.tensor([0] + SASA_DISCRETIZATION_BOUNDARIES, dtype=torch.float)
    sasa_bins = (sasa_bins[:-1] + sasa_bins[1:]) / 2
    sasa_bins = sasa_bins.to(sasa_probs.device)

    sampling_mask = get_sampling_mask(tokens, sampling_track_config, mask_idx)
    # Expected SASA value over the discretized bin tokens (vocab indices 3:-1).
    # Positions whose argmax is token id 18, or that fall outside the sampling
    # mask, are set to inf below.
    sasa_value = torch.sum(sasa_probs[..., 3:-1] * sasa_bins, dim=-1)
    sasa_value[max_prob_idx == 18] = float("inf")
    sasa_value[~sampling_mask] = float("inf")

    return sasa_value


def top_p_logits(logits: torch.Tensor, top_p: float | torch.Tensor) -> torch.Tensor:
    top_p = _tensorize_like(top_p, logits)

    batch_dims = logits.size()[:-1]
    logits = logits.reshape(-1, logits.shape[-1])

    # Sort logits in descending order and extract the mask for the top_p
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
    cumsum_logits = sorted_logits.softmax(-1).cumsum(-1)
    top_p_mask = cumsum_logits <= top_p[:, None]

    # Make sure at least one token is sampled
    top_p_mask[:, 0] = True

    # Mask out the logits that are not in the top_p
    batch_indices_to_mask, _ = torch.where(~top_p_mask)
    vocab_indices_to_mask = sorted_indices[~top_p_mask]
    logits[batch_indices_to_mask, vocab_indices_to_mask] = torch.finfo(logits.dtype).min

    return logits.reshape(*batch_dims, -1)
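

# Illustrative usage sketch (not part of the original module): nucleus (top-p)
# truncation on a hand-built distribution. Tokens outside the smallest set whose
# cumulative probability stays within top_p are pushed to the dtype minimum so
# softmax assigns them (effectively) zero probability.
def _example_top_p_logits() -> None:
    probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
    logits = probs.log()  # log-probabilities; softmax recovers probs
    truncated = top_p_logits(logits.clone(), top_p=0.85)
    # Cumulative mass reaches 0.8 after two tokens (<= 0.85), so they are kept;
    # the remaining two tokens are masked out.
    assert truncated[0, 2] == torch.finfo(truncated.dtype).min
    assert truncated[0, 3] == torch.finfo(truncated.dtype).min
    assert torch.allclose(truncated[0, :2], logits[0, :2])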


def _tensorize_like(value: int | float | torch.Tensor, logits: torch.Tensor):
    """Broadcast a scalar or tensor to the batch shape of logits and flatten it."""
    if isinstance(value, (float, int)):
        value = torch.full_like(logits[..., 0], value, dtype=logits.dtype)
    return value.to(logits.device).expand_as(logits[..., 0]).reshape(-1)


def get_sampling_mask(
    tokens: torch.Tensor, sampling_track_config: SamplingTrackConfig, mask_idx: int
):
    # Do not sample at BOS and EOS tokens
    sampling_mask = torch.ones_like(tokens, dtype=torch.bool)  # (B, L, )
    sampling_mask[:, 0] = False
    sampling_mask[:, -1] = False

    # Do not sample at special token positions but allow sampling at mask token
    special_minus_mask = list(set(sampling_track_config.invalid_ids) - {mask_idx})
    if len(special_minus_mask) > 0:
        special_tokens = torch.tensor(special_minus_mask, device=tokens.device)
        assert special_tokens.numel() > 0
        sampling_mask = sampling_mask & (
            tokens[..., None] != special_tokens[None, :]
        ).all(-1)

    # Keep only samples from masked positions (if specified)
    if sampling_track_config.only_sample_masked_tokens:
        masked_tokens = tokens == mask_idx
        sampling_mask = sampling_mask & masked_tokens
    return sampling_mask
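

# Illustrative usage sketch (not part of the original module): the sampling mask
# for a toy track. The token ids and the SamplingTrackConfig arguments below are
# made up; the keywords mirror those used in get_default_sampling_config above.
def _example_get_sampling_mask() -> None:
    bos_idx, pad_idx, eos_idx, mask_idx = 0, 1, 2, 32
    tokens = torch.tensor([[bos_idx, 5, mask_idx, pad_idx, 7, eos_idx]])  # (B=1, L=6)
    track_config = SamplingTrackConfig(
        invalid_ids=[bos_idx, pad_idx, eos_idx, mask_idx],
        temperature=1.0,
        top_p=1.0,
        only_sample_masked_tokens=True,
    )
    mask = get_sampling_mask(tokens, track_config, mask_idx)
    # Only the masked position may be sampled: BOS/EOS are always excluded, pad is
    # an invalid id, and unmasked residues are excluded by only_sample_masked_tokens.
    assert mask.tolist() == [[False, False, True, False, False, False]]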