File size: 4,361 Bytes
d5233a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5946936
 
 
d5233a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5946936
d5233a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os, json
import torch
import utils

def calc_feats(smi, ms, nls, cfg):
    """Assemble the feature dictionary for one spectrum / molecule pair.

    The result always carries ``'ms_bins'`` (output of ``utils.ms_binner``).
    Which molecule features are added depends on the tags found in
    ``cfg.mol_encoder``:

    * ``'fp'`` without ``'fm'``: ``'mol_fps'`` from ``utils.mol_fp_encoder``.
    * ``'fp'`` with ``'fm'``: ``'mol_fps'`` and ``'mol_fmvec'`` from a single
      ``utils.mol_fp_fm_encoder`` call.
    * ``'gnn'``: every key returned by ``utils.mol_graph_featurizer`` is
      merged in; with ``'fm'`` also present and not yet computed,
      ``'mol_fmvec'`` comes from ``utils.smi2fmvec``.

    Note that ``'fm'`` on its own (no ``'fp'`` and no ``'gnn'``) adds nothing.

    Args:
        smi: SMILES string of the molecule.
        ms: Spectrum peaks, passed through to ``utils.ms_binner``.
        nls: Neutral losses, passed through to ``utils.ms_binner``.
        cfg: Config object providing ``min_mz``, ``max_mz``, ``bin_size``,
            ``add_nl``, ``binary_intn``, ``mol_encoder`` and, for
            fingerprints, ``fptype`` and ``mol_embedding_dim``.

    Returns:
        The feature dict, or ``None`` when graph features were requested and
        ``utils.mol_graph_featurizer`` returned a falsy value.
    """
    feats = {
        'ms_bins': utils.ms_binner(
            ms,
            nls,
            min_mz=cfg.min_mz,
            max_mz=cfg.max_mz,
            bin_size=cfg.bin_size,
            add_nl=cfg.add_nl,
            binary_intn=cfg.binary_intn,
        ),
    }

    # Tracks whether the fingerprint branch already produced 'mol_fmvec'.
    fm_done = False
    if 'fp' in cfg.mol_encoder:
        if 'fm' in cfg.mol_encoder:
            fps, fmvec = utils.mol_fp_fm_encoder(
                smi,
                tp=cfg.fptype,
                nbits=cfg.mol_embedding_dim,
            )
            feats['mol_fps'] = fps
            feats['mol_fmvec'] = fmvec
            fm_done = True
        else:
            feats['mol_fps'] = utils.mol_fp_encoder(
                smi,
                tp=cfg.fptype,
                nbits=cfg.mol_embedding_dim,
            )

    if 'gnn' not in cfg.mol_encoder:
        return feats

    graph_feats = utils.mol_graph_featurizer(smi)
    if not graph_feats:
        # The featurizer produced nothing usable; signal "skip this sample".
        return None
    feats.update(graph_feats)

    if 'fm' in cfg.mol_encoder and not fm_done:
        feats['mol_fmvec'] = utils.smi2fmvec(smi)

    return feats

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset yielding feature dicts for spectrum/molecule records.

    Each record is indexed with the keys ``'ms'`` and ``'smiles'`` and
    optionally ``'nls'``. A record that already contains ``'ms_bins'`` is
    treated as pre-featurized and returned untouched.
    """

    def __init__(self, inp, cfg):
        """Store the records and the featurization config.

        Args:
            inp: Either a path (``str`` or ``os.PathLike``) to a JSON file
                holding the records, or the already-loaded records.
            cfg: Config object forwarded to ``calc_feats``.
        """
        if isinstance(inp, (str, os.PathLike)):
            # Context manager so the file handle is closed deterministically
            # instead of being left to the garbage collector.
            with open(inp) as fh:
                self.data = json.load(fh)
        else:
            self.data = inp

        self.cfg = cfg

    def __getitem__(self, idx):
        """Return the feature dict for record ``idx``.

        Returns:
            The stored record itself when it already has ``'ms_bins'``,
            otherwise the result of ``calc_feats`` (which may be ``None``).
            Any exception raised along the way, including a bad index or a
            missing key, is printed and turned into ``None``.
        """
        try:
            record = self.data[idx]
            if 'ms_bins' in record:
                return record

            # Neutral losses are optional; fall back to an empty list.
            nls = record['nls'] if 'nls' in record else []
            ms = record['ms']
            smi = record['smiles']

            return calc_feats(smi, ms, nls, self.cfg)

        except Exception as e:
            # Best-effort loading: report the failing index and let the
            # caller drop the sample.
            print('='*50, idx, str(e))
            return None

    def __len__(self):
        """Return the number of stored records."""
        return len(self.data)
    
class DatasetGNNFP(torch.utils.data.Dataset):
    """Map-style dataset yielding fingerprint plus graph features per molecule.

    Only the ``'smiles'`` key of each record is read; no spectrum features
    are computed here.
    """

    def __init__(self, inp, cfg):
        """Store the records and the featurization config.

        Args:
            inp: Either a path (``str`` or ``os.PathLike``) to a JSON file
                holding the records, or the already-loaded records.
            cfg: Config object providing ``fptype`` and ``mol_embedding_dim``.
        """
        if isinstance(inp, (str, os.PathLike)):
            # Context manager so the file handle is closed deterministically
            # instead of being left to the garbage collector.
            with open(inp) as fh:
                self.data = json.load(fh)
        else:
            self.data = inp

        self.cfg = cfg

    def __getitem__(self, idx):
        """Return ``'mol_fps'`` plus the graph-featurizer keys for record ``idx``.

        Returns:
            The feature dict, or ``None`` after printing the error when
            anything raises (bad index, missing ``'smiles'``, encoder failure,
            or a featurizer result that ``dict.update`` cannot consume).
        """
        try:
            smi = self.data[idx]['smiles']
            item = {}
            item['mol_fps'] = utils.mol_fp_encoder(smi,
                                                   tp=self.cfg.fptype,
                                                   nbits=self.cfg.mol_embedding_dim)
            # NOTE(review): calc_feats treats a falsy featurizer result as
            # "skip sample"; here a None result would raise inside update()
            # and end up in the except branch below - confirm that is intended.
            item.update(utils.mol_graph_featurizer(smi))
        except Exception as e:
            # Best-effort loading: report the failing index and let the
            # caller drop the sample.
            print('='*50, idx, str(e))
            return None

        return item

    def __len__(self):
        """Return the number of stored records."""
        return len(self.data)

class PathDataset(torch.utils.data.Dataset):
    """Map-style dataset that lazily parses one spectrum file per index.

    Parsed files are cached in ``self.data`` (keyed by index) so each file is
    read at most once per dataset instance.
    """

    def __init__(self, pathlist, cfg):
        """Store the file list and config.

        Args:
            pathlist: Sequence of spectrum file paths, one per sample.
            cfg: Config object; ``cfg.energy`` selects the block to parse and
                the whole object is forwarded to ``calc_feats``.
        """
        self.fns = pathlist
        self.cfg = cfg
        self.data = {}

    def __getitem__(self, idx):
        """Return the feature dict for file ``idx``.

        Returns:
            The result of ``calc_feats`` (called with an empty neutral-loss
            list), or ``None`` when the file cannot be parsed or anything
            raises along the way.
        """
        try:
            if idx not in self.data:
                parsed = self.proc_data(self.fns[idx], self.cfg.energy)
                if parsed is None:
                    return None
                self.data[idx] = parsed

            record = self.data[idx]
            ms = record['ms']
            smi = record['smiles']

            return calc_feats(smi, ms, [], self.cfg)

        except Exception:
            # Deliberately silent best-effort: unreadable or unfeaturizable
            # samples are reported to the caller as None.
            return None

    def proc_data(self, fn, energy='Energy1'):
        """Parse the peak list of one energy block from a spectrum file.

        The first line containing ``energy`` starts the block; the SMILES is
        the second-to-last ``;``-separated field of that line. Each following
        line must be ``"<mz> <intensity>"`` until a line containing
        ``'END IONS'`` closes the block.

        Args:
            fn: Path of the file to read.
            energy: Substring identifying the header line of the wanted block.

        Returns:
            ``{'ms': [(mz, intensity), ...], 'smiles': str}``, or ``None`` when
            no header line matches ``energy`` or a line inside the block is
            malformed.
        """
        # Close the handle deterministically instead of relying on the GC.
        with open(fn) as fh:
            lines = fh.readlines()

        peaks = []
        smi = None
        in_block = False
        try:
            for line in lines:
                if energy in line:
                    smi = line.split(';')[-2]
                    in_block = True
                    continue
                if not in_block:
                    continue
                if 'END IONS' in line:
                    break
                mz, intn = line.split(' ')
                peaks.append((float(mz), float(intn)))
        except (ValueError, IndexError):
            # ValueError: wrong token count or non-numeric peak field.
            # IndexError: header line without enough ';'-separated fields.
            return None

        if smi is None:
            # The requested energy block never appeared in the file.
            return None

        return {'ms': peaks, 'smiles': smi}

    def __len__(self):
        """Return the number of files in the dataset."""
        return len(self.fns)