import pandas as pd
import typing as T
import selfies as sf
from tokenizers import ByteLevelBPETokenizer
from tokenizers import Tokenizer, processors, models
from tokenizers.implementations import BaseTokenizer, ByteLevelBPETokenizer
import massspecgym.utils as utils
from massspecgym.definitions import PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN
class SpecialTokensBaseTokenizer(BaseTokenizer):
    def __init__(
        self,
        tokenizer: Tokenizer,
        max_len: T.Optional[int] = None,
    ):
        """Initialize the base tokenizer with special tokens, padding, and truncation.

        Args:
            tokenizer: Backing `tokenizers.Tokenizer` instance to wrap.
            max_len: Maximum sequence length. If None, padding is dynamic (to the
                longest sequence in a batch) and truncation is disabled.
        """
        super().__init__(tokenizer)
        # Save essential attributes
        self.pad_token = PAD_TOKEN
        self.sos_token = SOS_TOKEN
        self.eos_token = EOS_TOKEN
        self.unk_token = UNK_TOKEN
        self.max_length = max_len
        # Add special tokens (no-op for tokens already in the vocabulary)
        self.add_special_tokens([self.pad_token, self.sos_token, self.eos_token, self.unk_token])
        # Get token IDs (must happen after add_special_tokens so IDs exist)
        self.pad_token_id = self.token_to_id(self.pad_token)
        self.sos_token_id = self.token_to_id(self.sos_token)
        self.eos_token_id = self.token_to_id(self.eos_token)
        self.unk_token_id = self.token_to_id(self.unk_token)
        # Enable padding; length=None means pad dynamically to the batch maximum
        self.enable_padding(
            direction="right",
            pad_token=self.pad_token,
            pad_id=self.pad_token_id,
            length=max_len,
        )
        # Enable truncation only when a fixed maximum length is given:
        # `enable_truncation` requires an integer `max_length`, so passing
        # None (the default) would raise a TypeError.
        if max_len is not None:
            self.enable_truncation(max_len)
        # Set post-processing to wrap every sequence in SOS/EOS tokens
        self._tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{self.sos_token} $A {self.eos_token}",
            pair=f"{self.sos_token} $A {self.eos_token} {self.sos_token} $B {self.eos_token}",
            special_tokens=[
                (self.sos_token, self.sos_token_id),
                (self.eos_token, self.eos_token_id),
            ],
        )
class SelfiesTokenizer(SpecialTokensBaseTokenizer):
    def __init__(
        self,
        selfies_train: T.Optional[T.Union[str, T.List[str]]] = None,
        **kwargs
    ):
        """
        Initialize the SELFIES tokenizer with optional training data to build a vocabulary.

        Args:
            selfies_train (str or list of str): Either a list of SELFIES strings to build the
                vocabulary from, or the `'semantic_robust_alphabet'` string indicating the usage
                of the `selfies.get_semantic_robust_alphabet()` alphabet. If None, the MassSpecGym
                training molecules will be used.
        """
        if selfies_train == 'semantic_robust_alphabet':
            alphabet = list(sorted(sf.get_semantic_robust_alphabet()))
        else:
            if not selfies_train:
                # Fall back to the MassSpecGym training molecules (SMILES),
                # converted to SELFIES before alphabet extraction.
                selfies_train = utils.load_train_mols()
                selfies = [sf.encoder(s, strict=False) for s in selfies_train]
            else:
                selfies = selfies_train
            alphabet = list(sorted(sf.get_alphabet_from_selfies(selfies)))
        # Word-level vocabulary: one ID per SELFIES symbol, plus the UNK token.
        vocab = {symbol: i for i, symbol in enumerate(alphabet)}
        vocab[UNK_TOKEN] = len(vocab)
        tokenizer = Tokenizer(models.WordLevel(vocab=vocab, unk_token=UNK_TOKEN))
        super().__init__(tokenizer, **kwargs)

    def encode(self, text: str, add_special_tokens: bool = True):
        """Encode a SMILES string into an `Encoding` of SELFIES token IDs.

        Args:
            text: SMILES string to encode.
            add_special_tokens: Whether to add SOS/EOS (and padding) tokens.

        Returns:
            A `tokenizers.Encoding` with the SELFIES token IDs.
        """
        selfies_string = sf.encoder(text, strict=False)
        selfies_tokens = list(sf.split_selfies(selfies_string))
        return super().encode(
            selfies_tokens, is_pretokenized=True, add_special_tokens=add_special_tokens
        )

    def decode(self, token_ids: T.List[int], skip_special_tokens: bool = True) -> str:
        """Decode a list of SELFIES token IDs back into a SMILES string."""
        selfies_string = super().decode(
            token_ids, skip_special_tokens=skip_special_tokens
        )
        selfies_string = self._decode_wordlevel_str_to_selfies(selfies_string)
        return sf.decoder(selfies_string)

    def encode_batch(self, texts: T.List[str], add_special_tokens: bool = True):
        """Encode a batch of SMILES strings into `Encoding`s of SELFIES token IDs.

        Args:
            texts: SMILES strings to encode.
            add_special_tokens: Whether to add SOS/EOS (and padding) tokens.

        Returns:
            A list of `tokenizers.Encoding` objects.
        """
        selfies_strings = [
            list(sf.split_selfies(sf.encoder(text, strict=False))) for text in texts
        ]
        return super().encode_batch(
            selfies_strings, is_pretokenized=True, add_special_tokens=add_special_tokens
        )

    def decode_batch(
        self, token_ids_batch: T.List[T.List[int]], skip_special_tokens: bool = True
    ) -> T.List[str]:
        """Decode a batch of SELFIES token IDs back into SMILES strings."""
        selfies_strings = super().decode_batch(
            token_ids_batch, skip_special_tokens=skip_special_tokens
        )
        return [
            sf.decoder(
                self._decode_wordlevel_str_to_selfies(
                    selfies_string
                )
            )
            for selfies_string in selfies_strings
        ]

    def _decode_wordlevel_str_to_selfies(self, text: str) -> str:
        """Convert a space-separated WordLevel string back to a SELFIES string."""
        return text.replace(" ", "")
class SmilesBPETokenizer(SpecialTokensBaseTokenizer):
    def __init__(self, smiles_pth: T.Optional[str] = None, **kwargs):
        """
        Initialize a byte-level BPE tokenizer for SMILES strings, with optional training data.

        Args:
            smiles_pth (str): Path to a file containing SMILES strings to train the tokenizer on.
                If None, the MassSpecGym training molecules will be used.
        """
        bpe = ByteLevelBPETokenizer()
        if not smiles_pth:
            # No training file given: train on unlabeled + labeled training molecules.
            training_smiles = utils.load_unlabeled_mols("smiles").tolist()
            training_smiles += utils.load_train_mols().tolist()
            print(f"Training tokenizer on {len(training_smiles)} SMILES strings.")
            bpe.train_from_iterator(training_smiles)
        else:
            bpe.train(smiles_pth)
        super().__init__(bpe, **kwargs)
|