import re
from typing import List

from tokenizers import NormalizedString, PreTokenizedString
from tokenizers.pre_tokenizers import PreTokenizer
from transformers import PreTrainedTokenizerFast

# Token-level regexes; each pattern captures one token per match:
# ATOM_REGEX_PATTERN        - SMILES atoms, bonds, ring closures, and <...> tags
# FORMULATION_REGEX_PATTERN - the atom pattern extended with signed integers and
#                             floats (incl. scientific notation) for formulations
# NUMBER_REGEX_PATTERN      - numeric/alphanumeric chunks and parentheses
ATOM_REGEX_PATTERN = r"""(<(.*?)>|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""
FORMULATION_REGEX_PATTERN = r"""(<(.*?)>|[-+]?\d*\.\d+|[-+]?\d+\.?\d*[eE][-+]?\d+|[-+]?\d+|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""
NUMBER_REGEX_PATTERN = r"""(\d{2}|\d[a-zA-Z]\d|\d[a-zA-Z]|[a-zA-Z]\d+|\(|\))"""
# NUMBER_REGEX_PATTERN = r"""((?"""  # alternative pattern, truncated in the source
# NOTE: the original SPECIAL_REGEX_PATTERN definition was lost when angle-bracket
# tags were stripped from the source; it is reconstructed here, by assumption, to
# match <...> special tokens, consistent with the patterns above.
SPECIAL_REGEX_PATTERN = r"""(<(.*?)>)"""


class MoleculePreTokenizer:
    """Custom pre-tokenizer that picks a split regex based on the sequence's
    leading modality tag and drops the tag tokens themselves."""

    # NOTE: the literal prefix tags (e.g. '<smiles>') and special-token strings
    # below were stripped from the source; the empty strings are placeholders.
    # ''.startswith matches everything, so restore the real tags before use.
    def molecule_based_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
        splits = []
        text = str(normalized_string)
        if text.startswith(('', '', '')):
            # Molecule-like sequence: split on SMILES atoms/bonds.
            for m in re.finditer(ATOM_REGEX_PATTERN, text):
                start, stop = m.start(0), m.end(0)
                if start == 0:  # skip the leading special token
                    continue
                splits.append(normalized_string[start:stop])
        elif text.startswith(''):
            # Formulation sequence: also split out numeric values.
            for m in re.finditer(FORMULATION_REGEX_PATTERN, text):
                start, stop = m.start(0), m.end(0)
                if start == 0 or stop == len(text):  # skip leading/trailing special tokens
                    continue
                splits.append(normalized_string[start:stop])
        elif text.startswith(('', '')):
            # Numeric sequence: split on number/letter chunks.
            for m in re.finditer(NUMBER_REGEX_PATTERN, text):
                start, stop = m.start(0), m.end(0)
                if start == 0:  # skip the leading special token
                    continue
                splits.append(normalized_string[start:stop])
        else:
            # No known prefix: strip any <...> special tokens, keep the remainder whole.
            last = 0
            for m in re.finditer(SPECIAL_REGEX_PATTERN, text):
                start, stop = m.start(0), m.end(0)
                # splits.append(normalized_string[start:stop])
                last = stop
            splits.append(normalized_string[last:])
        return splits

    def pre_tokenize(self, pretok: PreTokenizedString):
        pretok.split(self.molecule_based_split)


class MultiMolTranBertTokenizer(PreTrainedTokenizerFast):
    # NOTE: the special-token string defaults (e.g. '<bos>') were stripped from
    # the source; the empty strings below are placeholders, not usable values.
    def __init__(self, vocab_file: str = '', do_lower_case=False,
                 cls_token='', eos_token='', pad_token='', unk_token='',
                 mask_token='', **kwargs):
        # do_lower_case is kept for API compatibility but not applied.
        super().__init__(
            tokenizer_file=vocab_file,
            bos_token=cls_token,  # the cls token doubles as bos
            eos_token=eos_token,
            pad_token=pad_token,
            unk_token=unk_token,
            mask_token=mask_token,
            **kwargs,  # forward remaining arguments instead of silently dropping them
        )

    def get_padding_idx(self):
        # Assumes the pad token sits at index 2 in the vocabulary.
        return 2

    def convert_idx_to_tokens(self, idx_tensor):
        # Convert a 2-D tensor of token ids into a list of token lists.
        return [self.convert_ids_to_tokens(idx) for idx in idx_tensor.tolist()]

    def convert_tokens_to_string(self, tokens):
        # NOTE: the two stopword strings (special tokens) were stripped from the
        # source; restore them so they are filtered out of the decoded string.
        stopwords = ['', '']
        clean_tokens = [word for word in tokens if word not in stopwords]
        return ''.join(clean_tokens)

    def idx_to_smiles(self, torch_model, idx):
        """Convert token ids back to SMILES text via the model's tokenizer."""
        rev_tokens = torch_model.tokenizer.convert_idx_to_tokens(idx)
        flat_list_tokens = [item for sublist in rev_tokens for item in sublist]
        return torch_model.tokenizer.convert_tokens_to_string(flat_list_tokens)


def load_tokenizer(vocab_file, **kwargs):
    tokenizer = MultiMolTranBertTokenizer(vocab_file, **kwargs)
    # Attach the custom molecule-aware pre-tokenizer to the backend tokenizer.
    tokenizer.backend_tokenizer.pre_tokenizer = PreTokenizer.custom(MoleculePreTokenizer())
    return tokenizer
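

# Usage sketch: a minimal smoke test, assuming a serialized tokenizer JSON at
# 'vocab.json' (a hypothetical path, not shipped with this module) whose special
# tokens match the placeholders above. Note that because the pre-tokenizer is a
# custom Python object, the resulting tokenizer cannot be serialized back to JSON.
if __name__ == '__main__':
    tokenizer = load_tokenizer('vocab.json')  # hypothetical vocab file path
    # Aspirin as a SMILES example; .tokens() is available on fast-tokenizer output.
    encoding = tokenizer('CC(=O)Oc1ccccc1C(=O)O')
    print(encoding.tokens())
    # The atom-level regex can also be inspected directly:
    print([m.group(0) for m in re.finditer(ATOM_REGEX_PATTERN, 'CC(=O)Oc1ccccc1C(=O)O')])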