File size: 14,202 Bytes
9e93243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adf1e27
 
 
 
 
9e93243
17b0e42
adf1e27
 
 
 
 
9e93243
 
 
 
 
 
 
17b0e42
 
 
f93cc1d
17b0e42
 
 
 
 
 
 
9e93243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2df9869
 
 
9e93243
 
 
 
 
17b0e42
 
 
 
9e93243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adf1e27
9e93243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adf1e27
 
9e93243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17b0e42
 
9e93243
 
 
17b0e42
 
9e93243
17b0e42
9e93243
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
import numpy as np
import pickle as pkl
import os
import argparse
import pandas as pd
from pathlib import Path
import torch

import utils.chem as uc
import utils.torch_util as ut
import utils.log as ul
import utils.plot as up
import configuration.config_default as cfgd
import models.dataset as md
import preprocess.vocabulary as mv
import configuration.opts as opts
from models.transformer.module.decode import decode
from models.transformer.encode_decode.model import EncoderDecoder
# from models.seq2seq.model import Model

def prepare_input(opt):
    """Return the base name (without ``.csv``) of the test file to generate from.

    An earlier version expanded every input row into one row per Delta_pki bin
    (``(-0.5, 0.5]`` ... ``(10.5, inf]``) when the file had no ``Delta_pki``
    column, and wrote the result to ``<name>_prepared.csv``. That expansion is
    currently disabled, so the file name is returned unchanged.

    :param opt: parsed options; reads ``opt.data_path`` and ``opt.test_file_name``.
    :return: ``opt.test_file_name``.
    :raises FileNotFoundError: if ``<data_path>/<test_file_name>.csv`` does not exist.
    """
    csv_path = os.path.join(opt.data_path, opt.test_file_name + '.csv')
    # Fail early with a clear error instead of parsing the whole CSV just to
    # check that it exists (the parsed frame was never used).
    if not os.path.isfile(csv_path):
        raise FileNotFoundError(f"Test file not found: {csv_path}")
    # TODO: decide whether molecule-to-molecule generation needs the Delta_pki
    # expansion described above.
    return opt.test_file_name

class GenerateRunner():
    """Load a trained model plus vocabulary and sample molecules for a test CSV.

    All configuration comes from the parsed generate options ``opt``. Inference
    runs on the CPU (``self.device``).
    """

    def __init__(self, opt):
        """Load the vocabulary and the model checkpoint described by ``opt``.

        :param opt: parsed options; reads ``dev_no``, ``vocab_path``,
            ``model_path``, ``epoch``, ``model_choice`` and, if present,
            ``save_directory``.
        :raises ValueError: if ``opt.model_choice`` is neither
            ``'transformer'`` nor ``'seq2seq'``.
        """
        # Directory that receives generated_molecules.csv. The original
        # assignment was commented out, which made generate() crash on save.
        # Falls back to the current working directory when the option is absent.
        self.save_path = getattr(opt, 'save_directory', None) or '.'
        self.dev_no = opt.dev_no
        # Inference is pinned to the CPU; dev_no is kept but currently unused.
        self.device = torch.device('cpu')

        # Load vocabulary.
        # NOTE(review): pickle.load can execute arbitrary code; vocab_path must
        # point at a trusted artifact.
        with open(os.path.join(opt.vocab_path, 'vocab.pkl'), "rb") as input_file:
            vocab = pkl.load(input_file)
        self.vocab = vocab
        self.tokenizer = mv.SMILESTokenizer()

        # Load the model once; generate() reuses it.
        self.model = self._load_model(opt)

    def _load_model(self, opt):
        """Load ``model_<epoch>.pt`` from ``opt.model_path`` onto ``self.device``."""
        file_name = os.path.join(opt.model_path, f'model_{opt.epoch}.pt')
        if opt.model_choice == 'transformer':
            model = EncoderDecoder.load_from_file(file_name)
            model.to(self.device)
            model.eval()
        elif opt.model_choice == 'seq2seq':
            # Imported lazily: the module-level import is commented out, so this
            # branch used to fail with NameError.
            from models.seq2seq.model import Model
            model = Model.load_from_file(file_name, evaluation_mode=True)
            model.network.encoder.to(self.device)
            model.network.decoder.to(self.device)
        else:
            raise ValueError('Specify transformer or seq2seq for model_choice')
        return model

    def initialize_dataloader(self, opt, vocab, test_file):
        """
        Build a non-shuffled DataLoader over ``<opt.data_path>/<test_file>.csv``.

        :param opt: parsed options; reads ``data_path`` and ``batch_size``
        :param vocab: vocabulary used to encode the SMILES
        :param test_file: test file base name (without ``.csv``)
        :return: torch DataLoader using ``md.Dataset.collate_fn``
        """
        data = pd.read_csv(os.path.join(opt.data_path, test_file + '.csv'), sep=",")
        dataset = md.Dataset(data=data, vocabulary=vocab, tokenizer=self.tokenizer, prediction_mode=True)
        dataloader = torch.utils.data.DataLoader(dataset, opt.batch_size,
                                                 shuffle=False, collate_fn=md.Dataset.collate_fn)
        return dataloader

    def generate(self, opt):
        """Sample ``opt.num_samples`` molecules per test row and write them to CSV.

        The input rows are written with extra columns ``Predicted_smi_1`` ..
        ``Predicted_smi_<num_samples>`` to ``<self.save_path>/generated_molecules.csv``.

        :param opt: parsed options; reads ``test_file_name``, ``model_choice``,
            ``decode_type`` and ``num_samples`` (plus the dataloader options).
        :raises ValueError: if the test file yields no batches.
        """
        device = self.device
        print(f"-------device:---------")
        print(device)
        dataloader_test = self.initialize_dataloader(opt, self.vocab, opt.test_file_name)

        # Reuse the model loaded in __init__ rather than reading the checkpoint again.
        model = self.model
        # TODO: sequences longer than max_len are cut off; consider a larger
        # (power-of-two) limit if the model produces truncated output.
        max_len = cfgd.DATA_DEFAULT['max_sequence_length']
        df_list = []
        sampled_smiles_list = []
        for batch in ul.progress_bar(dataloader_test, total=len(dataloader_test)):
            # df presumably holds the raw CSV rows of this batch (from collate_fn).
            src, source_length, _, src_mask, _, _, df = batch

            src = src.to(device)
            src_mask = src_mask.to(device)
            smiles = self.sample(opt.model_choice, model, src, src_mask,
                                 source_length,
                                 opt.decode_type,
                                 num_samples=opt.num_samples,
                                 max_len=max_len,
                                 device=device)

            df_list.append(df)
            sampled_smiles_list.extend(smiles)

        if not df_list:
            raise ValueError(f"No input rows found in test file '{opt.test_file_name}'")

        data_sorted = pd.concat(df_list)
        # One row per input row, one column per sample.
        sampled_smiles = np.array(sampled_smiles_list)
        for i in range(opt.num_samples):
            data_sorted['Predicted_smi_{}'.format(i + 1)] = sampled_smiles[:, i]

        Path(self.save_path).mkdir(parents=True, exist_ok=True)
        result_path = os.path.join(self.save_path, "generated_molecules.csv")
        data_sorted.to_csv(result_path, index=False)

    def sample(self, model_choice, model, src, src_mask, source_length, decode_type, num_samples=10,
               max_len=cfgd.DATA_DEFAULT['max_sequence_length'],
               device=None):
        """Collect up to ``num_samples`` unique, valid SMILES per source row.

        Rows that still need samples are decoded again in each round, for at most
        100000 rounds, or 1 round when ``decode_type == 'greedy'``.
        A sample counts only if ``uc.is_valid`` accepts its canonical SMILES and it
        was not seen before for that row. The row's own canonical source SMILES is
        pre-seeded as "seen".

        :return: list with one entry per batch row, each a list of ``num_samples``
            SMILES strings. Slots never filled decode from the all-ones
            placeholder sequence.
        """
        batch_size = src.shape[0]
        # Unique valid samples collected so far, per original batch row.
        num_valid_batch = np.zeros(batch_size)
        num_valid_batch_desired = np.asarray([num_samples] * batch_size)
        # Canonical SMILES already seen, per original batch row.
        unique_set_num_samples = [set() for _ in range(batch_size)]
        batch_index = torch.LongTensor(range(batch_size))
        batch_index_current = torch.LongTensor(range(batch_size)).to(device)
        # Placeholder filled with ones, not zeros: per the original author, token
        # 0 decodes to a string RDKit treats as valid.
        # Layout: (num_samples, batch_size, max_len).
        sequences_all = torch.ones((num_samples, batch_size, max_len))
        sequences_all = sequences_all.type(torch.LongTensor)
        max_trials = 100000  # Maximum sampling rounds
        current_trials = 0

        # Greedy decoding is deterministic, so only a single round is attempted.
        if decode_type == 'greedy':
            max_trials = 1

        # Seed each row's "seen" set with its canonical source so it is not re-emitted.
        if src is not None:
            # The first len(cfgd.PROPERTIES) tokens are skipped as property tokens.
            # TODO (original author): the delta value is not necessarily first;
            # revisit this offset.
            start_ind = len(cfgd.PROPERTIES)
            for ibatch in range(batch_size):
                source_smi = self.tokenizer.untokenize(self.vocab.decode(src[ibatch].tolist()[start_ind:]))
                source_smi = uc.get_canonical_smile(source_smi)
                if source_smi:
                    # TODO (original author): the source is an mmpdb fragment, not a
                    # complete molecule, so this de-duplication is only approximate.
                    unique_set_num_samples[ibatch].add(source_smi)

        with torch.no_grad():
            if model_choice == 'seq2seq':
                encoder_outputs, decoder_hidden = model.network.encoder(src, source_length)
            while not all(num_valid_batch >= num_valid_batch_desired) and current_trials < max_trials:
                current_trials += 1

                # Restrict this round to rows that still need samples.
                if src is not None:
                    src_current = src.index_select(0, batch_index_current)
                if src_mask is not None:
                    mask_current = src_mask.index_select(0, batch_index_current)
                n_current = src_current.shape[0]

                if model_choice == 'transformer':
                    sequences = decode(model, src_current, mask_current, max_len, decode_type)
                    # Right-pad with token 0 up to max_len so rows fit in sequences_all.
                    padding = (0, max_len - sequences.shape[1], 0, 0)
                    sequences = torch.nn.functional.pad(sequences, padding)
                elif model_choice == 'seq2seq':
                    sequences = self.sample_seq2seq(model, mask_current, batch_index_current, decoder_hidden,
                                                    encoder_outputs, max_len, device)

                # Position within this round -> row index in the original batch.
                batch_index_map = batch_index_current.tolist()

                smiles = []
                is_valid_index = []
                for ibatch in range(n_current):
                    smi = self.tokenizer.untokenize(self.vocab.decode(sequences[ibatch].cpu().numpy()))
                    smi = uc.get_canonical_smile(smi)
                    smiles.append(smi)
                    if uc.is_valid(smi):
                        is_valid_index.append(ibatch)

                # Keep only valid samples not yet seen for their source row.
                for good_index in is_valid_index:
                    orig_index = batch_index_map[good_index]
                    if smiles[good_index] not in unique_set_num_samples[orig_index]:
                        unique_set_num_samples[orig_index].add(smiles[good_index])
                        num_valid_batch[orig_index] += 1
                        sequences_all[int(num_valid_batch[orig_index] - 1), orig_index, :] = \
                            sequences[good_index]

                # Continue only with rows that have not collected enough samples.
                not_completed_index = np.where(num_valid_batch < num_valid_batch_desired)[0]
                if len(not_completed_index) > 0:
                    batch_index_current = batch_index.index_select(0, torch.LongTensor(not_completed_index)).to(device)

        # Convert token sequences to SMILES: result is [batch][sample].
        seqs = np.asarray(sequences_all.numpy())  # (num_samples, batch_size, max_len)
        smiles_list = []
        for ibatch in range(seqs.shape[1]):
            topk_list = [self.tokenizer.untokenize(self.vocab.decode(seqs[k, ibatch, :]))
                         for k in range(num_samples)]
            smiles_list.append(topk_list)

        return smiles_list

    def sample_seq2seq(self, model, mask, batch_index_current, decoder_hidden, encoder_outputs, max_len, device):
        """Multinomially sample ``max_len`` tokens per active row with the seq2seq decoder.

        Only the rows listed in ``batch_index_current`` are decoded. Decoding starts
        from the ``"^"`` token, which is not included in the output.

        :return: LongTensor of shape (len(batch_index_current), max_len) with token ids.
        """
        # The number of active rows shrinks as rows collect enough valid samples.
        encoder_outputs_current = encoder_outputs.index_select(0, batch_index_current)
        batch_size = encoder_outputs_current.shape[0]

        start_token = torch.zeros(batch_size, dtype=torch.long)
        start_token[:] = self.vocab["^"]
        decoder_input = start_token.to(device)
        sequences = []
        mask = torch.squeeze(mask, 1).to(device)

        # Select the active rows of the initial hidden state along dim 1
        # (assumes (num_layers, batch, hidden) layout; (h, c) tuple for LSTMs).
        if isinstance(decoder_hidden, tuple):
            decoder_hidden_current = (decoder_hidden[0].index_select(1, batch_index_current),
                                      decoder_hidden[1].index_select(1, batch_index_current))
        else:
            decoder_hidden_current = decoder_hidden.index_select(1, batch_index_current)
        for _ in range(max_len):
            logits, decoder_hidden_current = model.network.decoder(decoder_input.unsqueeze(1),
                                                                  decoder_hidden_current,
                                                                  encoder_outputs_current, mask)
            logits = logits.squeeze(1)
            probabilities = logits.softmax(dim=1)  # (batch_size, vocab_size)
            topi = torch.multinomial(probabilities, 1)  # (batch_size, 1)
            decoder_input = topi.view(-1).detach()
            sequences.append(decoder_input.view(-1, 1))

        sequences = torch.cat(sequences, 1)
        return sequences

def run_main():
    """Script entry point.

    Parses the generate options, resolves the test file via ``prepare_input``,
    builds a ``GenerateRunner`` and runs generation.
    """
    arg_parser = argparse.ArgumentParser(
        description='generate.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Debug output: a header line followed by the (not yet populated) parser.
    for item in ("--------parser-------------", arg_parser):
        print(item)

    opts.generate_opts(arg_parser)
    options = arg_parser.parse_args()
    options.test_file_name = prepare_input(options)

    for item in ("opt输出如下", options):
        print(item)

    generator = GenerateRunner(options)
    print()
    generator.generate(options)


if __name__ == "__main__":
    # Run generation only when executed as a script, not when imported.
    run_main()