File size: 2,715 Bytes
9627ce0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from src.data.esm.tokenization.tokenizer_base import EsmTokenizerBase
from src.data.esm.utils.constants import esm3 as C


class StructureTokenizer(EsmTokenizerBase):
    """A convenience class for accessing special token ids of
    the StructureTokenEncoder and StructureTokenDecoder.

    Structure tokens have no string form, so every ``*_token`` string
    accessor raises ``NotImplementedError``; only the ``*_token_id``
    properties are usable. The special tokens are assigned the ids
    immediately following the codebook, i.e. ``codebook_size`` through
    ``codebook_size + 4``.
    """

    # Shared error messages (kept byte-identical to the original literals).
    _NO_STRING_FORM_MSG = (
        "Structure tokens are defined on 3D coordinates, not strings."
    )
    _USE_ENCODER_DECODER_MSG = (
        "The StructureTokenizer class is provided as a convenience for "
        "accessing special token ids of the StructureTokenEncoder and StructureTokenDecoder.\n"
        "Please use them instead."
    )

    def __init__(self, codebook_size: int = C.VQVAE_CODEBOOK_SIZE):
        """Build the special-token id table.

        Args:
            codebook_size: Number of regular codebook ids. Special token ids
                are allocated starting at this value.
        """
        # Remembered so that ``all_token_ids`` agrees with the special ids
        # below even when a non-default codebook size is used.
        self.codebook_size = codebook_size
        self.vq_vae_special_tokens = {
            "MASK": codebook_size,
            "EOS": codebook_size + 1,
            "BOS": codebook_size + 2,
            "PAD": codebook_size + 3,
            "CHAINBREAK": codebook_size + 4,
        }

    def mask_token(self) -> str:
        """Not supported: structure tokens have no string form."""
        raise NotImplementedError(self._NO_STRING_FORM_MSG)

    @property
    def mask_token_id(self) -> int:
        """Id of the MASK special token."""
        return self.vq_vae_special_tokens["MASK"]

    def bos_token(self) -> str:
        """Not supported: structure tokens have no string form."""
        raise NotImplementedError(self._NO_STRING_FORM_MSG)

    @property
    def bos_token_id(self) -> int:
        """Id of the BOS special token."""
        return self.vq_vae_special_tokens["BOS"]

    def eos_token(self) -> str:
        """Not supported: structure tokens have no string form."""
        raise NotImplementedError(self._NO_STRING_FORM_MSG)

    @property
    def eos_token_id(self) -> int:
        """Id of the EOS special token."""
        return self.vq_vae_special_tokens["EOS"]

    def pad_token(self) -> str:
        """Not supported: structure tokens have no string form."""
        raise NotImplementedError(self._NO_STRING_FORM_MSG)

    @property
    def pad_token_id(self) -> int:
        """Id of the PAD special token."""
        return self.vq_vae_special_tokens["PAD"]

    def chain_break_token(self) -> str:
        """Not supported: structure tokens have no string form."""
        raise NotImplementedError(self._NO_STRING_FORM_MSG)

    @property
    def chain_break_token_id(self) -> int:
        """Id of the CHAINBREAK special token."""
        return self.vq_vae_special_tokens["CHAINBREAK"]

    @property
    def all_token_ids(self) -> list:
        """Return every token id: ``0 .. codebook_size - 1`` followed by the
        special token ids, as a contiguous list."""
        # Use the instance's codebook size (not the module default) so the
        # range matches the ids in ``vq_vae_special_tokens``.
        return list(range(self.codebook_size + len(self.vq_vae_special_tokens)))

    @property
    def special_token_ids(self) -> list:
        """Return the ids of all special tokens as a list (a snapshot, not a
        live view of the internal mapping)."""
        return list(self.vq_vae_special_tokens.values())

    def encode(self, *args, **kwargs):
        """Not supported: use StructureTokenEncoder instead."""
        raise NotImplementedError(self._USE_ENCODER_DECODER_MSG)

    def decode(self, *args, **kwargs):
        """Not supported: use StructureTokenDecoder instead."""
        raise NotImplementedError(self._USE_ENCODER_DECODER_MSG)