Tingxie committed on
Commit
603d88b
·
1 Parent(s): c8bfe50

Upload 2 files

Browse files
dataloader/dataset.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Wed Apr 27 10:43:40 2022
4
+
5
+ @author: ZNDX002
6
+ """
7
+
8
+ import numpy as np
9
+ import torch
10
+ from torch.utils.data import Dataset
11
+ import ast
12
+
13
class ClrDataset(Dataset):
    """Contrastive Learning Representations Dataset.

    Pairs a precomputed molecular graph with the raw peaks of its MS2
    spectrum.

    Args:
        file: pandas DataFrame with columns 'Graph' (graph object) and
            'MS2' (spectrum object exposing .mz and .intensities arrays).
        list_IDs: sequence of DataFrame index labels to sample from.
        transform: unused; kept for interface compatibility.
    """

    def __init__(self,
                 file,
                 list_IDs,
                 transform=None):

        self.clr_frame = file
        self.list_IDs = list_IDs

    def __len__(self):
        return len(self.clr_frame)

    def __getitem__(self, idx):
        index = self.list_IDs[idx]
        v_d = self.clr_frame.loc[index,'Graph']
        # BUG FIX: the spectrum must come from the same row as the graph.
        # The original read `loc[idx, 'MS2']` (the positional argument)
        # instead of `loc[index, 'MS2']` (the mapped label), pairing graphs
        # with the wrong spectra whenever list_IDs is not the identity.
        spec = self.clr_frame.loc[index,'MS2']
        spec_mz = spec.mz
        spec_intens = spec.intensities
        # Round m/z to 4 decimals to stabilise downstream binning/matching.
        spec_mz = np.around(spec_mz, decimals=4)
        num_peak = len(spec_mz)
        return v_d,spec_mz,spec_intens,num_peak
41
+
42
class re_train_dataset(Dataset):
    """Retrieval training dataset: yields (graph, spectrum-tensor) pairs.

    Args:
        file: pandas DataFrame with columns 'Graph' and 'MS2' (numpy array).
        list_IDs: sequence of DataFrame index labels to sample from.
        transform: unused; kept for interface compatibility.
    """

    def __init__(self,
                 file,
                 list_IDs,
                 transform=None):

        self.clr_frame = file
        self.list_IDs = list_IDs

    def __len__(self):
        return len(self.clr_frame)

    def __getitem__(self, idx):
        row_label = self.list_IDs[idx]
        graph = self.clr_frame.loc[row_label, 'Graph']
        # The stored MS2 vector is a numpy array; hand it out as float32.
        spectrum = torch.from_numpy(self.clr_frame.loc[row_label, 'MS2'])
        return graph, spectrum.to(torch.float32)
61
+
62
class re_eval_dataset(Dataset):
    """Retrieval evaluation dataset.

    Serves query spectra as float32 tensors and exposes the candidate
    structure pool (query graphs followed by a reference pool) plus the
    ground-truth mapping from spectrum index to structure index.

    Args:
        file: pandas DataFrame with columns 'Graph', 'MS2' (numpy array)
            and 'formula'.
        list_IDs: sequence of DataFrame index labels to sample from.
        smiles_reference: DataFrame with a 'Graph' column of decoy structures.
        transform: unused; kept for interface compatibility.
    """

    def __init__(self,
                 file,
                 list_IDs,
                 smiles_reference,
                 transform=None):

        self.clr_frame = file
        self.list_IDs = list_IDs
        self.valid_formulas = list(self.clr_frame['formula'])
        self.smiles_reference = smiles_reference
        # Candidate pool: the query structures first, then the reference decoys,
        # so spectrum i's ground truth sits at structure index i.
        self.structures = list(self.clr_frame['Graph']) + list(self.smiles_reference['Graph'])
        self.spectra = list(self.clr_frame['MS2'])
        # One ground-truth structure per spectrum (same index).
        self.spec2smi = {spec_id: [spec_id] for spec_id in range(len(self.spectra))}

    def __len__(self):
        return len(self.clr_frame)

    def __getitem__(self, idx):
        index = self.list_IDs[idx]
        spec = self.clr_frame.loc[index,'MS2']
        # FIX: dropped the unused `formula` lookup — it was computed and
        # discarded; only the spectrum tensor is returned.
        return torch.from_numpy(spec).to(torch.float32)
dataloader/dataset_wrapper.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import torch
4
+ from tqdm import tqdm
5
+ #from torch.utils.data import DataLoader
6
+ from torch.utils.data.sampler import SubsetRandomSampler
7
+ from torch.utils.data.distributed import DistributedSampler
8
+ from torchvision import datasets
9
+ from .dataset import ClrDataset,re_train_dataset,re_eval_dataset
10
+ from functools import partial
11
+ from rdkit import RDConfig
12
+ from rdkit import Chem
13
+ from rdkit.Chem import AllChem,rdMolDescriptors,Descriptors
14
+ from matchms.Fragments import Fragments
15
+ import matchms.filtering as msfilters
16
+ from matchms.importing import load_from_mgf
17
+ import warnings
18
+ from torch_geometric.data import Data, DataLoader,Batch
19
+ #from torch_geometric.data import Data
20
+ warnings.filterwarnings('ignore')
21
+ import json
22
+ import random
23
+ import ast
24
+ from rdkit.Chem.rdchem import BondType as BT
25
+ from rdkit import RDLogger
26
+ from toolz.sandbox import unzip
27
+ RDLogger.DisableLog('rdApp.*')
28
# Categorical vocabularies used by MolToGraph to index node/edge features.
ATOM_LIST = list(range(1,119))  # atomic numbers 1-118
# Recognised chirality tags; anything else maps to CHI_OTHER in MolToGraph.
CHIRALITY_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER
]
# Hybridization states encoded as node features.
HYBRID_TYPE = [Chem.rdchem.HybridizationType.SP,
               Chem.rdchem.HybridizationType.SP2,
               Chem.rdchem.HybridizationType.SP2D,
               Chem.rdchem.HybridizationType.SP3,
               Chem.rdchem.HybridizationType.SP3D,
               Chem.rdchem.HybridizationType.SP3D2,
               Chem.rdchem.HybridizationType.UNSPECIFIED,
               Chem.rdchem.HybridizationType.S]
VALENCE_LIST = list(range(1,7))  # total valence, clamped to <=6 by MolToGraph
# Atom degree vocabulary, clamped to <=5 by MolToGraph.
# NOTE(review): "DRGREE" is a typo for "DEGREE", kept because MolToGraph
# references this exact name.
DRGREE_LIST = list(range(1,6))
BOND_LIST = [BT.SINGLE, BT.DOUBLE, BT.TRIPLE, BT.AROMATIC]
# Bond direction (wedge/hash stereo) vocabulary.
BONDDIR_LIST = [
    Chem.rdchem.BondDir.NONE,
    Chem.rdchem.BondDir.ENDUPRIGHT,
    Chem.rdchem.BondDir.ENDDOWNRIGHT,
]
51
+
52
+
53
def collate_func(input_list):
    """Collate (graph, mz, intensities, num_peaks) samples into one batch.

    Pads the variable-length m/z and intensity arrays to the longest
    spectrum in the batch and merges the graphs into a single
    torch_geometric Batch.
    """
    # Transpose the list of sample tuples into per-field columns.
    graphs, mz_arrays, intensity_arrays, peak_counts = (
        list(column) for column in zip(*input_list)
    )
    peak_counts = torch.LongTensor(peak_counts)
    mz_list = [torch.from_numpy(arr).float() for arr in mz_arrays]
    intensity_list = [torch.from_numpy(arr).float() for arr in intensity_arrays]
    padded_mz = torch.nn.utils.rnn.pad_sequence(
        mz_list, batch_first=True, padding_value=0
    )
    padded_intensity = torch.nn.utils.rnn.pad_sequence(
        intensity_list, batch_first=True, padding_value=0
    )
    graph_batch = Batch.from_data_list(graphs)
    return graph_batch, padded_mz, padded_intensity, peak_counts
66
+
67
+ '''def valid_collate_func(x):
68
+ ms, formula = zip(*x)
69
+ return ms, formula
70
+ '''
71
def valid_collate_func(x):
    """Transpose a batch of samples into per-field columns (a lazy zip)."""
    transposed = zip(*x)
    return transposed
74
+
75
def MolToGraph(smiles):
    """Convert a SMILES string into a torch_geometric Data graph.

    Node features (5 categorical columns): atom type, chirality,
    hybridization, total valence (clamped to 6) and degree (clamped to 5),
    each encoded as an index into the module-level vocabulary lists.
    Edge features: bond type and bond direction; every bond is inserted in
    both directions so the graph is undirected.

    Args:
        smiles: a SMILES string parsable by RDKit.

    Returns:
        torch_geometric.data.Data with x [n_atoms, 5],
        edge_index [2, 2*n_bonds] and edge_attr [2*n_bonds, 2].

    NOTE(review): an atom with total valence 0 would make the
    VALENCE_LIST.index lookup raise ValueError — confirm inputs exclude
    such species.
    """
    mol = Chem.MolFromSmiles(smiles)
    mol = Chem.AddHs(mol)  # explicit hydrogens become graph nodes too

    # FIX: removed unused locals (N, M, atom_index) from the original.
    type_idx = []
    chirality_idx = []
    atomic_number = []
    hybrid_type_idx = []
    valence_idx = []
    degree_idx = []
    for atom in mol.GetAtoms():
        type_idx.append(ATOM_LIST.index(atom.GetAtomicNum()))
        # Unknown chirality tags fall back to CHI_OTHER.
        atom_charity = atom.GetChiralTag()
        if atom_charity in CHIRALITY_LIST:
            chirality_idx.append(CHIRALITY_LIST.index(atom_charity))
        else:
            chirality_idx.append(CHIRALITY_LIST.index(Chem.rdchem.ChiralType.CHI_OTHER))
        atomic_number.append(atom.GetAtomicNum())
        hybrid_type_idx.append(HYBRID_TYPE.index(atom.GetHybridization()))
        valence_idx.append(VALENCE_LIST.index(min(atom.GetTotalValence(), 6)))
        degree_idx.append(DRGREE_LIST.index(min(atom.GetDegree(), 5)))
    x1 = torch.tensor(type_idx, dtype=torch.long).view(-1, 1)
    x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1, 1)
    x3 = torch.tensor(hybrid_type_idx, dtype=torch.long).view(-1, 1)
    x4 = torch.tensor(valence_idx, dtype=torch.long).view(-1, 1)
    x5 = torch.tensor(degree_idx, dtype=torch.long).view(-1, 1)
    x = torch.cat([x1, x2, x3, x4, x5], dim=-1)

    row, col, edge_feat = [], [], []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        # Add the bond in both directions with identical features.
        row += [start, end]
        col += [end, start]
        edge_feat.append([
            BOND_LIST.index(bond.GetBondType()),
            BONDDIR_LIST.index(bond.GetBondDir())
        ])
        edge_feat.append([
            BOND_LIST.index(bond.GetBondType()),
            BONDDIR_LIST.index(bond.GetBondDir())
        ])

    edge_index = torch.tensor([row, col], dtype=torch.long)
    edge_attr = torch.tensor(np.array(edge_feat), dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    return data
125
+
126
def remove_peaks(mz, peak_intensities, threshold, percentage):
    """Zero out a random fraction of low-intensity peaks (augmentation).

    Picks `percentage` of the peaks whose intensity is strictly below
    `threshold` and sets their intensity to 0. m/z values pass through
    unchanged.

    Args:
        mz: peak m/z positions (returned as-is).
        peak_intensities: list or 1-D numpy array of intensities.
        threshold: eligibility cutoff (intensity < threshold).
        percentage: fraction in [0, 1] of eligible peaks to zero out.

    Returns:
        (mz, new_intensities).

    BUG FIX: the intensities are copied before modification. The original
    wrote into the caller's array in place, which silently corrupted the
    source spectrum (matchms exposes its underlying intensity array).
    """
    low_intensity_peaks_indices = [i for i, intensitie in enumerate(peak_intensities)
                                   if intensitie < threshold]
    num_peaks_to_remove = int(len(low_intensity_peaks_indices) * percentage)
    peaks_to_remove = random.sample(low_intensity_peaks_indices, num_peaks_to_remove)
    peak_intensities = peak_intensities.copy()  # works for list and ndarray
    for i in peaks_to_remove:
        peak_intensities[i] = 0
    return mz, peak_intensities
133
+
134
def enhance_peak_intensities(mz, peak_intensities, jitter_range):
    """Jitter every peak intensity by a random relative amount.

    Each intensity p becomes p + p * u with u drawn uniformly from
    [-jitter_range, +jitter_range]. m/z values pass through unchanged.
    Returns (mz, jittered_intensities) where the intensities are a new list.
    """
    jittered = [
        peak + peak * random.uniform(-jitter_range, jitter_range)
        for peak in peak_intensities
    ]
    return mz, jittered
141
+
142
def peak_addition(mz, peak_intensities, noise_max):
    """Inject a random number of low-intensity noise peaks (augmentation).

    Draws n in [0, noise_max) noise peaks at 0.01-spaced m/z positions
    between min(mz) and max(mz) that are not already occupied, each with a
    random intensity in [0, 0.01). Returns the extended (mz, intensities)
    arrays; the noise peaks are appended, not sorted in.
    """
    n_noise_peaks = np.random.randint(0, noise_max)  # upper bound exclusive
    upper = int(max(mz) * 100)
    lower = int(min(mz) * 100)
    # Candidate grid positions minus the occupied m/z values.
    free_positions = np.setdiff1d([i / 100 for i in range(lower, upper)], mz)
    chosen_positions = np.random.choice(free_positions, n_noise_peaks)
    noise_intensities = 0.01 * np.random.random(len(chosen_positions))
    mz = np.concatenate((mz, chosen_positions))
    peak_intensities = np.concatenate((peak_intensities, noise_intensities))
    return mz, peak_intensities
152
+
153
def data_augmentation(spectrum):
    """Augment a matchms Spectrum: drop weak peaks, jitter intensities,
    add noise peaks, then renormalize.

    NOTE(review): remove_peaks writes into spectrum.intensities in place,
    so the input spectrum is mutated before its peaks are replaced —
    confirm this is intended rather than operating on a copy.
    """
    mz_initial = spectrum.mz
    intens_initial = spectrum.intensities
    # Zero out 20% of peaks with intensity below 0.001.
    mz_rp, peak_rp = remove_peaks(mz_initial, intens_initial, threshold=0.001, percentage=0.2)
    # Jitter every intensity by up to ±40%.
    mz_enhance, peak_enhance = enhance_peak_intensities(mz_rp, peak_rp, jitter_range=0.4)
    # Inject up to 9 random noise peaks.
    mz_add, peak_add = peak_addition(mz_enhance, peak_enhance, noise_max=10)
    # Drop peaks whose m/z is exactly 0 (those zeroed by remove_peaks keep
    # their m/z; this only removes degenerate zero-m/z entries).
    indices = np.where(mz_add == 0)[0]
    mz_f = np.array([mz_add[i] for i in range(len(mz_add)) if i not in indices])
    peak_f = np.array([peak_add[i] for i in range(len(mz_add)) if i not in indices])
    # Re-sort peaks by ascending m/z, keeping intensities aligned.
    peak_f = np.array([peak_f[i] for i in mz_f.argsort()])
    mz_f.sort()
    spectrum.set('num_peaks', str(len(mz_f)))
    spectrum.peaks = Fragments(mz=mz_f, intensities=peak_f)
    # Rescale intensities to [0, 1].
    spectrum = msfilters.normalize_intensities(spectrum)
    return spectrum
168
+
169
def graph_spec2vec_calculation(smiles, spectra):
    """Build a DataFrame pairing molecular graphs with cleaned MS2 spectra.

    Each SMILES is converted to a graph and its spectrum is reduced to
    3-300 peaks; pairs whose spectrum is filtered out (None) or whose
    conversion raises are skipped with a message.

    Args:
        smiles: list of SMILES strings.
        spectra: list of matchms Spectrum objects, parallel to `smiles`.

    Returns:
        DataFrame with columns 'Graph' and 'MS2'.
    """
    print("calculating molecular graphs")
    df = pd.DataFrame(columns=['Graph','MS2'])
    for i in tqdm(range(len(smiles))):
        # Bind smi before the try so the failure message below is always valid
        # (the original could hit NameError / print a stale value).
        smi = smiles[i]
        try:
            v_d = MolToGraph(smi)
            spectrum = spectra[i]
            spectrum = msfilters.reduce_to_number_of_peaks(spectrum, n_required=3, n_max=300)
            if spectrum is not None:
                df.loc[len(df.index)] = [v_d, spectrum]
        # FIX: bare `except:` also swallowed SystemExit/KeyboardInterrupt;
        # catch Exception so Ctrl-C still aborts the run.
        except Exception:
            print("SMILES", smi, "calculation failure")
    print("Calculated", len(df), "molecular graph-mass spectrometry pairs")
    return df
185
+
186
def graph_spec2vec_valid_calculation(smiles, spectra, formulas):
    """Build a validation DataFrame of (graph, spectrum, formula) rows.

    Rows whose SMILES cannot be converted are skipped silently
    (best-effort, matching the original behaviour).

    Args:
        smiles: list of SMILES strings.
        spectra: list of matchms Spectrum objects, parallel to `smiles`.
        formulas: list of molecular formula strings, parallel to `smiles`.

    Returns:
        DataFrame with columns 'Graph', 'MS2' and 'formula'.
    """
    print("calculating molecular graphs")
    df = pd.DataFrame(columns=['Graph','MS2','formula'])
    for i in tqdm(range(len(smiles))):
        try:
            smi = smiles[i]
            formula = formulas[i]
            v_d = MolToGraph(smi)
            spectrum = spectra[i]
            df.loc[len(df.index)] = [v_d, spectrum, formula]
        # FIX: bare `except:` also swallowed SystemExit/KeyboardInterrupt;
        # catch Exception so Ctrl-C still aborts the run.
        except Exception:
            pass
    print("Calculated", len(df), "molecular graph-mass spectrometry pairs")
    return df
201
+
202
def graph_calculation(smiles, formulas):
    """Build a DataFrame of (graph, formula) rows from SMILES strings.

    SMILES that fail conversion are skipped silently (best-effort,
    matching the original behaviour).

    Args:
        smiles: list of SMILES strings.
        formulas: list of molecular formula strings, parallel to `smiles`.

    Returns:
        DataFrame with columns 'Graph' and 'formula'.
    """
    print("calculating molecular graphs")
    df = pd.DataFrame(columns=['Graph','formula'])
    for i in tqdm(range(len(smiles))):
        try:
            smi = smiles[i]
            formula = formulas[i]
            v_d = MolToGraph(smi)
            df.loc[len(df.index)] = [v_d, formula]
        # FIX: bare `except:` also swallowed SystemExit/KeyboardInterrupt;
        # catch Exception so Ctrl-C still aborts the run.
        except Exception:
            pass
    print("Calculated", len(df), "molecular graphs")
    return df
215
+
216
class DataSetWrapper(object):
    """Builds distributed (DDP) train/validation DataLoaders for
    contrastive graph-spectrum training.

    Args:
        world_size: number of DDP processes.
        rank: rank of this process.
        batch_size: per-process batch size.
        num_workers: DataLoader worker-process count.
        valid_size: fraction of the data held out for validation.
        s: unused here; kept for interface compatibility.
        ms2_file: path to an .mgf file with the MS2 spectra.
        smi_file: path to a .npy file with SMILES strings, parallel to ms2_file.
    """

    def __init__(self,
                 world_size,
                 rank,
                 batch_size,
                 num_workers,
                 valid_size,
                 s,
                 ms2_file,
                 smi_file):
        self.world_size = world_size
        self.rank = rank
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.valid_size = valid_size
        self.s = s
        self.ms2_file = ms2_file
        self.smi_file = smi_file

    def get_data_loaders(self):
        """Load data, split train/valid, precompute graphs, return loaders."""
        self.smiles = np.load(self.smi_file).tolist()
        self.ms2 = list(load_from_mgf(self.ms2_file))

        # Shuffle indices, then carve off the validation fraction.
        num_train = len(self.smiles)
        indices = list(range(num_train))
        np.random.shuffle(indices)

        split = int(np.floor(self.valid_size * num_train))
        train_idx, valid_idx = indices[split:], indices[:split]
        self.train_smiles = [self.smiles[i] for i in train_idx]
        self.train_ms2 = [self.ms2[i] for i in train_idx]
        self.valid_smiles = [self.smiles[i] for i in valid_idx]
        self.valid_ms2 = [self.ms2[i] for i in valid_idx]
        self.train_graph_file = graph_spec2vec_calculation(self.train_smiles, self.train_ms2)
        self.valid_graph_file = graph_spec2vec_calculation(self.valid_smiles, self.valid_ms2)
        train_dataset = ClrDataset(self.train_graph_file, self.train_graph_file.index.values)
        valid_dataset = ClrDataset(self.valid_graph_file, self.valid_graph_file.index.values)

        train_loader, valid_loader = self.get_train_validation_data_loaders(train_dataset, valid_dataset)
        return train_loader, valid_loader

    def get_train_validation_data_loaders(self, train_dataset, valid_dataset):
        """Wrap the datasets in DistributedSampler-backed DataLoaders.

        NOTE(review): the training loop must call
        train_sampler.set_epoch(epoch) each epoch for the shuffle to vary —
        confirm the caller does.
        """
        train_sampler = DistributedSampler(train_dataset, num_replicas=self.world_size,
                                           rank=self.rank, shuffle=True)
        # FIX: num_workers was stored but never passed to either loader, so
        # data loading always ran single-process despite the setting.
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size,
                                                   sampler=train_sampler, shuffle=False,
                                                   num_workers=self.num_workers,
                                                   collate_fn=collate_func)
        valid_sampler = DistributedSampler(valid_dataset, num_replicas=self.world_size,
                                           rank=self.rank, shuffle=False)
        valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=self.batch_size,
                                                   sampler=valid_sampler, shuffle=False,
                                                   num_workers=self.num_workers,
                                                   collate_fn=collate_func)
        return train_loader, valid_loader
273
+
274
+
275
class DataSetWrapper_noddp(object):
    """Single-process (non-DDP) variant of DataSetWrapper: builds plain
    train/validation DataLoaders for contrastive graph-spectrum training.

    Args:
        batch_size: loader batch size.
        num_workers: DataLoader worker-process count.
        valid_size: fraction of the data held out for validation.
        s: unused here; kept for interface compatibility.
        ms2_file: path to an .mgf file with the MS2 spectra.
        smi_file: path to a .npy file with SMILES strings, parallel to ms2_file.
    """

    def __init__(self,
                 batch_size,
                 num_workers,
                 valid_size,
                 s,
                 ms2_file,
                 smi_file):
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.valid_size = valid_size
        self.s = s
        self.ms2_file = ms2_file
        self.smi_file = smi_file

    def get_data_loaders(self):
        """Load spectra and SMILES, split train/valid, return the loaders."""
        self.smiles = np.load(self.smi_file).tolist()
        self.ms2 = list(load_from_mgf(self.ms2_file))

        # Shuffle sample indices, then carve off the validation fraction.
        sample_count = len(self.smiles)
        order = list(range(sample_count))
        np.random.shuffle(order)

        cut = int(np.floor(self.valid_size * sample_count))
        train_idx, valid_idx = order[cut:], order[:cut]
        self.train_smiles = [self.smiles[i] for i in train_idx]
        self.train_ms2 = [self.ms2[i] for i in train_idx]
        self.valid_smiles = [self.smiles[i] for i in valid_idx]
        self.valid_ms2 = [self.ms2[i] for i in valid_idx]
        self.train_graph_file = graph_spec2vec_calculation(self.train_smiles, self.train_ms2)
        self.valid_graph_file = graph_spec2vec_calculation(self.valid_smiles, self.valid_ms2)
        train_set = ClrDataset(self.train_graph_file, self.train_graph_file.index.values)
        valid_set = ClrDataset(self.valid_graph_file, self.valid_graph_file.index.values)

        return self.get_train_validation_data_loaders(train_set, valid_set)

    def get_train_validation_data_loaders(self, train_dataset, valid_dataset):
        """Return (train_loader, valid_loader); training drops the last
        incomplete batch, validation keeps it."""
        make_loader = partial(torch.utils.data.DataLoader,
                              batch_size=self.batch_size,
                              num_workers=self.num_workers,
                              shuffle=False,
                              collate_fn=collate_func)
        train_loader = make_loader(train_dataset, drop_last=True)
        valid_loader = make_loader(valid_dataset, drop_last=False)
        return train_loader, valid_loader
333
+
334
+