| | from functools import cached_property |
| |
|
| | import torch |
| |
|
| | from src.data.esm.tokenization.tokenizer_base import EsmTokenizerBase |
| | from src.data.esm.utils.constants import esm3 as C |
| |
|
| |
|
class SASADiscretizingTokenizer(EsmTokenizerBase):
    """Tokenizer for Solvent Accessible Surface Area (SASA).

    Continuous SASA values are discretized into buckets delimited by
    ``boundaries``; each bucket becomes one vocabulary token. The first
    three vocabulary slots are reserved for special tokens, so bucket ``b``
    maps to token id ``len(self.special_tokens) + b``. Note this tokenizer
    uses ``"<pad>"`` for all of BOS/EOS/mask/pad/chain-break.
    """

    def __init__(self, boundaries: list[float] = C.SASA_DISCRETIZATION_BOUNDARIES):
        """Initialize with the bucket edges used to discretize SASA values.

        Args:
            boundaries: bucket edges, in any order. ``sorted`` makes a copy,
                so neither the caller's list nor the module-level default is
                ever mutated. Assumed non-empty (midpoint computation indexes
                the last edge) — TODO confirm with callers.
        """
        self._boundaries = sorted(boundaries)

    @cached_property
    def special_tokens(self) -> list[str]:
        """Special tokens; they occupy the first vocabulary slots."""
        return ["<pad>", "<motif>", "<unk>"]

    @cached_property
    def vocab(self) -> list[str]:
        """Discrete token vocabulary.

        Returns:
            token vocabulary with ranges represented as "<low-high>",
            preceded by the special tokens.
        """
        boundary_strs = ["0"] + [str(b) for b in self._boundaries] + ["inf"]
        range_tokens = [
            f"<{low}-{high}>"
            for low, high in zip(boundary_strs[:-1], boundary_strs[1:])
        ]
        return self.special_tokens + range_tokens

    @cached_property
    def _boundaries_tensor(self) -> torch.Tensor:
        """Bucket edges as a tensor, built once for `torch.bucketize`."""
        return torch.tensor(self._boundaries)

    @cached_property
    def midpoints_tensor(self) -> torch.Tensor:
        """Midpoints of the SASA token ranges, indexed by token id.

        The first three entries are NaN placeholders for the special tokens.
        The open-ended top bucket ``(boundaries[-1], inf)`` has no finite
        midpoint, so it is approximated by doubling the top edge, which
        yields a midpoint of ``1.5 * boundaries[-1]``.
        """
        boundaries = [0] + self._boundaries + [self._boundaries[-1] * 2]
        midpoint_tokens = [
            (float(high) + float(low)) / 2
            for low, high in zip(boundaries[:-1], boundaries[1:])
        ]
        midpoint_tokens = [float("nan"), float("nan"), float("nan")] + midpoint_tokens
        return torch.tensor(midpoint_tokens)

    def midpoints(self) -> list[float]:
        """Midpoints of the SASA token ranges, as a plain Python list."""
        return self.midpoints_tensor.tolist()

    @cached_property
    def vocab_to_index(self) -> dict[str, int]:
        """Constructs token -> token id mapping."""
        return {word: i for i, word in enumerate(self.vocab)}

    def get_special_tokens_mask(self, tokens: torch.Tensor) -> torch.Tensor:
        """Determines which positions are special tokens.

        Args:
            tokens: <int>[length]
        Returns:
            <bool>[length] tensor, true where special tokens are located in
            the input. Relies on special tokens occupying the lowest ids.
        """
        return tokens < len(self.special_tokens)

    def encode(
        self, values: list[float | str], add_special_tokens: bool = True
    ) -> torch.Tensor:
        """Encodes SASA values as discrete tokens.

        Args:
            values: list of either SASA values or individual tokens. For
                example [1.2, "<pad>", 10.3, "<pad>", 0.]
            add_special_tokens: when True, wraps the sequence in BOS and EOS
                markers (both "<pad>" for this tokenizer).
        Returns:
            Token ids as an <int64>[len(values) (+2)] tensor.
        Raises:
            TypeError: if a value is neither a number nor a string token.
            KeyError: if a string value is not in the vocabulary.
        """
        ids: list[int] = []
        if add_special_tokens:
            ids.append(self.vocab_to_index["<pad>"])  # BOS
        # Loop-invariant: the boundaries tensor is cached rather than being
        # rebuilt once per value.
        boundaries = self._boundaries_tensor
        for value in values:
            if isinstance(value, (float, int)):
                # bucketize returns a 0-dim tensor; convert to a plain int so
                # `ids` stays a homogeneous list of Python ints.
                bucket = int(torch.bucketize(value, boundaries))
                token_id = len(self.special_tokens) + bucket
            elif isinstance(value, str):
                token_id = self.vocab_to_index[value]
            else:
                raise TypeError(value)
            ids.append(token_id)
        if add_special_tokens:
            ids.append(self.vocab_to_index["<pad>"])  # EOS

        return torch.tensor(ids, dtype=torch.int64)

    def decode_float(self, encoded: torch.Tensor) -> list[float]:
        """Decodes SASA token ids into float values.

        Special tokens decode to NaN: the midpoints table already stores NaN
        in their slots, so no masking pass is needed (the previous
        ``np_arr[nan_mask] = None`` assignment was a no-op — numpy coerces
        ``None`` to NaN in a float array).
        """
        return self.midpoints_tensor[encoded.cpu()].tolist()

    def decode(self, encoded: torch.Tensor) -> str:
        """Decodes SASA token ids into a single comma-separated string."""
        return ",".join(self.vocab[i] for i in encoded)

    def decode_list(self, encoded: torch.Tensor) -> list[str]:
        """Decodes SASA token ids into a list of token strings."""
        return [self.vocab[i] for i in encoded]

    @property
    def mask_token(self) -> str:
        return "<pad>"

    @property
    def mask_token_id(self) -> int:
        return self.vocab_to_index[self.mask_token]

    @property
    def bos_token(self) -> str:
        return "<pad>"

    @property
    def bos_token_id(self) -> int:
        return self.vocab_to_index[self.bos_token]

    @property
    def eos_token(self) -> str:
        return "<pad>"

    @property
    def eos_token_id(self) -> int:
        return self.vocab_to_index[self.eos_token]

    @property
    def pad_token(self) -> str:
        return "<pad>"

    @property
    def pad_token_id(self) -> int:
        return self.vocab_to_index[self.pad_token]

    @property
    def chain_break_token(self) -> str:
        return "<pad>"

    @property
    def chain_break_token_id(self) -> int:
        return self.vocab_to_index[self.chain_break_token]

    @property
    def all_token_ids(self):
        """All valid token ids, 0..len(vocab)-1."""
        return list(range(len(self.vocab)))

    @property
    def special_token_ids(self):
        """Token ids of the special tokens."""
        return [self.vocab_to_index[token] for token in self.special_tokens]
|