# massspecgym/models/tokenizers.py

import typing as T
import selfies as sf
from tokenizers import Encoding, Tokenizer, models, processors
from tokenizers.implementations import BaseTokenizer, ByteLevelBPETokenizer
import massspecgym.utils as utils
from massspecgym.definitions import PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN


class SpecialTokensBaseTokenizer(BaseTokenizer):
def __init__(
self,
tokenizer: Tokenizer,
max_len: T.Optional[int] = None,
):
"""Initialize the base tokenizer with special tokens performing padding and truncation."""
super().__init__(tokenizer)
# Save essential attributes
self.pad_token = PAD_TOKEN
self.sos_token = SOS_TOKEN
self.eos_token = EOS_TOKEN
self.unk_token = UNK_TOKEN
self.max_length = max_len
# Add special tokens
self.add_special_tokens([self.pad_token, self.sos_token, self.eos_token, self.unk_token])
# Get token IDs
self.pad_token_id = self.token_to_id(self.pad_token)
self.sos_token_id = self.token_to_id(self.sos_token)
self.eos_token_id = self.token_to_id(self.eos_token)
self.unk_token_id = self.token_to_id(self.unk_token)
# Enable padding
self.enable_padding(
direction="right",
pad_token=self.pad_token,
pad_id=self.pad_token_id,
length=max_len,
)
        # Enable truncation (only when a maximum length is given;
        # `enable_truncation` requires an integer max_length)
        if max_len is not None:
            self.enable_truncation(max_len)
# Set post-processing to add SOS and EOS tokens
self._tokenizer.post_processor = processors.TemplateProcessing(
single=f"{self.sos_token} $A {self.eos_token}",
pair=f"{self.sos_token} $A {self.eos_token} {self.sos_token} $B {self.eos_token}",
special_tokens=[
(self.sos_token, self.sos_token_id),
(self.eos_token, self.eos_token_id),
],
)
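

# Usage sketch (illustrative only, not part of the module API): given any
# trained `tokenizers.Tokenizer`, the wrapper above pads/truncates to
# `max_len` and brackets each sequence with SOS/EOS via the template
# post-processor. `trained_tokenizer` below is hypothetical; the special-token
# strings come from `massspecgym.definitions`.
#
#     base = SpecialTokensBaseTokenizer(trained_tokenizer, max_len=8)
#     enc = base.encode("some input")
#     enc.ids     # -> [sos_id, ..., eos_id, pad_id, pad_id, ...]
#     enc.tokens  # -> SOS/EOS/PAD strings as defined in massspecgym.definitions

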
class SelfiesTokenizer(SpecialTokensBaseTokenizer):
def __init__(
self,
selfies_train: T.Optional[T.Union[str, T.List[str]]] = None,
**kwargs
):
"""
Initialize the SELFIES tokenizer with optional training data to build a vocanulary.
Args:
selfies_train (str or list of str): Either a list of SELFIES strings to build the vocabulary from,
or a `semantic_robust_alphabet` string indicating the usahe of `selfies.get_semantic_robust_alphabet()`
alphabet. If None, the MassSpecGym training molecules will be used.
"""
        if selfies_train == "semantic_robust_alphabet":
            # Use the full set of semantically constrained SELFIES symbols
            alphabet = sorted(sf.get_semantic_robust_alphabet())
        else:
            if not selfies_train:
                # Fall back to the MassSpecGym training molecules (SMILES -> SELFIES)
                selfies_train = utils.load_train_mols()
                selfies = [sf.encoder(s, strict=False) for s in selfies_train]
            else:
                selfies = selfies_train
            # Build the alphabet from the symbols observed in the training SELFIES
            alphabet = sorted(sf.get_alphabet_from_selfies(selfies))
vocab = {symbol: i for i, symbol in enumerate(alphabet)}
vocab[UNK_TOKEN] = len(vocab)
tokenizer = Tokenizer(models.WordLevel(vocab=vocab, unk_token=UNK_TOKEN))
super().__init__(tokenizer, **kwargs)

    def encode(self, text: str, add_special_tokens: bool = True) -> Encoding:
        """Encodes a SMILES string into an `Encoding` of SELFIES token IDs."""
        selfies_string = sf.encoder(text, strict=False)
        selfies_tokens = list(sf.split_selfies(selfies_string))
        return super().encode(
            selfies_tokens, is_pretokenized=True, add_special_tokens=add_special_tokens
        )

    def decode(self, token_ids: T.List[int], skip_special_tokens: bool = True) -> str:
        """Decodes a list of SELFIES token IDs back into a SMILES string."""
        # The WordLevel decode joins tokens with spaces, so strip them to
        # recover a contiguous SELFIES string before converting to SMILES
        selfies_string = super().decode(
            token_ids, skip_special_tokens=skip_special_tokens
        )
        selfies_string = self._decode_wordlevel_str_to_selfies(selfies_string)
        return sf.decoder(selfies_string)

    def encode_batch(
        self, texts: T.List[str], add_special_tokens: bool = True
    ) -> T.List[Encoding]:
        """Encodes a batch of SMILES strings into `Encoding`s of SELFIES token IDs."""
selfies_strings = [
list(sf.split_selfies(sf.encoder(text, strict=False))) for text in texts
]
return super().encode_batch(
selfies_strings, is_pretokenized=True, add_special_tokens=add_special_tokens
)

    def decode_batch(
self, token_ids_batch: T.List[T.List[int]], skip_special_tokens: bool = True
) -> T.List[str]:
"""Decodes a batch of SELFIES token IDs back into SMILES strings."""
selfies_strings = super().decode_batch(
token_ids_batch, skip_special_tokens=skip_special_tokens
)
        return [
            sf.decoder(self._decode_wordlevel_str_to_selfies(selfies_string))
            for selfies_string in selfies_strings
        ]

    def _decode_wordlevel_str_to_selfies(self, text: str) -> str:
        """Converts a space-separated WordLevel decode output back to a SELFIES string."""
        return text.replace(" ", "")
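

# Usage sketch (illustrative only): round-tripping a SMILES string through the
# SELFIES tokenizer; exact IDs depend on the vocabulary built at init time.
#
#     tok = SelfiesTokenizer(selfies_train="semantic_robust_alphabet", max_len=64)
#     enc = tok.encode("CCO")   # ethanol, tokenized as SELFIES symbols
#     tok.decode(enc.ids)       # -> a SMILES string equivalent to "CCO"

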
class SmilesBPETokenizer(SpecialTokensBaseTokenizer):
def __init__(self, smiles_pth: T.Optional[str] = None, **kwargs):
"""
Initialize the BPE tokenizer for SMILES strings, with optional training data.
Args:
smiles_pth (str): Path to a file containing SMILES strings to train the tokenizer on. If None,
the MassSpecGym training molecules will be used.
"""
        tokenizer = ByteLevelBPETokenizer()
        if smiles_pth:
            tokenizer.train(smiles_pth)
        else:
            # Train on the unlabeled and labeled MassSpecGym molecules
            smiles = utils.load_unlabeled_mols("smiles").tolist()
            smiles += utils.load_train_mols().tolist()
            print(f"Training tokenizer on {len(smiles)} SMILES strings.")
            tokenizer.train_from_iterator(smiles)
        # Pass the underlying `tokenizers.Tokenizer` (not the implementation
        # wrapper) so the base class can attach padding, truncation, and the
        # SOS/EOS post-processor to the actual tokenizer object
        super().__init__(tokenizer._tokenizer, **kwargs)
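

if __name__ == "__main__":
    # Smoke-test sketch (illustrative only): the semantic robust alphabet
    # needs no training data, so the SELFIES tokenizer can be built offline.
    # The BPE tokenizer is not exercised here because constructing it without
    # `smiles_pth` loads the MassSpecGym molecules.
    tokenizer = SelfiesTokenizer(selfies_train="semantic_robust_alphabet", max_len=32)
    encoding = tokenizer.encode("CCO")  # ethanol
    print("ids:   ", encoding.ids)
    print("tokens:", encoding.tokens)
    print("smiles:", tokenizer.decode(encoding.ids))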