Fill-Mask
Transformers
Safetensors
modernbert
chemistry
molecules
selfies
ape-tokenizer
masked-language-modeling
Instructions to use HauserGroup/ModernMolBERT-base with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use HauserGroup/ModernMolBERT-base with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("fill-mask", model="HauserGroup/ModernMolBERT-base")

# Load model directly
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("HauserGroup/ModernMolBERT-base")
model = AutoModelForMaskedLM.from_pretrained("HauserGroup/ModernMolBERT-base")
```
- Notebooks
- Google Colab
- Kaggle
| """Hugging Face-compatible tokenizer for APE molecular vocabularies. | |
| This file is intentionally self-contained so it can be copied into a model repo | |
| and loaded by ``AutoTokenizer.from_pretrained(..., trust_remote_code=True)``. | |
| """ | |
| import json | |
| import os | |
| import re | |
| from collections.abc import Mapping | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from typing import Any, Literal | |
| from transformers import PreTrainedTokenizer | |
# The two molecular string notations this module can tokenize.
Representation = Literal["SELFIES", "SMILES"]

# Keyword-argument name -> on-disk file name for each vocabulary file.
VOCAB_FILES_NAMES = dict(
    vocab_file="vocab.json",
    selfies_vocab_file="selfies_vocab.json",
    smiles_vocab_file="smiles_vocab.json",
)

# A SELFIES primitive is one bracketed symbol, e.g. ``[C]`` or ``[=O]``.
SELFIES_RE = re.compile(r"\[[^\]]+\]")

# A SMILES primitive is the first alternative that matches at a position.
# NOTE(review): alternation is ordered, so e.g. ``Cl?`` already matches a bare
# ``C`` and ``Ca?`` is never reached for unbracketed input. The pattern is kept
# verbatim because changing it would change tokenization.
SMILES_RE = re.compile(
    # bracket atoms, then element symbols with an optional second letter
    r"(\[[^\]]+\]|Br?|Cl?|Si?|Se?|Li?|Na?|Mg?|Al?|Ca?|Fe?|Zn?|"
    # single-letter elements
    r"N|O|S|P|F|I|K|B|C|H|"
    # aromatic atoms
    r"b|c|n|o|s|p|"
    # ring-closure labels: two-digit ``%NN`` form or a single digit
    r"\%\d{2}|\d|"
    # branches, bonds, charges, stereo marks and wildcards
    r"\(|\)|\.|=|#|-|\+|\\|/|:|~|@|\?|\*|\$)"
)
def _base_piece_count(token: str, representation: str) -> int:
    """Return how many primitive molecular pieces ``token`` splits into.

    The result is never below 1, even when pre-tokenization yields nothing.
    """
    piece_total = len(pre_tokenize_molecule(token, representation))
    return piece_total if piece_total > 1 else 1
def _max_vocab_piece_span(vocab: dict[str, int], representation: str) -> int:
    """Return the largest primitive-piece count of any non-special vocab token.

    Tokens written as ``<...>`` are treated as special and skipped. The result
    is 1 when no other token is present.
    """
    spans = (
        _base_piece_count(entry, representation)
        for entry in vocab
        if not (entry.startswith("<") and entry.endswith(">"))
    )
    # Every span is already >= 1, so the default only covers the empty case.
    return max(spans, default=1)
def _coerce_vocab(vocab: Mapping[str, Any]) -> dict[str, int]:
    """Validate a raw vocabulary and return it as a ``{str: int}`` dict.

    Args:
        vocab: Mapping of tokens to IDs, typically the result of ``json.load``.

    Returns:
        A new dict whose keys are ``str(token)`` and whose values are ``int(id)``.

    Raises:
        ValueError: If ``vocab`` is not a mapping, an ID is a non-integral
            float, two tokens collapse to the same string, or two tokens share
            an ID.
    """
    if not isinstance(vocab, Mapping):
        raise ValueError("Vocabulary must be a JSON object mapping token strings to integer IDs.")

    out: dict[str, int] = {}
    for token, idx in vocab.items():
        # int() would silently truncate 1.5 to 1 and could alias another ID.
        if isinstance(idx, float) and not idx.is_integer():
            raise ValueError(f"Vocabulary ID for token {token!r} must be an integer, got {idx!r}.")
        out[str(token)] = int(idx)

    # Distinct keys such as 1 and "1" would otherwise overwrite each other.
    if len(out) != len(vocab):
        raise ValueError("Vocabulary tokens must be unique after conversion to strings.")
    if len(set(out.values())) != len(out):
        raise ValueError("Vocabulary token IDs must be unique.")
    return out
def _token_text(token: Any) -> str:
    """Return the string form of a token.

    Objects exposing a ``content`` attribute contribute that attribute;
    anything else is stringified directly.
    """
    content = getattr(token, "content", token)
    return str(content)
def _normalize_representation(representation: str) -> Representation:
    """Upper-case ``representation`` and check that it is a supported notation.

    Raises:
        ValueError: If the upper-cased value is neither ``SELFIES`` nor ``SMILES``.
    """
    upper_name = representation.upper()
    if upper_name == "SELFIES" or upper_name == "SMILES":
        return upper_name  # type: ignore[return-value]
    raise ValueError(f"representation must be 'SELFIES' or 'SMILES', got {representation!r}")
def _select_vocab_file(
    *,
    representation: Representation,
    vocab_file: str | os.PathLike[str] | None,
    selfies_vocab_file: str | os.PathLike[str] | None,
    smiles_vocab_file: str | os.PathLike[str] | None,
) -> str | os.PathLike[str] | None:
    """Pick the vocabulary file to load for ``representation``.

    The representation-specific file wins when it was supplied; otherwise the
    generic ``vocab_file`` (which may itself be ``None``) is returned.
    """
    per_representation = {
        "SELFIES": selfies_vocab_file,
        "SMILES": smiles_vocab_file,
    }
    specific = per_representation.get(representation)
    return vocab_file if specific is None else specific
def pre_tokenize_molecule(molecule: str, representation: str) -> list[str]:
    """Split a molecule string into primitive pieces.

    For SELFIES only bracketed symbols are returned; text outside brackets is
    dropped. For SMILES every character is kept: text the pattern does not
    recognise is emitted one character at a time, and only empty or
    whitespace-only pieces are removed.

    Raises:
        ValueError: If ``representation`` is not a supported notation.
    """
    if _normalize_representation(representation) == "SELFIES":
        return SELFIES_RE.findall(molecule)

    # SMILES_RE is one capturing group, so split() alternates
    # [gap, match, gap, match, ..., gap]: even slots hold unmatched text.
    pieces: list[str] = []
    for slot, chunk in enumerate(SMILES_RE.split(molecule)):
        if slot % 2:
            pieces.append(chunk)
        else:
            pieces.extend(chunk)  # unmatched text falls back to single characters
    return [piece for piece in pieces if piece and not piece.isspace()]
def ape_tokenize(
    text: str,
    vocab: dict[str, int],
    representation: str,
    unk_token: str = "<unk>",
    max_piece_span: int | None = None,
) -> list[str]:
    """Tokenize ``text`` by greedy longest-match over primitive pieces.

    Args:
        text: Molecule string to tokenize.
        vocab: Token-to-ID mapping; only membership is consulted.
        representation: ``"SELFIES"`` or ``"SMILES"`` (case-insensitive).
        unk_token: Emitted for a piece that starts no vocabulary token, and as
            the sole output when ``text`` yields no pieces.
        max_piece_span: Upper bound on pieces per token; computed from
            ``vocab`` when omitted.

    Returns:
        The vocabulary tokens (or ``unk_token``) covering ``text`` left to right.
    """
    pieces = pre_tokenize_molecule(text, representation)
    if not pieces:
        return [unk_token]
    if max_piece_span is None:
        max_piece_span = _max_vocab_piece_span(vocab, representation)

    total = len(pieces)
    result: list[str] = []
    start = 0
    while start < total:
        furthest = min(total, start + max_piece_span)
        # Try the longest window first and shrink until something is known.
        for stop in range(furthest, start, -1):
            merged = "".join(pieces[start:stop])
            if merged in vocab:
                result.append(merged)
                start = stop
                break
        else:
            # Not even the single piece is in the vocabulary.
            result.append(unk_token)
            start += 1
    return result
class APEPreTrainedTokenizer(PreTrainedTokenizer):
    """Hugging Face tokenizer backend for APE molecular tokenization. (Not fast)

    Molecules are split into primitive pieces (SELFIES bracket symbols or
    SMILES atoms/bonds) and then greedily merged into the longest pieces found
    in the vocabulary. The class also carries the APE merge trainer
    (``train``) and legacy accessors kept for callers of the older
    ``APETokenizer`` API.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file: str | os.PathLike[str] | None = None,
        selfies_vocab_file: str | os.PathLike[str] | None = None,
        smiles_vocab_file: str | os.PathLike[str] | None = None,
        vocab: dict[str, Any] | None = None,
        representation: str = "SELFIES",
        bos_token: str = "<s>",
        eos_token: str = "</s>",
        unk_token: str = "<unk>",
        pad_token: str = "<pad>",
        mask_token: str = "<mask>",
        model_max_length: int = 256,
        **kwargs,
    ) -> None:
        """Build the tokenizer from an in-memory vocab, a vocab file, or neither.

        Args:
            vocab_file: Generic vocabulary JSON, used when no
                representation-specific file is given.
            selfies_vocab_file: Vocabulary JSON preferred for SELFIES.
            smiles_vocab_file: Vocabulary JSON preferred for SMILES.
            vocab: Explicit token-to-ID mapping; when given, no file is read.
            representation: ``"SELFIES"`` or ``"SMILES"`` (case-insensitive).
            bos_token, eos_token, unk_token, pad_token, mask_token: Special
                tokens; each must be present in the vocabulary.
            model_max_length: Forwarded to the base class.
            **kwargs: Forwarded to the base class.

        Raises:
            ValueError: If the representation is unsupported, the vocabulary
                is invalid, or a special token is missing from it.
        """
        self.representation = _normalize_representation(representation)
        active_vocab_file = _select_vocab_file(
            representation=self.representation,
            vocab_file=vocab_file,
            selfies_vocab_file=selfies_vocab_file,
            smiles_vocab_file=smiles_vocab_file,
        )
        if vocab is None:
            if active_vocab_file is None:
                # No source at all: start with just the five special tokens.
                vocab = {
                    bos_token: 0,
                    pad_token: 1,
                    eos_token: 2,
                    unk_token: 3,
                    mask_token: 4,
                }
            else:
                with open(active_vocab_file, encoding="utf-8") as f:
                    vocab = json.load(f)
        if vocab is None:
            # A vocab file containing JSON ``null`` loads as None.
            raise ValueError("Loaded vocabulary is None.")

        self.vocab_file = str(active_vocab_file) if active_vocab_file is not None else None
        self.selfies_vocab_file = (
            str(selfies_vocab_file) if selfies_vocab_file is not None else None
        )
        self.smiles_vocab_file = str(smiles_vocab_file) if smiles_vocab_file is not None else None
        self.vocab = _coerce_vocab(vocab)
        self._require_special_tokens(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            mask_token=mask_token,
        )
        self.ids_to_tokens = {idx: token for token, idx in self.vocab.items()}
        # Both are filled in by train(); empty until then.
        self.vocabulary_frequency: dict[str, int] = {}
        self.pair_counts: dict[tuple[str, str], int] = {}
        self._max_piece_span = _max_vocab_piece_span(self.vocab, self.representation)

        # NOTE(review): vocabulary state is set up before the base initializer
        # runs, presumably because it may query the vocabulary - confirm.
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            mask_token=mask_token,
            model_max_length=model_max_length,
            representation=self.representation,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary."""
        return len(self.vocab)

    @property
    def vocabulary(self) -> dict[str, int]:
        """Legacy alias for callers that previously used APETokenizer."""
        return self.vocab

    @vocabulary.setter
    def vocabulary(self, value: dict[str, int]) -> None:
        """Replace the vocabulary and refresh everything derived from it."""
        self.vocab = _coerce_vocab(value)
        self.update_reverse_vocabulary()
        self._refresh_tokenization_cache()

    @property
    def special_tokens(self) -> dict[str, int]:
        """Mapping of the five special tokens to their current IDs."""
        names = (
            str(self.bos_token),
            str(self.pad_token),
            str(self.eos_token),
            str(self.unk_token),
            str(self.mask_token),
        )
        return {name: self._convert_token_to_id(name) for name in names}

    @special_tokens.setter
    def special_tokens(self, value: dict[str, int]) -> None:
        """Add any of the given tokens that are not yet in the vocabulary.

        Existing entries keep their IDs.

        Raises:
            ValueError: If the additions produce duplicate IDs; the vocabulary
                is left untouched in that case.
        """
        # Work on a copy so a failed validation cannot corrupt self.vocab.
        updated = dict(self.vocab)
        for token, token_id in value.items():
            updated.setdefault(str(token), int(token_id))
        self.vocab = _coerce_vocab(updated)
        self.update_reverse_vocabulary()
        self._refresh_tokenization_cache()

    def get_vocab(self) -> dict[str, int]:
        """Return a shallow copy of the token-to-ID mapping."""
        return dict(self.vocab)

    def update_reverse_vocabulary(self) -> None:
        """Rebuild the ID-to-token mapping from ``self.vocab``."""
        self.ids_to_tokens = {idx: token for token, idx in self.vocab.items()}

    def _refresh_tokenization_cache(self) -> None:
        """Recompute the cached maximum piece span after a vocab change."""
        self._max_piece_span = _max_vocab_piece_span(self.vocab, self.representation)

    def _require_special_tokens(
        self,
        *,
        bos_token: str,
        eos_token: str,
        unk_token: str,
        pad_token: str,
        mask_token: str,
        vocab: Mapping[str, int] | None = None,
    ) -> None:
        """Raise if any special token is absent from the vocabulary.

        Args:
            vocab: Vocabulary to check; defaults to ``self.vocab``. Passing a
                candidate lets callers validate before committing it.

        Raises:
            ValueError: Listing every missing special token.
        """
        target = self.vocab if vocab is None else vocab
        missing = [
            token_text
            for token in (bos_token, eos_token, unk_token, pad_token, mask_token)
            if (token_text := _token_text(token)) not in target
        ]
        if missing:
            raise ValueError(f"Vocabulary is missing required special tokens: {missing}")

    def pre_tokenize(self, molecule: str, representation: str | None = None) -> list[str]:
        """Split ``molecule`` into primitive pieces.

        Uses ``representation`` when given, else the tokenizer's own.
        """
        return pre_tokenize_molecule(molecule, representation or self.representation)

    def _tokenize(self, text: str, **kwargs) -> list[str]:
        """Tokenize ``text`` with the current vocabulary; ``kwargs`` are ignored."""
        return ape_tokenize(
            text,
            vocab=self.vocab,
            representation=self.representation,
            unk_token=str(self.unk_token),
            max_piece_span=self._max_piece_span,
        )

    def encode_molecule(
        self,
        text: str,
        add_special_tokens: bool = True,
        max_length: int | None = None,
        truncation: bool = True,
    ) -> list[int]:
        """Fast molecular encode path avoiding generic Hugging Face tokenizer overhead.

        Args:
            text: Molecule string.
            add_special_tokens: Wrap the IDs with BOS/EOS when true.
            max_length: Maximum number of IDs to return, if truncating.
            truncation: Apply ``max_length`` when true.

        Returns:
            The token IDs for ``text``.
        """
        ids = [self._convert_token_to_id(token) for token in self._tokenize(text)]
        if add_special_tokens:
            ids = self.build_inputs_with_special_tokens(ids)
        if max_length is not None and truncation:
            # NOTE(review): truncation happens after BOS/EOS are added, so a
            # truncated sequence loses its EOS. Left unchanged in case
            # existing callers depend on it - confirm the intent.
            ids = ids[:max_length]
        return ids

    def _convert_token_to_id(self, token: str) -> int:
        """Return the ID of ``token``, or the unknown-token ID if absent."""
        token_id = self.vocab.get(token)
        if token_id is None:
            # Only resolve the unk ID on a miss; hits are the hot path.
            return self.vocab[str(self.unk_token)]
        return token_id

    def _convert_id_to_token(self, index: int) -> str:
        """Return the token for ``index``, or the unknown token if absent."""
        return self.ids_to_tokens.get(int(index), str(self.unk_token))

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Concatenate tokens with no separator."""
        return "".join(tokens)

    def _required_special_token_id(
        self,
        token_value: int | list[int] | str | list[str] | None,
        token_name: str,
    ) -> int:
        """Resolve a special-token value to a single integer ID.

        Accepts an ID, a token string, or a one-element list of either.

        Raises:
            ValueError: If the value is ``None`` or does not resolve to exactly
                one token.
        """
        if token_value is None:
            raise ValueError(f"{token_name} must be set.")
        if isinstance(token_value, int):
            return token_value
        if isinstance(token_value, str):
            return self._convert_token_to_id(token_value)
        if len(token_value) == 1:
            only_value = token_value[0]
            if isinstance(only_value, int):
                return only_value
            if isinstance(only_value, str):
                return self._convert_token_to_id(only_value)
        raise ValueError(f"{token_name} must resolve to a single token id.")

    def build_inputs_with_special_tokens(
        self,
        token_ids_0: list[int],
        token_ids_1: list[int] | None = None,
    ) -> list[int]:
        """Wrap IDs as ``BOS A EOS`` or, for a pair, ``BOS A EOS B EOS``."""
        bos_id = self._required_special_token_id(self.bos_token, "bos_token")
        eos_id = self._required_special_token_id(self.eos_token, "eos_token")
        if token_ids_1 is None:
            return [bos_id, *token_ids_0, eos_id]
        return [bos_id, *token_ids_0, eos_id, *token_ids_1, eos_id]

    def create_token_type_ids_from_sequences(
        self,
        token_ids_0: list[int],
        token_ids_1: list[int] | None = None,
    ) -> list[int]:
        """Return all-zero type IDs matching the length of the wrapped input."""
        return [0] * len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1))

    def pad(
        self,
        encoded_inputs: Any,
        padding: Any = True,
        max_length: int | None = None,
        pad_to_multiple_of: int | None = None,
        padding_side: str | None = None,
        return_attention_mask: bool | None = None,
        return_tensors: Any = None,
        verbose: bool = True,
    ):
        """Pad a batch, first padding any ``labels`` with ``-100``.

        When padding is enabled and ``encoded_inputs`` is a list in which at
        least one item has ``labels``, every item's labels are extended with
        ``-100`` to the batch target length before the base ``pad`` runs.
        All arguments are then forwarded to the base implementation.
        """
        padding_enabled = padding not in (False, "do_not_pad")
        if (
            padding_enabled
            and isinstance(encoded_inputs, list)
            and any("labels" in item for item in encoded_inputs)
        ):
            target_length = max(
                len(item.get("input_ids", item.get("labels", []))) for item in encoded_inputs
            )
            if padding == "max_length" and max_length is not None:
                target_length = max_length
            if pad_to_multiple_of and target_length % pad_to_multiple_of:
                # Round up to the next multiple.
                target_length = ((target_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

            # Labels follow the side requested for this call, if any.
            effective_side = padding_side if padding_side is not None else self.padding_side

            padded_inputs = []
            for item in encoded_inputs:
                item = dict(item)  # copy so the caller's item is not mutated
                labels = list(item.get("labels", []))
                pad_len = max(0, target_length - len(labels))
                if pad_len:
                    label_padding = [-100] * pad_len
                    if effective_side == "left":
                        labels = label_padding + labels
                    else:
                        labels = labels + label_padding
                item["labels"] = labels
                padded_inputs.append(item)
            encoded_inputs = padded_inputs

        return super().pad(
            encoded_inputs,
            padding=padding,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            padding_side=padding_side,
            return_attention_mask=return_attention_mask,
            return_tensors=return_tensors,
            verbose=verbose,
        )

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: str | None = None,
    ) -> tuple[str, ...]:
        """Write the vocabulary as JSON into ``save_directory``.

        Returns:
            A one-element tuple holding the written file path.

        Raises:
            ValueError: If ``save_directory`` is not an existing directory.
        """
        if not os.path.isdir(save_directory):
            raise ValueError(f"Vocabulary path ({save_directory}) should be a directory.")
        vocab_file = Path(save_directory) / (
            f"{filename_prefix}-vocab.json" if filename_prefix else "vocab.json"
        )
        with vocab_file.open("w", encoding="utf-8") as f:
            json.dump(self.vocab, f, ensure_ascii=False, indent=4)
        return (str(vocab_file),)

    def add_tokens_to_vocabulary(self, tokens: list[str]) -> int:
        """Add tokens to the tokenizer vocabulary if they are not already present.

        This is intended for forcing coverage of rare valid molecular primitive
        symbols, especially SELFIES bracket tokens, after APE merge training.

        Tokens are stripped; blank ones are skipped. New tokens receive
        consecutive IDs starting just above the current maximum ID.

        Returns:
            The number of tokens actually added.
        """
        if not tokens:
            return 0
        next_id = max(self.vocab.values(), default=-1) + 1
        added = 0
        for token in tokens:
            token = str(token).strip()
            if not token or token in self.vocab:
                continue
            self.vocab[token] = next_id
            next_id += 1
            added += 1
        if added:
            self.update_reverse_vocabulary()
            self._refresh_tokenization_cache()
        return added

    def save_pretrained(self, save_directory: str | os.PathLike[str], *args, **kwargs):
        """Save via the base class, then rewrite two config files.

        ``special_tokens_map.json`` is overwritten with the five special
        tokens as plain strings. ``tokenizer_config.json`` loses its
        ``tokenizer_class`` entry and gains ``representation``,
        ``model_max_length`` and an ``auto_map`` entry naming this class.

        Returns:
            Whatever the base ``save_pretrained`` returned.
        """
        saved_files = super().save_pretrained(save_directory, *args, **kwargs)
        save_path = Path(save_directory)

        special_tokens_map = {
            "bos_token": str(self.bos_token),
            "eos_token": str(self.eos_token),
            "unk_token": str(self.unk_token),
            "pad_token": str(self.pad_token),
            "mask_token": str(self.mask_token),
        }
        with (save_path / "special_tokens_map.json").open("w", encoding="utf-8") as f:
            json.dump(special_tokens_map, f, ensure_ascii=False, indent=2)

        tokenizer_config_path = save_path / "tokenizer_config.json"
        if tokenizer_config_path.exists():
            with tokenizer_config_path.open(encoding="utf-8") as f:
                tokenizer_config = json.load(f)
        else:
            tokenizer_config = {}
        tokenizer_config.pop("tokenizer_class", None)
        tokenizer_config.update(
            {
                "representation": self.representation,
                "model_max_length": self.model_max_length,
                "auto_map": {
                    "AutoTokenizer": [
                        "tokenization_ape.APEPreTrainedTokenizer",
                        None,
                    ],
                },
            }
        )
        with tokenizer_config_path.open("w", encoding="utf-8") as f:
            json.dump(tokenizer_config, f, ensure_ascii=False, indent=2)
        return saved_files

    def save_vocabulary_file(self, file_path: str | os.PathLike[str]) -> None:
        """Write the vocabulary to ``file_path`` and frequencies next to it.

        The frequency table goes to ``<stem>_freq.json`` in the same
        directory. Parent directories are created as needed.
        """
        path = Path(file_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        freq_path = path.with_name(f"{path.stem}_freq.json")
        with path.open("w", encoding="utf-8") as f:
            json.dump(self.vocab, f, ensure_ascii=False, indent=4)
        with freq_path.open("w", encoding="utf-8") as f:
            json.dump(self.vocabulary_frequency, f, ensure_ascii=False, indent=4)

    def load_vocabulary_file(
        self,
        file_path: str | os.PathLike[str],
        representation: str | None = None,
    ) -> None:
        """Replace the vocabulary with the contents of a JSON file.

        The new vocabulary is fully validated before any tokenizer state
        changes, so a failed load leaves the tokenizer as it was.

        Args:
            file_path: JSON file holding a token-to-ID object.
            representation: Optional new representation to switch to.

        Raises:
            ValueError: If the representation is unsupported, the vocabulary
                is invalid, or a special token is missing from it.
        """
        new_representation = (
            self.representation
            if representation is None
            else _normalize_representation(representation)
        )
        with open(file_path, encoding="utf-8") as f:
            loaded = json.load(f)
        new_vocab = _coerce_vocab(loaded)
        self._require_special_tokens(
            bos_token=str(self.bos_token),
            eos_token=str(self.eos_token),
            unk_token=str(self.unk_token),
            pad_token=str(self.pad_token),
            mask_token=str(self.mask_token),
            vocab=new_vocab,
        )
        # Everything checked out: commit the new state in one go.
        self.representation = new_representation
        self.vocab = new_vocab
        self.update_reverse_vocabulary()
        self._refresh_tokenization_cache()

    def _install_trained_vocabulary(self, vocabulary_frequency: Mapping[str, int]) -> None:
        """Adopt a frequency table as the vocabulary.

        Special tokens take IDs 0-4; learned tokens follow from ID 5 in the
        table's insertion order. Derived lookups are refreshed.
        """
        self.vocabulary_frequency = dict(vocabulary_frequency)
        self.vocab = {
            str(self.bos_token): 0,
            str(self.pad_token): 1,
            str(self.eos_token): 2,
            str(self.unk_token): 3,
            str(self.mask_token): 4,
            **{word: idx for idx, word in enumerate(vocabulary_frequency, start=5)},
        }
        self.update_reverse_vocabulary()
        self._refresh_tokenization_cache()

    @staticmethod
    def _apply_merge(
        tokens: list[str],
        left_token: str,
        right_token: str,
        merged_word: str,
    ) -> list[str]:
        """Return ``tokens`` with each adjacent (left, right) pair merged, scanning left to right."""
        merged: list[str] = []
        last = len(tokens) - 1
        i = 0
        while i <= last:
            if i < last and tokens[i] == left_token and tokens[i + 1] == right_token:
                merged.append(merged_word)
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        return merged

    def train(
        self,
        corpus,
        type: str = "selfies",
        representation: str | None = None,
        max_vocab_size: int = 5000,
        min_freq_for_merge: int = 2000,
        max_merge_pieces: int | None = 8,
        save_checkpoint: bool = False,
        checkpoint_path: str = "checkpoint",
        checkpoint_interval: int = 500,
    ) -> None:
        """Learn an APE vocabulary by repeatedly merging the most frequent pair.

        Args:
            corpus: Iterable of molecules; each item is stringified.
            type: Representation name, used when ``representation`` is unset.
            representation: Representation name; takes priority over ``type``.
            max_vocab_size: Stop once this many learned tokens exist (special
                tokens are not counted).
            min_freq_for_merge: Stop when the best pair occurs less often.
            max_merge_pieces: Skip merges spanning more primitive pieces than
                this; ``None`` disables the limit.
            save_checkpoint: Save intermediate vocabularies while training.
            checkpoint_path: Directory that receives checkpoints.
            checkpoint_interval: Learned-token growth between checkpoints.

        Raises:
            ValueError: If the representation is unsupported, the corpus is
                empty or yields no tokens, or ``checkpoint_interval`` is not
                positive while checkpointing is enabled.
        """
        import warnings

        new_rep = _normalize_representation(representation or type)
        if new_rep != self.representation:
            warnings.warn(
                f"train() representation={new_rep!r} differs from tokenizer "
                f"representation={self.representation!r}. Overwriting.",
                UserWarning,
                stacklevel=2,
            )
        self.representation = new_rep

        if not corpus:
            raise ValueError("Cannot train APE tokenizer on an empty corpus.")
        # Reject bad settings before the expensive pre-tokenization pass.
        if save_checkpoint and checkpoint_interval <= 0:
            raise ValueError(
                "checkpoint_interval must be positive when save_checkpoint is enabled."
            )

        print(f"Pretokenizing {self.representation}...", flush=True)
        tokenized_corpus: list[list[str]] = []
        vocabulary_frequency: defaultdict[str, int] = defaultdict(int)
        for sentence in corpus:
            tokens = self.pre_tokenize(str(sentence))
            for token in tokens:
                vocabulary_frequency[token] += 1
            # A single-piece molecule has no pair to merge.
            if len(tokens) > 1:
                tokenized_corpus.append(tokens)
        print(
            f"Pretokenization complete, found {len(vocabulary_frequency)} tokens",
            flush=True,
        )
        if not vocabulary_frequency:
            raise ValueError("Cannot train APE tokenizer on an empty corpus.")

        pre_tokens_counts = len(vocabulary_frequency)
        merged_counter = pre_tokens_counts + 1
        # Learned-token count at which the next checkpoint is written.
        batch = checkpoint_interval + pre_tokens_counts
        piece_count_cache: dict[str, int] = {}

        def merged_piece_count(token: str) -> int:
            """Primitive-piece count of ``token``, memoized for this run."""
            count = piece_count_cache.get(token)
            if count is None:
                count = _base_piece_count(token, self.representation)
                piece_count_cache[token] = count
            return count

        def get_most_common_pair(tokenized):
            """Count adjacent pairs and return the most frequent with its count.

            Ties go to the pair seen first. Returns ``(("", ""), 0)`` when no
            pair is eligible. The counts are also stored on ``self.pair_counts``.
            """
            pair_counts: defaultdict[tuple[str, str], int] = defaultdict(int)
            for tokens in tokenized:
                for pair in zip(tokens, tokens[1:]):
                    if (
                        max_merge_pieces is not None
                        and merged_piece_count(pair[0] + pair[1]) > max_merge_pieces
                    ):
                        continue
                    pair_counts[pair] += 1
            self.pair_counts = dict(pair_counts)
            if not pair_counts:
                return ("", ""), 0
            # max() keeps the first maximal item, matching a strict ">" scan.
            return max(pair_counts.items(), key=lambda entry: entry[1])

        while True:
            if save_checkpoint and len(vocabulary_frequency) >= batch:
                self._install_trained_vocabulary(vocabulary_frequency)
                checkpoint_dir = Path(checkpoint_path)
                checkpoint_dir.mkdir(parents=True, exist_ok=True)
                self.save_vocabulary_file(checkpoint_dir / f"checkpoint_{batch}.json")
                self.save_pretrained(str(checkpoint_dir / f"checkpoint_{batch}"))
                print(f"Checkpoint saved at {checkpoint_dir}/checkpoint_{batch}.json")
                batch += checkpoint_interval

            if len(vocabulary_frequency) >= max_vocab_size:
                print("Max vocabulary achieved", flush=True)
                break
            if not tokenized_corpus:
                print("No more mergeable pairs", flush=True)
                break

            most_common_pair, freq = get_most_common_pair(tokenized_corpus)
            if freq < min_freq_for_merge:
                print("Not enough frequency found", flush=True)
                break
            if not most_common_pair[0] or not most_common_pair[1]:
                print("No valid merge pair found", flush=True)
                break

            left_token, right_token = most_common_pair
            merged_word = left_token + right_token
            if merged_word not in vocabulary_frequency:
                print(
                    f"New merge found: {merged_word} {merged_counter}/{max_vocab_size} "
                    f"{round(merged_counter / max_vocab_size * 100, 2)}%",
                    flush=True,
                )
                merged_counter += 1
            vocabulary_frequency[merged_word] += freq

            new_tokenized_corpus = []
            for tokens in tokenized_corpus:
                new_tokens = self._apply_merge(tokens, left_token, right_token, merged_word)
                # Fully merged molecules can no longer contribute pairs.
                if len(new_tokens) > 1:
                    new_tokenized_corpus.append(new_tokens)
            tokenized_corpus = new_tokenized_corpus

        self._install_trained_vocabulary(vocabulary_frequency)

    def train_from_iterator(self, iterator, *args, **kwargs) -> None:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("train_from_iterator is not implemented for APE")


# Register the class so save_pretrained() output can be loaded through
# AutoTokenizer with trust_remote_code=True.
APEPreTrainedTokenizer.register_for_auto_class("AutoTokenizer")