import warnings from typing import Literal import attr import torch import torch.nn.functional as F from src.data.esm.sdk.api import ( ESMProteinTensor, SamplingConfig, SamplingTrackConfig, ) from src.data.esm.tokenization import ( TokenizerCollectionProtocol, get_invalid_tokenizer_ids, ) from src.data.esm.tokenization.function_tokenizer import ( InterProQuantizedTokenizer, ) from src.data.esm.utils.constants.esm3 import ( MAX_RESIDUE_ANNOTATIONS, SASA_DISCRETIZATION_BOUNDARIES, ) def _non_batched_dims(k: str, v: torch.Tensor): match k: case "sequence": return 1 case "structure": if v.is_floating_point(): # This is the one hot soft structure token. return 2 else: # This is the normal int structure token. return 1 case "secondary_structure": return 1 case "sasa": return 1 case "function": return 2 case "residue_annotations": return 2 case "coordinates": return 3 case _: raise ValueError(f"Unknown dim for track {k}") class _BatchedESMProteinTensor(ESMProteinTensor): @staticmethod def from_protein_tensor(protein: ESMProteinTensor): def _maybe_unsqueeze(x: torch.Tensor | None): return x.unsqueeze(0) if x is not None else None return _BatchedESMProteinTensor( sequence=_maybe_unsqueeze(protein.sequence), structure=_maybe_unsqueeze(protein.structure), secondary_structure=_maybe_unsqueeze(protein.secondary_structure), sasa=_maybe_unsqueeze(protein.sasa), function=_maybe_unsqueeze(protein.function), residue_annotations=_maybe_unsqueeze(protein.residue_annotations), coordinates=_maybe_unsqueeze(protein.coordinates), ) def __len__(self) -> int: def get_len(k, v) -> int: assert len(v.shape) == _non_batched_dims(k, v) + 1 return v.size(1) l = self._detect_attribute(get_len, "length") return l if l is not None else 0 @property def batch_size(self) -> int: def get_batch_size(k, v) -> int: assert len(v.shape) == _non_batched_dims(k, v) + 1 return v.size(0) d = self._detect_attribute(get_batch_size, "batch size") assert d is not None return d def slice(self, i: int, sequence_len: int | None = None) -> ESMProteinTensor: def _maybe_slice(x: torch.Tensor | None): if x is None: return None row = x[i] if sequence_len is not None: row = row[:sequence_len] return row return ESMProteinTensor( sequence=_maybe_slice(self.sequence), structure=_maybe_slice(self.structure), secondary_structure=_maybe_slice(self.secondary_structure), sasa=_maybe_slice(self.sasa), function=_maybe_slice(self.function), residue_annotations=_maybe_slice(self.residue_annotations), coordinates=_maybe_slice(self.coordinates), ) def set_slice(self, i: int, slice: ESMProteinTensor): """Update the i-th slice of this tensor data class.""" for f in attr.fields(ESMProteinTensor): s = getattr(self, f.name) v = getattr(slice, f.name) assert v is None or ( v is not None and s is not None ), f"Trying to set a slice on None tensor ({f.name})." if v is not None: s[i, ...] = v def get_default_sampling_config( tokenizers: TokenizerCollectionProtocol, ) -> SamplingConfig: tracks = [f.name for f in attr.fields(SamplingConfig)] sampling_config = SamplingConfig() for current_track in tracks: setattr( sampling_config, current_track, SamplingTrackConfig( invalid_ids=get_invalid_tokenizer_ids( getattr(tokenizers, current_track) ), temperature=1.0, top_p=1.0, # TODO: Add different mask and padding tokens for all tracks # Some tracks have the same pad and mask, which causes ambiguity when sampling only_sample_masked_tokens=current_track not in ["secondary_structure", "sasa", "function"], ), ) return sampling_config def validate_sampling_config( sampling_config: SamplingConfig, on_invalid: Literal["raise", "warn"] = "warn" ): # Check that all tracks have topk_logprobs less or equal to MAX_TOP_K for track in attr.fields(SamplingConfig): track: attr.Attribute track_config = getattr(sampling_config, track.name, None) if isinstance(track_config, SamplingTrackConfig): max_topk = track.metadata["max_topk"] if track_config.topk_logprobs > max_topk: msg = ( f"Sampling track {track.name} has topk_logprobs={track_config.topk_logprobs} " f"greater than MAX_TOPK={max_topk}." ) if on_invalid == "raise": raise AssertionError(msg) else: warnings.warn(msg) def sample_logits( logits: torch.Tensor, temperature: float | torch.Tensor, valid_ids: list[int] = [], top_p: float | torch.Tensor = 1.0, mask_logits_of_invalid_ids: bool = True, ): """Default sampling from logits. Args: logits is shape (..., vocab_size) temperature is broadcastable to (...) """ if len(valid_ids) == 0: raise ValueError( "Can not sample logits if there are no valid ids to sample from." ) if top_p < 1.0: logits = top_p_logits(logits, top_p=top_p) temperature = _tensorize_like(temperature, logits) batch_dims = logits.size()[:-1] logits = logits.reshape(-1, logits.shape[-1]) # Only sample from valid ids # the /logits endpoint should receive unmodified logits if mask_logits_of_invalid_ids: mask = torch.ones_like(logits, dtype=torch.bool) mask[..., valid_ids] = False logits[mask] = -torch.inf if torch.all(temperature == 0): ids = logits.argmax(-1) return ids.reshape(*batch_dims) assert not torch.any(temperature == 0), "Partial temperature 0 not supported." # Sample from all logits probs = F.softmax(logits / temperature[..., None], dim=-1) ids = torch.multinomial(probs, 1).squeeze(1) ids = ids.reshape(*batch_dims) return ids def sample_function_logits( logits: torch.Tensor, tokenizer: InterProQuantizedTokenizer, top_p: float | torch.Tensor = 1.0, temperature: float | torch.Tensor = 1.0, p_none_threshold: float = 0.05, ) -> tuple[torch.Tensor, torch.Tensor]: """Works with inputs that have batch dimension.""" [B, L, D, V] = logits.shape assert D == tokenizer.depth if top_p < 1.0: logits = top_p_logits(logits, top_p=top_p) temperature = torch.ones_like(logits[..., 0]) * temperature log_p = F.log_softmax(logits / temperature[..., None], dim=-1) # (B, L, D, V) # Choose which positions have no predicted function. none_index = tokenizer.vocab_to_index[""] log_p_nones = log_p[..., none_index] # (B, L, D) p_none = torch.exp(log_p_nones).mean(dim=-1) # "Ensemble of predictions" where_none = p_none > p_none_threshold # (B, L) # Set probability of to 0 for all not-none positions batch_size, seq_len, depth = log_p.shape[:-1] expanded_where_not_none = ~where_none.unsqueeze(-1).unsqueeze(-1) # (B, L, 1, 1) expanded_where_not_none = expanded_where_not_none.expand( batch_size, seq_len, depth, 1 ) # (B, L, D, 1) indices = torch.arange(log_p.shape[-1], device=log_p.device) # (V,) mask = indices == none_index # (V,) mask = expanded_where_not_none & mask # (B, L, D, 1) x (V,) -> (B, L, D, V) log_p[mask] = -torch.inf ids = torch.argmax(log_p, dim=-1) # (B, L, D) ids[where_none, :] = tokenizer.vocab_to_index[""] return ids, log_p def sample_residue_annotation_logits( logits: torch.Tensor, annotation_threshold: float = 0.5 ) -> tuple[torch.Tensor, torch.Tensor]: # Take top residue annotations top_residue_annotations_idx = logits.argsort(dim=-1, descending=True)[ ..., :MAX_RESIDUE_ANNOTATIONS ] # (B, L, MAX_R) top_residue_annotations_logprobs = torch.gather( F.logsigmoid(logits), -1, top_residue_annotations_idx ) # (B, L, MAX_R) top_residue_annotations_probs = top_residue_annotations_logprobs.exp() # Keep only positive predictions is_negative = top_residue_annotations_probs < annotation_threshold top_residue_annotations_idx[is_negative] = 0 top_residue_annotations_logprobs = top_residue_annotations_logprobs return top_residue_annotations_idx, top_residue_annotations_logprobs def sample_sasa_logits( logits: torch.Tensor, tokens: torch.Tensor, sampling_track_config: SamplingTrackConfig, mask_idx: int, valid_ids: list[int], mask_logits_of_invalid_ids: bool = True, ) -> torch.Tensor: # Only sample from valid ids # the /logits endpoint should receive unmodified logits if mask_logits_of_invalid_ids: mask = torch.ones_like(logits, dtype=torch.bool) mask[..., valid_ids] = False logits[mask] = -torch.inf sasa_probs = torch.nn.functional.softmax(logits, dim=-1) max_prob_idx = torch.argmax(sasa_probs, dim=-1) sasa_bins = torch.tensor([0] + SASA_DISCRETIZATION_BOUNDARIES, dtype=torch.float) sasa_bins = (sasa_bins[:-1] + sasa_bins[1:]) / 2 sasa_bins = sasa_bins.to(sasa_probs.device) sampling_mask = get_sampling_mask(tokens, sampling_track_config, mask_idx) # Adjust sasa_values based on max_prob_idx conditions sasa_value = torch.sum(sasa_probs[..., 3:-1] * sasa_bins, dim=-1) sasa_value[max_prob_idx == 18] = float("inf") sasa_value[~sampling_mask] = float("inf") return sasa_value def top_p_logits(logits: torch.Tensor, top_p: float | torch.Tensor) -> torch.Tensor: top_p = _tensorize_like(top_p, logits) batch_dims = logits.size()[:-1] logits = logits.reshape(-1, logits.shape[-1]) # Sort logits in descending order and extract the mask for the top_p sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True) cumsum_logits = sorted_logits.softmax(-1).cumsum(-1) top_p_mask = cumsum_logits <= top_p[:, None] # Make sure at least one token is sampled top_p_mask[:, 0] = True # Mask out the logits that are not in the top_p batch_indices_to_mask, _ = torch.where(~top_p_mask) vocab_indices_to_mask = sorted_indices[~top_p_mask] logits[batch_indices_to_mask, vocab_indices_to_mask] = torch.finfo(logits.dtype).min return logits.reshape(*batch_dims, -1) def _tensorize_like(value: int | float | torch.Tensor, logits: torch.Tensor): if isinstance(value, (float, int)): value = torch.full_like(logits[..., 0], value, dtype=logits.dtype) return value.to(logits.device).expand_as(logits[..., 0]).reshape(-1) def get_sampling_mask( tokens: torch.Tensor, sampling_track_config: SamplingTrackConfig, mask_idx: int ): # Do not sample at BOS and EOS tokens sampling_mask = torch.ones_like(tokens, dtype=torch.bool) # (B, L, ) sampling_mask[:, 0] = False sampling_mask[:, -1] = False # Do not sample at special token positions but allow sampling at mask token special_minus_mask = list(set(sampling_track_config.invalid_ids) - {mask_idx}) if len(special_minus_mask) > 0: special_tokens = torch.tensor(special_minus_mask, device=tokens.device) assert special_tokens.numel() > 0 sampling_mask = sampling_mask & ( tokens[..., None] != special_tokens[None, :] ).all(-1) # Keep only samples from masked positions (if specified) if sampling_track_config.only_sample_masked_tokens: masked_tokens = tokens == mask_idx sampling_mask = sampling_mask & masked_tokens return sampling_mask