File size: 9,126 Bytes
45723db
 
 
 
62249d1
45723db
 
 
 
 
 
 
62249d1
 
45723db
62249d1
 
45723db
62249d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45723db
 
62249d1
 
 
 
 
 
 
 
 
 
 
 
45723db
62249d1
 
 
 
 
 
 
 
 
 
45723db
62249d1
 
 
45723db
62249d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45723db
62249d1
 
45723db
62249d1
45723db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62249d1
 
 
 
 
 
45723db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
"""HELM-BERT tokenizer."""

import json
import os
from typing import Dict, List, Optional, Tuple

from transformers import PreTrainedTokenizer


# Default vocabulary for HELM notation.
# Character-level mapping: every key is a single character, including the
# reserved encode characters ("/", "*", "\t", "&") that stand in for the
# multi-character HELM tokens in HELM_ENCODE_MAP below.
HELM_VOCAB: Dict[str, int] = {
    # Special tokens (0-4)
    " ": 0,  # PAD
    "@": 1,  # BOS/CLS
    "\n": 2,  # EOS/SEP
    "§": 3,  # UNK
    "¶": 4,  # MASK
    # Natural amino acids (5-25)
    "A": 5,
    "R": 6,
    "N": 7,
    "D": 8,
    "C": 9,
    "E": 10,
    "Q": 11,
    "G": 12,
    "H": 13,
    "I": 14,
    "L": 15,
    "K": 16,
    "M": 17,
    "F": 18,
    "P": 19,
    "S": 20,
    "T": 21,
    "W": 22,
    "Y": 23,
    "V": 24,
    "X": 25,  # Unknown amino acid
    # Structure symbols (26-37)
    "[": 26,
    "]": 27,
    "{": 28,
    "}": 29,
    "(": 30,
    ")": 31,
    "$": 32,
    ",": 33,
    ":": 34,
    "|": 35,
    "-": 36,
    ".": 37,
    # Numbers (38-47)
    "0": 38,
    "1": 39,
    "2": 40,
    "3": 41,
    "4": 42,
    "5": 43,
    "6": 44,
    "7": 45,
    "8": 46,
    "9": 47,
    # Uppercase non-amino acids (48-50)
    "B": 48,
    "O": 49,
    ">": 50,
    # Lowercase letters (51-72); note: j, k, q, w are intentionally absent
    "a": 51,
    "b": 52,
    "c": 53,
    "d": 54,
    "e": 55,
    "f": 56,
    "g": 57,
    "h": 58,
    "i": 59,
    "l": 60,
    "m": 61,
    "n": 62,
    "o": 63,
    "p": 64,
    "r": 65,
    "s": 66,
    "t": 67,
    "u": 68,
    "v": 69,
    "x": 70,
    "y": 71,
    "z": 72,
    # Encoded polymer markers (73-76) — see HELM_ENCODE_MAP
    "/": 73,  # PEPTIDE
    "*": 74,  # me
    "\t": 75,  # am
    "&": 76,  # ac
    # Miscellaneous (77)
    "_": 77,
}

# Multi-character HELM tokens collapsed to reserved single characters so that
# tokenization can be purely character-level. Replacement is applied in
# insertion order ("PEPTIDE" first); HELM_DECODE_MAP is the exact inverse,
# so encode -> decode round-trips the original string.
HELM_ENCODE_MAP: Dict[str, str] = {"PEPTIDE": "/", "me": "*", "am": "\t", "ac": "&"}
HELM_DECODE_MAP: Dict[str, str] = {v: k for k, v in HELM_ENCODE_MAP.items()}


class HELMBertTokenizer(PreTrainedTokenizer):
    """Tokenizer for HELM-BERT.

    This tokenizer handles HELM (Hierarchical Editing Language for Macromolecules)
    notation, converting peptide sequences into token IDs for the HELM-BERT model.

    The tokenizer uses character-level tokenization with special handling for
    multi-character HELM tokens like "PEPTIDE", "me", "am", "ac": they are
    collapsed to reserved single characters (see ``HELM_ENCODE_MAP``) before
    the string is split into characters, and expanded again on decode.

    Example:
        >>> from helmbert import HELMBertTokenizer
        >>> tokenizer = HELMBertTokenizer()
        >>> inputs = tokenizer("PEPTIDE1{A.C.D.E}$$$$", return_tensors="pt")
        >>> inputs.input_ids
        tensor([[ 1, 73, 39, 28, 5, 37, 9, 37, 8, 37, 10, 29, 32, 32, 32, 32,  2]])
    """

    # File name used by save_vocabulary() and discovered by from_pretrained().
    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        unk_token: str = "§",
        sep_token: str = "\n",
        pad_token: str = " ",
        cls_token: str = "@",
        mask_token: str = "¶",
        bos_token: str = "@",
        eos_token: str = "\n",
        model_max_length: int = 512,
        **kwargs,
    ):
        """Initialize the tokenizer.

        Args:
            vocab_file: Optional path to a JSON token->id mapping. When the
                path is missing or not a file, the built-in ``HELM_VOCAB``
                is used silently instead.
            unk_token: Unknown-token character (id 3 in the default vocab).
            sep_token: Separator/EOS character (id 2).
            pad_token: Padding character (id 0).
            cls_token: Classification/BOS character (id 1).
            mask_token: Mask character for MLM (id 4).
            bos_token: Beginning-of-sequence character (same as cls_token).
            eos_token: End-of-sequence character (same as sep_token).
            model_max_length: Maximum sequence length accepted by the model.
            **kwargs: Forwarded to ``PreTrainedTokenizer``.
        """
        # Load or create vocabulary. NOTE: self.vocab must be populated
        # BEFORE super().__init__() runs — the transformers base class may
        # invoke token-conversion hooks (e.g. _convert_token_to_id) while
        # registering special tokens. Keep this statement order.
        if vocab_file is not None and os.path.isfile(vocab_file):
            with open(vocab_file, encoding="utf-8") as f:
                self.vocab = json.load(f)
        else:
            self.vocab = HELM_VOCAB.copy()

        # Reverse mapping id -> token for decoding.
        self.ids_to_tokens = {v: k for k, v in self.vocab.items()}

        # HELM encoding/decoding maps (copied so instances can be customized
        # without mutating the module-level defaults).
        self.encode_map = HELM_ENCODE_MAP.copy()
        self.decode_map = HELM_DECODE_MAP.copy()

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            bos_token=bos_token,
            eos_token=eos_token,
            model_max_length=model_max_length,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Return the vocabulary size."""
        return len(self.vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return the vocabulary as a dictionary (a defensive copy)."""
        return self.vocab.copy()

    def _encode_helm(self, text: str) -> str:
        """Encode multi-character HELM tokens to single characters.

        Replacements are applied in the insertion order of ``encode_map``
        ("PEPTIDE" before "me"/"am"/"ac"), which keeps encode/decode an
        exact round-trip for well-formed HELM strings.

        Args:
            text: Raw HELM notation string

        Returns:
            Encoded string with single-character tokens
        """
        if not text:
            return ""
        result = text
        for seq, tok in self.encode_map.items():
            result = result.replace(seq, tok)
        return result

    def _decode_helm(self, text: str) -> str:
        """Decode single-character tokens back to multi-character HELM tokens.

        Inverse of :meth:`_encode_helm`.

        Args:
            text: Encoded string with single-character tokens

        Returns:
            Decoded HELM notation string
        """
        if not text:
            return ""
        result = text
        for tok, seq in self.decode_map.items():
            result = result.replace(tok, seq)
        return result

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize a HELM string into a list of tokens.

        Args:
            text: HELM notation string

        Returns:
            List of single-character tokens
        """
        # First encode multi-character tokens to single characters
        encoded = self._encode_helm(text)
        # Return as list of characters
        return list(encoded)

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to its ID, falling back to the UNK id (3)."""
        return self.vocab.get(token, self.vocab.get(self.unk_token, 3))

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an ID to its token, falling back to the UNK token."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Convert a list of tokens to a HELM string.

        Args:
            tokens: List of tokens

        Returns:
            Decoded HELM notation string
        """
        # Join tokens and decode back to HELM notation
        joined = "".join(tokens)
        return self._decode_helm(joined)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Build model inputs by adding special tokens.

        Single sequence: ``[CLS] X [SEP]``.
        Pair of sequences: ``[CLS] A [SEP] B [SEP]``.

        Args:
            token_ids_0: First sequence of token IDs
            token_ids_1: Optional second sequence of token IDs

        Returns:
            List of token IDs with special tokens added
        """
        cls_id = [self.cls_token_id]
        sep_id = [self.sep_token_id]

        if token_ids_1 is None:
            return cls_id + token_ids_0 + sep_id

        return cls_id + token_ids_0 + sep_id + token_ids_1 + sep_id

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Get a mask identifying special tokens.

        Args:
            token_ids_0: First sequence of token IDs
            token_ids_1: Optional second sequence of token IDs
            already_has_special_tokens: Whether the sequences already have special tokens

        Returns:
            List of 0s and 1s (1 = special token)
        """
        # NOTE(review): unlike the transformers base implementation (which
        # flags every id in all_special_ids), this path flags only CLS, SEP
        # and PAD — MASK/UNK ids are reported as regular tokens, and
        # token_ids_1 is ignored. Confirm this is intended for MLM collators.
        if already_has_special_tokens:
            return [
                1
                if x in [self.cls_token_id, self.sep_token_id, self.pad_token_id]
                else 0
                for x in token_ids_0
            ]

        if token_ids_1 is None:
            return [1] + [0] * len(token_ids_0) + [1]

        return [1] + [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Create token type IDs for sequence pairs.

        Segment A (``[CLS] A [SEP]``) gets type 0; segment B (``B [SEP]``),
        when present, gets type 1 — matching BERT conventions.

        Args:
            token_ids_0: First sequence of token IDs
            token_ids_1: Optional second sequence of token IDs

        Returns:
            List of token type IDs
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            return [0] * len(cls + token_ids_0 + sep)

        return [0] * len(cls + token_ids_0 + sep) + [1] * len(token_ids_1 + sep)

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """Save the vocabulary to a file.

        The directory is created if it does not exist. The file name is
        ``{prefix}-vocab.json`` when a prefix is given, else ``vocab.json``.

        Args:
            save_directory: Directory to save the vocabulary
            filename_prefix: Optional prefix for the filename

        Returns:
            Tuple containing the path to the saved vocabulary file
        """
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory, exist_ok=True)

        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )

        # ensure_ascii=False keeps non-ASCII specials ("§", "¶") human-readable.
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f, ensure_ascii=False, indent=2)

        return (vocab_file,)

    @property
    def mask_token_id(self) -> int:
        """Return the mask token ID.

        Overrides the base-class property; falls back to the default id 4
        when the mask token is absent from a custom vocabulary.
        """
        return self.vocab.get(self.mask_token, 4)