from functools import cached_property
import torch
from src.data.esm.tokenization.tokenizer_base import EsmTokenizerBase
from src.data.esm.utils.constants import esm3 as C
class SASADiscretizingTokenizer(EsmTokenizerBase):
    """Tokenizer for Solvent Accessible Surface Area (SASA).

    Continuous SASA values are discretized into half-open buckets defined by
    ``boundaries``; each bucket is one vocabulary token rendered as
    ``"<low-high>"``. Three special tokens (``<pad>``, ``<motif>``, ``<unk>``)
    occupy the lowest token ids, and ``<pad>`` doubles as the BOS/EOS/mask/
    pad/chain-break token.
    """

    def __init__(self, boundaries: list[float] = C.SASA_DISCRETIZATION_BOUNDARIES):
        """Initializes the tokenizer.

        Args:
            boundaries: bucket edges for discretization; sorted ascending.
                ``sorted()`` copies the input, so the shared default list is
                never mutated.
        """
        self._boundaries = sorted(boundaries)

    @cached_property
    def special_tokens(self) -> list[str]:
        """Special tokens; by construction they get ids 0..2, below all range tokens."""
        return ["<pad>", "<motif>", "<unk>"]

    @cached_property
    def vocab(self) -> list[str]:
        """Discrete token vocabulary.

        Returns:
            token vocabulary with ranges represented as "<low-high>".
        """
        # "0" and "inf" cap the open-ended bottom and top buckets.
        boundary_strs = ["0"] + [str(b) for b in self._boundaries] + ["inf"]
        range_tokens = [
            f"<{low}-{high}>"
            for low, high in zip(boundary_strs[:-1], boundary_strs[1:])
        ]
        return self.special_tokens + range_tokens

    @cached_property
    def _boundaries_tensor(self) -> torch.Tensor:
        """Bucket boundaries as a tensor, built once and reused by ``encode``."""
        return torch.tensor(self._boundaries)

    @cached_property
    def midpoints_tensor(self) -> torch.Tensor:
        """Midpoints of the SASA token ranges, index-aligned with ``vocab``.

        Special-token positions hold NaN. The open-ended top bucket uses
        twice the last boundary as a synthetic upper edge.
        """
        boundaries = [0.0] + self._boundaries + [self._boundaries[-1] * 2]
        midpoints = [
            (float(low) + float(high)) / 2
            for low, high in zip(boundaries[:-1], boundaries[1:])
        ]
        # One NaN per special token so indexing by token id lines up;
        # derived from len(special_tokens) rather than hardcoding 3.
        nan = float("nan")
        return torch.tensor([nan] * len(self.special_tokens) + midpoints)

    def midpoints(self) -> list[float]:
        """Midpoints of the SASA token ranges as plain Python floats."""
        return self.midpoints_tensor.tolist()

    @cached_property
    def vocab_to_index(self) -> dict[str, int]:
        """Constructs token -> token id mapping."""
        return {word: i for i, word in enumerate(self.vocab)}

    def get_special_tokens_mask(self, tokens: torch.Tensor) -> torch.Tensor:
        """Determines which positions are special tokens.

        Args:
            tokens: <int>[length]

        Returns:
            <bool>[length] tensor, true where special tokens are located in the input.
        """
        # Valid because special tokens occupy the lowest ids (see ``vocab``).
        return tokens < len(self.special_tokens)

    def encode(
        self, values: list[float | str], add_special_tokens: bool = True
    ) -> torch.Tensor:
        """Encodes SASA values as discrete tokens.

        Args:
            values: list of either SASA values or individual tokens. For example
                [1.2, "<pad>", 10.3, "<pad>", 0.]
            add_special_tokens: if True, wrap the sequence in BOS/EOS
                (both "<pad>" for this tokenizer).

        Returns:
            <int64>[length (+2 if add_special_tokens)] token id tensor.

        Raises:
            TypeError: if a value is neither a number nor a string token.
            KeyError: if a string value is not in the vocabulary.
        """
        ids: list[int] = []
        if add_special_tokens:
            ids.append(self.vocab_to_index["<pad>"])  # BOS
        boundaries = self._boundaries_tensor  # hoisted: invariant across values
        for value in values:
            if isinstance(value, (float, int)):
                # bucketize gives the index of the bucket containing ``value``;
                # offset past the special tokens to get the vocab id.
                bucket = int(torch.bucketize(value, boundaries))
                token_id = len(self.special_tokens) + bucket
            elif isinstance(value, str):
                token_id = self.vocab_to_index[value]
            else:
                raise TypeError(value)
            ids.append(token_id)
        if add_special_tokens:
            ids.append(self.vocab_to_index["<pad>"])  # EOS
        return torch.tensor(ids, dtype=torch.int64)

    def decode_float(self, encoded: torch.Tensor) -> list[float]:
        """Decodes SASA token ids into float values (NaN at special tokens).

        Special-token ids map to NaN via ``midpoints_tensor``; the previous
        ``np_arr[nan_mask] = None`` pass was a no-op (those slots were
        already NaN, and None stores as NaN in a float array) and is removed.
        """
        return self.midpoints_tensor[encoded.cpu()].tolist()

    def decode(self, encoded: torch.Tensor) -> str:
        """Decodes SASA token ids into a single comma-joined token string."""
        return ",".join(self.vocab[i] for i in encoded)

    def decode_list(self, encoded: torch.Tensor) -> list[str]:
        """Decodes SASA token ids into a list of token strings."""
        return [self.vocab[i] for i in encoded]

    # "<pad>" serves every special role for this tokenizer.
    @property
    def mask_token(self) -> str:
        return "<pad>"

    @property
    def mask_token_id(self) -> int:
        return self.vocab_to_index[self.mask_token]

    @property
    def bos_token(self) -> str:
        return "<pad>"

    @property
    def bos_token_id(self) -> int:
        return self.vocab_to_index[self.bos_token]

    @property
    def eos_token(self) -> str:
        return "<pad>"

    @property
    def eos_token_id(self) -> int:
        return self.vocab_to_index[self.eos_token]

    @property
    def pad_token(self) -> str:
        return "<pad>"

    @property
    def pad_token_id(self) -> int:
        return self.vocab_to_index[self.pad_token]

    @property
    def chain_break_token(self) -> str:
        return "<pad>"

    @property
    def chain_break_token_id(self) -> int:
        return self.vocab_to_index[self.chain_break_token]

    @property
    def all_token_ids(self):
        """Every valid token id, in vocabulary order."""
        return list(range(len(self.vocab)))

    @property
    def special_token_ids(self):
        """Ids of the special tokens (the lowest ids in the vocabulary)."""
        return [self.vocab_to_index[token] for token in self.special_tokens]
|