tox21_tabpfn_classifier / preprocess.py
antoniaebner's picture
add preprocessing
7c1c2c8
# pipeline taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
"""
This file includes the data processing for Tox21.
As an input it takes a list of SMILES and it outputs a nested dictionary with
SMILES and target names as keys.
"""
import os
import argparse
import numpy as np
from src.data import create_descriptors, get_tox21_split
from src.utils import (
TASKS,
HF_TOKEN,
write_pickle,
create_dir,
)
def _build_parser():
    """Create the command-line parser for the Tox21 preprocessing script.

    Options are declared as ``(flag, type, default, help)`` rows and
    registered in the order listed, so ``--help`` shows them in that order.

    Returns:
        argparse.ArgumentParser: Parser exposing all preprocessing options.
    """
    cli = argparse.ArgumentParser(
        description="Data preprocessing script for the Tox21 dataset"
    )
    option_table = (
        (
            "--data_folder",
            str,
            "data/",
            "Folder containing the tox21_compoundData.csv file.",
        ),
        (
            "--save_folder",
            str,
            "data/",
            "Folder to which preprocessed the data CSV and NPZ files should be saved.",
        ),
        (
            "--cv_fold",
            int,
            4,
            "Select fold used as validation set.",
        ),
        (
            "--feature_selection",
            int,
            1,
            "True (=1) to use feature selection.",
        ),
        (
            "--feature_selection_path",
            str,
            "feat_selection.npz",
            "Filename for saving feature selections.",
        ),
        (
            "--min_var",
            float,
            0.05,
            "Minimum variance threshold for selecting features.",
        ),
        (
            "--max_corr",
            float,
            0.95,
            "Maximum correlation threshold for selecting features.",
        ),
        (
            "--ecdfs_path",
            str,
            "ecdfs.pkl",
            "Filename to save ECDFs.",
        ),
        (
            "--ecfps_radius",
            int,
            3,
            "Radius used for creating ECFPs.",
        ),
        (
            "--ecfps_folds",
            int,
            8192,
            "Folds used for creating ECFPs.",
        ),
    )
    for flag, value_type, fallback, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=help_text)
    return cli


# Module-level parser; the script entry point calls parser.parse_args().
parser = _build_parser()
def main(args):
    """Preprocess the Tox21 train/validation splits for use with TabPFN.

    Steps:
        1. Load the Tox21 data for fold ``args.cv_fold`` via
           ``get_tox21_split``.
        2. Create descriptors for the train split and persist the feature
           selection (NPZ) and ECDFs (pickle) returned alongside them.
        3. Create descriptors for the validation split, reusing the feature
           selection and ECDFs obtained from the train split.
        4. Save each split's features and labels to
           ``<save_folder>/tox21_<split>_cv<cv_fold>.npz``.

    Args:
        args: Parsed command-line namespace. This function reads
            ``cv_fold``, ``ecfps_radius``, ``ecfps_folds``, ``min_var``,
            ``max_corr``, ``feature_selection_path``, ``ecdfs_path`` and
            ``save_folder``.
    """
    # NOTE(review): args.feature_selection and args.data_folder are parsed by
    # the CLI but never read here — confirm whether they should be wired in.
    ds = get_tox21_split(HF_TOKEN, cvfold=args.cv_fold)

    feature_creation_kwargs = {
        "radius": args.ecfps_radius,
        "fpsize": args.ecfps_folds,
        "min_var": args.min_var,
        "max_corr": args.max_corr,
    }

    # "train" must come first: the validation branch reuses the feature
    # selection and ECDFs produced while processing the train split.
    splits = ["train", "validation"]
    for split in splits:
        print(f"Preprocess {split} molecules")
        ds_split = ds[split]
        smiles = list(ds_split["smiles"])

        if split == "train":
            output = create_descriptors(
                smiles,
                return_feature_selection=True,
                return_ecdfs=True,
                **feature_creation_kwargs,
            )
            features = output.pop("features")
            feature_selection = output.pop("feature_selection")
            ecdfs = output.pop("ecdfs")

            np.savez(
                args.feature_selection_path,
                ecfps_selec=feature_selection["ecfps_selec"],
                tox_selec=feature_selection["tox_selec"],
            )
            print(f"Saved feature selection under {args.feature_selection_path}")

            write_pickle(args.ecdfs_path, ecdfs)
            print(f"Saved ECDFs under {args.ecdfs_path}")
        else:
            features = create_descriptors(
                smiles,
                ecdfs=ecdfs,
                feature_selection=feature_selection,
                **feature_creation_kwargs,
            )["features"]

        # One label column per task, in TASKS order.
        labels = np.stack(
            [ds_split[task].to_numpy() for task in TASKS],
            axis=1,
        )

        # Name the file after the fold actually used; the previous hard-coded
        # "cv4" mislabeled (and overwrote) outputs for any other --cv_fold.
        save_path = os.path.join(
            args.save_folder, f"tox21_{split}_cv{args.cv_fold}.npz"
        )
        with open(save_path, "wb") as f:
            # `features` is unpacked as keyword arguments, so each of its
            # entries is stored as its own named array next to `labels`.
            np.savez(
                f,
                labels=labels,
                **features,
            )
        print(f"Saved preprocessed {split} split under {save_path}")

    print("Preprocessing finished successfully")
if __name__ == "__main__":
    args = parser.parse_args()

    # Both artifact filenames are joined onto the save folder before the
    # namespace is handed to main().
    artifact_attrs = ("ecdfs_path", "feature_selection_path")
    for attr in artifact_attrs:
        setattr(args, attr, os.path.join(args.save_folder, getattr(args, attr)))

    # Prepare the output locations: the save folder first, then each
    # artifact path (flagged as a file path via is_file=True).
    create_dir(args.save_folder)
    for attr in artifact_attrs:
        create_dir(getattr(args, attr), is_file=True)

    main(args)