File size: 7,152 Bytes
f3b11f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# coding=utf-8

"""
Implementation of a SMILES dataset.
"""
import pandas as pd

import torch
import torch.utils.data as tud
from torch.autograd import Variable

import configuration.config_default as cfgd
from models.transformer.module.subsequent_mask import subsequent_mask

from rdkit.Chem.SaltRemover import SaltRemover
import random
import rdkit.Chem as rkc
from common.utils import Data_Type

class Dataset(tud.Dataset):
    """Custom PyTorch Dataset that takes a file containing
    Source_Mol_ID,Target_Mol_ID,Source_Mol,Target_Mol,
    Source_Mol_LogD,Target_Mol_LogD,Delta_LogD,
    Source_Mol_Solubility,Target_Mol_Solubility,Delta_Solubility,
    Source_Mol_Clint,Target_Mol_Clint,Delta_Clint,
    Transformation,Core"""

    def __init__(self, data, vocabulary, tokenizer, prediction_mode=False, use_random=False, data_type=Data_Type.frag.value):
        """
        :param data: dataframe read from training, validation or test file
        :param vocabulary: used to encode source/target tokens
        :param tokenizer: used to tokenize source/target smiles
        :param prediction_mode: if True only the source is encoded (inference);
            otherwise the target SMILES is tokenized and encoded as well
        :param use_random: if True, smiles_preprocess returns a randomized
            SMILES instead of the input string
        :param data_type: selects the column layout of ``data`` — fragment
            pairs (constant/variable parts) or whole molecules
        """
        self._vocabulary = vocabulary
        self._tokenizer = tokenizer
        self._data = data
        self._prediction_mode = prediction_mode
        self._use_random = use_random
        self._data_type = data_type

    def smiles_preprocess(self, smiles, random_type="unrestricted"):
        """
        Return a randomized SMILES of the molecule when ``use_random`` is
        enabled; otherwise return ``smiles`` unchanged.

        :param smiles: input SMILES string
        :param random_type: the type (unrestricted, restricted) of
            randomization performed
        :return: a random SMILES string of the same molecule, the original
            SMILES if randomization fails, or None if the molecule is invalid
        :raises ValueError: if ``random_type`` is not recognized
        """
        if not self._use_random:
            return smiles
        mol = rkc.MolFromSmiles(smiles)
        if not mol:
            return None

        remover = SaltRemover()  # default salt remover
        if random_type == "unrestricted":
            stripped = remover.StripMol(mol)
            if stripped is None:
                # Salt stripping failed; fall back to the original string.
                return smiles
            ret = rkc.MolToSmiles(stripped, canonical=False, doRandom=True, isomericSmiles=False)
            # MolToSmiles may return an empty/None result; keep the input then.
            return ret if ret else smiles
        if random_type == "restricted":
            # Shuffle the atom numbering, then emit a non-canonical SMILES so
            # the output atom order follows the shuffled numbering.
            new_atom_order = list(range(mol.GetNumAtoms()))
            random.shuffle(new_atom_order)
            random_mol = rkc.RenumberAtoms(mol, newOrder=new_atom_order)
            ret = rkc.MolToSmiles(random_mol, canonical=False, isomericSmiles=False)
            return ret if ret else smiles
        raise ValueError("Type '{}' is not valid".format(random_type))

    def __getitem__(self, i):
        """
        Tokenize and encode the source sequence (SMILES plus property/target
        condition tokens) and, unless in prediction mode, the target SMILES.

        :param i: row index into the underlying dataframe
        :return: (source_tensor, target_tensor, row) in training mode,
                 (source_tensor, row) in prediction mode
        """
        row = self._data.iloc[i]
        main_cls = row['main_cls']
        minor_cls = row['minor_cls']
        target_name = row['target_name']
        # Missing target names come through as NaN; normalize to ''.
        target_name = target_name if isinstance(target_name, str) else ''
        value = row['Delta_Value']

        # Source layout: [variable SMILES] [constant SMILES] (frag mode) or
        # [whole SMILES] (whole mode), then main class (e.g. activity),
        # minor class (e.g. Ki), delta value, and target-name characters.
        source_tokens = []
        if self._data_type == Data_Type.frag.value:
            sourceConstant = self.smiles_preprocess(row['constantSMILES'])
            sourceVariable = self.smiles_preprocess(row['fromVarSMILES'])
            # Variable part first, constant part second.
            source_tokens.extend(self._tokenizer.tokenize(sourceVariable))
            source_tokens.extend(self._tokenizer.tokenize(sourceConstant))
        elif self._data_type == Data_Type.whole.value:
            sourceSmi = self.smiles_preprocess(row['cpd1SMILES'])
            source_tokens.extend(self._tokenizer.tokenize(sourceSmi))
        source_tokens.append(main_cls)
        source_tokens.append(minor_cls)
        source_tokens.append(value)
        # Target name is encoded character by character.
        source_tokens.extend(list(target_name))

        source_encoded = self._vocabulary.encode(source_tokens)

        if self._prediction_mode:
            return torch.tensor(source_encoded, dtype=torch.long), row

        # Training/validation: also encode the target SMILES.
        if self._data_type == Data_Type.frag.value:
            target_smi = row['toVarSMILES']
        elif self._data_type == Data_Type.whole.value:
            target_smi = row['cpd2SMILES']
        else:
            target_smi = ''
        target_tokens = self._tokenizer.tokenize(target_smi)
        target_encoded = self._vocabulary.encode(target_tokens)

        return torch.tensor(source_encoded, dtype=torch.long), torch.tensor(
            target_encoded, dtype=torch.long), row

    def __len__(self):
        return len(self._data)

    @classmethod
    def collate_fn(cls, data_all):
        """
        Batch samples produced by ``__getitem__`` into padded tensors.

        :param data_all: list of (source, row) or (source, target, row) tuples
        :return: (padded_source, source_lengths, padded_target_or_None,
                  src_mask, trg_mask_or_None, max_target_length_or_None,
                  rows_dataframe)
        """
        # Sort by source length (descending) so padding is minimal per batch.
        data_all.sort(key=lambda x: len(x[0]), reverse=True)
        # 2-tuples come from prediction mode, 3-tuples from training mode.
        is_prediction_mode = len(data_all[0]) == 2
        if is_prediction_mode:
            source_encoded, data = zip(*data_all)
        else:
            source_encoded, target_encoded, data = zip(*data_all)
        data = pd.DataFrame(data)

        # Pad source sequences with zeros up to the batch maximum length.
        max_length_source = max(seq.size(0) for seq in source_encoded)
        collated_arr_source = torch.zeros(len(source_encoded), max_length_source, dtype=torch.long)
        for i, seq in enumerate(source_encoded):
            collated_arr_source[i, :seq.size(0)] = seq
        source_length = torch.tensor([seq.size(0) for seq in source_encoded])
        # Padding positions (id 0) are masked out of attention.
        src_mask = (collated_arr_source != 0).unsqueeze(-2)

        if not is_prediction_mode:
            max_length_target = max(seq.size(0) for seq in target_encoded)
            collated_arr_target = torch.zeros(len(target_encoded), max_length_target, dtype=torch.long)
            for i, seq in enumerate(target_encoded):
                collated_arr_target[i, :seq.size(0)] = seq

            # Combine padding mask with the causal (subsequent) mask.
            trg_mask = (collated_arr_target != 0).unsqueeze(-2)
            trg_mask = trg_mask & subsequent_mask(collated_arr_target.size(-1)).type_as(trg_mask)
            trg_mask = trg_mask[:, :-1, :-1]  # save start token, skip end token
        else:
            trg_mask = None
            max_length_target = None
            collated_arr_target = None

        return collated_arr_source, source_length, collated_arr_target, src_mask, trg_mask, max_length_target, data