File size: 4,351 Bytes
3d83373
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from typing import List

from tokenizers import NormalizedString, PreTokenizedString
from tokenizers.pre_tokenizers import PreTokenizer
from transformers import PreTrainedTokenizerFast

import re


# Split patterns used by MoleculePreTokenizer.molecule_based_split. Each is
# applied with re.finditer, so only the matched spans become tokens.

# Applied to strings starting with '<smiles>', '<selfies>' or '<polymer_spg>'.
# Alternatives, in order: a '<...>' tag, a bracketed '[...]' group, the listed
# element/aromatic symbols, bond/branch/punctuation characters, '%' followed
# by two digits, and a single digit.
ATOM_REGEX_PATTERN = r"""(<(.*?)>|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""
# Applied to strings starting with '<formulation_start>'. Same alternatives as
# ATOM_REGEX_PATTERN, with three numeric alternatives (decimal, exponent
# notation, optionally signed integer) inserted right after the '<...>' tag so
# they are tried before the single-character alternatives.
FORMULATION_REGEX_PATTERN = r"""(<(.*?)>|[-+]?\d*\.\d+|[-+]?\d+\.?\d*[eE][-+]?\d+|[-+]?\d+|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""
# Applied to strings starting with '<formula>' or '<inchi>'. Alternatives, in
# order: two digits, digit-letter-digit, digit-letter, a letter followed by
# one or more digits, and a single parenthesis.
NUMBER_REGEX_PATTERN = r"""(\d{2}|\d[a-zA-Z]\d|\d[a-zA-Z]|[a-zA-Z]\d+|\(|\))"""
# Earlier variants of NUMBER_REGEX_PATTERN, kept for reference (not active):
# NUMBER_REGEX_PATTERN = r"""((?<!\d)\d{2}(?!\d)|\d[a-zA-Z]\d|\d[a-zA-Z]|[a-zA-Z]\d)"""
# NUMBER_REGEX_PATTERN = r"""(\d[a-zA-Z]\d|\d[a-zA-Z]|[a-zA-Z]\d|\b\d{2}\b)"""
# Non-greedy match of any '<...>' tag; used in the fallback branch to locate
# the end of the last tag in the string.
SPECIAL_REGEX_PATTERN = r"""<(.*?)>"""


class MoleculePreTokenizer:
    """Regex pre-tokenizer that chooses its split pattern from the input's leading tag.

    ``pre_tokenize`` hands ``molecule_based_split`` to ``PreTokenizedString.split``;
    the split function inspects how the text starts and slices the normalized
    string at the matches of the corresponding module-level pattern.
    """

    @staticmethod
    def _split_on_matches(normalized_string: NormalizedString, text: str, pattern: str,
                          skip_trailing: bool = False) -> List[NormalizedString]:
        """Return one slice of ``normalized_string`` per ``pattern`` match in ``text``.

        A match that starts at offset 0 (the leading tag) is always skipped.
        When ``skip_trailing`` is true, a match that ends exactly at the end
        of ``text`` is skipped as well. Text that no alternative of the
        pattern matches is not emitted at all.
        """
        text_end = len(text)  # computed once instead of once per match
        pieces = []
        for match in re.finditer(pattern, text):
            begin, end = match.span()
            if begin == 0 or (skip_trailing and end == text_end):
                continue
            pieces.append(normalized_string[begin:end])
        return pieces

    def molecule_based_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
        """Split one normalized string into sub-strings according to its leading tag.

        Args:
            i: Index passed in by ``PreTokenizedString.split``; not used here.
            normalized_string: The piece of text to split.

        Returns:
            The list of slices of ``normalized_string`` to keep:

            * ``<smiles>`` / ``<selfies>`` / ``<polymer_spg>``: matches of
              ``ATOM_REGEX_PATTERN``, without the match at offset 0.
            * ``<formulation_start>``: matches of ``FORMULATION_REGEX_PATTERN``,
              without the match at offset 0 and without a match that ends at
              the end of the string.
            * ``<formula>`` / ``<inchi>``: matches of ``NUMBER_REGEX_PATTERN``,
              without the match at offset 0.
            * anything else: a single slice holding whatever follows the last
              ``<...>`` tag (the whole string when there is no tag).
        """
        # Convert once; the original re-evaluated str(...) for every test and loop.
        text = str(normalized_string)

        if text.startswith(('<smiles>', '<selfies>', '<polymer_spg>')):
            return self._split_on_matches(normalized_string, text, ATOM_REGEX_PATTERN)
        if text.startswith('<formulation_start>'):
            return self._split_on_matches(
                normalized_string, text, FORMULATION_REGEX_PATTERN, skip_trailing=True
            )
        if text.startswith(('<formula>', '<inchi>')):
            return self._split_on_matches(normalized_string, text, NUMBER_REGEX_PATTERN)

        # Fallback: keep only the text after the last '<...>' tag.
        # NOTE(review): anything that precedes the last tag is dropped too —
        # confirm this is intended for inputs with more than one tag.
        tail_start = 0
        for match in re.finditer(SPECIAL_REGEX_PATTERN, text):
            tail_start = match.end()
        return [normalized_string[tail_start:]]

    def pre_tokenize(self, pretok: PreTokenizedString):
        """Apply ``molecule_based_split`` to ``pretok`` via its ``split`` method."""
        pretok.split(self.molecule_based_split)


class MultiMolTranBertTokenizer(PreTrainedTokenizerFast):
    """``PreTrainedTokenizerFast`` subclass with helpers for decoding token ids to text."""

    def __init__(self, vocab_file: str = '',
                 do_lower_case=False,
                 cls_token='<bos>',
                 eos_token='<sep>',
                 pad_token='<pad>',
                 unk_token='<unk>',
                 mask_token='<mask>',
                 **kwargs):
        """Initialise the fast tokenizer from a serialized tokenizer file.

        Args:
            vocab_file: Path handed to the base class as ``tokenizer_file``.
            do_lower_case: Accepted for signature compatibility; not used.
            cls_token: Token registered as the base class ``bos_token``.
            eos_token: End-of-sequence token.
            pad_token: Padding token.
            unk_token: Unknown-token placeholder.
            mask_token: Mask token.
            **kwargs: Extra options forwarded to ``PreTrainedTokenizerFast``.
                Previously these were accepted but silently discarded, so
                options given to ``load_tokenizer`` never reached the base
                class.
        """
        # The explicit arguments overwrite same-named entries in kwargs, which
        # keeps their original precedence and avoids a duplicate-keyword error.
        kwargs.update(
            tokenizer_file=vocab_file,
            bos_token=cls_token,
            eos_token=eos_token,
            pad_token=pad_token,
            unk_token=unk_token,
            mask_token=mask_token,
        )
        super().__init__(**kwargs)

    def get_padding_idx(self):
        """Return the padding index used by callers of this tokenizer."""
        # NOTE(review): hard-coded; presumably the id of the pad token in the
        # shipped vocabulary — confirm it agrees with ``self.pad_token_id``.
        return 2

    def convert_idx_to_tokens(self, idx_tensor):
        """Map each entry of ``idx_tensor.tolist()`` through ``convert_ids_to_tokens``.

        Args:
            idx_tensor: Object exposing ``tolist()``; each element of the
                resulting list is converted independently.

        Returns:
            A list with one conversion result per element.
        """
        return [self.convert_ids_to_tokens(idx) for idx in idx_tensor.tolist()]

    def convert_tokens_to_string(self, tokens):
        """Concatenate ``tokens`` into one string, leaving out '<bos>' and '<eos>'."""
        # NOTE(review): the default eos_token of this class is '<sep>', which
        # is not in this list — confirm which end marker the vocabulary uses.
        stopwords = ('<bos>', '<eos>')
        return ''.join(word for word in tokens if word not in stopwords)

    def idx_to_smiles(self, torch_model, idx):
        """Convert token ids back to SMILES text using ``torch_model.tokenizer``.

        Args:
            torch_model: Object whose ``tokenizer`` attribute provides
                ``convert_idx_to_tokens`` and ``convert_tokens_to_string``.
            idx: Token ids, passed to ``convert_idx_to_tokens``.

        Returns:
            The decoded string.
        """
        rev_tokens = torch_model.tokenizer.convert_idx_to_tokens(idx)
        # Flatten one nesting level before joining.
        flat_list_tokens = [item for sublist in rev_tokens for item in sublist]
        return torch_model.tokenizer.convert_tokens_to_string(flat_list_tokens)


def load_tokenizer(vocab_file, **kwargs):
    """Build a ``MultiMolTranBertTokenizer`` and attach the custom molecule pre-tokenizer.

    Args:
        vocab_file: Passed as the first argument to ``MultiMolTranBertTokenizer``.
        **kwargs: Passed through to ``MultiMolTranBertTokenizer``.

    Returns:
        The tokenizer, with its backend pre-tokenizer replaced by a
        ``MoleculePreTokenizer`` wrapped in ``PreTokenizer.custom``.
    """
    molecule_tokenizer = MultiMolTranBertTokenizer(vocab_file, **kwargs)
    custom_pre_tokenizer = PreTokenizer.custom(MoleculePreTokenizer())
    backend = molecule_tokenizer.backend_tokenizer
    backend.pre_tokenizer = custom_pre_tokenizer
    return molecule_tokenizer