Spaces:
Runtime error
Runtime error
| """ | |
| Script for fitting and saving any preprocessing assets, as well as the fitted XGBoost model | |
| """ | |
| import os | |
| import json | |
| import logging | |
| import argparse | |
| import numpy as np | |
| from datetime import datetime | |
| from sklearn.feature_selection import VarianceThreshold | |
| from sklearn.preprocessing import StandardScaler | |
| from src.model import Tox21XGBClassifier | |
| from src.utils import create_dir, normalize_config | |
# Command-line interface: a single --config flag pointing at the JSON
# configuration file that drives the whole training run.
parser = argparse.ArgumentParser(
    description="XGBoost Training script for Tox21 dataset"
)
parser.add_argument(
    "--config",
    type=str,
    default="config/config.json",
    help="Path to the JSON configuration file (default: %(default)s)",
)
def main(config):
    """Fit per-task preprocessing assets and XGBoost models on Tox21 data.

    For every Tox21 task this fits a ``VarianceThreshold`` selector and a
    ``StandardScaler`` on the labeled rows, trains an XGBoost classifier with
    early stopping on a 10% random hold-out, and finally persists the model
    and the fitted preprocessors to disk.

    Parameters
    ----------
    config : dict
        Normalized run configuration. Keys read here: ``log_folder``,
        ``data_folder``, ``descriptors``, ``merge_train_val``,
        ``model_config``, ``device``, ``seed``, ``debug``, ``ckpt_path``,
        ``preprocessor_path``.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # --- Logging: one timestamped file per run, mirrored to the console ---
    logger = logging.getLogger(__name__)
    script_name = os.path.splitext(os.path.basename(__file__))[0]
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            logging.FileHandler(
                os.path.join(
                    config["log_folder"],
                    f"{script_name}_{timestamp}.log",
                )
            ),
            logging.StreamHandler(),
        ],
    )
    logger.info(f"Config: {config}")
    model_configs_repr = "Model configs: \n" + "\n".join(
        [str(val) for val in config["model_config"].values()]
    )
    # BUGFIX: model_configs_repr already starts with "Model configs: \n";
    # wrapping it in another f-string with the same prefix logged the
    # heading twice. Log the prepared representation directly.
    logger.info(model_configs_repr)

    # --- Load train/validation arrays and drop molecules flagged unclean ---
    logger.info("Preprocess train molecules")
    train_data = np.load(os.path.join(config["data_folder"], "tox21_train_cv4.npz"))
    val_data = np.load(os.path.join(config["data_folder"], "tox21_validation_cv4.npz"))
    # filter out unsanitized molecules (clean_mol_mask marks usable rows)
    train_is_clean = train_data["clean_mol_mask"]
    val_is_clean = val_data["clean_mol_mask"]
    train_data = {descr: array[train_is_clean] for descr, array in train_data.items()}
    val_data = {descr: array[val_is_clean] for descr, array in val_data.items()}

    # Optionally fold the validation split into training (e.g. for a final fit).
    if config["merge_train_val"]:
        train_X = {
            descr: np.concatenate([train_data[descr], val_data[descr]], axis=0)
            for descr in config["descriptors"]
        }
        train_y = np.concatenate([train_data["labels"], val_data["labels"]], axis=0)
    else:
        train_X = {descr: train_data[descr] for descr in config["descriptors"]}
        train_y = train_data["labels"]
    # Stack the selected descriptor matrices column-wise into one feature table.
    train_X = np.concatenate(
        [train_X[descr] for descr in config["descriptors"]], axis=1
    )

    # Clean data: zero out any column containing inf/NaN so downstream
    # scaling and tree fitting never see non-finite values.
    bad_entries = np.isinf(train_X) | np.isnan(train_X)
    bad_cols = np.any(bad_entries, axis=0)
    if np.any(bad_cols):
        train_X[:, bad_cols] = 0.0

    # Per-task hyperparameters shared by every model; scale_pos_weight
    # counters class imbalance (NaN labels mean "unlabeled for this task").
    for i, task in enumerate(config["model_config"].keys()):
        npos = np.nansum(train_y[:, i])
        nneg = np.sum(~np.isnan(train_y[:, i])) - npos
        config["model_config"][task].update(
            {
                "tree_method": "hist",
                "n_estimators": 10_000,
                "early_stopping_rounds": 50,
                "eval_metric": "auc",
                # max(npos, 1) guards against division by zero for a task
                # with no positive labels.
                "scale_pos_weight": nneg / max(npos, 1),
                "device": config["device"],
            }
        )

    model = Tox21XGBClassifier(seed=config["seed"], task_configs=config["model_config"])
    logger.info("Start training.")
    for i, task in enumerate(model.tasks):
        # Training -----------------------
        # Keep only rows labeled for this task.
        task_labels = train_y[:, i]
        label_mask = ~np.isnan(task_labels)
        task_data = train_X[label_mask]
        task_labels = task_labels[label_mask].astype(int)

        # Remove low variance features and scale; keep the fitted
        # transformers so inference can replay the same preprocessing.
        var_thresh = VarianceThreshold(
            threshold=config["model_config"][task]["var_threshold"]
        )
        task_data = var_thresh.fit_transform(task_data)
        scaler = StandardScaler()
        task_data = scaler.fit_transform(task_data)
        model.feature_processors[task] = {
            "selector": var_thresh,
            "scaler": scaler,
        }

        # From X_train split 10% for an early stopping validation set.
        # Re-seeding each iteration keeps the split reproducible per task.
        np.random.seed(config["seed"])
        random_numbers = np.random.rand(task_data.shape[0])
        es_val_mask = random_numbers < 0.1
        es_train_mask = random_numbers >= 0.1
        X_es_val, y_es_val = task_data[es_val_mask], task_labels[es_val_mask]
        X_es_train, y_es_train = task_data[es_train_mask], task_labels[es_train_mask]

        logger.info(
            f"Fit task {task} using {int(label_mask.sum())} samples and {task_data.shape[1]} features"
        )
        model.fit(
            task, X_es_train, y_es_train, eval_set=[(X_es_val, y_es_val)], verbose=False
        )
        if config["debug"]:
            # Debug mode: train only the first task to shorten the run.
            break

    logger.info("Finished training.")
    logger.info(f"Save model under {config['ckpt_path']}")
    logger.info(f"Save feature preprocessors under {config['preprocessor_path']}")
    model.save_model(config["ckpt_path"], config["preprocessor_path"])
if __name__ == "__main__":
    args = parser.parse_args()
    # Explicit encoding: JSON configs must not depend on the platform's
    # default locale encoding.
    with open(args.config, "r", encoding="utf-8") as f:
        config = json.load(f)
    config = normalize_config(config)
    # Ensure the log directory exists before main() attaches a FileHandler.
    create_dir(config["log_folder"])
    main(config)