import numpy as np
import pickle as pkl
import os
import argparse
import pandas as pd
from pathlib import Path
import torch

import utils.chem as uc
import utils.torch_util as ut
import utils.log as ul
import utils.plot as up
import configuration.config_default as cfgd
import models.dataset as md
import preprocess.vocabulary as mv
import configuration.opts as opts
from models.transformer.module.decode import decode
from models.transformer.encode_decode.model import EncoderDecoder


def prepare_input(opt):
    """Return the base name (without ``.csv``) of the test file to generate from.

    An earlier version expanded each input row with a fixed set of
    ``Delta_pki`` bins when that column was missing and wrote a
    ``<name>_prepared.csv`` file. That expansion is currently disabled, so
    ``opt.test_file_name`` is returned unchanged.

    :param opt: parsed options; only ``test_file_name`` is read.
    :return: the test file name to hand to the data loader.
    """
    # TODO: re-enable the Delta_pki expansion if molecule-to-molecule
    # generation from files without a property column is needed again.
    return opt.test_file_name


class GenerateRunner():
    """Load a trained model plus vocabulary and sample molecules for a test set."""

    def __init__(self, opt):
        """
        :param opt: parsed generation options. Reads ``dev_no``, ``vocab_path``,
            ``model_path``, ``epoch``, ``model_choice`` and, if present,
            ``save_directory`` (defaults to the current directory).
        """
        self.dev_no = opt.dev_no
        # Generation is hard-coded to CPU; dev_no is kept but not used for device selection.
        self.device = torch.device('cpu')

        # Output directory for generated_molecules.csv. generate() writes here;
        # previously this attribute was never assigned, so saving crashed.
        self.save_path = str(Path(getattr(opt, 'save_directory', None) or '.'))

        # Load vocabulary. NOTE: pickle can execute arbitrary code on load;
        # only use vocabulary files from a trusted source.
        with open(os.path.join(opt.vocab_path, 'vocab.pkl'), "rb") as input_file:
            self.vocab = pkl.load(input_file)
        self.tokenizer = mv.SMILESTokenizer()

        # Load the model once; generate() reuses it.
        self.model_choice = opt.model_choice
        self.model_file = self._model_file(opt)
        self.model = self._load_model(opt.model_choice, self.model_file, self.device)

    @staticmethod
    def _model_file(opt):
        """Return the checkpoint path ``<model_path>/model_<epoch>.pt``."""
        return os.path.join(opt.model_path, f'model_{opt.epoch}.pt')

    @staticmethod
    def _load_model(model_choice, file_name, device):
        """Load the checkpoint for ``model_choice`` and move it to ``device``.

        :raises ValueError: if ``model_choice`` is neither 'transformer' nor 'seq2seq'.
        """
        if model_choice == 'transformer':
            model = EncoderDecoder.load_from_file(file_name)
            model.to(device)
            model.eval()
            return model
        if model_choice == 'seq2seq':
            # Imported lazily: the module-level import was commented out, which
            # left `Model` undefined and made this branch raise NameError.
            from models.seq2seq.model import Model
            model = Model.load_from_file(file_name, evaluation_mode=True)
            model.network.encoder.to(device)
            model.network.decoder.to(device)
            return model
        raise ValueError(
            f"Unknown model_choice {model_choice!r}; expected 'transformer' or 'seq2seq'")

    def initialize_dataloader(self, opt, vocab, test_file):
        """
        Build a non-shuffled prediction-mode dataloader over the test CSV.

        :param opt: options; reads ``data_path`` and ``batch_size``
        :param vocab: vocabulary
        :param test_file: test file base name (``.csv`` is appended)
        :return: torch DataLoader using ``md.Dataset.collate_fn``
        """
        data = pd.read_csv(os.path.join(opt.data_path, test_file + '.csv'), sep=",")
        dataset = md.Dataset(data=data, vocabulary=vocab, tokenizer=self.tokenizer,
                             prediction_mode=True)
        dataloader = torch.utils.data.DataLoader(dataset, opt.batch_size,
                                                 shuffle=False, collate_fn=md.Dataset.collate_fn)
        return dataloader

    def generate(self, opt):
        """Sample ``opt.num_samples`` molecules per test row and save them.

        Writes ``generated_molecules.csv`` into ``self.save_path`` with the
        original test rows plus columns ``Predicted_smi_1`` ..
        ``Predicted_smi_<num_samples>``.
        """
        device = self.device
        print("-------device:---------")
        print(device)

        dataloader_test = self.initialize_dataloader(opt, self.vocab, opt.test_file_name)

        # Reuse the model loaded in __init__ unless opt names a different checkpoint.
        file_name = self._model_file(opt)
        if file_name != self.model_file or opt.model_choice != self.model_choice:
            self.model_choice = opt.model_choice
            self.model_file = file_name
            self.model = self._load_model(opt.model_choice, file_name, device)
        model = self.model

        # TODO: can generated sequences exceed this length? If the model fails,
        # adjust it (original note: keep it a multiple of 2).
        max_len = cfgd.DATA_DEFAULT['max_sequence_length']
        df_list = []
        sampled_smiles_list = []
        for batch in ul.progress_bar(dataloader_test, total=len(dataloader_test)):
            # df holds the raw input rows of this batch.
            src, source_length, _, src_mask, _, _, df = batch
            src = src.to(device)
            src_mask = src_mask.to(device)
            smiles = self.sample(opt.model_choice, model, src, src_mask, source_length,
                                 opt.decode_type, num_samples=opt.num_samples,
                                 max_len=max_len, device=device)
            df_list.append(df)
            sampled_smiles_list.extend(smiles)

        if not df_list:
            # pd.concat([]) would raise; nothing to write for an empty test file.
            print("No input rows found; nothing to generate.")
            return

        data_sorted = pd.concat(df_list)
        sampled_smiles = np.array(sampled_smiles_list)
        for i in range(opt.num_samples):
            data_sorted['Predicted_smi_{}'.format(i + 1)] = sampled_smiles[:, i]

        Path(self.save_path).mkdir(parents=True, exist_ok=True)
        result_path = os.path.join(self.save_path, "generated_molecules.csv")
        data_sorted.to_csv(result_path, index=False)

    def sample(self, model_choice, model, src, src_mask, source_length, decode_type, num_samples=10,
               max_len=cfgd.DATA_DEFAULT['max_sequence_length'], device=None):
        """Collect up to ``num_samples`` unique, valid SMILES for each source in a batch.

        Sources are re-sampled until each has ``num_samples`` unique valid
        outputs (its own canonical source SMILES excluded) or ``max_trials``
        trials have run. ``decode_type == 'greedy'`` allows only one trial.
        Slots that are never filled decode from the all-ones placeholder sequence.

        :param src: source token tensor; first dimension is the batch (must not be None)
        :param src_mask: source mask, indexed along dim 0 like ``src`` (may be None)
        :return: list with one entry per source, each a list of ``num_samples`` SMILES strings
        :raises ValueError: for an unknown ``model_choice``
        """
        batch_size = src.shape[0]
        # Number of unique valid samples accepted so far, per source.
        num_valid_batch = np.zeros(batch_size)
        num_valid_batch_desired = np.asarray([num_samples] * batch_size)
        # Canonical SMILES already seen per source, for de-duplication.
        unique_set_num_samples = [set() for _ in range(batch_size)]
        batch_index = torch.LongTensor(range(batch_size))
        # Original indices of the sources still needing samples.
        batch_index_current = torch.LongTensor(range(batch_size)).to(device)

        # sequences_all[k, b] holds the k-th accepted token sequence for source b.
        sequences_all = torch.ones((num_samples, batch_size, max_len), dtype=torch.long)
        max_trials = 100000
        current_trials = 0
        # Greedy: a single attempt; a source whose output is invalid or a
        # duplicate simply gets no molecule in that slot.
        if decode_type == 'greedy':
            max_trials = 1

        # Seed each source's seen-set with its own canonical SMILES.
        # NOTE(review): assumes the first len(cfgd.PROPERTIES) tokens are the
        # property tokens; the original author flagged that the delta value is
        # not necessarily placed first.
        start_ind = len(cfgd.PROPERTIES)
        for ibatch in range(batch_size):
            source_smi = self.tokenizer.untokenize(self.vocab.decode(src[ibatch].tolist()[start_ind:]))
            source_smi = uc.get_canonical_smile(source_smi)
            if source_smi:
                # TODO: the source here is an mmpdb fragment rather than a complete
                # SMILES, so this duplicate check may not be meaningful.
                unique_set_num_samples[ibatch].add(source_smi)

        with torch.no_grad():
            if model_choice == 'seq2seq':
                encoder_outputs, decoder_hidden = model.network.encoder(src, source_length)

            while (num_valid_batch < num_valid_batch_desired).any() and current_trials < max_trials:
                current_trials += 1
                # Restrict this trial to the sources that still need samples.
                src_current = src.index_select(0, batch_index_current)
                mask_current = (src_mask.index_select(0, batch_index_current)
                                if src_mask is not None else None)
                n_current = src_current.shape[0]

                if model_choice == 'transformer':
                    sequences = decode(model, src_current, mask_current, max_len, decode_type)
                    # Right-pad the time dimension with zeros to exactly max_len tokens.
                    padding = (0, max_len - sequences.shape[1], 0, 0)
                    sequences = torch.nn.functional.pad(sequences, padding)
                elif model_choice == 'seq2seq':
                    sequences = self.sample_seq2seq(model, mask_current, batch_index_current,
                                                    decoder_hidden, encoder_outputs, max_len, device)
                else:
                    raise ValueError(
                        f"Unknown model_choice {model_choice!r}; expected 'transformer' or 'seq2seq'")

                # Position in this trial's sub-batch -> original source index.
                batch_index_map = dict(enumerate(batch_index_current.tolist()))

                smiles = []
                is_valid_index = []
                for ibatch in range(n_current):
                    smi = self.tokenizer.untokenize(self.vocab.decode(sequences[ibatch].cpu().numpy()))
                    smi = uc.get_canonical_smile(smi)
                    smiles.append(smi)
                    if uc.is_valid(smi):
                        is_valid_index.append(ibatch)

                # Accept valid outputs not seen before for their source.
                for good_index in is_valid_index:
                    orig_index = batch_index_map[good_index]
                    if smiles[good_index] not in unique_set_num_samples[orig_index]:
                        unique_set_num_samples[orig_index].add(smiles[good_index])
                        num_valid_batch[orig_index] += 1
                        sequences_all[int(num_valid_batch[orig_index] - 1), orig_index, :] = \
                            sequences[good_index]

                # Keep sampling only for sources that are not yet complete.
                not_completed_index = np.where(num_valid_batch < num_valid_batch_desired)[0]
                if len(not_completed_index) > 0:
                    batch_index_current = batch_index.index_select(
                        0, torch.LongTensor(not_completed_index)).to(device)

        # Decode the stored token sequences to SMILES: [batch][num_samples].
        seqs = sequences_all.numpy()  # [num_samples, batch_size, max_len]
        smiles_list = []
        for ibatch in range(seqs.shape[1]):
            smiles_list.append([self.tokenizer.untokenize(self.vocab.decode(seqs[k, ibatch, :]))
                                for k in range(num_samples)])
        return smiles_list

    def sample_seq2seq(self, model, mask, batch_index_current, decoder_hidden, encoder_outputs, max_len, device):
        """Sample one ``max_len``-token sequence per active source from the seq2seq decoder.

        Each step draws the next token with ``torch.multinomial`` from the
        softmax of the decoder logits. All ``max_len`` steps always run; there
        is no early stop on an end token.

        :param mask: source mask for the active sources (squeezed on dim 1 before use)
        :param batch_index_current: original indices of the active sources
        :return: LongTensor of shape [num_active, max_len]
        """
        # Only the still-active sources take part; this set shrinks as sources complete.
        encoder_outputs_current = encoder_outputs.index_select(0, batch_index_current)
        batch_size = encoder_outputs_current.shape[0]

        # Every sequence starts from the "^" start token.
        start_token = torch.zeros(batch_size, dtype=torch.long)
        start_token[:] = self.vocab["^"]
        decoder_input = start_token.to(device)
        sequences = []
        mask = torch.squeeze(mask, 1).to(device)

        # Select the initial hidden state of the active sources (dim 1); LSTM-style
        # decoders carry a (h, c) tuple.
        if isinstance(decoder_hidden, tuple):
            decoder_hidden_current = (decoder_hidden[0].index_select(1, batch_index_current),
                                      decoder_hidden[1].index_select(1, batch_index_current))
        else:
            decoder_hidden_current = decoder_hidden.index_select(1, batch_index_current)

        for _ in range(max_len):
            logits, decoder_hidden_current = model.network.decoder(decoder_input.unsqueeze(1),
                                                                   decoder_hidden_current,
                                                                   encoder_outputs_current, mask)
            logits = logits.squeeze(1)
            probabilities = logits.softmax(dim=1)  # [batch_size, vocab_size]
            topi = torch.multinomial(probabilities, 1)  # [batch_size, 1]
            decoder_input = topi.view(-1).detach()
            sequences.append(decoder_input.view(-1, 1))

        sequences = torch.cat(sequences, 1)
        return sequences


def run_main():
    """Parse generation options, then build a GenerateRunner and run generation."""
    parser = argparse.ArgumentParser(
        description='generate.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    print("--------parser-------------")
    print(parser)
    opts.generate_opts(parser)
    opt = parser.parse_args()
    opt.test_file_name = prepare_input(opt)
    print("opt输出如下")
    print(opt)
    runner = GenerateRunner(opt)
    print()
    runner.generate(opt)


if __name__ == "__main__":
    run_main()