"""
Script for fitting and saving any preprocessing assets, as well as the fitted XGBoost model
"""
import os
import json
import logging
import argparse
import numpy as np
from datetime import datetime
from sklearn.feature_selection import VarianceThreshold
from sklearn.preprocessing import StandardScaler
from src.model import Tox21XGBClassifier
from src.utils import create_dir, normalize_config
# Command-line interface: a single --config flag pointing at the JSON config.
parser = argparse.ArgumentParser(
    description="XGBoost Training script for Tox21 dataset"
)
parser.add_argument(
    "--config",
    type=str,
    default="config/config.json",
)
def main(config):
    """Fit per-task preprocessing (variance filter + scaler) and XGBoost models.

    Loads the Tox21 train/validation .npz splits named in ``config``, cleans the
    feature matrix, fits a VarianceThreshold + StandardScaler per task, trains
    one XGBoost classifier per task with early stopping, and saves the model
    plus preprocessors to the paths given in ``config``.

    Args:
        config: Normalized configuration dict. Keys used here: ``log_folder``,
            ``data_folder``, ``merge_train_val``, ``descriptors``,
            ``model_config``, ``device``, ``seed``, ``debug``, ``ckpt_path``,
            ``preprocessor_path``.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    # setup logger: mirror all output to a timestamped file and the console
    logger = logging.getLogger(__name__)
    script_name = os.path.splitext(os.path.basename(__file__))[0]
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            logging.FileHandler(
                os.path.join(
                    config["log_folder"],
                    f"{script_name}_{timestamp}.log",
                )
            ),
            logging.StreamHandler(),
        ],
    )
    logger.info(f"Config: {config}")
    # BUGFIX: the "Model configs:" prefix was previously duplicated — it was
    # baked into the string below AND repeated in the log call's f-string.
    model_configs_repr = "Model configs: \n" + "\n".join(
        [str(val) for val in config["model_config"].values()]
    )
    logger.info(model_configs_repr)
    logger.info("Preprocess train molecules")
    train_data = np.load(os.path.join(config["data_folder"], "tox21_train_cv4.npz"))
    val_data = np.load(os.path.join(config["data_folder"], "tox21_validation_cv4.npz"))
    # filter out unsanitized molecules (mask computed during preprocessing)
    train_is_clean = train_data["clean_mol_mask"]
    val_is_clean = val_data["clean_mol_mask"]
    train_data = {descr: array[train_is_clean] for descr, array in train_data.items()}
    val_data = {descr: array[val_is_clean] for descr, array in val_data.items()}
    if config["merge_train_val"]:
        # final-model setting: train on train + validation combined
        train_X = {
            descr: np.concatenate([train_data[descr], val_data[descr]], axis=0)
            for descr in config["descriptors"]
        }
        train_y = np.concatenate([train_data["labels"], val_data["labels"]], axis=0)
    else:
        train_X = {descr: train_data[descr] for descr in config["descriptors"]}
        train_y = train_data["labels"]
    # stack all descriptor groups into one feature matrix (column-wise)
    train_X = np.concatenate(
        [train_X[descr] for descr in config["descriptors"]], axis=1
    )
    # clean data: zero out any column containing inf/NaN values
    bad_entries = np.isinf(train_X) | np.isnan(train_X)
    bad_cols = np.any(bad_entries, axis=0)
    if np.any(bad_cols):
        train_X[:, bad_cols] = 0.0
    # update model config: per-task class balancing + shared training params.
    # NaN labels mark molecules unlabeled for that task and are excluded from
    # the positive/negative counts.
    for i, task in enumerate(config["model_config"].keys()):
        npos = np.nansum(train_y[:, i])
        nneg = np.sum(~np.isnan(train_y[:, i])) - npos
        config["model_config"][task].update(
            {
                "tree_method": "hist",
                "n_estimators": 10_000,
                "early_stopping_rounds": 50,
                "eval_metric": "auc",
                # guard against npos == 0 to avoid division by zero
                "scale_pos_weight": nneg / max(npos, 1),
                "device": config["device"],
            }
        )
    model = Tox21XGBClassifier(seed=config["seed"], task_configs=config["model_config"])
    logger.info("Start training.")
    for i, task in enumerate(model.tasks):
        # Training -----------------------
        # keep only molecules labeled for this task
        task_labels = train_y[:, i]
        label_mask = ~np.isnan(task_labels)
        task_data = train_X[label_mask]
        task_labels = task_labels[label_mask].astype(int)
        # Remove low variance features and scale; keep the fitted
        # transformers so inference can replay the exact same pipeline
        var_thresh = VarianceThreshold(
            threshold=config["model_config"][task]["var_threshold"]
        )
        task_data = var_thresh.fit_transform(task_data)
        scaler = StandardScaler()
        task_data = scaler.fit_transform(task_data)
        model.feature_processors[task] = {
            "selector": var_thresh,
            "scaler": scaler,
        }
        # From X_train split 10% for an early stopping validation set
        # (seed reset per task keeps the split reproducible)
        np.random.seed(config["seed"])
        random_numbers = np.random.rand(task_data.shape[0])
        es_val_mask = random_numbers < 0.1
        es_train_mask = random_numbers >= 0.1
        X_es_val, y_es_val = task_data[es_val_mask], task_labels[es_val_mask]
        X_es_train, y_es_train = task_data[es_train_mask], task_labels[es_train_mask]
        logger.info(
            f"Fit task {task} using {int(label_mask.sum())} samples and {task_data.shape[1]} features"
        )
        model.fit(
            task, X_es_train, y_es_train, eval_set=[(X_es_val, y_es_val)], verbose=False
        )
        if config["debug"]:
            # debug mode: train only the first task to shorten iteration time
            break
    logger.info("Finished training.")
    logger.info(f"Save model under {config['ckpt_path']}")
    logger.info(f"Save feature preprocessors under {config['preprocessor_path']}")
    model.save_model(config["ckpt_path"], config["preprocessor_path"])
if __name__ == "__main__":
    # Parse CLI args, load + normalize the JSON config, ensure the log
    # directory exists, then run training.
    args = parser.parse_args()
    with open(args.config, "r") as cfg_file:
        raw_config = json.load(cfg_file)
    config = normalize_config(raw_config)
    create_dir(config["log_folder"])
    main(config)