| | """Perform CV (with explainability) on different feature sets and log to mlflow. |
| | |
| | Includes functionality to nest runs under parent run (e.g. different feature sets |
| | under a main run) and set a decision threshold for model scores. Logs the following |
| | artifacts as well as metrics and parameters: |
| | 1. List of model features |
| | 2. Feature correlation matrix |
| | 3. Global explainability (averaged over K folds) |
| | 4. Cumulative gains curve |
| | 5. Lift curve |
| | 6. Probability distributions with KDE |
| | """ |
| | from imblearn.ensemble import BalancedRandomForestClassifier |
| | from lenusml import splits, crossvalidation, plots |
| | import numpy as np |
| | import os |
| | import pandas as pd |
| | from sklearn.metrics import RocCurveDisplay, PrecisionRecallDisplay |
| | import mlflow |
| | import matplotlib.pyplot as plt |
| | |
| |
|
| |
|
| | def get_crossvalidation_importance(*, feature_names, crossval): |
| | """ |
| | Create dataframe of mean global feature importance for all EBMs used in CV. |
| | |
| | Args: |
| | feature_names (list): list of model feature names |
| | crossval (dict): output of cross_validation_return_estimator_and_scores |
| | |
| | Returns: |
| | pd.DataFrame: contains feature names, global importance for each of the K |
| | estimators, mean importance across the estimators and scaled mean importance |
| | relative to the most important feature. |
| | """ |
| | |
| | for i, est in enumerate(crossval['estimator']): |
| | exp_global = crossval['estimator'][i].feature_importances_ |
| |
|
| | explanations = pd.DataFrame([feature_names, exp_global]).T |
| | explanations.columns = ['Feature', 'Score_{}'.format(i)] |
| |
|
| | |
| | if i == 0: |
| | explanations_all = explanations.copy() |
| | else: |
| | explanations_all = explanations_all.merge(explanations, on='Feature') |
| |
|
| | |
| | explanations_all['Mean'] = explanations_all.drop(columns=['Feature']).mean(axis=1) |
| | explanations_all = explanations_all.sort_values('Mean', ascending=False) |
| | |
| | explanations_all['Mean_scaled'] = explanations_all['Mean'] /\ |
| | explanations_all['Mean'].abs().max() |
| | return explanations_all |
| |
|
| |
|
# ---------------------------------------------------------------------------
# Paths for model data, cohort fold definitions and model outputs.
# ---------------------------------------------------------------------------
data_dir = '../data/models/model1/'
cohort_info_dir = '../data/cohort_info/'
output_dir = '../data/models/model1/output'

# Load the patient-level fold assignments and the CV training data.
fold_patients = np.load(
    os.path.join(cohort_info_dir, 'fold_patients.npy'), allow_pickle=True)
train_data = pd.read_pickle(os.path.join(data_dir, 'train_data_cv.pkl'))

# Translate patient fold membership into row-index folds on train_data so
# that all rows of a patient stay within a single fold.
cross_validation_fold_indices = splits.custom_cv_fold_indices(
    fold_patients=fold_patients, id_column='StudyId', train_data=train_data)

# Local sqlite-backed mlflow tracking store and target experiment.
mlflow.set_tracking_uri("sqlite:///mlruns.sqlite")
mlflow.set_experiment('model_drop2')

# CV metrics reported to mlflow for each run.
scoring = ['f1', 'balanced_accuracy', 'accuracy', 'precision', 'recall',
           'roc_auc', 'average_precision']

# Everything except the identifier and the target column is a model feature.
cols_to_drop = ['StudyId', 'IsExac']
features_list = [col for col in train_data.columns if col not in cols_to_drop]

features = train_data[features_list].astype('float')
target = train_data.IsExac.astype('float')

# Scratch directory for artifacts logged to mlflow; start from a clean slate
# so stale artifacts from a previous run are never re-logged.
artifact_dir = './tmp'
os.makedirs(artifact_dir, exist_ok=True)
for stale_name in os.listdir(artifact_dir):
    os.remove(os.path.join(artifact_dir, stale_name))

# Persist the feature list as a plain-text artifact.
np.savetxt(os.path.join(artifact_dir, 'features.txt'), features_list,
           delimiter=",", fmt='%s')

# Feature correlation heatmap, sized proportionally to the feature count.
correlation_figsize = (len(features_list) // 2, len(features_list) // 2)
plots.plot_feature_correlations(features=features,
                                figsize=correlation_figsize,
                                savefig=True, output_dir=artifact_dir,
                                figname='features_correlations.png')
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
with mlflow.start_run(run_name='eosinophil_count_0.3_threshold'):

    model = BalancedRandomForestClassifier(random_state=0)

    # K-fold CV on the precomputed patient-level folds; returns the fitted
    # estimator per fold plus pooled out-of-fold scores for every row.
    crossval, model_scores =\
        crossvalidation.cross_validation_return_estimator_and_scores(
            model=model, features=features,
            target=target,
            fold_indices=cross_validation_fold_indices)

    # Log the mean of each CV metric across folds.
    for score in scoring:
        mlflow.log_metric(score, np.mean(crossval['test_' + score]))

    # Log every model hyperparameter.
    params = model.get_params()
    for param in params:
        mlflow.log_param(param, params[param])

    # Global explainability averaged over the K fold estimators, as CSV and
    # as a plot sized to the number of features.
    explainability = get_crossvalidation_importance(feature_names=features_list,
                                                    crossval=crossval)
    explainability.to_csv(os.path.join(artifact_dir,
                                       'global_feature_importances.csv'),
                          index=False)
    plots.plot_global_explainability_cv(importances=explainability,
                                        scaled=True,
                                        figsize=(
                                            len(features_list) // 2.5,
                                            len(features_list) // 6),
                                        savefig=True, output_dir=artifact_dir)

    # BUG FIX: the lift and cumulative-gains filenames were swapped, so each
    # curve was saved under the other curve's name.
    plots.plot_lift_curve(scores=model_scores, savefig=True,
                          output_dir=artifact_dir,
                          figname='lift_curve.png')
    plots.plot_cumulative_gains_curve(scores=model_scores, savefig=True,
                                      output_dir=artifact_dir,
                                      figname='cumulative_gains_curve.png')

    # Per-class score distributions with KDE. NOTE(review): 'postive_class_name'
    # mirrors the (apparently misspelled) parameter name in lenusml.plots —
    # don't "fix" the spelling here without changing the library.
    plots.plot_score_distribution(scores=model_scores, postive_class_name='Exac',
                                  negative_class_name='No exac', savefig=True,
                                  output_dir=artifact_dir,
                                  figname='model_score_distribution.png')

    # Confusion matrices for a range of decision thresholds on the model score.
    for threshold in [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]:
        plots.plot_confusion_matrix(
            target_true=model_scores.true_label,
            target_predicted=np.where(model_scores.model_score > threshold, 1, 0),
            classes=['No exac', 'Exac'], savefig=True,
            output_dir=artifact_dir,
            figname='confusion_matrix_{}.png'.format(threshold))

    # ROC curve from the pooled out-of-fold predictions.
    fig, ax = plt.subplots(figsize=(8, 6))
    RocCurveDisplay.from_predictions(y_true=model_scores.true_label,
                                     y_pred=model_scores.model_score, ax=ax)
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    plt.legend(frameon=False)
    plt.tight_layout()
    plt.savefig(os.path.join(artifact_dir, 'roc_curve.png'), dpi=150)
    plt.close()

    # Precision-recall curve.
    fig, ax = plt.subplots(figsize=(8, 6))
    PrecisionRecallDisplay.from_predictions(y_true=model_scores.true_label,
                                            y_pred=model_scores.model_score, ax=ax)
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    plt.legend(frameon=False)
    plt.tight_layout()
    plt.savefig(os.path.join(artifact_dir, 'precision_recall_curve.png'), dpi=150)
    plt.close()

    # Log everything in the scratch directory to the active run. The run is
    # closed automatically when the `with` block exits, so the explicit
    # mlflow.end_run() the original called was redundant and has been removed.
    mlflow.log_artifacts(artifact_dir)
| | |
| |
|