"""Hugging Face-compatible tokenizer for APE molecular vocabularies. This file is intentionally self-contained so it can be copied into a model repo and loaded by ``AutoTokenizer.from_pretrained(..., trust_remote_code=True)``. """ import json import os import re from collections.abc import Mapping from collections import defaultdict from pathlib import Path from typing import Any, Literal from transformers import PreTrainedTokenizer Representation = Literal["SELFIES", "SMILES"] VOCAB_FILES_NAMES = { "vocab_file": "vocab.json", "selfies_vocab_file": "selfies_vocab.json", "smiles_vocab_file": "smiles_vocab.json", } SELFIES_RE = re.compile(r"\[[^\]]+\]") SMILES_RE = re.compile( r"(\[[^\]]+\]|Br?|Cl?|Si?|Se?|Li?|Na?|Mg?|Al?|Ca?|Fe?|Zn?|" r"N|O|S|P|F|I|K|B|C|H|" r"b|c|n|o|s|p|" r"\%\d{2}|\d|" r"\(|\)|\.|=|#|-|\+|\\|/|:|~|@|\?|\*|\$)" ) def _base_piece_count(token: str, representation: str) -> int: """Count primitive molecular pieces in a vocab token.""" pieces = pre_tokenize_molecule(token, representation) return max(1, len(pieces)) def _max_vocab_piece_span(vocab: dict[str, int], representation: str) -> int: """Maximum number of primitive pieces covered by any non-special vocab token.""" max_span = 1 for token in vocab: if token.startswith("<") and token.endswith(">"): continue max_span = max(max_span, _base_piece_count(token, representation)) return max_span def _coerce_vocab(vocab: Mapping[str, Any]) -> dict[str, int]: if not isinstance(vocab, Mapping): raise ValueError("Vocabulary must be a JSON object mapping token strings to integer IDs.") out = {str(token): int(idx) for token, idx in vocab.items()} if len(set(out.values())) != len(out): raise ValueError("Vocabulary token IDs must be unique.") return out def _token_text(token: Any) -> str: return str(getattr(token, "content", token)) def _normalize_representation(representation: str) -> Representation: normalized = representation.upper() if normalized not in {"SELFIES", "SMILES"}: raise ValueError(f"representation must be 'SELFIES' or 'SMILES', got {representation!r}") return normalized # type: ignore[return-value] def _select_vocab_file( *, representation: Representation, vocab_file: str | os.PathLike[str] | None, selfies_vocab_file: str | os.PathLike[str] | None, smiles_vocab_file: str | os.PathLike[str] | None, ) -> str | os.PathLike[str] | None: if representation == "SELFIES" and selfies_vocab_file is not None: return selfies_vocab_file if representation == "SMILES" and smiles_vocab_file is not None: return smiles_vocab_file return vocab_file def pre_tokenize_molecule(molecule: str, representation: str) -> list[str]: active_representation = _normalize_representation(representation) if active_representation == "SELFIES": return SELFIES_RE.findall(molecule) tokens: list[str] = [] cursor = 0 for match in SMILES_RE.finditer(molecule): if match.start() > cursor: tokens.extend(molecule[cursor : match.start()]) tokens.append(match.group(0)) cursor = match.end() if cursor < len(molecule): tokens.extend(molecule[cursor:]) return [token for token in tokens if token and not token.isspace()] def ape_tokenize( text: str, vocab: dict[str, int], representation: str, unk_token: str = "", max_piece_span: int | None = None, ) -> list[str]: pieces = pre_tokenize_molecule(text, representation) if not pieces: return [unk_token] if max_piece_span is None: max_piece_span = _max_vocab_piece_span(vocab, representation) n = len(pieces) tokens: list[str] = [] append_token = tokens.append vocab_contains = vocab.__contains__ join_pieces = "".join i = 0 while i < n: upper = min(n, i + max_piece_span) for j in range(upper, i, -1): candidate = join_pieces(pieces[i:j]) if vocab_contains(candidate): append_token(candidate) i = j break else: append_token(unk_token) i += 1 return tokens class APEPreTrainedTokenizer(PreTrainedTokenizer): """Hugging Face tokenizer backend for APE molecular tokenization. (Not fast)""" vocab_files_names = VOCAB_FILES_NAMES model_input_names = ["input_ids", "attention_mask"] def __init__( self, vocab_file: str | os.PathLike[str] | None = None, selfies_vocab_file: str | os.PathLike[str] | None = None, smiles_vocab_file: str | os.PathLike[str] | None = None, vocab: dict[str, Any] | None = None, representation: str = "SELFIES", bos_token: str = "", eos_token: str = "", unk_token: str = "", pad_token: str = "", mask_token: str = "", model_max_length: int = 256, **kwargs, ) -> None: self.representation = _normalize_representation(representation) active_vocab_file = _select_vocab_file( representation=self.representation, vocab_file=vocab_file, selfies_vocab_file=selfies_vocab_file, smiles_vocab_file=smiles_vocab_file, ) if vocab is None: if active_vocab_file is None: vocab = { bos_token: 0, pad_token: 1, eos_token: 2, unk_token: 3, mask_token: 4, } else: with open(active_vocab_file, encoding="utf-8") as f: vocab = json.load(f) if vocab is None: raise ValueError("Loaded vocabulary is None.") self.vocab_file = str(active_vocab_file) if active_vocab_file is not None else None self.selfies_vocab_file = ( str(selfies_vocab_file) if selfies_vocab_file is not None else None ) self.smiles_vocab_file = str(smiles_vocab_file) if smiles_vocab_file is not None else None self.vocab = _coerce_vocab(vocab) self._require_special_tokens( bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, pad_token=pad_token, mask_token=mask_token, ) self.ids_to_tokens = {idx: token for token, idx in self.vocab.items()} self.vocabulary_frequency: dict[str, int] = {} self.pair_counts: dict[tuple[str, str], int] = {} self._max_piece_span = _max_vocab_piece_span(self.vocab, self.representation) super().__init__( bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, pad_token=pad_token, mask_token=mask_token, model_max_length=model_max_length, representation=self.representation, **kwargs, ) @property def vocab_size(self) -> int: return len(self.vocab) @property def vocabulary(self) -> dict[str, int]: """Legacy alias for callers that previously used APETokenizer.""" return self.vocab @vocabulary.setter def vocabulary(self, value: dict[str, int]) -> None: self.vocab = _coerce_vocab(value) self.update_reverse_vocabulary() self._refresh_tokenization_cache() @property def special_tokens(self) -> dict[str, int]: bos_token = str(self.bos_token) pad_token = str(self.pad_token) eos_token = str(self.eos_token) unk_token = str(self.unk_token) mask_token = str(self.mask_token) return { bos_token: self._convert_token_to_id(bos_token), pad_token: self._convert_token_to_id(pad_token), eos_token: self._convert_token_to_id(eos_token), unk_token: self._convert_token_to_id(unk_token), mask_token: self._convert_token_to_id(mask_token), } @special_tokens.setter def special_tokens(self, value: dict[str, int]) -> None: for token, token_id in value.items(): self.vocab.setdefault(str(token), int(token_id)) self.vocab = _coerce_vocab(self.vocab) self.update_reverse_vocabulary() self._refresh_tokenization_cache() def get_vocab(self) -> dict[str, int]: return dict(self.vocab) def update_reverse_vocabulary(self) -> None: self.ids_to_tokens = {idx: token for token, idx in self.vocab.items()} def _refresh_tokenization_cache(self) -> None: self._max_piece_span = _max_vocab_piece_span(self.vocab, self.representation) def _require_special_tokens( self, *, bos_token: str, eos_token: str, unk_token: str, pad_token: str, mask_token: str, ) -> None: missing = [ token_text for token in [bos_token, eos_token, unk_token, pad_token, mask_token] if (token_text := _token_text(token)) not in self.vocab ] if missing: raise ValueError(f"Vocabulary is missing required special tokens: {missing}") def pre_tokenize(self, molecule: str, representation: str | None = None) -> list[str]: return pre_tokenize_molecule(molecule, representation or self.representation) def _tokenize(self, text: str, **kwargs) -> list[str]: return ape_tokenize( text, vocab=self.vocab, representation=self.representation, unk_token=str(self.unk_token), max_piece_span=self._max_piece_span, ) def encode_molecule( self, text: str, add_special_tokens: bool = True, max_length: int | None = None, truncation: bool = True, ) -> list[int]: """Fast molecular encode path avoiding generic Hugging Face tokenizer overhead.""" tokens = self._tokenize(text) ids = [self._convert_token_to_id(token) for token in tokens] if add_special_tokens: ids = self.build_inputs_with_special_tokens(ids) if max_length is not None and truncation: ids = ids[:max_length] return ids def _convert_token_to_id(self, token: str) -> int: return self.vocab.get(token, self.vocab[str(self.unk_token)]) def _convert_id_to_token(self, index: int) -> str: return self.ids_to_tokens.get(int(index), str(self.unk_token)) def convert_tokens_to_string(self, tokens: list[str]) -> str: return "".join(tokens) def _required_special_token_id( self, token_value: int | list[int] | str | list[str] | None, token_name: str, ) -> int: if token_value is None: raise ValueError(f"{token_name} must be set.") if isinstance(token_value, int): return token_value if isinstance(token_value, str): return self._convert_token_to_id(token_value) if len(token_value) == 1: only_value = token_value[0] if isinstance(only_value, int): return only_value if isinstance(only_value, str): return self._convert_token_to_id(only_value) raise ValueError(f"{token_name} must resolve to a single token id.") def build_inputs_with_special_tokens( self, token_ids_0: list[int], token_ids_1: list[int] | None = None, ) -> list[int]: bos_id = self._required_special_token_id(self.bos_token, "bos_token") eos_id = self._required_special_token_id(self.eos_token, "eos_token") if token_ids_1 is None: return [bos_id, *token_ids_0, eos_id] return [bos_id, *token_ids_0, eos_id, *token_ids_1, eos_id] def create_token_type_ids_from_sequences( self, token_ids_0: list[int], token_ids_1: list[int] | None = None, ) -> list[int]: return [0] * len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1)) def pad( self, encoded_inputs: Any, padding: Any = True, max_length: int | None = None, pad_to_multiple_of: int | None = None, padding_side: str | None = None, return_attention_mask: bool | None = None, return_tensors: Any = None, verbose: bool = True, ): padding_enabled = padding not in (False, "do_not_pad") if ( padding_enabled and isinstance(encoded_inputs, list) and any("labels" in item for item in encoded_inputs) ): target_length = max( len(item.get("input_ids", item.get("labels", []))) for item in encoded_inputs ) if padding == "max_length" and max_length is not None: target_length = max_length if pad_to_multiple_of and target_length % pad_to_multiple_of: target_length = ((target_length // pad_to_multiple_of) + 1) * pad_to_multiple_of padded_inputs = [] for item in encoded_inputs: item = dict(item) labels = list(item.get("labels", [])) pad_len = max(0, target_length - len(labels)) if pad_len: label_padding = [-100] * pad_len if self.padding_side == "left": labels = label_padding + labels else: labels = labels + label_padding item["labels"] = labels padded_inputs.append(item) encoded_inputs = padded_inputs return super().pad( encoded_inputs, padding=padding, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of, padding_side=padding_side, return_attention_mask=return_attention_mask, return_tensors=return_tensors, verbose=verbose, ) def save_vocabulary( self, save_directory: str, filename_prefix: str | None = None, ) -> tuple[str, ...]: if not os.path.isdir(save_directory): raise ValueError(f"Vocabulary path ({save_directory}) should be a directory.") vocab_file = Path(save_directory) / ( f"{filename_prefix}-vocab.json" if filename_prefix else "vocab.json" ) with vocab_file.open("w", encoding="utf-8") as f: json.dump(self.vocab, f, ensure_ascii=False, indent=4) return (str(vocab_file),) def add_tokens_to_vocabulary(self, tokens: list[str]) -> int: """Add tokens to the tokenizer vocabulary if they are not already present. This is intended for forcing coverage of rare valid molecular primitive symbols, especially SELFIES bracket tokens, after APE merge training. """ if not tokens: return 0 next_id = max(self.vocab.values(), default=-1) + 1 added = 0 for token in tokens: token = str(token).strip() if not token: continue if token in self.vocab: continue self.vocab[token] = next_id next_id += 1 added += 1 if added: self.update_reverse_vocabulary() self._refresh_tokenization_cache() return added def save_pretrained(self, save_directory: str | os.PathLike[str], *args, **kwargs): saved_files = super().save_pretrained(save_directory, *args, **kwargs) save_path = Path(save_directory) special_tokens_map = { "bos_token": str(self.bos_token), "eos_token": str(self.eos_token), "unk_token": str(self.unk_token), "pad_token": str(self.pad_token), "mask_token": str(self.mask_token), } with (save_path / "special_tokens_map.json").open("w", encoding="utf-8") as f: json.dump(special_tokens_map, f, ensure_ascii=False, indent=2) tokenizer_config_path = save_path / "tokenizer_config.json" if tokenizer_config_path.exists(): with tokenizer_config_path.open(encoding="utf-8") as f: tokenizer_config = json.load(f) else: tokenizer_config = {} tokenizer_config.pop("tokenizer_class", None) tokenizer_config.update( { "representation": self.representation, "model_max_length": self.model_max_length, "auto_map": { "AutoTokenizer": [ "tokenization_ape.APEPreTrainedTokenizer", None, ], }, } ) with tokenizer_config_path.open("w", encoding="utf-8") as f: json.dump(tokenizer_config, f, ensure_ascii=False, indent=2) return saved_files def save_vocabulary_file(self, file_path: str | os.PathLike[str]) -> None: path = Path(file_path) path.parent.mkdir(parents=True, exist_ok=True) freq_path = path.with_name(f"{path.stem}_freq.json") with path.open("w", encoding="utf-8") as f: json.dump(self.vocab, f, ensure_ascii=False, indent=4) with freq_path.open("w", encoding="utf-8") as f: json.dump(self.vocabulary_frequency, f, ensure_ascii=False, indent=4) def load_vocabulary_file( self, file_path: str | os.PathLike[str], representation: str | None = None, ) -> None: if representation is not None: self.representation = _normalize_representation(representation) with open(file_path, encoding="utf-8") as f: vocab = json.load(f) self.vocab = _coerce_vocab(vocab) self._require_special_tokens( bos_token=str(self.bos_token), eos_token=str(self.eos_token), unk_token=str(self.unk_token), pad_token=str(self.pad_token), mask_token=str(self.mask_token), ) self.ids_to_tokens = {idx: token for token, idx in self.vocab.items()} self._refresh_tokenization_cache() def train( self, corpus, type: str = "selfies", representation: str | None = None, max_vocab_size: int = 5000, min_freq_for_merge: int = 2000, max_merge_pieces: int | None = 8, save_checkpoint: bool = False, checkpoint_path: str = "checkpoint", checkpoint_interval: int = 500, ) -> None: import warnings new_rep = _normalize_representation(representation or type) if new_rep != self.representation: warnings.warn( f"train() representation={new_rep!r} differs from tokenizer " f"representation={self.representation!r}. Overwriting.", UserWarning, stacklevel=2, ) self.representation = new_rep if not corpus: raise ValueError("Cannot train APE tokenizer on an empty corpus.") print(f"Pretokenizing {self.representation}...", flush=True) tokenized_corpus = [] vocabulary_frequency: defaultdict[str, int] = defaultdict(int) saw_tokens = False for sentence in corpus: tokens = self.pre_tokenize(str(sentence)) if not tokens: continue saw_tokens = True for token in tokens: vocabulary_frequency[token] += 1 if len(tokens) > 1: tokenized_corpus.append(tokens) print( f"Pretokenization complete, found {len(vocabulary_frequency)} tokens", flush=True, ) if not saw_tokens: raise ValueError("Cannot train APE tokenizer on an empty corpus.") pre_tokens_counts = len(vocabulary_frequency) merged_counter = len(vocabulary_frequency) + 1 if save_checkpoint and checkpoint_interval <= 0: raise ValueError( "checkpoint_interval must be positive when save_checkpoint is enabled." ) checkpoint_increment = checkpoint_interval batch = checkpoint_interval + pre_tokens_counts piece_count_cache: dict[str, int] = {} def merged_piece_count(token: str) -> int: count = piece_count_cache.get(token) if count is None: count = _base_piece_count(token, self.representation) piece_count_cache[token] = count return count def get_most_common_pair(tokenized): pair_counts: defaultdict[tuple[str, str], int] = defaultdict(int) for tokens in tokenized: for i in range(len(tokens) - 1): pair = (tokens[i], tokens[i + 1]) if max_merge_pieces is not None: merged_candidate = pair[0] + pair[1] if merged_piece_count(merged_candidate) > max_merge_pieces: continue pair_counts[pair] += 1 self.pair_counts = dict(pair_counts) if not pair_counts: return ("", ""), 0 most_common_pair = ("", "") most_common_frequency = 0 for pair, count in pair_counts.items(): if count > most_common_frequency: most_common_pair = pair most_common_frequency = count return most_common_pair, most_common_frequency while True: if save_checkpoint and len(vocabulary_frequency) >= batch: self.vocabulary_frequency = dict(vocabulary_frequency) self.vocab = { **{ str(self.bos_token): 0, str(self.pad_token): 1, str(self.eos_token): 2, str(self.unk_token): 3, str(self.mask_token): 4, }, **{ word: idx for idx, word in enumerate( vocabulary_frequency.keys(), start=5, ) }, } self.ids_to_tokens = {idx: token for token, idx in self.vocab.items()} self._refresh_tokenization_cache() checkpoint_dir = Path(checkpoint_path) checkpoint_dir.mkdir(parents=True, exist_ok=True) self.save_vocabulary_file(checkpoint_dir / f"checkpoint_{batch}.json") self.save_pretrained(str(checkpoint_dir / f"checkpoint_{batch}")) print(f"Checkpoint saved at {checkpoint_dir}/checkpoint_{batch}.json") batch += checkpoint_increment if len(vocabulary_frequency) >= max_vocab_size: print("Max vocabulary achieved", flush=True) break if not tokenized_corpus: print("No more mergeable pairs", flush=True) break most_common_pair, freq = get_most_common_pair(tokenized_corpus) if freq < min_freq_for_merge: print("Not enough frequency found", flush=True) break if not most_common_pair[0] or not most_common_pair[1]: print("No valid merge pair found", flush=True) break left_token, right_token = most_common_pair merged_word = left_token + right_token if merged_word not in vocabulary_frequency: print( f"New merge found: {merged_word} {merged_counter}/{max_vocab_size} " f"{round(merged_counter / max_vocab_size * 100, 2)}%", flush=True, ) merged_counter += 1 vocabulary_frequency[merged_word] += freq new_tokenized_corpus = [] for tokens in tokenized_corpus: new_tokens = [] append_token = new_tokens.append i = 0 token_count = len(tokens) while i < token_count: if ( i < token_count - 1 and tokens[i] == left_token and tokens[i + 1] == right_token ): append_token(merged_word) i += 2 else: append_token(tokens[i]) i += 1 if len(new_tokens) > 1: new_tokenized_corpus.append(new_tokens) tokenized_corpus = new_tokenized_corpus self.vocabulary_frequency = dict(vocabulary_frequency) self.vocab = { str(self.bos_token): 0, str(self.pad_token): 1, str(self.eos_token): 2, str(self.unk_token): 3, str(self.mask_token): 4, **{word: idx for idx, word in enumerate(vocabulary_frequency.keys(), start=5)}, } self.ids_to_tokens = {idx: token for token, idx in self.vocab.items()} self._refresh_tokenization_cache() def train_from_iterator(self, iterator, *args, **kwargs) -> None: raise NotImplementedError("train_from_iterator is not implemented for APE") APEPreTrainedTokenizer.register_for_auto_class("AutoTokenizer")