import json
import os

import torch

import utils


def _load_records(inp):
    """Return the record collection backing a dataset.

    A ``str`` is treated as the path of a JSON file, which is parsed and
    returned; any other object is returned unchanged.
    """
    if isinstance(inp, str):
        # Context manager so the handle is closed as soon as parsing is done.
        with open(inp) as fh:
            return json.load(fh)
    return inp


def calc_feats(smi, ms, nls, cfg):
    """Build the feature dict for one molecule / spectrum pair.

    Args:
        smi: Molecule identifier handed to the ``utils`` molecule encoders
            (presumably a SMILES string -- confirm against ``utils``).
        ms: Spectrum data handed to ``utils.ms_binner``.
        nls: Second positional argument of ``utils.ms_binner`` (presumably
            neutral losses -- confirm against ``utils``).
        cfg: Configuration object. Reads ``min_mz``, ``max_mz``, ``bin_size``,
            ``add_nl``, ``binary_intn``, ``mol_encoder``, and, when a
            fingerprint is requested, ``fptype`` and ``mol_embedding_dim``.

    Returns:
        A dict that always contains ``'ms_bins'`` and, depending on which of
        ``'fp'``, ``'gnn'`` and ``'fm'`` are found in ``cfg.mol_encoder``
        (tested with ``in``), also ``'mol_fps'``, the keys produced by
        ``utils.mol_graph_featurizer``, and ``'mol_fmvec'``.
        ``None`` when graph features are requested but the featurizer
        returns a falsy value.
    """
    item = {
        'ms_bins': utils.ms_binner(
            ms,
            nls,
            min_mz=cfg.min_mz,
            max_mz=cfg.max_mz,
            bin_size=cfg.bin_size,
            add_nl=cfg.add_nl,
            binary_intn=cfg.binary_intn,
        ),
    }

    # Set when 'mol_fmvec' was already produced together with the fingerprint,
    # so it is not computed a second time below.
    fm_done = False

    if 'fp' in cfg.mol_encoder:
        if 'fm' not in cfg.mol_encoder:
            item['mol_fps'] = utils.mol_fp_encoder(
                smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim
            )
        else:
            item['mol_fps'], item['mol_fmvec'] = utils.mol_fp_fm_encoder(
                smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim
            )
            fm_done = True

    if 'gnn' in cfg.mol_encoder:
        graph_feats = utils.mol_graph_featurizer(smi)
        if not graph_feats:
            return None
        item.update(graph_feats)

    if 'fm' in cfg.mol_encoder and not fm_done:
        item['mol_fmvec'] = utils.smi2fmvec(smi)

    return item


class Dataset(torch.utils.data.Dataset):
    """Dataset over records holding ``'ms'``, ``'smiles'`` and optional ``'nls'``.

    Records that already contain ``'ms_bins'`` are returned as they are; all
    others are converted with :func:`calc_feats` on access.
    """

    def __init__(self, inp, cfg):
        """Store the records and configuration.

        Args:
            inp: Path of a JSON file with the records, or the records
                themselves (any indexable object supporting ``len``).
            cfg: Configuration object forwarded to :func:`calc_feats`.
        """
        self.data = _load_records(inp)
        self.cfg = cfg

    def __getitem__(self, idx):
        """Return the features for record ``idx``, or ``None`` on failure.

        Any exception raised while reading the record or computing features
        is printed and turned into ``None``; ``None`` is also returned when
        :func:`calc_feats` itself returns ``None``. Consumers have to cope
        with ``None`` items.
        """
        try:
            rec = self.data[idx]
            if 'ms_bins' in rec:
                # Already featurized: hand the stored record back unchanged.
                return rec
            nls = rec['nls'] if 'nls' in rec else []
            ms = rec['ms']
            smi = rec['smiles']
            return calc_feats(smi, ms, nls, self.cfg)
        except Exception as e:
            print('=' * 50, idx, str(e))
            return None

    def __len__(self):
        """Return the number of records."""
        return len(self.data)


class DatasetGNNFP(torch.utils.data.Dataset):
    """Dataset yielding fingerprint plus graph features for each record."""

    def __init__(self, inp, cfg):
        """Store the records and configuration.

        Args:
            inp: Path of a JSON file with the records, or the records
                themselves.
            cfg: Configuration object; ``fptype`` and ``mol_embedding_dim``
                are read on item access.
        """
        self.data = _load_records(inp)
        self.cfg = cfg

    def __getitem__(self, idx):
        """Return ``'mol_fps'`` merged with the graph features of record ``idx``.

        Any exception is printed and turned into ``None``. That includes the
        ``TypeError`` raised by ``dict.update`` should the graph featurizer
        return ``None``.
        """
        try:
            smi = self.data[idx]['smiles']
            item = {
                'mol_fps': utils.mol_fp_encoder(
                    smi, tp=self.cfg.fptype, nbits=self.cfg.mol_embedding_dim
                ),
            }
            item.update(utils.mol_graph_featurizer(smi))
        except Exception as e:
            print('=' * 50, idx, str(e))
            return None
        return item

    def __len__(self):
        """Return the number of records."""
        return len(self.data)


class PathDataset(torch.utils.data.Dataset):
    """Dataset that parses one spectrum file per item.

    Parsed files are cached in ``self.data`` (keyed by index), so each file
    is read at most once per dataset instance.
    """

    def __init__(self, pathlist, cfg):
        """Store the file paths and configuration.

        Args:
            pathlist: Indexable collection of file paths.
            cfg: Configuration object; ``energy`` selects the section to parse
                and the object is forwarded to :func:`calc_feats`.
        """
        self.fns = pathlist
        self.cfg = cfg
        self.data = {}

    def __getitem__(self, idx):
        """Return the features for file ``idx``, or ``None`` on any failure.

        Unlike the sibling datasets, failures are dropped silently here.
        """
        try:
            if idx not in self.data:
                parsed = self.proc_data(self.fns[idx], self.cfg.energy)
                if parsed is None:
                    return None
                self.data[idx] = parsed
            rec = self.data[idx]
            ms = rec['ms']
            smi = rec['smiles']
            # These files carry no 'nls' information, so an empty list is used.
            return calc_feats(smi, ms, [], self.cfg)
        except Exception:
            # Deliberately quiet: unreadable or unparsable files become None.
            return None

    def proc_data(self, fn, energy='Energy1'):
        """Parse the ``energy`` section of the spectrum file ``fn``.

        Parsing starts at a line containing ``energy``; the second-to-last
        ``';'``-separated field of that line is taken as the SMILES. Each
        following line must consist of two space-separated numbers
        (``mz intensity``) and is collected until a line containing
        ``'END IONS'`` is reached.

        Args:
            fn: Path of the file to read.
            energy: Substring identifying the header line of the section.

        Returns:
            ``{'ms': [(mz, intensity), ...], 'smiles': str}``, or ``None``
            when the section is missing or a line cannot be parsed.

        Raises:
            OSError: If the file cannot be opened or read.
        """
        with open(fn) as fh:
            lines = fh.readlines()

        peaks = []
        smi = None
        in_section = False
        try:
            for line in lines:
                if energy in line:
                    smi = line.split(';')[-2]
                    in_section = True
                    continue
                if not in_section:
                    continue
                if 'END IONS' in line:
                    break
                mz, intn = line.split(' ')
                peaks.append((float(mz), float(intn)))
        except Exception:
            # Malformed header or peak line: treat the whole file as unusable.
            return None

        if not in_section:
            # No header line matched, so there is no SMILES to report.
            return None
        return {'ms': peaks, 'smiles': smi}

    def __len__(self):
        """Return the number of file paths."""
        return len(self.fns)