antoniaebner committed on
Commit
f0ecde9
·
1 Parent(s): 1561c1d

add preprocessing pipeline

Browse files
Files changed (2) hide show
  1. requirements.txt +3 -3
  2. src/preprocess_old.py +513 -0
requirements.txt CHANGED
@@ -1,12 +1,12 @@
1
  fastapi
2
  uvicorn[standard]
3
- statsmodels
4
- rdkit
5
  numpy==2.2.6
6
  scikit-learn==1.6.1
7
  joblib
8
  tabulate
9
- datasets
10
  scipy==1.16.1
11
  pandas==2.3.2
12
  tabpfn==2.2.1
 
1
  fastapi
2
  uvicorn[standard]
3
+ statsmodels==0.14.5
4
+ rdkit==2025.03.5
5
  numpy==2.2.6
6
  scikit-learn==1.6.1
7
  joblib
8
  tabulate
9
+ datasets==4.0.0
10
  scipy==1.16.1
11
  pandas==2.3.2
12
  tabpfn==2.2.1
src/preprocess_old.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# pipeline taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py

"""
This file implements the data preprocessing for Tox21.

As input it takes a list of SMILES and it outputs a nested dictionary with
SMILES and target names as keys.
"""

import os
import argparse
import json
from typing import Iterable

import numpy as np
import pandas as pd

from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import VarianceThreshold
from statsmodels.distributions.empirical_distribution import ECDF
from datasets import load_dataset

from rdkit import Chem, DataStructs
from rdkit.Chem import Descriptors, rdFingerprintGenerator, MACCSkeys
from rdkit.Chem.rdchem import Mol

from utils import (
    TASKS,
    KNOWN_DESCR,
    HF_TOKEN,
    USED_200_DESCR,
    Standardizer,
    load_pickle,
    write_pickle,
)

# Command-line interface. Note: the path-like flags below are relative names;
# they are joined onto --save_folder / --data_folder in the __main__ block.
parser = argparse.ArgumentParser(
    description="Data preprocessing script for the Tox21 dataset"
)

# Folder containing the raw tox21_<split>_cv4.csv files.
parser.add_argument(
    "--data_folder",
    type=str,
    default="data/",
)

# Folder where the preprocessed .npz files and fitted artifacts are written.
parser.add_argument(
    "--save_folder",
    type=str,
    default="data/",
)

# 1: load splits from the Hugging Face hub; 0: read local CSVs.
parser.add_argument(
    "--use_hf",
    type=int,
    default=0,
)

# Filename for the pickled ECDFs fitted on the train split.
parser.add_argument(
    "--path_ecdfs",
    type=str,
    default="ecdfs.pkl",
)

# Filename for the .npz holding the selected feature indices.
parser.add_argument(
    "--path_feat_selec",
    type=str,
    default="feat_selection.npz",
)

# JSON file of (name, SMARTS) pairs used to build the tox features.
parser.add_argument(
    "--tox_smarts_filepath",
    type=str,
    default="tox_smarts.json",
)

# 1: apply variance/correlation-based feature selection; 0: keep all features.
parser.add_argument(
    "--feature_selection",
    type=int,
    default=1,
)

# Minimum per-feature variance kept by the VarianceThreshold step.
parser.add_argument(
    "--min_var",
    type=float,
    default=0.05,
)

# Maximum allowed pairwise correlation between kept features.
parser.add_argument(
    "--max_corr",
    type=float,
    default=0.95,
)

# Morgan fingerprint radius; None lets RDKit use its default.
parser.add_argument(
    "--ecfps_radius",
    type=int,
    default=None,
)

# Morgan fingerprint length (passed as fpSize to the generator).
parser.add_argument(
    "--ecfps_folds",
    type=int,
    default=8192,
)
107
def create_cleaned_mol_objects(smiles: list[str]) -> tuple[list[Mol], np.ndarray]:
    """This function creates cleaned RDKit mol objects from a list of SMILES.

    Args:
        smiles (list[str]): list of SMILES

    Returns:
        list[Mol]: list of cleaned molecules
        np.ndarray[bool]: mask that contains False at index `i`, if molecule in `smiles` at
            index `i` could not be cleaned and was removed.
    """
    sm = Standardizer(canon_taut=True)

    clean_mol_mask = []
    mols = []
    for smile in smiles:
        mol = Chem.MolFromSmiles(smile)
        # RDKit returns None for unparsable SMILES; count those as "not
        # cleaned" instead of handing None to the standardizer.
        if mol is None:
            clean_mol_mask.append(False)
            continue
        standardized_mol, _ = sm.standardize_mol(mol)
        is_cleaned = standardized_mol is not None
        clean_mol_mask.append(is_cleaned)
        if not is_cleaned:
            continue
        # Round-trip through canonical SMILES to obtain a canonical mol object.
        can_mol = Chem.MolFromSmiles(Chem.MolToSmiles(standardized_mol))
        mols.append(can_mol)

    return mols, np.array(clean_mol_mask)
133
+
134
+
135
def create_ecfp_fps(mols: list[Mol], radius=None, fpsize=None) -> np.ndarray:
    """This function creates ECFP (Morgan, count-simulated) fingerprints.

    Args:
        mols (list[Mol]): list of molecules
        radius: Morgan radius; RDKit's default is used when None
        fpsize: fingerprint length (fpSize); RDKit's default is used when None

    Returns:
        np.ndarray: ECFP fingerprints of molecules, one row per molecule
    """
    kwargs = {}
    if fpsize is not None:
        kwargs["fpSize"] = fpsize
    if radius is not None:
        kwargs["radius"] = radius

    # The generator is configuration-only and loop-invariant: build it once
    # rather than once per molecule.
    gen = rdFingerprintGenerator.GetMorganGenerator(countSimulation=True, **kwargs)

    ecfps = []
    for mol in mols:
        fp_sparse_vec = gen.GetCountFingerprint(mol)

        # ConvertToNumpyArray resizes/fills the target array in place.
        fp = np.zeros((0,), np.int8)
        DataStructs.ConvertToNumpyArray(fp_sparse_vec, fp)

        ecfps.append(fp)

    return np.array(ecfps)
161
+
162
+
163
def create_maccs_keys(mols: list[Mol]) -> np.ndarray:
    """Compute MACCS keys for every molecule and stack them into one array."""
    keys = []
    for mol in mols:
        keys.append(MACCSkeys.GenMACCSKeys(mol))
    return np.array(keys)
166
+
167
+
168
def get_tox_patterns(filepath: str):
    """Parse the tox SMARTS definitions from a JSON file into matchable patterns.

    Args:
        filepath (str): path to tox_smarts.json; each entry is a pair whose
            second element (index 1) is the SMARTS expression.

    Returns:
        list[tuple[list, list[bool], bool]]: one entry per SMARTS expression:
            (compiled patterns, per-pattern negation flags, merge_any), where
            merge_any=True means sub-pattern matches are OR-ed, else AND-ed.
    """
    # load patterns
    with open(filepath) as f:
        smarts_list = [s[1] for s in json.load(f)]

    # Code does not work for expressions mixing AND and OR.
    assert len([s for s in smarts_list if ("AND" in s) and ("OR" in s)]) == 0

    # Chem.MolFromSmarts takes a long time so it pays off to parse all the smarts first
    # and then use them for all molecules. This gives a huge speedup over existing code.
    # a list of patterns, whether to negate the match result and how to join them to obtain one boolean value
    all_patterns = []
    for smarts in smarts_list:
        patterns = []  # list of smarts-patterns
        # value for each of the patterns above. Negates the values of the above later.
        negations = []

        if " AND " in smarts:
            smarts = smarts.split(" AND ")
            merge_any = False  # If an ' AND ' is found all 'subsmarts' have to match
        else:
            # If there is an ' OR ' present it's enough if any of the 'subsmarts' match.
            # This also accumulates smarts where neither ' OR ' nor ' AND ' occur
            smarts = smarts.split(" OR ")
            merge_any = True

        # for all subsmarts check if they are preceded by 'NOT '
        for s in smarts:
            neg = s.startswith("NOT ")
            if neg:
                s = s[4:]
            patterns.append(Chem.MolFromSmarts(s))
            negations.append(neg)

        all_patterns.append((patterns, negations, merge_any))
    return all_patterns
209
+
210
+
211
def create_tox_features(mols: list[Mol], patterns: list) -> np.ndarray:
    """Match the parsed tox patterns against every molecule.

    Returns a boolean array with one row per molecule and one column per
    pattern group (as produced by get_tox_patterns).
    """
    rows = []
    for mol in mols:
        row = []
        for patts, negations, merge_any in patterns:
            # XOR each substructure match with its negation flag.
            hits = [mol.HasSubstructMatch(p) != neg for p, neg in zip(patts, negations)]
            # OR-merge for ' OR ' groups, AND-merge for ' AND ' groups.
            row.append(any(hits) if merge_any else all(hits))

        rows.append(np.array(row))

    return np.array(rows)
228
+
229
+
230
def create_rdkit_descriptors(mols: list[Mol]) -> np.ndarray:
    """Compute the RDKit descriptor vector for each molecule.

    Evaluates every descriptor in Descriptors._descList and keeps only the
    entries indexed by USED_200_DESCR.

    Args:
        mols (list[Mol]): list of molecules

    Returns:
        np.ndarray: selected RDKit descriptors, one row per molecule
    """
    all_descrs = []

    for mol in mols:
        values = np.array([fn(mol) for _, fn in Descriptors._descList])
        all_descrs.append(values[USED_200_DESCR])

    return np.array(all_descrs)
251
+
252
+
253
def create_quantiles(raw_features: np.ndarray, ecdfs: list) -> np.ndarray:
    """Map raw feature values to quantiles via per-column ECDFs.

    Args:
        raw_features (np.ndarray): 2-D array of raw values
        ecdfs (list): one callable ECDF per column of `raw_features`

    Returns:
        np.ndarray: array of the same shape holding the quantile values
    """
    quantiles = np.zeros_like(raw_features)

    for col in range(raw_features.shape[1]):
        quantiles[:, col] = ecdfs[col](raw_features[:, col].ravel())

    return quantiles
272
+
273
+
274
def fill(features, mask, value=np.nan):
    """Expand `features` to one row per mask entry.

    Rows where `mask` is True are filled with `value`; rows where `mask` is
    False receive the rows of `features` in order. Used to re-align feature
    matrices with the original SMILES list after invalid molecules were dropped.
    """
    data = np.full((len(mask), features.shape[1]), value)
    data[~mask] = features
    return data
282
+
283
+
284
def normalize_features(
    raw_features,
    scaler=None,
    save_scaler_path: str = "",
    verbose=True,
):
    """Standardize feature vectors with a (possibly freshly fitted) scaler.

    Args:
        raw_features: 2-D array-like of raw feature values.
        scaler: already-fitted scaler to reuse; when None, a StandardScaler
            is fitted on `raw_features`.
        save_scaler_path (str): when non-empty and a scaler was fitted here,
            the fitted scaler is pickled to this path.
        verbose: whether to print progress messages.

    Returns:
        tuple: (normalized features, the scaler that was used)
    """
    needs_fit = scaler is None
    if needs_fit:
        scaler = StandardScaler()
        scaler.fit(raw_features)
        if verbose:
            print("Fitted the StandardScaler")
        if save_scaler_path:
            write_pickle(save_scaler_path, scaler)
            if verbose:
                print(f"Saved the StandardScaler under {save_scaler_path}")

    # Normalize feature vectors
    normalized = scaler.transform(raw_features)
    if verbose:
        print("Normalized molecule features")
    return normalized, scaler
305
+
306
+
307
def create_descriptors(
    smiles,
    ecdfs=None,
    scaler=None,
    feature_selection=None,
    descriptors: Iterable = KNOWN_DESCR,
):
    """Build the full (ECFP | MACCS | tox | RDKit-quantile) feature matrix.

    Cleans the SMILES, computes all descriptor families, aligns them back to
    the input order via the clean-molecule mask, concatenates them, and
    standardizes the result.

    Args:
        smiles: list of SMILES strings.
        ecdfs: per-column ECDFs for the RDKit descriptors; fitted on the fly
            from this input when None.
        scaler: fitted scaler to reuse; a new StandardScaler is fitted when None.
        feature_selection: NOTE(review): accepted but never used in this body —
            confirm whether selection was meant to be applied here.
        descriptors (Iterable): NOTE(review): also unused — the per-family
            "if ... in descriptors" guards below are commented out, so every
            family is always computed.

    Returns:
        tuple: (normalized feature matrix with one row per input SMILES,
            boolean mask marking which SMILES could be cleaned)
    """
    # Create cleaned rdkit mol objects
    mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
    print("Cleaned molecules")

    features = []
    # if "ecfps" in descriptors:
    # Create fingerprints and descriptors
    ecfps = create_ecfp_fps(mols)
    # expand using mol_mask so rows line up with the input SMILES
    ecfps = fill(ecfps, ~clean_mol_mask)
    features.append(ecfps)
    print("Created ECFP fingerprints")

    # if "maccs" in descriptors:
    maccs = create_maccs_keys(mols)
    maccs = fill(maccs, ~clean_mol_mask)
    features.append(maccs)
    print("Created MACCS keys")

    # if "tox" in descriptors:
    # NOTE(review): SMARTS path is hard-coded here, unlike main() which takes
    # it from the CLI — confirm "assets/tox_smarts.json" exists at runtime.
    tox_patterns = get_tox_patterns("assets/tox_smarts.json")
    tox = create_tox_features(mols, tox_patterns)
    tox = fill(tox, ~clean_mol_mask)
    features.append(tox)
    print("Created Tox features")

    # if "rdkit_descr_quantiles" in descriptors:
    rdkit_descrs = create_rdkit_descriptors(mols)
    print("Created RDKit descriptors")

    # Fit ECDFs on this input when none were supplied.
    if ecdfs is None:
        print("Create ECDFs")
        ecdfs = []
        for column in range(rdkit_descrs.shape[1]):
            raw_values = rdkit_descrs[:, column].reshape(-1)
            ecdfs.append(ECDF(raw_values))

    # Create quantiles
    rdkit_descr_quantiles = create_quantiles(rdkit_descrs, ecdfs)
    # expand using mol_mask
    rdkit_descr_quantiles = fill(rdkit_descr_quantiles, ~clean_mol_mask)
    features.append(rdkit_descr_quantiles)
    print("Created quantiles of RDKit descriptors")

    # concatenate features column-wise into one matrix
    raw_features = np.concatenate(features, axis=1)

    # normalize with scaler if scaler is passed, else create scaler
    features, _ = normalize_features(
        raw_features,
        scaler=scaler,
        verbose=True,
    )

    return features, clean_mol_mask
370
+
371
+
372
def get_feature_selection(
    raw_features: np.ndarray, min_var=0.01, max_corr=0.95
) -> np.ndarray:
    """Select feature columns by variance and pairwise correlation.

    First keeps columns whose variance is at least `min_var`, then drops any
    column that correlates above `max_corr` with an earlier kept column.

    Args:
        raw_features (np.ndarray): 2-D feature matrix (samples x features)
        min_var (float): minimum variance required by VarianceThreshold
        max_corr (float): maximum allowed correlation with an earlier column

    Returns:
        np.ndarray: indices (into the original columns) of selected features
    """
    # select features with at least `min_var` variation
    var_thresh = VarianceThreshold(threshold=min_var)
    feature_selection = var_thresh.fit(raw_features).get_support(
        indices=True
    )  # array containing selected feature indices

    # Remove highly correlated features. The original nested Python loop was
    # O(n^2) in interpreted code; the vectorized any(axis=0) below drops
    # exactly the same columns: column j is removed iff some i < j has
    # corr(i, j) > max_corr (strictly-upper triangle, reduced over rows i).
    corr_matrix = np.corrcoef(raw_features[:, feature_selection], rowvar=False)
    upper_tri = np.triu(corr_matrix, k=1)
    to_keep = ~(upper_tri > max_corr).any(axis=0)

    return feature_selection[to_keep]
394
+
395
+
396
def main(args):
    """Preprocess all Tox21 splits into feature matrices saved as .npz files.

    For each split: clean the SMILES, compute ECFP / tox / MACCS / RDKit
    quantile features, fit feature selection and ECDFs on the train split
    (persisting them, then reusing them for the other splits), and write one
    ``tox21_<split>_cv4.npz`` with labels and features into ``args.save_folder``.
    """
    splits = ["train", "validation", "test"]  # TODO: remove test
    if args.use_hf:
        ds = load_dataset("tschouis/tox21", token=HF_TOKEN)
    else:
        # The original if/else loaded the exact same CSV in both branches,
        # so a single read per split is sufficient.
        ds = {
            split: pd.read_csv(
                os.path.join(args.data_folder, f"tox21_{split}_cv4.csv")
            )
            for split in splits
        }

    # The SMARTS definitions do not change between splits: parse them once.
    tox_patterns = get_tox_patterns(args.tox_smarts_filepath)

    for split in splits:

        print(f"Preprocess {split} molecules")
        smiles = list(ds[split]["smiles"])

        # Create cleaned rdkit mol objects
        mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
        print("Cleaned molecules")

        # Create fingerprints and descriptors
        ecfps = create_ecfp_fps(mols, radius=args.ecfps_radius, fpsize=args.ecfps_folds)
        # expand using mol_mask so rows line up with the input SMILES
        ecfps = fill(ecfps, ~clean_mol_mask)
        print("Created ECFP fingerprints")

        tox = create_tox_features(mols, tox_patterns)
        tox = fill(tox, ~clean_mol_mask)
        print("Created Tox features")

        # Fit feature selection on train and persist it; reuse it afterwards.
        if args.feature_selection:
            if split == "train":
                print("Create Feature selection")
                ecfps_selec = get_feature_selection(ecfps, args.min_var, args.max_corr)
                tox_selec = get_feature_selection(tox, args.min_var, args.max_corr)
                np.savez(
                    args.path_feat_selec, ecfps_selec=ecfps_selec, tox_selec=tox_selec
                )
            else:
                print(f"Load feature selection from {args.path_feat_selec}")
                feature_selection = np.load(args.path_feat_selec)
                ecfps_selec = feature_selection["ecfps_selec"]
                tox_selec = feature_selection["tox_selec"]

            ecfps = ecfps[:, ecfps_selec]
            tox = tox[:, tox_selec]

        maccs = create_maccs_keys(mols)
        maccs = fill(maccs, ~clean_mol_mask)
        print("Created MACCS keys")

        rdkit_descrs = create_rdkit_descriptors(mols)
        print("Created RDKit descriptors")

        # Fit ECDFs on train and persist them; reuse them for other splits.
        if split == "train":
            print("Create ECDFs")
            ecdfs = []
            for column in range(rdkit_descrs.shape[1]):
                raw_values = rdkit_descrs[:, column].reshape(-1)
                ecdfs.append(ECDF(raw_values))

            write_pickle(args.path_ecdfs, ecdfs)
            print(f"Saved ECDFs under {args.path_ecdfs}")
        else:
            print(f"Load ECDFs from {args.path_ecdfs}")
            ecdfs = load_pickle(args.path_ecdfs)

        # Create quantiles
        rdkit_descr_quantiles = create_quantiles(rdkit_descrs, ecdfs)
        # expand using mol_mask
        rdkit_descr_quantiles = fill(rdkit_descr_quantiles, ~clean_mol_mask)
        print("Created quantiles of RDKit descriptors")

        # Collect labels; HF datasets need an explicit pandas conversion,
        # which is split-invariant and therefore hoisted out of the task loop.
        datasplit = ds[split].to_pandas() if args.use_hf else ds[split]
        labels = np.stack([datasplit[task].to_numpy() for task in TASKS], axis=1)

        save_path = os.path.join(args.save_folder, f"tox21_{split}_cv4.npz")
        with open(save_path, "wb") as f:
            np.savez(
                f,
                labels=labels,
                ecfps=ecfps,
                tox=tox,
                maccs=maccs,
                rdkit_descr_quantiles=rdkit_descr_quantiles,
            )
        print(f"Saved preprocessed {split} split under {save_path}")

    print("Preprocessing finished successfully")
498
+
499
+
500
+ if __name__ == "__main__":
501
+ args = parser.parse_args()
502
+
503
+ if not os.path.exists(args.save_folder):
504
+ os.makedirs(args.save_folder)
505
+
506
+ args.path_ecdfs = os.path.join(args.save_folder, args.path_ecdfs)
507
+ args.path_feat_selec = os.path.join(args.save_folder, args.path_feat_selec)
508
+ args.tox_smarts_filepath = os.path.join(args.data_folder, args.tox_smarts_filepath)
509
+
510
+ if not os.path.exists(os.path.dirname(args.path_ecdfs)):
511
+ os.makedirs(os.path.dirname(args.path_ecdfs))
512
+
513
+ main(args)