File size: 4,361 Bytes
d5233a9 5946936 d5233a9 5946936 d5233a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import os, json
import torch
import utils
def calc_feats(smi, ms, nls, cfg):
    """Assemble the feature dict for one spectrum/molecule pair.

    Args:
        smi: Molecule identifier handed to the ``utils`` molecule encoders
            (presumably a SMILES string -- confirm against ``utils``).
        ms: Spectrum data forwarded unchanged to ``utils.ms_binner``.
        nls: Second positional argument of ``utils.ms_binner``; callers in
            this file pass the record's ``'nls'`` entry or an empty list.
        cfg: Config object providing ``min_mz``, ``max_mz``, ``bin_size``,
            ``add_nl``, ``binary_intn``, ``mol_encoder``, ``fptype`` and
            ``mol_embedding_dim``.

    Returns:
        A dict that always holds ``'ms_bins'`` and, depending on which of
        ``'fp'``, ``'gnn'`` and ``'fm'`` occur in ``cfg.mol_encoder``, also
        ``'mol_fps'``, the graph-featurizer entries and ``'mol_fmvec'``.
        ``None`` when the graph featurizer returns a falsy value.
    """
    features = {
        'ms_bins': utils.ms_binner(
            ms,
            nls,
            min_mz=cfg.min_mz,
            max_mz=cfg.max_mz,
            bin_size=cfg.bin_size,
            add_nl=cfg.add_nl,
            binary_intn=cfg.binary_intn,
        ),
    }

    want_fp = 'fp' in cfg.mol_encoder
    want_fm = 'fm' in cfg.mol_encoder

    if want_fp and want_fm:
        # The combined encoder yields the fingerprint and the fm vector in
        # one call, so the standalone fm step at the bottom is skipped.
        fps, fmvec = utils.mol_fp_fm_encoder(
            smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim
        )
        features['mol_fps'] = fps
        features['mol_fmvec'] = fmvec
    elif want_fp:
        features['mol_fps'] = utils.mol_fp_encoder(
            smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim
        )

    if 'gnn' in cfg.mol_encoder:
        graph_feats = utils.mol_graph_featurizer(smi)
        if not graph_feats:
            # No usable graph features: signal the caller to drop the sample.
            return None
        features.update(graph_feats)

    if want_fm and not want_fp:
        features['mol_fmvec'] = utils.smi2fmvec(smi)

    return features
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that turns raw records into feature dicts.

    A record is indexed with the keys ``'ms'`` and ``'smiles'`` and,
    optionally, ``'nls'``. Records that already contain ``'ms_bins'`` are
    treated as pre-computed and returned as they are.
    """

    def __init__(self, inp, cfg):
        """Store the records and the feature configuration.

        Args:
            inp: Either a path (``str`` or ``os.PathLike``) to a JSON file
                holding the records, or the already-loaded records.
            cfg: Config object forwarded to ``calc_feats``.
        """
        if isinstance(inp, (str, os.PathLike)):
            # Context manager so the file handle is closed deterministically.
            with open(inp) as fh:
                self.data = json.load(fh)
        else:
            self.data = inp
        self.cfg = cfg

    def __getitem__(self, idx):
        """Return the feature dict for record ``idx``.

        Returns:
            The stored record itself when it already has ``'ms_bins'``,
            otherwise the result of ``calc_feats`` (which may be ``None``).
            Any exception raised while reading or featurizing the record is
            printed and converted into ``None`` (best-effort loading).
        """
        try:
            record = self.data[idx]
            if 'ms_bins' in record:
                return record
            nls = record['nls'] if 'nls' in record else []
            ms = record['ms']
            smi = record['smiles']
            return calc_feats(smi, ms, nls, self.cfg)
        except Exception as e:
            # Deliberate best-effort: report the failing index and skip it.
            print('='*50, idx, str(e))
            return None

    def __len__(self):
        """Return the number of stored records."""
        return len(self.data)
class DatasetGNNFP(torch.utils.data.Dataset):
    """Map-style dataset yielding fingerprint plus graph features per record.

    Only the ``'smiles'`` entry of each record is read.
    """

    def __init__(self, inp, cfg):
        """Store the records and the feature configuration.

        Args:
            inp: Either a path (``str`` or ``os.PathLike``) to a JSON file
                holding the records, or the already-loaded records.
            cfg: Config object providing ``fptype`` and
                ``mol_embedding_dim``.
        """
        if isinstance(inp, (str, os.PathLike)):
            # Context manager so the file handle is closed deterministically.
            with open(inp) as fh:
                self.data = json.load(fh)
        else:
            self.data = inp
        self.cfg = cfg

    def __getitem__(self, idx):
        """Return the feature dict for record ``idx``.

        Returns:
            A dict with ``'mol_fps'`` from ``utils.mol_fp_encoder`` merged
            with whatever ``utils.mol_graph_featurizer`` returns. Any
            exception (including a featurizer result that cannot be merged)
            is printed and converted into ``None``.
        """
        try:
            smi = self.data[idx]['smiles']
            item = {
                'mol_fps': utils.mol_fp_encoder(
                    smi,
                    tp=self.cfg.fptype,
                    nbits=self.cfg.mol_embedding_dim,
                ),
            }
            item.update(utils.mol_graph_featurizer(smi))
        except Exception as e:
            # Deliberate best-effort: report the failing index and skip it.
            print('='*50, idx, str(e))
            return None
        return item

    def __len__(self):
        """Return the number of stored records."""
        return len(self.data)
class PathDataset(torch.utils.data.Dataset):
    """Map-style dataset that parses one spectrum file per index on demand.

    Parsed files are kept in ``self.data`` (keyed by index) so each file is
    read at most once per dataset instance.
    """

    def __init__(self, pathlist, cfg):
        """Store the file list and the feature configuration.

        Args:
            pathlist: Indexable collection of file paths.
            cfg: Config object; ``cfg.energy`` selects the section to parse
                and the whole object is forwarded to ``calc_feats``.
        """
        self.fns = pathlist
        self.cfg = cfg
        self.data = {}

    def __getitem__(self, idx):
        """Return the feature dict for file ``idx``.

        Returns:
            The result of ``calc_feats`` for the parsed file, or ``None``
            when the file cannot be parsed or any exception occurs
            (failures are skipped silently).
        """
        try:
            if idx not in self.data:
                parsed = self.proc_data(self.fns[idx], self.cfg.energy)
                if parsed is None:
                    return None
                self.data[idx] = parsed
            record = self.data[idx]
            ms = record['ms']
            smi = record['smiles']
            # An empty list is always passed as the ``nls`` argument here.
            return calc_feats(smi, ms, [], self.cfg)
        except Exception:
            # Deliberate best-effort: unreadable samples yield None.
            return None

    def proc_data(self, fn, energy='Energy1'):
        """Parse the peak list of the section whose header contains ``energy``.

        The header is the first line containing ``energy``; its
        second-to-last ``;``-separated field is stored under ``'smiles'``.
        Each following line up to the next line containing ``'END IONS'``
        must hold exactly two space-separated numbers.

        Args:
            fn: Path of the file to read.
            energy: Substring identifying the header line of the section.

        Returns:
            ``{'ms': [(mz, intensity), ...], 'smiles': ...}``, or ``None``
            when no header line is found or a line in the section is
            malformed.

        Raises:
            OSError: If the file cannot be opened.
        """
        with open(fn) as fh:
            lines = fh.readlines()

        peaks = []
        smi = None
        in_section = False
        try:
            for line in lines:
                if energy in line:
                    smi = line.split(';')[-2]
                    in_section = True
                    continue
                if not in_section:
                    continue
                if 'END IONS' in line:
                    break
                mz, intn = line.split(' ')
                peaks.append((float(mz), float(intn)))
        except (ValueError, IndexError):
            # Malformed header or peak line: treat the file as unusable.
            return None

        if not in_section:
            # The requested section never appeared in the file.
            return None
        return {'ms': peaks, 'smiles': smi}

    def __len__(self):
        """Return the number of file paths."""
        return len(self.fns)