|
|
import torch |
|
|
from torch.utils.data import Dataset |
|
|
from utils import SmilesEnumerator |
|
|
import numpy as np |
|
|
import re |
|
|
|
|
|
class SmileDataset(Dataset):
    """Dataset yielding next-token training pairs for SMILES strings.

    Each item is a tuple ``(x, y, prop, sca_tensor)`` where ``x`` and ``y``
    are the token-index sequence of one SMILES string shifted by one
    position against each other, ``prop`` is the associated property value
    and ``sca_tensor`` is the token-index sequence of the scaffold.

    SMILES and scaffold strings are tokenized with a regular expression,
    truncated to a fixed number of tokens and right-padded with ``'<'``.
    """

    # Token used for right-padding; it must be present in the vocabulary
    # whenever a sequence is shorter than its target length.
    _PAD_TOKEN = "<"

    # Compiled once for the whole class instead of once per item. A raw
    # string keeps the regex identical to the original pattern while
    # avoiding invalid escape sequences in a normal string literal.
    _TOKEN_REGEX = re.compile(
        r"(\[[^\]]+]|<|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+"
        r"|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
    )

    def __init__(self, args, data, content, block_size, aug_prob = 0.5, prop = None, scaffold = None, scaffold_maxlen = None):
        """Build the vocabulary and store the raw data.

        Args:
            args: Object exposing a ``debug`` attribute.
            data: Indexable collection of SMILES strings.
            content: Iterable whose unique elements form the vocabulary.
                NOTE(review): presumably the collection of all tokens that
                can occur in ``data`` and ``scaffold`` -- confirm with caller.
            block_size: Number of tokens every SMILES is padded/truncated to.
            aug_prob: Probability of applying ``randomize_smiles`` to an item.
            prop: Indexable collection of property values, aligned with
                ``data``. Defaults to ``None``, but ``__getitem__`` indexes
                it unconditionally, so it must be supplied in practice.
            scaffold: Indexable collection of scaffold strings, aligned with
                ``data``. Same caveat as ``prop``.
            scaffold_maxlen: Number of tokens every scaffold is
                padded/truncated to.
        """
        chars = sorted(set(content))
        data_size, vocab_size = len(data), len(chars)
        print('data has %d smiles, %d unique characters.' % (data_size, vocab_size))

        # Token <-> index lookup tables.
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}
        self.max_len = block_size
        self.vocab_size = vocab_size
        self.data = data
        self.prop = prop
        self.sca = scaffold
        self.scaf_max_len = scaffold_maxlen
        self.debug = args.debug
        self.tfm = SmilesEnumerator()
        self.aug_prob = aug_prob

    def __len__(self):
        """Return the number of items (a reduced count in debug mode)."""
        if self.debug:
            # Integer ceiling division: same value as
            # ceil(len(data) / (max_len + 1)) without needing ``math``,
            # which this file never imports.
            chunk = self.max_len + 1
            return (len(self.data) + chunk - 1) // chunk
        return len(self.data)

    def _encode(self, text, length):
        """Tokenize ``text`` and return exactly ``length`` vocabulary indices.

        Characters not matched by the token regex are dropped. The sequence
        is truncated on token boundaries and right-padded with ``'<'``.

        Raises:
            KeyError: If a token is missing from the vocabulary.
        """
        tokens = self._TOKEN_REGEX.findall(text)
        # Truncate on token boundaries so multi-character tokens such as
        # 'Cl' or '[nH]' are never split and the length is always exact.
        tokens = tokens[:length]
        tokens.extend([self._PAD_TOKEN] * (length - len(tokens)))
        return [self.stoi[tok] for tok in tokens]

    def __getitem__(self, idx):
        """Return ``(x, y, prop, sca_tensor)`` for the item at ``idx``.

        Returns:
            x: Long tensor holding the first ``max_len - 1`` token indices.
            y: Long tensor holding the token indices shifted one to the right.
            prop: Float tensor wrapping the property value in a list.
            sca_tensor: Long tensor with ``scaf_max_len`` scaffold indices.
        """
        smiles, prop, scaffold = self.data[idx], self.prop[idx], self.sca[idx]
        smiles = smiles.strip()
        scaffold = scaffold.strip()

        # Random augmentation with probability ``aug_prob``.
        # NOTE(review): randomize_smiles presumably returns an equivalent,
        # re-ordered SMILES string -- confirm against SmilesEnumerator.
        p = np.random.uniform()
        if p < self.aug_prob:
            smiles = self.tfm.randomize_smiles(smiles)

        dix = self._encode(smiles, self.max_len)
        sca_dix = self._encode(scaffold, self.scaf_max_len)

        sca_tensor = torch.tensor(sca_dix, dtype=torch.long)
        # Input and target are the same sequence offset by one token.
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)

        prop = torch.tensor([prop], dtype = torch.float)
        return x, y, prop, sca_tensor
|
|
|