# coding=utf-8
"""
Implementation of a SMILES dataset.
"""
import random

import pandas as pd
import rdkit.Chem as rkc
import torch
import torch.utils.data as tud
from rdkit.Chem.SaltRemover import SaltRemover

from common.utils import Data_Type
from models.transformer.module.subsequent_mask import subsequent_mask


class Dataset(tud.Dataset):
    """Custom PyTorch Dataset that takes a file containing
    Source_Mol_ID,Target_Mol_ID,Source_Mol,Target_Mol,
    Source_Mol_LogD,Target_Mol_LogD,Delta_LogD,
    Source_Mol_Solubility,Target_Mol_Solubility,Delta_Solubility,
    Source_Mol_Clint,Target_Mol_Clint,Delta_Clint,
    Transformation,Core"""

    def __init__(self, data, vocabulary, tokenizer, prediction_mode=False,
                 use_random=False, data_type=Data_Type.frag.value):
        """
        :param data: dataframe read from training, validation or test file
        :param vocabulary: used to encode source/target tokens
        :param tokenizer: used to tokenize source/target SMILES
        :param prediction_mode: when True, only the source is encoded
            (inference); when False, the target SMILES is encoded as well
            (training/validation)
        :param use_random: if True, SMILES are randomized on the fly as a
            data augmentation
        :param data_type: Data_Type.frag (constant/variable fragment pairs)
            or Data_Type.whole (whole-molecule pairs)
        """
        self._vocabulary = vocabulary
        self._tokenizer = tokenizer
        self._data = data
        self._prediction_mode = prediction_mode
        self._use_random = use_random
        self._data_type = data_type

    def smiles_preprocess(self, smiles, random_type="unrestricted"):
        """
        Returns a random SMILES given a SMILES of a molecule; if use_random
        is False, the input SMILES is returned unchanged.
        :param smiles: A SMILES string
        :param random_type: The type (unrestricted, restricted) of randomization performed.
        :return: A random SMILES string of the same molecule, or None if the molecule is invalid.
        """
        if not self._use_random:
            return smiles
        mol = rkc.MolFromSmiles(smiles)
        if not mol:
            return None
        remover = SaltRemover()  # default salt remover
        if random_type == "unrestricted":
            stripped = remover.StripMol(mol)
            if stripped is None:
                return smiles
            ret = rkc.MolToSmiles(stripped, canonical=False, doRandom=True, isomericSmiles=False)
            return ret if ret else smiles
        if random_type == "restricted":
            # shuffle the atom order, then write a non-canonical SMILES
            new_atom_order = list(range(mol.GetNumAtoms()))
            random.shuffle(new_atom_order)
            random_mol = rkc.RenumberAtoms(mol, newOrder=new_atom_order)
            ret = rkc.MolToSmiles(random_mol, canonical=False, isomericSmiles=False)
            return ret if ret else smiles
        raise ValueError("Type '{}' is not valid".format(random_type))
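
    # Hedged examples of the two randomization modes (outputs are
    # illustrative; actual strings vary per call):
    #   self.smiles_preprocess("c1ccccc1O", random_type="unrestricted")
    #       -> e.g. "Oc1ccccc1"     (doRandom=True re-serialization)
    #   self.smiles_preprocess("c1ccccc1O", random_type="restricted")
    #       -> e.g. "c1cc(O)ccc1"   (shuffled atom order, canonical=False)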

    def __getitem__(self, i):
        """
        Tokenize and encode the source SMILES, and the target SMILES as well
        unless prediction_mode is set.
        :param i: index of the row in the underlying dataframe
        :return: (source, target, row) tensors in training mode,
            (source, row) in prediction mode
        """
        row = self._data.iloc[i]
        # tokenize and encode source smiles
        main_cls = row['main_cls']
        minor_cls = row['minor_cls']
        target_name = row['target_name']
        target_name = target_name if isinstance(target_name, str) else ''
        value = row['Delta_Value']
        source_tokens = []
        if self._data_type == Data_Type.frag.value:
            sourceConstant = self.smiles_preprocess(row['constantSMILES'])
            sourceVariable = self.smiles_preprocess(row['fromVarSMILES'])
            # variable part first
            source_tokens.extend(self._tokenizer.tokenize(sourceVariable))  # add source variable SMILES tokens
            # then the constant part
            source_tokens.extend(self._tokenizer.tokenize(sourceConstant))  # add source constant SMILES tokens
        elif self._data_type == Data_Type.whole.value:
            sourceSmi = self.smiles_preprocess(row['cpd1SMILES'])
            source_tokens.extend(self._tokenizer.tokenize(sourceSmi))
        # then the major class, e.g. activity
        source_tokens.append(main_cls)
        # then the minor class, e.g. Ki
        source_tokens.append(minor_cls)
        # then the delta value
        source_tokens.append(value)
        # finally the target name, character by character
        source_tokens.extend(list(target_name))
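        # Illustrative layout of the assembled source sequence for frag-type
        # data (hedged; actual tokens depend on the tokenizer/vocabulary):
        #   [<variable SMILES tokens>] + [<constant SMILES tokens>]
        #   + [main_cls, minor_cls, Delta_Value] + [<target-name characters>]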
        source_encoded = self._vocabulary.encode(source_tokens)
        # tokenize and encode target smiles if it is for training instead of evaluation
        if not self._prediction_mode:
            target_smi = ''
            if self._data_type == Data_Type.frag.value:
                target_smi = row['toVarSMILES']
            elif self._data_type == Data_Type.whole.value:
                target_smi = row['cpd2SMILES']
            target_tokens = self._tokenizer.tokenize(target_smi)
            target_encoded = self._vocabulary.encode(target_tokens)
            return torch.tensor(source_encoded, dtype=torch.long), torch.tensor(
                target_encoded, dtype=torch.long), row
        else:
            return torch.tensor(source_encoded, dtype=torch.long), row

    def __len__(self):
        return len(self._data)
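
    # collate_fn below pads a batch of variable-length items to a common
    # length and builds the attention masks; a usage sketch follows the class.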
    @classmethod
    def collate_fn(cls, data_all):
        # sort based on source sequence's length
        data_all.sort(key=lambda x: len(x[0]), reverse=True)
        # prediction-mode items are (source, row); training items are
        # (source, target, row)
        is_prediction_mode = len(data_all[0]) == 2
        if is_prediction_mode:
            source_encoded, data = zip(*data_all)
            data = pd.DataFrame(data)
        else:
            source_encoded, target_encoded, data = zip(*data_all)
            data = pd.DataFrame(data)
        # maximum length of source sequences
        max_length_source = max(seq.size(0) for seq in source_encoded)
        # pad source sequences with zeros
        collated_arr_source = torch.zeros(len(source_encoded), max_length_source, dtype=torch.long)
        for i, seq in enumerate(source_encoded):
            collated_arr_source[i, :seq.size(0)] = seq
        # length of each source sequence
        source_length = torch.tensor([seq.size(0) for seq in source_encoded])
        # mask of source seqs: True where the position is not padding
        src_mask = (collated_arr_source != 0).unsqueeze(-2)
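        # Shape note (hedged): collated_arr_source is (batch, max_len_src),
        # so src_mask is (batch, 1, max_len_src); the singleton dimension
        # broadcasts across attention rows in the transformer.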
        # target seqs
        if not is_prediction_mode:
            max_length_target = max(seq.size(0) for seq in target_encoded)
            collated_arr_target = torch.zeros(len(target_encoded), max_length_target, dtype=torch.long)
            for i, seq in enumerate(target_encoded):
                collated_arr_target[i, :seq.size(0)] = seq
            # combine the padding mask with the causal (subsequent) mask so a
            # position attends only to earlier, non-padding positions
            trg_mask = (collated_arr_target != 0).unsqueeze(-2)
            trg_mask = trg_mask & subsequent_mask(collated_arr_target.size(-1)).type_as(trg_mask)
            # drop the last position: the decoder input keeps the start token
            # and skips the end token under teacher forcing
            trg_mask = trg_mask[:, :-1, :-1]
        else:
            trg_mask = None
            max_length_target = None
            collated_arr_target = None
        return collated_arr_source, source_length, collated_arr_target, src_mask, trg_mask, max_length_target, data
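

# Minimal usage sketch (hedged): `vocabulary` and `tokenizer` are
# project-specific objects constructed elsewhere; the CSV path and batch
# size below are illustrative, not part of this module.
#
#   train_df = pd.read_csv("train.csv")
#   dataset = Dataset(train_df, vocabulary, tokenizer,
#                     prediction_mode=False, use_random=True,
#                     data_type=Data_Type.frag.value)
#   loader = tud.DataLoader(dataset, batch_size=32, shuffle=True,
#                           collate_fn=Dataset.collate_fn)
#   for src, src_len, trg, src_mask, trg_mask, max_trg_len, df in loader:
#       pass  # feed src/trg and the masks to the transformer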