import warnings
from typing import Literal

import attr
import torch
import torch.nn.functional as F

from src.data.esm.sdk.api import (
    ESMProteinTensor,
    SamplingConfig,
    SamplingTrackConfig,
)
from src.data.esm.tokenization import (
    TokenizerCollectionProtocol,
    get_invalid_tokenizer_ids,
)
from src.data.esm.tokenization.function_tokenizer import (
    InterProQuantizedTokenizer,
)
from src.data.esm.utils.constants.esm3 import (
    MAX_RESIDUE_ANNOTATIONS,
    SASA_DISCRETIZATION_BOUNDARIES,
)


def _non_batched_dims(k: str, v: torch.Tensor):
    """Number of non-batch dimensions expected for each track's tensor."""
    match k:
        case "sequence":
            return 1
        case "structure":
            if v.is_floating_point():
                # Floating-point structure inputs carry a per-position token
                # distribution, so they have an extra vocab dimension.
                return 2
            else:
                # Integer structure inputs are plain token ids.
                return 1
        case "secondary_structure":
            return 1
        case "sasa":
            return 1
        case "function":
            return 2
        case "residue_annotations":
            return 2
        case "coordinates":
            return 3
        case _:
            raise ValueError(f"Unknown dim for track {k}")


class _BatchedESMProteinTensor(ESMProteinTensor):
    @staticmethod
    def from_protein_tensor(protein: ESMProteinTensor):
        def _maybe_unsqueeze(x: torch.Tensor | None):
            return x.unsqueeze(0) if x is not None else None

        return _BatchedESMProteinTensor(
            sequence=_maybe_unsqueeze(protein.sequence),
            structure=_maybe_unsqueeze(protein.structure),
            secondary_structure=_maybe_unsqueeze(protein.secondary_structure),
            sasa=_maybe_unsqueeze(protein.sasa),
            function=_maybe_unsqueeze(protein.function),
            residue_annotations=_maybe_unsqueeze(protein.residue_annotations),
            coordinates=_maybe_unsqueeze(protein.coordinates),
        )

    def __len__(self) -> int:
        def get_len(k, v) -> int:
            assert len(v.shape) == _non_batched_dims(k, v) + 1
            return v.size(1)

        length = self._detect_attribute(get_len, "length")
        return length if length is not None else 0

    @property
    def batch_size(self) -> int:
        def get_batch_size(k, v) -> int:
            assert len(v.shape) == _non_batched_dims(k, v) + 1
            return v.size(0)

        d = self._detect_attribute(get_batch_size, "batch size")
        assert d is not None
        return d

    def slice(self, i: int, sequence_len: int | None = None) -> ESMProteinTensor:
        def _maybe_slice(x: torch.Tensor | None):
            if x is None:
                return None
            row = x[i]
            if sequence_len is not None:
                row = row[:sequence_len]
            return row

        return ESMProteinTensor(
            sequence=_maybe_slice(self.sequence),
            structure=_maybe_slice(self.structure),
            secondary_structure=_maybe_slice(self.secondary_structure),
            sasa=_maybe_slice(self.sasa),
            function=_maybe_slice(self.function),
            residue_annotations=_maybe_slice(self.residue_annotations),
            coordinates=_maybe_slice(self.coordinates),
        )

    def set_slice(self, i: int, slice: ESMProteinTensor):
        """Update the i-th slice of this tensor data class."""
        for f in attr.fields(ESMProteinTensor):
            s = getattr(self, f.name)
            v = getattr(slice, f.name)

            assert (
                v is None or s is not None
            ), f"Trying to set a slice on a None tensor ({f.name})."

            if v is not None:
                s[i, ...] = v
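
# Illustrative sketch (not part of the original module): how the batched tensor
# wrapper above is typically used. `protein` is assumed to be an ESMProteinTensor
# obtained elsewhere (e.g. by encoding a protein with the model tokenizers).
#
#   batched = _BatchedESMProteinTensor.from_protein_tensor(protein)
#   assert batched.batch_size == 1
#   first = batched.slice(0)        # back to a plain ESMProteinTensor
#   batched.set_slice(0, first)     # write a (possibly modified) row back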


def get_default_sampling_config(
    tokenizers: TokenizerCollectionProtocol,
) -> SamplingConfig:
    tracks = [f.name for f in attr.fields(SamplingConfig)]
    sampling_config = SamplingConfig()
    for current_track in tracks:
        setattr(
            sampling_config,
            current_track,
            SamplingTrackConfig(
                invalid_ids=get_invalid_tokenizer_ids(
                    getattr(tokenizers, current_track)
                ),
                temperature=1.0,
                top_p=1.0,
                # Secondary structure, SASA, and function are sampled at every
                # position; the remaining tracks only resample masked positions.
                only_sample_masked_tokens=current_track
                not in ["secondary_structure", "sasa", "function"],
            ),
        )
    return sampling_config
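
# Illustrative sketch (not part of the original module): building a default config
# and then adjusting a single track. `tokenizers` is assumed to be any object
# satisfying TokenizerCollectionProtocol (one tokenizer attribute per track).
#
#   config = get_default_sampling_config(tokenizers)
#   config.sequence = SamplingTrackConfig(temperature=0.7, top_p=0.9)
#   validate_sampling_config(config, on_invalid="warn")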


def validate_sampling_config(
    sampling_config: SamplingConfig, on_invalid: Literal["raise", "warn"] = "warn"
):
    # Verify that each track's topk_logprobs does not exceed the maximum number of
    # top-k log probabilities supported for that track.
    for track in attr.fields(SamplingConfig):
        track: attr.Attribute
        track_config = getattr(sampling_config, track.name, None)
        if isinstance(track_config, SamplingTrackConfig):
            max_topk = track.metadata["max_topk"]
            if track_config.topk_logprobs > max_topk:
                msg = (
                    f"Sampling track {track.name} has topk_logprobs={track_config.topk_logprobs} "
                    f"greater than MAX_TOPK={max_topk}."
                )
                if on_invalid == "raise":
                    raise AssertionError(msg)
                else:
                    warnings.warn(msg)
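
# Illustrative sketch (not part of the original module): requesting more top-k log
# probabilities than a track supports and turning the violation into a hard error.
# The value 50 is arbitrary and assumed to exceed the track's "max_topk" metadata.
#
#   config = SamplingConfig(sequence=SamplingTrackConfig(topk_logprobs=50))
#   validate_sampling_config(config, on_invalid="raise")   # raises AssertionError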


def sample_logits(
    logits: torch.Tensor,
    temperature: float | torch.Tensor,
    valid_ids: list[int] = [],
    top_p: float | torch.Tensor = 1.0,
    mask_logits_of_invalid_ids: bool = True,
):
    """Default sampling from logits.

    Args:
        logits: tensor of shape (..., vocab_size).
        temperature: scalar or tensor broadcastable to the batch dims (...).
        valid_ids: token ids that are allowed to be sampled.
        top_p: nucleus sampling threshold; values < 1.0 enable top-p filtering.
        mask_logits_of_invalid_ids: if True, set logits of ids outside valid_ids
            to -inf before sampling.
    """
    if len(valid_ids) == 0:
        raise ValueError(
            "Cannot sample logits if there are no valid ids to sample from."
        )

    if top_p < 1.0:
        logits = top_p_logits(logits, top_p=top_p)

    temperature = _tensorize_like(temperature, logits)
    batch_dims = logits.size()[:-1]
    logits = logits.reshape(-1, logits.shape[-1])

    # Mask out any token id that is not explicitly allowed so it can never be sampled.
    if mask_logits_of_invalid_ids:
        mask = torch.ones_like(logits, dtype=torch.bool)
        mask[..., valid_ids] = False
        logits[mask] = -torch.inf

    if torch.all(temperature == 0):
        # Temperature 0 means greedy decoding: take the argmax everywhere.
        ids = logits.argmax(-1)
        return ids.reshape(*batch_dims)

    assert not torch.any(temperature == 0), "Partial temperature 0 not supported."

    # Sample from the temperature-scaled categorical distribution.
    probs = F.softmax(logits / temperature[..., None], dim=-1)
    ids = torch.multinomial(probs, 1).squeeze(1)

    ids = ids.reshape(*batch_dims)
    return ids
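
# Illustrative sketch (not part of the original module): sampling token ids from
# random logits. The vocabulary size of 64 and the valid-id range are invented for
# the example.
#
#   logits = torch.randn(2, 8, 64)
#   ids = sample_logits(
#       logits,
#       temperature=0.7,
#       valid_ids=list(range(4, 64)),
#       top_p=0.9,
#   )
#   assert ids.shape == (2, 8)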


def sample_function_logits(
    logits: torch.Tensor,
    tokenizer: InterProQuantizedTokenizer,
    top_p: float | torch.Tensor = 1.0,
    temperature: float | torch.Tensor = 1.0,
    p_none_threshold: float = 0.05,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Works with inputs that have a batch dimension."""
    [B, L, D, V] = logits.shape
    assert D == tokenizer.depth

    if top_p < 1.0:
        logits = top_p_logits(logits, top_p=top_p)

    temperature = torch.ones_like(logits[..., 0]) * temperature

    log_p = F.log_softmax(logits / temperature[..., None], dim=-1)  # (B, L, D, V)

    # Average the probability of <none> across the function-token depth; if it is
    # high enough, the position is predicted to carry no function annotation.
    none_index = tokenizer.vocab_to_index["<none>"]
    log_p_nones = log_p[..., none_index]  # (B, L, D)
    p_none = torch.exp(log_p_nones).mean(dim=-1)  # (B, L)
    where_none = p_none > p_none_threshold  # (B, L)

    # For positions not predicted as <none>, forbid selecting the <none> token at
    # any depth by setting its log-probability to -inf.
    expanded_where_not_none = ~where_none.unsqueeze(-1).unsqueeze(-1)
    expanded_where_not_none = expanded_where_not_none.expand(B, L, D, 1)
    indices = torch.arange(log_p.shape[-1], device=log_p.device)
    mask = indices == none_index
    mask = expanded_where_not_none & mask
    log_p[mask] = -torch.inf

    ids = torch.argmax(log_p, dim=-1)  # (B, L, D)
    ids[where_none, :] = none_index

    return ids, log_p
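
# Illustrative sketch (not part of the original module): the shapes involved in
# function-token sampling. `function_tokenizer` is assumed to be an
# InterProQuantizedTokenizer, and B, L are a batch size and sequence length.
#
#   vocab_size = len(function_tokenizer.vocab_to_index)
#   logits = torch.randn(B, L, function_tokenizer.depth, vocab_size)
#   ids, log_p = sample_function_logits(logits, function_tokenizer, temperature=0.8)
#   # ids: (B, L, depth) token ids; log_p: (B, L, depth, vocab_size) log-probabilities.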


def sample_residue_annotation_logits(
    logits: torch.Tensor, annotation_threshold: float = 0.5
) -> tuple[torch.Tensor, torch.Tensor]:
    # Take the top MAX_RESIDUE_ANNOTATIONS annotations per residue, ranked by logit.
    top_residue_annotations_idx = logits.argsort(dim=-1, descending=True)[
        ..., :MAX_RESIDUE_ANNOTATIONS
    ]
    top_residue_annotations_logprobs = torch.gather(
        F.logsigmoid(logits), -1, top_residue_annotations_idx
    )
    top_residue_annotations_probs = top_residue_annotations_logprobs.exp()
    # Zero out the index of any annotation whose probability falls below the threshold.
    is_negative = top_residue_annotations_probs < annotation_threshold
    top_residue_annotations_idx[is_negative] = 0

    return top_residue_annotations_idx, top_residue_annotations_logprobs
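
# Illustrative sketch (not part of the original module): thresholded multi-label
# selection from random per-annotation logits; the annotation vocabulary size of
# 100 is invented for the example.
#
#   logits = torch.randn(2, 8, 100)
#   idx, logprobs = sample_residue_annotation_logits(logits, annotation_threshold=0.5)
#   assert idx.shape == (2, 8, MAX_RESIDUE_ANNOTATIONS)
#   # Entries below the threshold are set to index 0; logprobs holds log-sigmoid scores.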


def sample_sasa_logits(
    logits: torch.Tensor,
    tokens: torch.Tensor,
    sampling_track_config: SamplingTrackConfig,
    mask_idx: int,
    valid_ids: list[int],
    mask_logits_of_invalid_ids: bool = True,
) -> torch.Tensor:
    # Mask out any token id that is not explicitly allowed so it cannot be selected.
    if mask_logits_of_invalid_ids:
        mask = torch.ones_like(logits, dtype=torch.bool)
        mask[..., valid_ids] = False
        logits[mask] = -torch.inf

    sasa_probs = torch.nn.functional.softmax(logits, dim=-1)
    max_prob_idx = torch.argmax(sasa_probs, dim=-1)
    # Midpoints of the SASA discretization bins, used to decode a continuous value.
    sasa_bins = torch.tensor([0] + SASA_DISCRETIZATION_BOUNDARIES, dtype=torch.float)
    sasa_bins = (sasa_bins[:-1] + sasa_bins[1:]) / 2
    sasa_bins = sasa_bins.to(sasa_probs.device)

    sampling_mask = get_sampling_mask(tokens, sampling_track_config, mask_idx)

    # Expected SASA value from the bin midpoints (vocab positions 3:-1 hold the bins).
    sasa_value = torch.sum(sasa_probs[..., 3:-1] * sasa_bins, dim=-1)
    sasa_value[max_prob_idx == 18] = float("inf")
    sasa_value[~sampling_mask] = float("inf")

    return sasa_value


def top_p_logits(logits: torch.Tensor, top_p: float | torch.Tensor) -> torch.Tensor:
    top_p = _tensorize_like(top_p, logits)

    batch_dims = logits.size()[:-1]
    logits = logits.reshape(-1, logits.shape[-1])

    # Sort logits in descending order and keep the smallest set of tokens whose
    # cumulative probability stays within top_p.
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
    cumsum_logits = sorted_logits.softmax(-1).cumsum(-1)
    top_p_mask = cumsum_logits <= top_p[:, None]

    # Always keep at least the single most likely token.
    top_p_mask[:, 0] = True

    # Push every token outside the nucleus down to the smallest representable logit.
    batch_indices_to_mask, _ = torch.where(~top_p_mask)
    vocab_indices_to_mask = sorted_indices[~top_p_mask]
    logits[batch_indices_to_mask, vocab_indices_to_mask] = torch.finfo(logits.dtype).min

    return logits.reshape(*batch_dims, -1)
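
# Illustrative sketch (not part of the original module): with top_p=0.6, only the
# most probable token of this toy distribution survives, since it alone already
# accounts for more than 60% of the probability mass.
#
#   logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
#   trimmed = top_p_logits(logits.clone(), top_p=0.6)
#   # trimmed[0, 0] keeps its value; the other three entries become the dtype minimum.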


def _tensorize_like(value: int | float | torch.Tensor, logits: torch.Tensor):
    if isinstance(value, (float, int)):
        value = torch.full_like(logits[..., 0], value, dtype=logits.dtype)
    return value.to(logits.device).expand_as(logits[..., 0]).reshape(-1)


def get_sampling_mask(
    tokens: torch.Tensor, sampling_track_config: SamplingTrackConfig, mask_idx: int
):
    # Do not sample at the first and last (BOS/EOS) positions.
    sampling_mask = torch.ones_like(tokens, dtype=torch.bool)
    sampling_mask[:, 0] = False
    sampling_mask[:, -1] = False

    # Do not sample at positions holding special tokens (other than the mask token).
    special_minus_mask = list(set(sampling_track_config.invalid_ids) - {mask_idx})
    if len(special_minus_mask) > 0:
        special_tokens = torch.tensor(special_minus_mask, device=tokens.device)
        assert special_tokens.numel() > 0
        sampling_mask = sampling_mask & (
            tokens[..., None] != special_tokens[None, :]
        ).all(-1)

    # Optionally restrict sampling to positions that currently hold the mask token.
    if sampling_track_config.only_sample_masked_tokens:
        masked_tokens = tokens == mask_idx
        sampling_mask = sampling_mask & masked_tokens
    return sampling_mask
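
# Illustrative sketch (not part of the original module): the token ids below
# (0 = BOS, 2 = EOS, 32 = mask) are invented for the example.
#
#   tokens = torch.tensor([[0, 32, 5, 32, 2]])
#   cfg = SamplingTrackConfig(only_sample_masked_tokens=True, invalid_ids=[0, 2])
#   mask = get_sampling_mask(tokens, cfg, mask_idx=32)
#   # -> tensor([[False, True, False, True, False]]): only the masked residues
#   #    at positions 1 and 3 are eligible for sampling.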