"""HELM-BERT tokenizer.""" import json import os from typing import Dict, List, Optional, Tuple from transformers import PreTrainedTokenizer # Default vocabulary for HELM notation HELM_VOCAB = { # Special tokens (0-4) " ": 0, # PAD "@": 1, # BOS/CLS "\n": 2, # EOS/SEP "§": 3, # UNK "¶": 4, # MASK # Natural amino acids (5-25) "A": 5, "R": 6, "N": 7, "D": 8, "C": 9, "E": 10, "Q": 11, "G": 12, "H": 13, "I": 14, "L": 15, "K": 16, "M": 17, "F": 18, "P": 19, "S": 20, "T": 21, "W": 22, "Y": 23, "V": 24, "X": 25, # Unknown amino acid # Structure symbols (26-37) "[": 26, "]": 27, "{": 28, "}": 29, "(": 30, ")": 31, "$": 32, ",": 33, ":": 34, "|": 35, "-": 36, ".": 37, # Numbers (38-47) "0": 38, "1": 39, "2": 40, "3": 41, "4": 42, "5": 43, "6": 44, "7": 45, "8": 46, "9": 47, # Uppercase non-amino acids (48-50) "B": 48, "O": 49, ">": 50, # Lowercase letters (51-72) "a": 51, "b": 52, "c": 53, "d": 54, "e": 55, "f": 56, "g": 57, "h": 58, "i": 59, "l": 60, "m": 61, "n": 62, "o": 63, "p": 64, "r": 65, "s": 66, "t": 67, "u": 68, "v": 69, "x": 70, "y": 71, "z": 72, # Encoded polymer markers (73-76) "/": 73, # PEPTIDE "*": 74, # me "\t": 75, # am "&": 76, # ac # Miscellaneous (77) "_": 77, } # Multi-character to single-character encoding HELM_ENCODE_MAP = {"PEPTIDE": "/", "me": "*", "am": "\t", "ac": "&"} HELM_DECODE_MAP = {v: k for k, v in HELM_ENCODE_MAP.items()} class HELMBertTokenizer(PreTrainedTokenizer): """Tokenizer for HELM-BERT. This tokenizer handles HELM (Hierarchical Editing Language for Macromolecules) notation, converting peptide sequences into token IDs for the HELM-BERT model. The tokenizer uses character-level tokenization with special handling for multi-character HELM tokens like "PEPTIDE", "me", "am", "ac". Example: >>> from helmbert import HELMBertTokenizer >>> tokenizer = HELMBertTokenizer() >>> inputs = tokenizer("PEPTIDE1{A.C.D.E}$$$$", return_tensors="pt") >>> inputs.input_ids tensor([[ 1, 73, 39, 28, 5, 37, 9, 37, 8, 37, 10, 29, 32, 32, 32, 32, 2]]) """ vocab_files_names = {"vocab_file": "vocab.json"} model_input_names = ["input_ids", "attention_mask"] def __init__( self, vocab_file: Optional[str] = None, unk_token: str = "§", sep_token: str = "\n", pad_token: str = " ", cls_token: str = "@", mask_token: str = "¶", bos_token: str = "@", eos_token: str = "\n", model_max_length: int = 512, **kwargs, ): # Load or create vocabulary if vocab_file is not None and os.path.isfile(vocab_file): with open(vocab_file, encoding="utf-8") as f: self.vocab = json.load(f) else: self.vocab = HELM_VOCAB.copy() self.ids_to_tokens = {v: k for k, v in self.vocab.items()} # HELM encoding/decoding maps self.encode_map = HELM_ENCODE_MAP.copy() self.decode_map = HELM_DECODE_MAP.copy() super().__init__( unk_token=unk_token, sep_token=sep_token, pad_token=pad_token, cls_token=cls_token, mask_token=mask_token, bos_token=bos_token, eos_token=eos_token, model_max_length=model_max_length, **kwargs, ) @property def vocab_size(self) -> int: """Return the vocabulary size.""" return len(self.vocab) def get_vocab(self) -> Dict[str, int]: """Return the vocabulary as a dictionary.""" return self.vocab.copy() def _encode_helm(self, text: str) -> str: """Encode multi-character HELM tokens to single characters. 
        Args:
            text: Raw HELM notation string

        Returns:
            Encoded string with single-character tokens
        """
        if not text:
            return ""
        result = text
        for seq, tok in self.encode_map.items():
            result = result.replace(seq, tok)
        return result

    def _decode_helm(self, text: str) -> str:
        """Decode single-character tokens back to multi-character HELM tokens.

        Args:
            text: Encoded string with single-character tokens

        Returns:
            Decoded HELM notation string
        """
        if not text:
            return ""
        result = text
        for tok, seq in self.decode_map.items():
            result = result.replace(tok, seq)
        return result

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize a HELM string into a list of tokens.

        Args:
            text: HELM notation string

        Returns:
            List of single-character tokens
        """
        # First encode multi-character tokens to single characters
        encoded = self._encode_helm(text)
        # Return as list of characters
        return list(encoded)

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to its ID."""
        return self.vocab.get(token, self.vocab.get(self.unk_token, 3))

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an ID to its token."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Convert a list of tokens to a HELM string.

        Args:
            tokens: List of tokens

        Returns:
            Decoded HELM notation string
        """
        # Join tokens and decode back to HELM notation
        joined = "".join(tokens)
        return self._decode_helm(joined)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Build model inputs by adding special tokens.

        Args:
            token_ids_0: First sequence of token IDs
            token_ids_1: Optional second sequence of token IDs

        Returns:
            List of token IDs with special tokens added
        """
        cls_id = [self.cls_token_id]
        sep_id = [self.sep_token_id]
        if token_ids_1 is None:
            return cls_id + token_ids_0 + sep_id
        return cls_id + token_ids_0 + sep_id + token_ids_1 + sep_id

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Get a mask identifying special tokens.

        Args:
            token_ids_0: First sequence of token IDs
            token_ids_1: Optional second sequence of token IDs
            already_has_special_tokens: Whether the sequences already have
                special tokens

        Returns:
            List of 0s and 1s (1 = special token)
        """
        if already_has_special_tokens:
            return [
                1
                if x in [self.cls_token_id, self.sep_token_id, self.pad_token_id]
                else 0
                for x in token_ids_0
            ]
        if token_ids_1 is None:
            return [1] + [0] * len(token_ids_0) + [1]
        return [1] + [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Create token type IDs for sequence pairs.

        Args:
            token_ids_0: First sequence of token IDs
            token_ids_1: Optional second sequence of token IDs

        Returns:
            List of token type IDs
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return [0] * len(cls + token_ids_0 + sep)
        return [0] * len(cls + token_ids_0 + sep) + [1] * len(token_ids_1 + sep)
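    # Illustrative summary of the layouts produced by the three special-token
    # helpers above (reference comment, derived from this class, not part of
    # the original source):
    #   single sequence: "@" + tokens_0 + "\n"                     (type IDs all 0)
    #   sequence pair:   "@" + tokens_0 + "\n" + tokens_1 + "\n"
    #                    (type IDs: 0 over the first segment, 1 over the second)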
    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """Save the vocabulary to a file.

        Args:
            save_directory: Directory to save the vocabulary
            filename_prefix: Optional prefix for the filename

        Returns:
            Tuple containing the path to the saved vocabulary file
        """
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)

    @property
    def mask_token_id(self) -> int:
        """Return the mask token ID."""
        return self.vocab.get(self.mask_token, 4)
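

# Minimal usage sketch (illustrative; not part of the original module). It
# exercises only the public API defined above with the default vocabulary,
# and relies on `convert_tokens_to_string` restoring multi-character HELM
# tokens such as "PEPTIDE" on decode.
if __name__ == "__main__":
    tokenizer = HELMBertTokenizer()

    helm = "PEPTIDE1{A.C.D.E}$$$$"
    encoding = tokenizer(helm)
    print(encoding["input_ids"])
    # Expected: [1, 73, 39, 28, 5, 37, 9, 37, 8, 37, 10, 29, 32, 32, 32, 32, 2]

    # Round trip: drop CLS/SEP and decode back to HELM notation.
    decoded = tokenizer.decode(encoding["input_ids"], skip_special_tokens=True)
    print(decoded)  # PEPTIDE1{A.C.D.E}$$$$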