# coding=utf-8
"""
Implementation of a SMILES dataset.
"""
import pandas as pd
import torch
import torch.utils.data as tud
from torch.autograd import Variable
import configuration.config_default as cfgd
from models.transformer.module.subsequent_mask import subsequent_mask
from rdkit.Chem.SaltRemover import SaltRemover
import random
import rdkit.Chem as rkc
from common.utils import Data_Type
class Dataset(tud.Dataset):
    """Custom PyTorch Dataset over a dataframe of matched molecular pairs.

    Expected columns (depending on data_type) include:
    Source_Mol_ID,Target_Mol_ID,Source_Mol,Target_Mol,
    Source_Mol_LogD,Target_Mol_LogD,Delta_LogD,
    Source_Mol_Solubility,Target_Mol_Solubility,Delta_Solubility,
    Source_Mol_Clint,Target_Mol_Clint,Delta_Clint,
    Transformation,Core
    """

    def __init__(self, data, vocabulary, tokenizer, prediction_mode=False, use_random=False, data_type=Data_Type.frag.value):
        """
        :param data: dataframe read from training, validation or test file
        :param vocabulary: used to encode source/target tokens
        :param tokenizer: used to tokenize source/target smiles
        :param prediction_mode: if True, __getitem__ yields only the encoded
            source (no target column is read, e.g. at test time)
        :param use_random: if True, smiles_preprocess returns a randomized
            SMILES instead of the input string
        :param data_type: Data_Type value selecting the fragment-based
            (constant/variable) or whole-molecule column layout
        """
        self._vocabulary = vocabulary
        self._tokenizer = tokenizer
        self._data = data
        self._prediction_mode = prediction_mode
        self._use_random = use_random
        self._data_type = data_type

    def smiles_preprocess(self, smiles, random_type="unrestricted"):
        """
        Return a (possibly randomized) SMILES for the given molecule.

        :param smiles: SMILES string of a molecule
        :param random_type: the type ("unrestricted", "restricted") of
            randomization performed
        :return: a random SMILES of the same molecule; the original string
            when randomization is disabled or produces an empty result;
            None if the SMILES cannot be parsed
        :raises ValueError: if random_type is not recognized
        """
        if not self._use_random:
            return smiles
        mol = rkc.MolFromSmiles(smiles)
        if not mol:
            # NOTE(review): unparseable SMILES yields None; callers feed the
            # result straight to the tokenizer, so invalid input fails there.
            return None
        remover = SaltRemover()  # default salt remover
        if random_type == "unrestricted":
            stripped = remover.StripMol(mol)
            if stripped is None:  # fixed: was `== None`
                return smiles
            ret = rkc.MolToSmiles(stripped, canonical=False, doRandom=True, isomericSmiles=False)
            # fall back to the input when RDKit returns an empty string
            return ret if ret else smiles
        if random_type == "restricted":
            # shuffle the atom order, then emit a non-canonical SMILES
            new_atom_order = list(range(mol.GetNumAtoms()))
            random.shuffle(new_atom_order)
            random_mol = rkc.RenumberAtoms(mol, newOrder=new_atom_order)
            ret = rkc.MolToSmiles(random_mol, canonical=False, isomericSmiles=False)
            return ret if ret else smiles
        raise ValueError("Type '{}' is not valid".format(random_type))

    def __getitem__(self, i):
        """
        Tokenize and encode the source (and, during training, the target)
        SMILES of row *i*.

        :param i: integer index into the dataframe
        :return: (source_tensor, row) in prediction mode, otherwise
            (source_tensor, target_tensor, row)
        """
        row = self._data.iloc[i]
        main_cls = row['main_cls']
        minor_cls = row['minor_cls']
        target_name = row['target_name']
        # guard against missing names: pandas yields float NaN for empty cells
        target_name = target_name if isinstance(target_name, str) else ''
        value = row['Delta_Value']

        # Build the source token sequence.
        source_tokens = []
        if self._data_type == Data_Type.frag.value:
            source_constant = self.smiles_preprocess(row['constantSMILES'])
            source_variable = self.smiles_preprocess(row['fromVarSMILES'])
            # variable part first ...
            source_tokens.extend(self._tokenizer.tokenize(source_variable))
            # ... then the constant part
            source_tokens.extend(self._tokenizer.tokenize(source_constant))
        elif self._data_type == Data_Type.whole.value:
            source_smi = self.smiles_preprocess(row['cpd1SMILES'])
            source_tokens.extend(self._tokenizer.tokenize(source_smi))
        # then the major class (e.g. activity)
        source_tokens.append(main_cls)
        # then the minor class (e.g. Ki)
        source_tokens.append(minor_cls)
        # then the property-change value
        source_tokens.append(value)
        # then the target name, one character per token
        source_tokens.extend(list(target_name))
        source_encoded = self._vocabulary.encode(source_tokens)

        if self._prediction_mode:
            return torch.tensor(source_encoded, dtype=torch.long), row

        # Training/validation: also tokenize and encode the target SMILES.
        if self._data_type == Data_Type.frag.value:
            target_smi = row['toVarSMILES']
        elif self._data_type == Data_Type.whole.value:
            target_smi = row['cpd2SMILES']
        else:
            target_smi = ''  # unknown data_type: empty target (as before)
        target_tokens = self._tokenizer.tokenize(target_smi)
        target_encoded = self._vocabulary.encode(target_tokens)
        return (torch.tensor(source_encoded, dtype=torch.long),
                torch.tensor(target_encoded, dtype=torch.long), row)

    def __len__(self):
        """Number of rows in the underlying dataframe."""
        return len(self._data)

    @classmethod
    def collate_fn(cls, data_all):
        """
        Batch a list of __getitem__ outputs into padded tensors and masks.

        :param data_all: list of (source, row) or (source, target, row)
        :return: (collated_arr_source, source_length, collated_arr_target,
            src_mask, trg_mask, max_length_target, data); the target-related
            entries are None in prediction mode
        """
        # sort based on source sequence's length, longest first
        data_all.sort(key=lambda x: len(x[0]), reverse=True)
        # 2-tuples come from prediction mode, 3-tuples from training
        is_prediction_mode = len(data_all[0]) == 2
        if is_prediction_mode:
            source_encoded, data = zip(*data_all)
        else:
            source_encoded, target_encoded, data = zip(*data_all)
        data = pd.DataFrame(data)

        # pad source sequences with zeroes up to the batch maximum length
        max_length_source = max(seq.size(0) for seq in source_encoded)
        collated_arr_source = torch.zeros(len(source_encoded), max_length_source, dtype=torch.long)
        for i, seq in enumerate(source_encoded):
            collated_arr_source[i, :seq.size(0)] = seq
        # length of each source sequence
        source_length = torch.tensor([seq.size(0) for seq in source_encoded])
        # mask of source seqs: True where not padding
        src_mask = (collated_arr_source != 0).unsqueeze(-2)

        # target sequences (training only)
        if not is_prediction_mode:
            max_length_target = max(seq.size(0) for seq in target_encoded)
            collated_arr_target = torch.zeros(len(target_encoded), max_length_target, dtype=torch.long)
            for i, seq in enumerate(target_encoded):
                collated_arr_target[i, :seq.size(0)] = seq
            # padding mask combined with the causal (subsequent-position) mask
            trg_mask = (collated_arr_target != 0).unsqueeze(-2)
            trg_mask = trg_mask & Variable(subsequent_mask(collated_arr_target.size(-1)).type_as(trg_mask))
            trg_mask = trg_mask[:, :-1, :-1]  # save start token, skip end token
        else:
            trg_mask = None
            max_length_target = None
            collated_arr_target = None
        return collated_arr_source, source_length, collated_arr_target, src_mask, trg_mask, max_length_target, data