import torch from torch.utils.data import Dataset from utils import SmilesEnumerator import numpy as np import re class SmileDataset(Dataset): def __init__(self, args, data, content, block_size, aug_prob = 0.5, prop = None, scaffold = None, scaffold_maxlen = None): chars = sorted(list(set(content))) data_size, vocab_size = len(data), len(chars) print('data has %d smiles, %d unique characters.' % (data_size, vocab_size)) self.stoi = { ch:i for i,ch in enumerate(chars) } self.itos = { i:ch for i,ch in enumerate(chars) } self.max_len = block_size self.vocab_size = vocab_size self.data = data self.prop = prop self.sca = scaffold self.scaf_max_len = scaffold_maxlen self.debug = args.debug self.tfm = SmilesEnumerator() self.aug_prob = aug_prob def __len__(self): if self.debug: return math.ceil(len(self.data) / (self.max_len + 1)) else: return len(self.data) def __getitem__(self, idx): smiles, prop, scaffold = self.data[idx], self.prop[idx], self.sca[idx] # self.prop.iloc[idx, :].values --> if multiple properties smiles = smiles.strip() scaffold = scaffold.strip() p = np.random.uniform() if p < self.aug_prob: smiles = self.tfm.randomize_smiles(smiles) pattern = "(\[[^\]]+]|<|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" regex = re.compile(pattern) smiles += str('<')*(self.max_len - len(regex.findall(smiles))) if len(regex.findall(smiles)) > self.max_len: smiles = smiles[:self.max_len] smiles=regex.findall(smiles) scaffold += str('<')*(self.scaf_max_len - len(regex.findall(scaffold))) if len(regex.findall(scaffold)) > self.scaf_max_len: scaffold = scaffold[:self.scaf_max_len] scaffold=regex.findall(scaffold) dix = [self.stoi[s] for s in smiles] sca_dix = [self.stoi[s] for s in scaffold] sca_tensor = torch.tensor(sca_dix, dtype=torch.long) x = torch.tensor(dix[:-1], dtype=torch.long) y = torch.tensor(dix[1:], dtype=torch.long) # prop = torch.tensor([prop], dtype=torch.long) prop = torch.tensor([prop], dtype = torch.float) return x, y, prop, sca_tensor