LLM-fastAPI / models /dataset.py
Songyou's picture
add files
f3b11f9
# coding=utf-8
"""
Implementation of a SMILES dataset.
"""
import pandas as pd
import torch
import torch.utils.data as tud
from torch.autograd import Variable
import configuration.config_default as cfgd
from models.transformer.module.subsequent_mask import subsequent_mask
from rdkit.Chem.SaltRemover import SaltRemover
import random
import rdkit.Chem as rkc
from common.utils import Data_Type
class Dataset(tud.Dataset):
"""Custom PyTorch Dataset that takes a file containing
Source_Mol_ID,Target_Mol_ID,Source_Mol,Target_Mol,
Source_Mol_LogD,Target_Mol_LogD,Delta_LogD,
Source_Mol_Solubility,Target_Mol_Solubility,Delta_Solubility,
Source_Mol_Clint,Target_Mol_Clint,Delta_Clint,
Transformation,Core"""
def __init__(self, data, vocabulary, tokenizer, prediction_mode=False, use_random=False, data_type=Data_Type.frag.value):
"""
:param data: dataframe read from training, validation or test file
:param vocabulary: used to encode source/target tokens
:param tokenizer: used to tokenize source/target smiles
:param prediction_mode: if use target smiles or not (training or test)
"""
self._vocabulary = vocabulary
self._tokenizer = tokenizer
self._data = data
self._prediction_mode = prediction_mode
self._use_random = use_random
self._data_type = data_type
def smiles_preprocess(self, smiles, random_type="unrestricted"):
"""
Returns a random SMILES given a SMILES of a molecule.
:param mol: A Mol object
:param random_type: The type (unrestricted, restricted) of randomization performed.
:return : A random SMILES string of the same molecule or None if the molecule is invalid.
"""
if not self._use_random:
return smiles
mol = rkc.MolFromSmiles(smiles)
if not mol:
return None
remover = SaltRemover() ## default salt remover
if random_type == "unrestricted":
stripped = remover.StripMol(mol)
if stripped == None:
return smiles
ret = rkc.MolToSmiles(stripped, canonical=False, doRandom=True, isomericSmiles=False)
if not bool(ret):
return smiles
return ret
if random_type == "restricted":
new_atom_order = list(range(mol.GetNumAtoms()))
random.shuffle(new_atom_order)
random_mol = rkc.RenumberAtoms(mol, newOrder=new_atom_order)
ret = rkc.MolToSmiles(random_mol, canonical=False, isomericSmiles=False)
if not bool(ret):
return smiles
return ret
raise ValueError("Type '{}' is not valid".format(random_type))
def __getitem__(self, i):
"""
Tokenize and encode source smile and/or target smile (if prediction_mode is True)
:param i:
:return:
"""
row = self._data.iloc[i]
# tokenize and encode source smiles
main_cls = row['main_cls']
minor_cls = row['minor_cls']
target_name = row['target_name']
target_name = target_name if isinstance(target_name, str) else ''
value = row['Delta_Value']
# value = row['Delta_pki']
source_tokens = []
if self._data_type == Data_Type.frag.value:
sourceConstant = self.smiles_preprocess(row['constantSMILES'])
sourceVariable = self.smiles_preprocess(row['fromVarSMILES'])
# 先variable
source_tokens.extend(self._tokenizer.tokenize(sourceVariable)) ## add source variable SMILES token
# 接着constant
source_tokens.extend(self._tokenizer.tokenize(sourceConstant)) ## add source constant SMILES token
elif self._data_type == Data_Type.whole.value:
sourceSmi = self.smiles_preprocess(row['cpd1SMILES'])
source_tokens.extend(self._tokenizer.tokenize(sourceSmi))
# 再 major class eg activity
source_tokens.append(main_cls)
# 再 minor class eg Ki
source_tokens.append(minor_cls)
# 然后value
source_tokens.append(value)
# 然后target name
source_tokens.extend(list(target_name))
source_encoded = self._vocabulary.encode(source_tokens)
# print(source_tokens,'\n=====\n', source_encoded)
# tokenize and encode target smiles if it is for training instead of evaluation
if not self._prediction_mode:
target_smi = ''
if self._data_type == Data_Type.frag.value:
target_smi = row['toVarSMILES']
elif self._data_type == Data_Type.whole.value:
target_smi = row['cpd2SMILES']
target_tokens = self._tokenizer.tokenize(target_smi)
target_encoded = self._vocabulary.encode(target_tokens)
return torch.tensor(source_encoded, dtype=torch.long), torch.tensor(
target_encoded, dtype=torch.long), row
else:
return torch.tensor(source_encoded, dtype=torch.long), row
def __len__(self):
return len(self._data)
@classmethod
def collate_fn(cls, data_all):
# sort based on source sequence's length
data_all.sort(key=lambda x: len(x[0]), reverse=True)
is_prediction_mode = True if len(data_all[0]) == 2 else False
if is_prediction_mode:
source_encoded, data = zip(*data_all)
data = pd.DataFrame(data)
else:
source_encoded, target_encoded, data = zip(*data_all)
data = pd.DataFrame(data)
# maximum length of source sequences
max_length_source = max([seq.size(0) for seq in source_encoded])
# print('=====max len', max_length_source)
# padded source sequences with zeroes
collated_arr_source = torch.zeros(len(source_encoded), max_length_source, dtype=torch.long)
for i, seq in enumerate(source_encoded):
collated_arr_source[i, :seq.size(0)] = seq
# length of each source sequence
source_length = [seq.size(0) for seq in source_encoded]
source_length = torch.tensor(source_length)
# mask of source seqs
src_mask = (collated_arr_source !=0).unsqueeze(-2)
# target seq
if not is_prediction_mode:
max_length_target = max([seq.size(0) for seq in target_encoded])
collated_arr_target = torch.zeros(len(target_encoded), max_length_target, dtype=torch.long)
for i, seq in enumerate(target_encoded):
collated_arr_target[i, :seq.size(0)] = seq
trg_mask = (collated_arr_target != 0).unsqueeze(-2)
trg_mask = trg_mask & Variable(subsequent_mask(collated_arr_target.size(-1)).type_as(trg_mask))
trg_mask = trg_mask[:, :-1, :-1] # save start token, skip end token
else:
trg_mask = None
max_length_target = None
collated_arr_target = None
return collated_arr_source, source_length, collated_arr_target, src_mask, trg_mask, max_length_target, data