File size: 2,715 Bytes
9627ce0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 | from src.data.esm.tokenization.tokenizer_base import EsmTokenizerBase
from src.data.esm.utils.constants import esm3 as C
class StructureTokenizer(EsmTokenizerBase):
    """A convenience class for accessing special token ids of
    the StructureTokenEncoder and StructureTokenDecoder.

    Ids ``0 .. codebook_size - 1`` are left for the codebook entries; the
    special tokens are numbered immediately after them, in the order
    MASK, EOS, BOS, PAD, CHAINBREAK.

    This class does not tokenize anything itself: ``encode`` and ``decode``
    always raise ``NotImplementedError``, as do the string-valued token
    accessors (``mask_token``, ``bos_token``, ...).
    """

    # Shared error texts, kept in one place so every accessor raises the
    # exact same message.
    _NO_STRING_TOKENS_MSG = (
        "Structure tokens are defined on 3D coordinates, not strings."
    )
    _CONVENIENCE_ONLY_MSG = (
        "The StructureTokenizer class is provided as a convenience for "
        "accessing special token ids of the StructureTokenEncoder and StructureTokenDecoder.\n"
        "Please use them instead."
    )

    def __init__(self, codebook_size: int = C.VQVAE_CODEBOOK_SIZE):
        """Assign ids to the special tokens.

        Args:
            codebook_size: Number of ids reserved ahead of the special
                tokens; the first special token (MASK) receives this id.
        """
        self.vq_vae_special_tokens = {
            "MASK": codebook_size,
            "EOS": codebook_size + 1,
            "BOS": codebook_size + 2,
            "PAD": codebook_size + 3,
            "CHAINBREAK": codebook_size + 4,
        }

    def mask_token(self) -> str:
        """Always raise: there is no string form of the mask token.

        Note that this is a plain method, not a property, so the error is
        raised on call rather than on attribute access.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(self._NO_STRING_TOKENS_MSG)

    @property
    def mask_token_id(self) -> int:
        """Id of the MASK special token."""
        return self.vq_vae_special_tokens["MASK"]

    def bos_token(self) -> str:
        """Always raise: there is no string form of the BOS token.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(self._NO_STRING_TOKENS_MSG)

    @property
    def bos_token_id(self) -> int:
        """Id of the BOS special token."""
        return self.vq_vae_special_tokens["BOS"]

    def eos_token(self) -> str:
        """Always raise: there is no string form of the EOS token.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(self._NO_STRING_TOKENS_MSG)

    @property
    def eos_token_id(self) -> int:
        """Id of the EOS special token."""
        return self.vq_vae_special_tokens["EOS"]

    def pad_token(self) -> str:
        """Always raise: there is no string form of the PAD token.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(self._NO_STRING_TOKENS_MSG)

    @property
    def pad_token_id(self) -> int:
        """Id of the PAD special token."""
        return self.vq_vae_special_tokens["PAD"]

    def chain_break_token(self) -> str:
        """Always raise: there is no string form of the chain-break token.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(self._NO_STRING_TOKENS_MSG)

    @property
    def chain_break_token_id(self) -> int:
        """Id of the CHAINBREAK special token."""
        return self.vq_vae_special_tokens["CHAINBREAK"]

    @property
    def all_token_ids(self):
        """List of every id: ``0`` up to and including the last special token.

        The size is taken from this instance rather than from the module
        constant, so the list stays consistent with the special-token ids
        when a non-default ``codebook_size`` was passed to ``__init__``.
        """
        # MASK is assigned the first id after the codebook, so its id equals
        # the codebook size this instance was built with.
        codebook_size = self.vq_vae_special_tokens["MASK"]
        return list(range(codebook_size + len(self.vq_vae_special_tokens)))

    @property
    def special_token_ids(self):
        """Ids of the special tokens, as a view over the id mapping's values."""
        return self.vq_vae_special_tokens.values()

    def encode(self, *args, **kwargs):
        """Always raise: this class only exposes special token ids.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(self._CONVENIENCE_ONLY_MSG)

    def decode(self, *args, **kwargs):
        """Always raise: this class only exposes special token ids.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(self._CONVENIENCE_ONLY_MSG)
|