|
|
import os, json
|
|
|
import torch
|
|
|
import utils
|
|
|
|
|
|
def calc_feats(smi, ms, nls, cfg):
    """Assemble the feature dict for one spectrum / molecule pair.

    The binned spectrum is always stored under ``'ms_bins'``. Which molecule
    features are added depends on the tags found in ``cfg.mol_encoder``:

    * ``'fp'`` without ``'fm'``: ``'mol_fps'`` from ``utils.mol_fp_encoder``.
    * ``'fp'`` and ``'fm'``: ``'mol_fps'`` and ``'mol_fmvec'`` from a single
      ``utils.mol_fp_fm_encoder`` call.
    * ``'gnn'``: every key returned by ``utils.mol_graph_featurizer``.
    * ``'fm'`` without ``'fp'``: ``'mol_fmvec'`` from ``utils.smi2fmvec``.

    Args:
        smi: Molecule identifier handed to the ``utils`` encoders
            (presumably a SMILES string -- confirm against callers).
        ms: Spectrum data forwarded to ``utils.ms_binner``.
        nls: Second positional input of ``utils.ms_binner``.
        cfg: Config object providing ``min_mz``, ``max_mz``, ``bin_size``,
            ``add_nl``, ``binary_intn``, ``mol_encoder`` and, when
            fingerprints are requested, ``fptype`` and ``mol_embedding_dim``.

    Returns:
        The feature dict, or ``None`` when the graph featurizer returns a
        falsy value.
    """
    feats = {
        'ms_bins': utils.ms_binner(
            ms,
            nls,
            min_mz=cfg.min_mz,
            max_mz=cfg.max_mz,
            bin_size=cfg.bin_size,
            add_nl=cfg.add_nl,
            binary_intn=cfg.binary_intn,
        ),
    }

    encoders = cfg.mol_encoder
    has_fp = 'fp' in encoders
    has_fm = 'fm' in encoders

    if has_fp and has_fm:
        # One call yields both the fingerprint and the formula vector.
        fps, fmvec = utils.mol_fp_fm_encoder(
            smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim
        )
        feats['mol_fps'] = fps
        feats['mol_fmvec'] = fmvec
    elif has_fp:
        feats['mol_fps'] = utils.mol_fp_encoder(
            smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim
        )

    if 'gnn' in encoders:
        graph_feats = utils.mol_graph_featurizer(smi)
        if not graph_feats:
            return None
        feats.update(graph_feats)

    # The formula vector still has to be computed when 'fm' was requested
    # without 'fp' (the combined encoder above did not run).
    if has_fm and not has_fp:
        feats['mol_fmvec'] = utils.smi2fmvec(smi)

    return feats
|
|
|
|
|
|
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that yields feature dicts built by ``calc_feats``.

    Every record in ``self.data`` is either already featurized (it contains
    an ``'ms_bins'`` key and is returned untouched) or a raw record with
    ``'ms'`` and ``'smiles'`` keys, plus an optional ``'nls'`` key, that is
    featurized on access.
    """

    def __init__(self, inp, cfg):
        """Store the records and the feature config.

        Args:
            inp: Path (``str`` or ``os.PathLike``) of a JSON file holding the
                records, or the already loaded records themselves.
            cfg: Config object forwarded to ``calc_feats``.
        """
        if isinstance(inp, (str, os.PathLike)):
            # Context manager so the file handle is closed deterministically
            # instead of being left to the garbage collector.
            with open(inp) as fh:
                self.data = json.load(fh)
        else:
            self.data = inp

        self.cfg = cfg

    def __getitem__(self, idx):
        """Return the feature dict for record ``idx``.

        Any exception raised while looking up or featurizing the record
        (including an out-of-range ``idx``) is printed and turned into a
        ``None`` return value instead of propagating.
        """
        try:
            record = self.data[idx]
            if 'ms_bins' in record:
                # Already featurized; hand it back unchanged.
                return record

            nls = record['nls'] if 'nls' in record else []
            ms = record['ms']
            smi = record['smiles']

            item = calc_feats(smi, ms, nls, self.cfg)
        except Exception as e:
            print('='*50, idx, str(e))
            return None

        return item

    def __len__(self):
        """Return the number of records."""
        return len(self.data)
|
|
|
|
|
|
class DatasetGNNFP(torch.utils.data.Dataset):
    """Map-style dataset yielding fingerprint plus graph features per molecule.

    Only the ``'smiles'`` key of each record is read.
    """

    def __init__(self, inp, cfg):
        """Store the records and the feature config.

        Args:
            inp: Path (``str`` or ``os.PathLike``) of a JSON file holding the
                records, or the already loaded records themselves.
            cfg: Config object providing ``fptype`` and ``mol_embedding_dim``.
        """
        if isinstance(inp, (str, os.PathLike)):
            # Context manager so the file handle is closed deterministically
            # instead of being left to the garbage collector.
            with open(inp) as fh:
                self.data = json.load(fh)
        else:
            self.data = inp

        self.cfg = cfg

    def __getitem__(self, idx):
        """Return the feature dict for record ``idx``.

        The dict holds ``'mol_fps'`` from ``utils.mol_fp_encoder`` merged
        with whatever ``utils.mol_graph_featurizer`` returns. Any exception
        raised along the way is printed and turned into a ``None`` return
        value instead of propagating.
        """
        try:
            smi = self.data[idx]['smiles']
            item = {
                'mol_fps': utils.mol_fp_encoder(
                    smi,
                    tp=self.cfg.fptype,
                    nbits=self.cfg.mol_embedding_dim,
                ),
            }
            item.update(utils.mol_graph_featurizer(smi))
        except Exception as e:
            print('='*50, idx, str(e))
            return None

        return item

    def __len__(self):
        """Return the number of records."""
        return len(self.data)
|
|
|
|
|
|
class PathDataset(torch.utils.data.Dataset):
    """Map-style dataset that parses one spectrum file per index on demand.

    Successfully parsed records are cached in ``self.data`` (keyed by index),
    so each file is parsed at most once per instance.
    """

    def __init__(self, pathlist, cfg):
        """Store the file paths and the config.

        Args:
            pathlist: Indexable collection of spectrum file paths.
            cfg: Config object; ``cfg.energy`` selects the block read from
                each file and the whole object is forwarded to ``calc_feats``.
        """
        self.fns = pathlist
        self.cfg = cfg
        self.data = {}

    def __getitem__(self, idx):
        """Return the feature dict for file ``idx``, or ``None`` on failure.

        Any failure (unreadable or unparsable file, featurization error) is
        swallowed without logging and reported as ``None``.
        """
        try:
            if idx not in self.data:
                parsed = self.proc_data(self.fns[idx], self.cfg.energy)
                if parsed is None:
                    return None
                self.data[idx] = parsed

            record = self.data[idx]
            # ``nls`` is always passed as an empty list for this dataset.
            item = calc_feats(record['smiles'], record['ms'], [], self.cfg)
        except Exception:
            return None

        return item

    def proc_data(self, fn, energy='Energy1'):
        """Parse the peak list of the ``energy`` block from file ``fn``.

        Scanning starts at a line containing ``energy``; the second-to-last
        ``;``-separated field of that line is taken as the SMILES. Every
        following line up to the next line containing ``'END IONS'`` must be
        of the form ``"<mz> <intensity>"``.

        Args:
            fn: Path of the spectrum file.
            energy: Substring identifying the header line of the wanted block.

        Returns:
            ``{'ms': [(mz, intensity), ...], 'smiles': smi}``, or ``None``
            when the file holds no line containing ``energy`` or a line of
            the block cannot be parsed.
        """
        # Read everything up front and close the handle right away.
        with open(fn) as fh:
            lines = fh.readlines()

        peaks = []
        smi = None
        in_block = False
        try:
            for line in lines:
                if energy in line:
                    smi = line.split(';')[-2]
                    in_block = True
                    continue
                if not in_block:
                    # Still looking for the requested block.
                    continue
                if 'END IONS' in line:
                    break
                mz, intn = line.split(' ')
                peaks.append((float(mz), float(intn)))
        except (ValueError, IndexError, TypeError):
            # Malformed header or peak line (or a non-string ``energy``).
            return None

        if not in_block:
            # The requested block never appeared, so there is no SMILES.
            return None

        return {'ms': peaks, 'smiles': smi}

    def __len__(self):
        """Return the number of files."""
        return len(self.fns)