diff --git "a/XGBOOST_XAI/ML2.json" "b/XGBOOST_XAI/ML2.json" new file mode 100644--- /dev/null +++ "b/XGBOOST_XAI/ML2.json" @@ -0,0 +1,2172 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting xgboost\n", + " Downloading xgboost-3.2.0-py3-none-manylinux_2_28_x86_64.whl.metadata (2.1 kB)\n", + "Collecting shap\n", + " Downloading shap-0.51.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (25 kB)\n", + "Requirement already satisfied: pandas in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (2.1.4)\n", + "Requirement already satisfied: numpy in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (1.26.4)\n", + "Requirement already satisfied: scikit-learn in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (1.3.2)\n", + "Requirement already satisfied: matplotlib in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (3.8.2)\n", + "Collecting seaborn\n", + " Downloading seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)\n", + "Collecting pyarrow\n", + " Downloading pyarrow-23.0.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.1 kB)\n", + "Collecting fastparquet\n", + " Downloading fastparquet-2026.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (4.8 kB)\n", + "Requirement already satisfied: nvidia-nccl-cu12 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from xgboost) (2.27.3)\n", + "Requirement already satisfied: scipy in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from xgboost) (1.11.4)\n", + "Collecting numpy\n", + " Downloading numpy-2.4.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (6.6 kB)\n", + "Requirement already satisfied: tqdm>=4.27.0 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from shap) (4.67.3)\n", + "Requirement already satisfied: packaging>20.9 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from shap) (25.0)\n", + "Collecting slicer==0.0.8 (from shap)\n", + " Downloading slicer-0.0.8-py3-none-any.whl.metadata (4.0 kB)\n", + "Collecting numba (from shap)\n", + " Downloading numba-0.64.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.9 kB)\n", + "Collecting llvmlite (from shap)\n", + " Downloading llvmlite-0.46.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (5.0 kB)\n", + "Collecting cloudpickle (from shap)\n", + " Downloading cloudpickle-3.1.2-py3-none-any.whl.metadata (7.1 kB)\n", + "Requirement already satisfied: typing-extensions in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from shap) (4.15.0)\n", + "INFO: pip is looking at multiple versions of pandas to determine which version is compatible with other requirements. This could take a while.\n", + "Collecting pandas\n", + " Downloading pandas-3.0.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (79 kB)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from pandas) (2.9.0.post0)\n", + "INFO: pip is looking at multiple versions of scikit-learn to determine which version is compatible with other requirements. This could take a while.\n", + "Collecting scikit-learn\n", + " Downloading scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (11 kB)\n", + "Requirement already satisfied: joblib>=1.3.0 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from scikit-learn) (1.5.3)\n", + "Requirement already satisfied: threadpoolctl>=3.2.0 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from scikit-learn) (3.6.0)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from matplotlib) (1.3.3)\n", + "Requirement already satisfied: cycler>=0.10 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from matplotlib) (0.12.1)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from matplotlib) (4.62.1)\n", + "Requirement already satisfied: kiwisolver>=1.3.1 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from matplotlib) (1.5.0)\n", + "INFO: pip is looking at multiple versions of matplotlib to determine which version is compatible with other requirements. This could take a while.\n", + "Collecting matplotlib\n", + " Downloading matplotlib-3.10.8-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (52 kB)\n", + "Requirement already satisfied: pillow>=8 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from matplotlib) (12.1.1)\n", + "Requirement already satisfied: pyparsing>=3 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from matplotlib) (3.3.2)\n", + "Collecting cramjam>=2.3 (from fastparquet)\n", + " Downloading cramjam-2.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.6 kB)\n", + "Requirement already satisfied: fsspec in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from fastparquet) (2026.2.0)\n", + "Requirement already satisfied: six>=1.5 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas) (1.17.0)\n", + "INFO: pip is looking at multiple versions of scipy to determine which version is compatible with other requirements. This could take a while.\n", + "Collecting scipy (from xgboost)\n", + " Downloading scipy-1.17.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (62 kB)\n", + "Downloading xgboost-3.2.0-py3-none-manylinux_2_28_x86_64.whl (131.7 MB)\n", + "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m131.7/131.7 MB\u001b[0m \u001b[31m210.6 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m0:00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hDownloading shap-0.51.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (1.1 MB)\n", + "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m104.0 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading slicer-0.0.8-py3-none-any.whl (15 kB)\n", + "Downloading pandas-3.0.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (10.9 MB)\n", + "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m10.9/10.9 MB\u001b[0m \u001b[31m190.0 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading numpy-2.4.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.6 MB)\n", + "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m16.6/16.6 MB\u001b[0m \u001b[31m174.3 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (8.9 MB)\n", + "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m8.9/8.9 MB\u001b[0m \u001b[31m180.6 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading matplotlib-3.10.8-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (8.7 MB)\n", + "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m8.7/8.7 MB\u001b[0m \u001b[31m179.1 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading seaborn-0.13.2-py3-none-any.whl (294 kB)\n", + "Downloading pyarrow-23.0.1-cp312-cp312-manylinux_2_28_x86_64.whl (47.6 MB)\n", + "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m47.6/47.6 MB\u001b[0m \u001b[31m203.0 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0mm0:00:01\u001b[0m\n", + "\u001b[?25hDownloading fastparquet-2026.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (1.8 MB)\n", + "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m161.4 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading cramjam-2.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)\n", + "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m168.1 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading scipy-1.17.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (35.2 MB)\n", + "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m35.2/35.2 MB\u001b[0m \u001b[31m198.6 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading cloudpickle-3.1.2-py3-none-any.whl (22 kB)\n", + "Downloading llvmlite-0.46.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (56.3 MB)\n", + "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m205.4 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0mm0:00:01\u001b[0m\n", + "\u001b[?25hDownloading numba-0.64.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (3.8 MB)\n", + "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m3.8/3.8 MB\u001b[0m \u001b[31m179.7 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: slicer, pyarrow, numpy, llvmlite, cramjam, cloudpickle, scipy, pandas, numba, xgboost, scikit-learn, matplotlib, fastparquet, shap, seaborn\n", + "\u001b[2K Attempting uninstall: numpy\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 1/15\u001b[0m [pyarrow]\n", + "\u001b[2K Found existing installation: numpy 1.26.4\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 1/15\u001b[0m [pyarrow]\n", + "\u001b[2K Uninstalling numpy-1.26.4:\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 1/15\u001b[0m [pyarrow]\n", + "\u001b[2K Successfully uninstalled numpy-1.26.4\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 1/15\u001b[0m [pyarrow]\n", + "\u001b[2K Attempting uninstall: scipy[0m\u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 4/15\u001b[0m [cramjam]]\n", + "\u001b[2K Found existing installation: scipy 1.11.4\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 4/15\u001b[0m [cramjam]\n", + "\u001b[2K Uninstalling scipy-1.11.4:90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 4/15\u001b[0m [cramjam]\n", + "\u001b[2K Successfully uninstalled scipy-1.11.4\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 4/15\u001b[0m [cramjam]\n", + "\u001b[2K Attempting uninstall: pandas90m\u257a\u001b[0m\u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 6/15\u001b[0m [scipy]\n", + "\u001b[2K Found existing installation: pandas 2.1.4\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 6/15\u001b[0m [scipy]\n", + "\u001b[2K Uninstalling pandas-2.1.4:\u001b[91m\u2578\u001b[0m\u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 7/15\u001b[0m [pandas]\n", + "\u001b[2K Successfully uninstalled pandas-2.1.4\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 7/15\u001b[0m [pandas]\n", + "\u001b[2K Attempting uninstall: scikit-learn\u001b[90m\u257a\u001b[0m\u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 9/15\u001b[0m [xgboost]\n", + "\u001b[2K Found existing installation: scikit-learn 1.3.2\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 9/15\u001b[0m [xgboost]\n", + "\u001b[2K Uninstalling scikit-learn-1.3.2:0m\u001b[91m\u2578\u001b[0m\u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m10/15\u001b[0m [scikit-learn]\n", + "\u001b[2K Successfully uninstalled scikit-learn-1.3.2\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m10/15\u001b[0m [scikit-learn]\n", + "\u001b[2K Attempting uninstall: matplotlib\u001b[0m\u001b[91m\u2578\u001b[0m\u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m10/15\u001b[0m [scikit-learn]\n", + "\u001b[2K Found existing installation: matplotlib 3.8.2\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m10/15\u001b[0m [scikit-learn]\n", + "\u001b[2K Uninstalling matplotlib-3.8.2:\u001b[91m\u2578\u001b[0m\u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m10/15\u001b[0m [scikit-learn]\n", + "\u001b[2K Successfully uninstalled matplotlib-3.8.2\u001b[0m\u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m11/15\u001b[0m [matplotlib]\n", + "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m15/15\u001b[0m [seaborn]4/15\u001b[0m [seaborn]uet]\n", + "\u001b[1A\u001b[2KSuccessfully installed cloudpickle-3.1.2 cramjam-2.11.0 fastparquet-2026.3.0 llvmlite-0.46.0 matplotlib-3.10.8 numba-0.64.0 numpy-2.4.3 pandas-3.0.1 pyarrow-23.0.1 scikit-learn-1.8.0 scipy-1.17.1 seaborn-0.13.2 shap-0.51.0 slicer-0.0.8 xgboost-3.2.0\n" + ] + } + ], + "source": [ + "#cell-0: Install Dependencies\n", + "!pip install xgboost shap pandas numpy scikit-learn matplotlib seaborn pyarrow fastparquet" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Launching Training on 32 Cores...\n" + ] + } + ], + "source": [ + "# Cell 1: Imports & Hardware Configuration\n", + "import pandas as pd\n", + "import numpy as np\n", + "import xgboost as xgb\n", + "import shap\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import multiprocessing\n", + "import gc\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import roc_auc_score, average_precision_score, classification_report, confusion_matrix, f1_score, balanced_accuracy_score, roc_curve, auc\n", + "\n", + "# 1. Hardware Setup\n", + "n_cores = multiprocessing.cpu_count()\n", + "print(f\"Launching Training on {n_cores} Cores...\")\n", + "\n", + "# 2. Global Config\n", + "SEED = 42\n", + "pd.set_option('display.max_columns', None)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u26a1 Loading Parquet Files...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Shape: (299999, 66)\n", + "Test Shape: (14073, 66)\n", + "\n", + "Using 10 Tier-1 Features:\n", + "['chrom', 'pos', 'ref', 'alt', 'gnomad_af', 'GERP_91_mammals_rankscore', 'phyloP100way_vertebrate_rankscore', 'phyloP470way_mammalian_rankscore', 'phastCons470way_mammalian_rankscore', 'phastCons17way_primate_rankscore']\n", + " -> Encoding labels to 0 (Benign) and 1 (Pathogenic)...\n", + " -> Train unique labels: [1 0]\n", + " -> Test unique labels: [0 1]\n", + "Labels successfully encoded to integers.\n" + ] + } + ], + "source": [ + "# Cell 2: Load Data & Define Features (FIXED)\n", + "print(\"\u26a1 Loading Parquet Files...\")\n", + "\n", + "# Load your specific files\n", + "train_df = pd.read_parquet(\"Balanced_300k_SEQUENCES.parquet\")\n", + "test_df = pd.read_parquet(\"test_enriched_SEQUENCES.parquet\")\n", + "\n", + "print(f\"Train Shape: {train_df.shape}\")\n", + "print(f\"Test Shape: {test_df.shape}\")\n", + "\n", + "# ---- Define Tier-1 Features (Physics/Evolution only) ----\n", + "tier1_features = [\n", + " \"chrom\", \"pos\", \"ref\", \"alt\", \"gnomad_af\",\n", + " \"GERP_91_mammals_rankscore\", # keep (weak but not harmful)\n", + " \"phyloP100way_vertebrate_rankscore\", # strong\n", + " \"phyloP470way_mammalian_rankscore\", # moderate\n", + " \"phastCons470way_mammalian_rankscore\",\n", + " \"phastCons17way_primate_rankscore\"\n", + "]\n", + "\n", + "# Filter to keep only what exists in your DF\n", + "features = [c for c in tier1_features if c in train_df.columns]\n", + "print(f\"\\nUsing {len(features)} Tier-1 Features:\")\n", + "print(features)\n", + "\n", + "# ---- FIXED LABEL MAPPING ----\n", + "# We force conversion to string first to handle both 'category' and 'object' types\n", + "print(\" -> Encoding labels to 0 (Benign) and 1 (Pathogenic)...\")\n", + "\n", + "# 1. Force strings and strip whitespace to be safe\n", + "train_df[\"clean_label\"] = train_df[\"clean_label\"].astype(str).str.strip()\n", + "test_df[\"clean_label\"] = test_df[\"clean_label\"].astype(str).str.strip()\n", + "\n", + "# 2. Map explicitly\n", + "label_map = {\"Benign\": 0, \"Pathogenic\": 1}\n", + "\n", + "train_df[\"y\"] = train_df[\"clean_label\"].map(label_map)\n", + "test_df[\"y\"] = test_df[\"clean_label\"].map(label_map)\n", + "\n", + "# 3. Safety Check: Verify we have only 0s and 1s\n", + "print(f\" -> Train unique labels: {train_df['y'].unique()}\")\n", + "print(f\" -> Test unique labels: {test_df['y'].unique()}\")\n", + "\n", + "if train_df[\"y\"].isna().any():\n", + " raise ValueError(\"\u274c Error: Some labels in Train could not be mapped! Check for typos in 'Benign'/'Pathogenic'.\")\n", + "if test_df[\"y\"].isna().any():\n", + " raise ValueError(\"\u274c Error: Some labels in Test could not be mapped! Check for typos.\")\n", + "\n", + "print(\"Labels successfully encoded to integers.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " -> Casting chrom to 'category'...\n", + " -> Casting ref to 'category'...\n", + " -> Casting alt to 'category'...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Data Ready: Train(269999), Val(30000), Hard Test(14073)\n" + ] + } + ], + "source": [ + "# Cell 3: Preprocessing & Type Casting\n", + "\n", + "# 1. Cast Categoricals for Native XGBoost Support\n", + "cat_cols = [\"chrom\", \"ref\", \"alt\"]\n", + "for col in cat_cols:\n", + " if col in features:\n", + " print(f\" -> Casting {col} to 'category'...\")\n", + " train_df[col] = train_df[col].astype(\"category\")\n", + " test_df[col] = test_df[col].astype(\"category\")\n", + "\n", + "# 2. Ensure 'pos' and scores are numeric\n", + "num_cols = [c for c in features if c not in cat_cols]\n", + "for col in num_cols:\n", + " train_df[col] = pd.to_numeric(train_df[col], errors='coerce').fillna(0)\n", + " test_df[col] = pd.to_numeric(test_df[col], errors='coerce').fillna(0)\n", + "\n", + "# 3. Create X and y matrices\n", + "X = train_df[features]\n", + "y = train_df['y']\n", + "\n", + "# Keep benchmarking columns in test_df separate, but extract X_test for prediction\n", + "X_test = test_df[features]\n", + "y_test = test_df['y']\n", + "\n", + "# 4. Stratified Split for Validation (Using 10% of Train)\n", + "X_train, X_val, y_train, y_val = train_test_split(\n", + " X, y, test_size=0.10, stratify=y, random_state=SEED\n", + ")\n", + "\n", + "print(f\"\\nData Ready: Train({len(X_train)}), Val({len(X_val)}), Hard Test({len(X_test)})\")\n", + "gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\udee0 Configuring XGBoost SOTA Hyperparameters...\n" + ] + } + ], + "source": [ + "# Cell 4: Initialize SOTA Model\n", + "print(\"\ud83d\udee0 Configuring XGBoost SOTA Hyperparameters...\")\n", + "\n", + "model = xgb.XGBClassifier(\n", + " # --- Architecture ---\n", + " n_estimators=10000, # High ceiling (relies on early stopping)\n", + " learning_rate=0.01, # Very slow learning to squeeze out signal\n", + " max_depth=10, # Deep trees for complex genomic interactions\n", + " \n", + " # --- Regularization (Prevent Overfitting) ---\n", + " gamma=1.5, # High gamma: Conservative splitting (crucial for hard test sets)\n", + " min_child_weight=5, # Requires 5 samples to make a leaf\n", + " reg_alpha=0.5, # L1 Regularization\n", + " reg_lambda=1.0, # L2 Regularization\n", + " \n", + " # --- Sampling (Stochastic Gradient Boosting) ---\n", + " subsample=0.8, # Use 80% of rows per tree\n", + " colsample_bytree=0.8, # Use 80% of features per tree\n", + " \n", + " # --- Speed & Hardware ---\n", + " n_jobs=n_cores, # Use all 32 cores\n", + " tree_method='hist', # Histogram-based (Fastest)\n", + " enable_categorical=True, # Native categorical support\n", + " \n", + " # --- Objective ---\n", + " objective='binary:logistic',\n", + " eval_metric=['auc', 'logloss'],\n", + " early_stopping_rounds=200, # Patience\n", + " random_state=SEED\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting Training (this will take time due to LR=0.01)...\n", + "[0]\tvalidation_0-auc:0.96589\tvalidation_0-logloss:0.68516\tvalidation_1-auc:0.96544\tvalidation_1-logloss:0.68518\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[500]\tvalidation_0-auc:0.97908\tvalidation_0-logloss:0.18481\tvalidation_1-auc:0.97640\tvalidation_1-logloss:0.19517\n", + "[1000]\tvalidation_0-auc:0.98221\tvalidation_0-logloss:0.17017\tvalidation_1-auc:0.97810\tvalidation_1-logloss:0.18719\n", + "[1500]\tvalidation_0-auc:0.98369\tvalidation_0-logloss:0.16353\tvalidation_1-auc:0.97882\tvalidation_1-logloss:0.18417\n", + "[2000]\tvalidation_0-auc:0.98446\tvalidation_0-logloss:0.16001\tvalidation_1-auc:0.97909\tvalidation_1-logloss:0.18296\n", + "[2500]\tvalidation_0-auc:0.98493\tvalidation_0-logloss:0.15781\tvalidation_1-auc:0.97922\tvalidation_1-logloss:0.18237\n", + "[3000]\tvalidation_0-auc:0.98525\tvalidation_0-logloss:0.15626\tvalidation_1-auc:0.97929\tvalidation_1-logloss:0.18202\n", + "[3500]\tvalidation_0-auc:0.98549\tvalidation_0-logloss:0.15512\tvalidation_1-auc:0.97935\tvalidation_1-logloss:0.18178\n", + "[4000]\tvalidation_0-auc:0.98569\tvalidation_0-logloss:0.15414\tvalidation_1-auc:0.97939\tvalidation_1-logloss:0.18156\n", + "[4500]\tvalidation_0-auc:0.98585\tvalidation_0-logloss:0.15336\tvalidation_1-auc:0.97943\tvalidation_1-logloss:0.18136\n", + "[5000]\tvalidation_0-auc:0.98600\tvalidation_0-logloss:0.15262\tvalidation_1-auc:0.97945\tvalidation_1-logloss:0.18125\n", + "[5500]\tvalidation_0-auc:0.98612\tvalidation_0-logloss:0.15201\tvalidation_1-auc:0.97947\tvalidation_1-logloss:0.18117\n", + "[6000]\tvalidation_0-auc:0.98621\tvalidation_0-logloss:0.15157\tvalidation_1-auc:0.97948\tvalidation_1-logloss:0.18110\n", + "[6500]\tvalidation_0-auc:0.98631\tvalidation_0-logloss:0.15106\tvalidation_1-auc:0.97949\tvalidation_1-logloss:0.18103\n", + "[7000]\tvalidation_0-auc:0.98640\tvalidation_0-logloss:0.15059\tvalidation_1-auc:0.97951\tvalidation_1-logloss:0.18097\n", + "[7500]\tvalidation_0-auc:0.98649\tvalidation_0-logloss:0.15015\tvalidation_1-auc:0.97952\tvalidation_1-logloss:0.18091\n", + "[8000]\tvalidation_0-auc:0.98656\tvalidation_0-logloss:0.14978\tvalidation_1-auc:0.97953\tvalidation_1-logloss:0.18085\n", + "[8500]\tvalidation_0-auc:0.98661\tvalidation_0-logloss:0.14951\tvalidation_1-auc:0.97953\tvalidation_1-logloss:0.18084\n", + "[8582]\tvalidation_0-auc:0.98662\tvalidation_0-logloss:0.14946\tvalidation_1-auc:0.97953\tvalidation_1-logloss:0.18084\n", + "\n", + " Training Complete! Best Iteration: 8382\n" + ] + } + ], + "source": [ + "# Cell 5: Training Run\n", + "print(\"Starting Training (this will take time due to LR=0.01)...\")\n", + "\n", + "eval_set = [(X_train, y_train), (X_val, y_val)]\n", + "\n", + "model.fit(\n", + " X_train, y_train,\n", + " eval_set=eval_set,\n", + " verbose=500 # Log every 500 trees\n", + ")\n", + "\n", + "print(f\"\\n Training Complete! Best Iteration: {model.best_iteration}\")\n", + "\n", + "# --- Plot Learning Curve ---\n", + "results = model.evals_result()\n", + "epochs = len(results['validation_0']['auc'])\n", + "x_axis = range(0, epochs)\n", + "\n", + "plt.figure(figsize=(10, 6))\n", + "plt.plot(x_axis, results['validation_0']['auc'], label='Train')\n", + "plt.plot(x_axis, results['validation_1']['auc'], label='Validation')\n", + "plt.legend()\n", + "plt.ylabel('AUC Score')\n", + "plt.xlabel('Iterations')\n", + "plt.title('XGBoost Learning Curve (Check for divergence!)')\n", + "plt.grid(True)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + ], + "source": [ + "# Cell 6: ROC-AUC Visualization\n", + "from sklearn.metrics import roc_curve, auc\n", + "\n", + "# Predict probabilities\n", + "y_prob = model.predict_proba(X_test)[:, 1]\n", + "\n", + "# Calculate ROC\n", + "fpr, tpr, _ = roc_curve(y_test, y_prob)\n", + "roc_auc = auc(fpr, tpr)\n", + "\n", + "plt.figure(figsize=(8, 8))\n", + "plt.plot(fpr, tpr, color='darkorange', lw=2, label=f' Model (AUC = {roc_auc:.4f})')\n", + "plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')\n", + "plt.xlim([0.0, 1.0])\n", + "plt.ylim([0.0, 1.05])\n", + "plt.xlabel('False Positive Rate')\n", + "plt.ylabel('True Positive Rate')\n", + "plt.title('ROC Curve on 14k Test Set')\n", + "plt.legend(loc=\"lower right\")\n", + "plt.grid(True, alpha=0.3)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " Calculating SHAP Values...\n" + ] + } + ], + "source": [ + "# Cell 7: SHAP Explainability\n", + "print(\"\\n Calculating SHAP Values...\")\n", + "\n", + "explainer = shap.TreeExplainer(model)\n", + "\n", + "# Subsample 5000 rows for speed (calculating on all 14k is fine on 32 cores too)\n", + "shap_sample = X_test.sample(n=5000, random_state=SEED) if len(X_test) > 5000 else X_test\n", + "shap_values = explainer.shap_values(shap_sample)\n", + "\n", + "# Beeswarm Plot\n", + "plt.figure(figsize=(12, 8))\n", + "plt.title(\"SHAP Feature Importance (Global View)\")\n", + "shap.summary_plot(shap_values, shap_sample, show=False)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "==========================================\n", + " XGBOOST MODEL PERFORMANCE\n", + "==========================================\n", + "ROC-AUC : 0.766809\n", + "PR-AUC : 0.771672\n", + "\n", + "==========================================\n", + " COMPARISION WITH ESTABLISHED TOOLS\n", + "==========================================\n", + "CADD ROC-AUC = 0.592509 PR-AUC = 0.704837\n", + "REVEL ROC-AUC = 0.587114 PR-AUC = 0.600423\n", + "AlphaMissense ROC-AUC = 0.592032 PR-AUC = 0.603459\n", + "PrimateAI ROC-AUC = 0.566141 PR-AUC = 0.581314\n", + "MPC ROC-AUC = 0.559678 PR-AUC = 0.578559\n", + "ClinPred ROC-AUC = 0.590860 PR-AUC = 0.606637\n", + "BayesDel ROC-AUC = 0.594821 PR-AUC = 0.706338\n", + "\n", + "==========================================\n", + " FINAL LEADERBOARD\n", + "==========================================\n", + " 1. My XGBoost AUC=0.766809 PR=0.771672\n", + " 2. BayesDel AUC=0.594821 PR=0.706338\n", + " 3. CADD AUC=0.592509 PR=0.704837\n", + " 4. AlphaMissense AUC=0.592032 PR=0.603459\n", + " 5. ClinPred AUC=0.590860 PR=0.606637\n", + " 6. REVEL AUC=0.587114 PR=0.600423\n", + " 7. PrimateAI AUC=0.566141 PR=0.581314\n", + " 8. MPC AUC=0.559678 PR=0.578559\n" + ] + } + ], + "source": [ + "# Cell 8: Benchmarking Logic + Plots\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve, precision_recall_curve\n", + "\n", + "roc_auc = roc_auc_score(y_test, y_prob)\n", + "pr_auc = average_precision_score(y_test, y_prob)\n", + "\n", + "print(\"\\n==========================================\")\n", + "print(f\" XGBOOST MODEL PERFORMANCE\")\n", + "print(\"==========================================\")\n", + "print(f\"ROC-AUC : {roc_auc:.6f}\")\n", + "print(f\"PR-AUC : {pr_auc:.6f}\")\n", + "\n", + "# ---- Benchmark against DBNSFP models ----\n", + "benchmarks = {\n", + " \"CADD\": \"CADD_raw_rankscore\",\n", + " \"REVEL\": \"REVEL_rankscore\",\n", + " \"AlphaMissense\": \"AlphaMissense_rankscore\",\n", + " \"PrimateAI\": \"PrimateAI_rankscore\",\n", + " \"MPC\": \"MPC_rankscore\",\n", + " \"ClinPred\": \"ClinPred_rankscore\",\n", + " \"BayesDel\": \"BayesDel_addAF_rankscore\"\n", + "}\n", + "\n", + "print(\"\\n==========================================\")\n", + "print(\" COMPARISION WITH ESTABLISHED TOOLS\")\n", + "print(\"==========================================\")\n", + "\n", + "benchmark_results = []\n", + "\n", + "for name, col in benchmarks.items():\n", + " if col not in test_df.columns:\n", + " continue\n", + " \n", + " scores = test_df[col].fillna(0).values\n", + " \n", + " try:\n", + " b_auc = roc_auc_score(y_test, scores)\n", + " b_pr = average_precision_score(y_test, scores)\n", + " benchmark_results.append((name, b_auc, b_pr))\n", + " print(f\"{name:14s} ROC-AUC = {b_auc:.6f} PR-AUC = {b_pr:.6f}\")\n", + " except ValueError:\n", + " pass\n", + "\n", + "# ---- Leaderboard ----\n", + "print(\"\\n==========================================\")\n", + "print(\" FINAL LEADERBOARD\")\n", + "print(\"==========================================\")\n", + "\n", + "ranking = sorted([(\"My XGBoost\", roc_auc, pr_auc)] + benchmark_results, key=lambda x: x[1], reverse=True)\n", + "\n", + "for i, (name, a, p) in enumerate(ranking, 1):\n", + " print(f\"{i:2d}. {name:14s} AUC={a:.6f} PR={p:.6f}\")\n", + "\n", + "# ---- ROC Curve ----\n", + "fpr, tpr, _ = roc_curve(y_test, y_prob)\n", + "\n", + "plt.figure()\n", + "plt.plot(fpr, tpr)\n", + "plt.xlabel(\"False Positive Rate\")\n", + "plt.ylabel(\"True Positive Rate\")\n", + "plt.title(f\"ROC Curve (AUC = {roc_auc:.4f})\")\n", + "plt.grid()\n", + "plt.show()\n", + "\n", + "# ---- Precision-Recall Curve ----\n", + "precision, recall, _ = precision_recall_curve(y_test, y_prob)\n", + "\n", + "plt.figure()\n", + "plt.plot(recall, precision)\n", + "plt.xlabel(\"Recall\")\n", + "plt.ylabel(\"Precision\")\n", + "plt.title(f\"Precision-Recall Curve (AUC = {pr_auc:.4f})\")\n", + "plt.grid()\n", + "plt.show()\n", + "\n", + "# ---- Bar Plot Comparison ----\n", + "df_bench = pd.DataFrame(\n", + " [(\"My XGBoost\", roc_auc, pr_auc)] + benchmark_results,\n", + " columns=[\"Model\", \"ROC-AUC\", \"PR-AUC\"]\n", + ")\n", + "\n", + "df_bench = df_bench.sort_values(by=\"ROC-AUC\", ascending=False)\n", + "\n", + "plt.figure(figsize=(10,5))\n", + "plt.bar(df_bench[\"Model\"], df_bench[\"ROC-AUC\"])\n", + "plt.xticks(rotation=45)\n", + "plt.ylabel(\"ROC-AUC\")\n", + "plt.title(\"Model Comparison (ROC-AUC)\")\n", + "plt.grid(axis='y')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "==============================\n", + " VALIDATION RESULTS (TUNING)\n", + "==============================\n", + "Optimal Threshold : 0.488\n", + "Best F1 (Val) : 0.9242\n", + "Balanced Acc (Val): 0.9268\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "==============================\n", + " TEST PERFORMANCE (FINAL)\n", + "==============================\n", + "Threshold Used : 0.488\n", + "F1 (Test) : 0.7215\n", + "Balanced Acc (Test): 0.6174\n", + "\n", + "Confusion Matrix (Test):\n", + "[[1754 5258]\n", + " [ 109 6952]]\n", + "\n", + "Classification Report (Test):\n", + " precision recall f1-score support\n", + "\n", + " 0 0.9415 0.2501 0.3953 7012\n", + " 1 0.5694 0.9846 0.7215 7061\n", + "\n", + " accuracy 0.6186 14073\n", + " macro avg 0.7554 0.6174 0.5584 14073\n", + "weighted avg 0.7548 0.6186 0.5590 14073\n", + "\n" + ] + } + ], + "source": [ + "# Cell 9: Threshold Tuning (FULL VERSION - VAL + TEST, NO LEAKAGE)\n", + "\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from sklearn.metrics import (\n", + " f1_score, balanced_accuracy_score,\n", + " confusion_matrix, classification_report\n", + ")\n", + "\n", + "# =========================================================\n", + "# 1. GET PROBABILITIES\n", + "# =========================================================\n", + "y_val_prob = model.predict_proba(X_val)[:, 1]\n", + "y_test_prob = model.predict_proba(X_test)[:, 1]\n", + "\n", + "thresholds = np.linspace(0.01, 0.99, 200)\n", + "\n", + "# =========================================================\n", + "# 2. TUNE THRESHOLD ON VALIDATION SET\n", + "# =========================================================\n", + "best_f1 = 0\n", + "best_t = 0\n", + "best_bal_acc = 0\n", + "\n", + "f1_scores_val = []\n", + "bal_scores_val = []\n", + "\n", + "for t in thresholds:\n", + " y_pred_val = (y_val_prob >= t).astype(int)\n", + "\n", + " f1 = f1_score(y_val, y_pred_val)\n", + " bal = balanced_accuracy_score(y_val, y_pred_val)\n", + "\n", + " f1_scores_val.append(f1)\n", + " bal_scores_val.append(bal)\n", + "\n", + " if f1 > best_f1:\n", + " best_f1 = f1\n", + " best_t = t\n", + " best_bal_acc = bal\n", + "\n", + "print(\"\\n==============================\")\n", + "print(\" VALIDATION RESULTS (TUNING)\")\n", + "print(\"==============================\")\n", + "print(f\"Optimal Threshold : {best_t:.3f}\")\n", + "print(f\"Best F1 (Val) : {best_f1:.4f}\")\n", + "print(f\"Balanced Acc (Val): {best_bal_acc:.4f}\")\n", + "\n", + "# =========================================================\n", + "# 3. PLOT VALIDATION CURVE\n", + "# =========================================================\n", + "plt.figure(figsize=(8,5))\n", + "plt.plot(thresholds, f1_scores_val, label=\"F1 (Val)\")\n", + "plt.plot(thresholds, bal_scores_val, label=\"Balanced Acc (Val)\")\n", + "plt.axvline(best_t, linestyle='--', label=f\"Best T = {best_t:.3f}\")\n", + "\n", + "plt.xlabel(\"Threshold\")\n", + "plt.ylabel(\"Score\")\n", + "plt.title(\"Threshold Optimization (Validation Set)\")\n", + "plt.legend()\n", + "plt.grid()\n", + "plt.show()\n", + "\n", + "# =========================================================\n", + "# 4. APPLY THRESHOLD ON TEST SET\n", + "# =========================================================\n", + "y_pred_test = (y_test_prob >= best_t).astype(int)\n", + "\n", + "f1_test = f1_score(y_test, y_pred_test)\n", + "bal_test = balanced_accuracy_score(y_test, y_pred_test)\n", + "\n", + "print(\"\\n==============================\")\n", + "print(\" TEST PERFORMANCE (FINAL)\")\n", + "print(\"==============================\")\n", + "print(f\"Threshold Used : {best_t:.3f}\")\n", + "print(f\"F1 (Test) : {f1_test:.4f}\")\n", + "print(f\"Balanced Acc (Test): {bal_test:.4f}\")\n", + "\n", + "# =========================================================\n", + "# 5. TEST CONFUSION MATRIX\n", + "# =========================================================\n", + "cm_test = confusion_matrix(y_test, y_pred_test)\n", + "\n", + "print(\"\\nConfusion Matrix (Test):\")\n", + "print(cm_test)\n", + "\n", + "print(\"\\nClassification Report (Test):\")\n", + "print(classification_report(y_test, y_pred_test, digits=4))\n", + "\n", + "plt.figure(figsize=(6,5))\n", + "plt.imshow(cm_test)\n", + "\n", + "for i in range(cm_test.shape[0]):\n", + " for j in range(cm_test.shape[1]):\n", + " plt.text(j, i, cm_test[i, j], ha='center', va='center')\n", + "\n", + "plt.xlabel(\"Predicted Label\")\n", + "plt.ylabel(\"True Label\")\n", + "plt.title(\"Confusion Matrix (Test Set)\")\n", + "plt.xticks([0,1], [\"Class 0\", \"Class 1\"])\n", + "plt.yticks([0,1], [\"Class 0\", \"Class 1\"])\n", + "\n", + "plt.colorbar()\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "# =========================================================\n", + "# 6. ALSO PLOT TEST CURVE (FOR ANALYSIS ONLY)\n", + "# =========================================================\n", + "f1_scores_test = []\n", + "bal_scores_test = []\n", + "\n", + "for t in thresholds:\n", + " y_pred_t = (y_test_prob >= t).astype(int)\n", + "\n", + " f1_scores_test.append(f1_score(y_test, y_pred_t))\n", + " bal_scores_test.append(balanced_accuracy_score(y_test, y_pred_t))\n", + "\n", + "plt.figure(figsize=(8,5))\n", + "plt.plot(thresholds, f1_scores_test, label=\"F1 (Test)\")\n", + "plt.plot(thresholds, bal_scores_test, label=\"Balanced Acc (Test)\")\n", + "plt.axvline(best_t, linestyle='--', label=f\"Chosen T = {best_t:.3f}\")\n", + "\n", + "plt.xlabel(\"Threshold\")\n", + "plt.ylabel(\"Score\")\n", + "plt.title(\"Test Threshold Curve (Analysis Only - Not Used for Tuning)\")\n", + "plt.legend()\n", + "plt.grid()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Generating Deep Feature Analysis...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_7026/3990650171.py:49: FutureWarning: \n", + "\n", + "Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.\n", + "\n", + " sns.barplot(x=\"gain\", y=\"feature\", data=fi.head(25), palette=\"viridis\")\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Top Features by GAIN:\n", + "\n", + " feature gain split\n", + " phyloP100way_vertebrate_rankscore 178.044281 32840.0\n", + " phyloP470way_mammalian_rankscore 35.505932 34554.0\n", + " gnomad_af 31.278227 54468.0\n", + " chrom 25.794613 51327.0\n", + "phastCons470way_mammalian_rankscore 10.803547 10682.0\n", + " phastCons17way_primate_rankscore 9.607828 36827.0\n", + " pos 9.198639 101420.0\n", + " alt 7.743703 21750.0\n", + " ref 6.535082 21038.0\n", + " GERP_91_mammals_rankscore 6.072751 30064.0\n" + ] + } + ], + "source": [ + "# Cell 10: Advanced Feature Importance (SHAP + Gain)\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import shap\n", + "\n", + "print(\" Generating Deep Feature Analysis...\")\n", + "\n", + "# -------------------------------------------------------\n", + "# 1. SHAP Mean Absolute Value (Global Feature Importance)\n", + "# -------------------------------------------------------\n", + "plt.figure(figsize=(10, 6))\n", + "plt.title(\"Mean Absolute SHAP Values (Global Impact on Model)\")\n", + "# plot_type=\"bar\" automatically calculates the mean absolute value for you\n", + "shap.summary_plot(shap_values, shap_sample, plot_type=\"bar\", show=False)\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "# -------------------------------------------------------\n", + "# 2. XGBoost Native Feature Importance (Gain & Split)\n", + "# -------------------------------------------------------\n", + "# XGBoost handles this differently than LightGBM.\n", + "# We access the internal booster to get specific metrics.\n", + "booster = model.get_booster()\n", + "\n", + "# Get Gain (Average gain across all splits the feature is used in)\n", + "gain_scores = booster.get_score(importance_type='gain')\n", + "# Get Weight (Number of times a feature is used to split the data)\n", + "split_scores = booster.get_score(importance_type='weight')\n", + "\n", + "# Convert to DataFrame\n", + "# Note: XGBoost usually returns a dict {feature: score}, so we map it back\n", + "all_features = model.feature_names_in_\n", + "fi_data = []\n", + "\n", + "for feat in all_features:\n", + " fi_data.append({\n", + " \"feature\": feat,\n", + " \"gain\": gain_scores.get(feat, 0), # Default to 0 if not used\n", + " \"split\": split_scores.get(feat, 0)\n", + " })\n", + "\n", + "fi = pd.DataFrame(fi_data).sort_values(\"gain\", ascending=False)\n", + "\n", + "# -------------------------------------------------------\n", + "# 3. Plot Top 25 Features by GAIN\n", + "# -------------------------------------------------------\n", + "plt.figure(figsize=(10, 8))\n", + "sns.barplot(x=\"gain\", y=\"feature\", data=fi.head(25), palette=\"viridis\")\n", + "plt.title(\"Top Features by Gain (Structural Importance)\")\n", + "plt.xlabel(\"Average Gain (Contribution to Accuracy)\")\n", + "plt.ylabel(\"Feature\")\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "# -------------------------------------------------------\n", + "# 4. Print Table\n", + "# -------------------------------------------------------\n", + "print(\" Top Features by GAIN:\\n\")\n", + "print(fi.head(30).to_string(index=False))" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\udcbe Saving Model Artifacts...\n", + " -> Saved Native JSON: pathopreter_sota_xgboost.json\n", + " -> Saved Joblib Pickle: pathopreter_sota_xgboost.pkl\n", + " -> Saved Feature List: pathopreter_features.txt\n", + "\n", + " Verification - Files in Root:\n", + " - pathopreter_features.txt : 0.00 MB\n", + " - pathopreter_sota_xgboost.pkl : 35.90 MB\n", + " - pathopreter_sota_xgboost.json : 50.47 MB\n" + ] + } + ], + "source": [ + "# Cell 11: Save SOTA Model to Root\n", + "import joblib\n", + "import os\n", + "\n", + "print(\"\ud83d\udcbe Saving Model Artifacts...\")\n", + "\n", + "# 1. Save as Native XGBoost JSON (Best for portability/deployment)\n", + "# This saves the tree structure, feature names, and weights\n", + "model_filename_json = \"pathopreter_sota_xgboost.json\"\n", + "model.save_model(model_filename_json)\n", + "print(f\" -> Saved Native JSON: {model_filename_json}\")\n", + "\n", + "# 2. Save as Joblib Pickle (Best for Python reloading)\n", + "# This saves the entire Sklearn wrapper, including hyperparameters and configuration\n", + "model_filename_pkl = \"pathopreter_sota_xgboost.pkl\"\n", + "joblib.dump(model, model_filename_pkl)\n", + "print(f\" -> Saved Joblib Pickle: {model_filename_pkl}\")\n", + "\n", + "# 3. Save Feature List (Crucial for inference consistency)\n", + "# If you reload later, you MUST ensure columns are in this exact order\n", + "feature_filename = \"pathopreter_features.txt\"\n", + "with open(feature_filename, \"w\") as f:\n", + " for feat in features:\n", + " f.write(f\"{feat}\\n\")\n", + "print(f\" -> Saved Feature List: {feature_filename}\")\n", + "\n", + "# --- Verification ---\n", + "print(\"\\n Verification - Files in Root:\")\n", + "files = [f for f in os.listdir('.') if \"pathopreter\" in f]\n", + "for f in files:\n", + " size_mb = os.path.getsize(f) / (1024 * 1024)\n", + " print(f\" - {f} : {size_mb:.2f} MB\")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " STARTING EXPERIMENT 2: ROBUSTNESS CHECK (NO LOCATION BIAS)...\n", + " -> Original Feature Count: 10\n", + " -> Robust Feature Count: 8\n", + " -> DROPPED: ['chrom', 'pos']\n", + "\ud83d\udd25 Training Robust Model (This may take a moment)...\n", + "[0]\tvalidation_0-auc:0.93564\tvalidation_0-logloss:0.68610\tvalidation_1-auc:0.93442\tvalidation_1-logloss:0.68610\n", + "[500]\tvalidation_0-auc:0.94672\tvalidation_0-logloss:0.27310\tvalidation_1-auc:0.94417\tvalidation_1-logloss:0.27910\n", + "[1000]\tvalidation_0-auc:0.94798\tvalidation_0-logloss:0.26800\tvalidation_1-auc:0.94445\tvalidation_1-logloss:0.27766\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1500]\tvalidation_0-auc:0.94832\tvalidation_0-logloss:0.26676\tvalidation_1-auc:0.94446\tvalidation_1-logloss:0.27752\n", + "[2000]\tvalidation_0-auc:0.94851\tvalidation_0-logloss:0.26601\tvalidation_1-auc:0.94451\tvalidation_1-logloss:0.27748\n", + "[2500]\tvalidation_0-auc:0.94865\tvalidation_0-logloss:0.26549\tvalidation_1-auc:0.94456\tvalidation_1-logloss:0.27744\n", + "[3000]\tvalidation_0-auc:0.94878\tvalidation_0-logloss:0.26498\tvalidation_1-auc:0.94459\tvalidation_1-logloss:0.27742\n", + "[3104]\tvalidation_0-auc:0.94880\tvalidation_0-logloss:0.26489\tvalidation_1-auc:0.94460\tvalidation_1-logloss:0.27741\n", + "\n", + "==========================================\n", + " ROBUSTNESS CHECK RESULTS\n", + "==========================================\n", + "Baseline AUC (With Loc) : 0.766809\n", + "Robust AUC (No Loc) : 0.760368\n", + "Difference : -0.006441\n", + "\n", + "\u2705 VERDICT: PASSED. The model is learning BIOLOGY, not location.\n", + " It still beats REVEL/AlphaMissense without cheating!\n", + "\n", + "==========================================\n", + " NEW LEADERBOARD (ROBUST MODEL)\n", + "==========================================\n", + " 1. My Robust XGB AUC=0.760368 PR=0.769851\n", + " 2. BayesDel AUC=0.594821 PR=0.706338\n", + " 3. CADD AUC=0.592509 PR=0.704837\n", + " 4. AlphaMissense AUC=0.592032 PR=0.603459\n", + " 5. ClinPred AUC=0.590860 PR=0.606637\n", + " 6. REVEL AUC=0.587114 PR=0.600423\n", + " 7. PrimateAI AUC=0.566141 PR=0.581314\n", + " 8. MPC AUC=0.559678 PR=0.578559\n", + "\n", + "\ud83c\udfc6 Top Features in Robust Model (Pure Biology):\n", + " feature gain\n", + " phyloP100way_vertebrate_rankscore 230.754532\n", + " phyloP470way_mammalian_rankscore 59.637676\n", + " gnomad_af 53.463173\n", + " phastCons17way_primate_rankscore 17.051365\n", + "phastCons470way_mammalian_rankscore 13.066975\n", + " ref 11.242688\n", + " alt 10.989820\n", + " GERP_91_mammals_rankscore 7.565748\n" + ] + } + ], + "source": [ + "# Cell 12: Robustness Experiment (Drop Chrom/Pos) & Re-Benchmark\n", + "import pandas as pd\n", + "import xgboost as xgb\n", + "import numpy as np\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import roc_auc_score, average_precision_score\n", + "\n", + "print(\" STARTING EXPERIMENT 2: ROBUSTNESS CHECK (NO LOCATION BIAS)...\")\n", + "\n", + "# ==========================================\n", + "# 1. Define Robust Feature Set\n", + "# ==========================================\n", + "# Remove 'chrom' and 'pos' to force the model to learn biology only\n", + "features_robust = [f for f in features if f not in ['chrom', 'pos']]\n", + "\n", + "print(f\" -> Original Feature Count: {len(features)}\")\n", + "print(f\" -> Robust Feature Count: {len(features_robust)}\")\n", + "print(f\" -> DROPPED: ['chrom', 'pos']\")\n", + "\n", + "# ==========================================\n", + "# 2. Prepare Data (Robust Split)\n", + "# ==========================================\n", + "X_rob = train_df[features_robust]\n", + "y_rob = train_df['y']\n", + "\n", + "X_test_rob = test_df[features_robust]\n", + "y_test_rob = test_df['y']\n", + "\n", + "# Stratified Validation Split\n", + "X_train_r, X_val_r, y_train_r, y_val_r = train_test_split(\n", + " X_rob, y_rob, test_size=0.10, stratify=y_rob, random_state=SEED\n", + ")\n", + "\n", + "# ==========================================\n", + "# 3. Train Robust Model (SOTA Params)\n", + "# ==========================================\n", + "print(\"\ud83d\udd25 Training Robust Model (This may take a moment)...\")\n", + "\n", + "model_robust = xgb.XGBClassifier(\n", + " n_estimators=10000,\n", + " learning_rate=0.01,\n", + " max_depth=10,\n", + " gamma=1.5,\n", + " min_child_weight=5,\n", + " reg_alpha=0.5,\n", + " reg_lambda=1.0,\n", + " subsample=0.8,\n", + " colsample_bytree=0.8,\n", + " n_jobs=n_cores,\n", + " tree_method='hist',\n", + " enable_categorical=True, # Still needed for ref/alt\n", + " early_stopping_rounds=200,\n", + " random_state=SEED,\n", + " objective='binary:logistic',\n", + " eval_metric=['auc', 'logloss']\n", + ")\n", + "\n", + "model_robust.fit(\n", + " X_train_r, y_train_r,\n", + " eval_set=[(X_train_r, y_train_r), (X_val_r, y_val_r)],\n", + " verbose=500\n", + ")\n", + "\n", + "# ==========================================\n", + "# 4. Evaluation & Comparison\n", + "# ==========================================\n", + "print(\"\\n==========================================\")\n", + "print(\" ROBUSTNESS CHECK RESULTS\")\n", + "print(\"==========================================\")\n", + "\n", + "# Predict\n", + "y_prob_r = model_robust.predict_proba(X_test_rob)[:, 1]\n", + "\n", + "# Metrics\n", + "roc_auc_r = roc_auc_score(y_test_rob, y_prob_r)\n", + "pr_auc_r = average_precision_score(y_test_rob, y_prob_r)\n", + "\n", + "# Compare with Baseline (Assumes 'roc_auc' exists from previous cell)\n", + "# If running fresh, we assume the previous score was ~0.7477\n", + "baseline_auc = roc_auc if 'roc_auc' in locals() else 0.747704\n", + "delta = roc_auc_r - baseline_auc\n", + "\n", + "print(f\"Baseline AUC (With Loc) : {baseline_auc:.6f}\")\n", + "print(f\"Robust AUC (No Loc) : {roc_auc_r:.6f}\")\n", + "print(f\"Difference : {delta:.6f}\")\n", + "\n", + "if roc_auc_r > 0.70:\n", + " print(\"\\n\u2705 VERDICT: PASSED. The model is learning BIOLOGY, not location.\")\n", + " print(\" It still beats REVEL/AlphaMissense without cheating!\")\n", + "else:\n", + " print(\"\\n\u26a0\ufe0f VERDICT: WARNING. Significant drop. Re-evaluate features.\")\n", + "\n", + "# ==========================================\n", + "# 5. Re-Benchmark Against SOTA Tools\n", + "# ==========================================\n", + "benchmarks = {\n", + " \"CADD\": \"CADD_raw_rankscore\",\n", + " \"REVEL\": \"REVEL_rankscore\",\n", + " \"AlphaMissense\": \"AlphaMissense_rankscore\",\n", + " \"PrimateAI\": \"PrimateAI_rankscore\",\n", + " \"MPC\": \"MPC_rankscore\",\n", + " \"ClinPred\": \"ClinPred_rankscore\",\n", + " \"BayesDel\": \"BayesDel_addAF_rankscore\"\n", + "}\n", + "\n", + "print(\"\\n==========================================\")\n", + "print(\" NEW LEADERBOARD (ROBUST MODEL)\")\n", + "print(\"==========================================\")\n", + "\n", + "benchmark_results = []\n", + "\n", + "for name, col in benchmarks.items():\n", + " if col not in test_df.columns:\n", + " continue\n", + " \n", + " scores = test_df[col].fillna(0).values\n", + " \n", + " try:\n", + " b_auc = roc_auc_score(y_test_rob, scores)\n", + " b_pr = average_precision_score(y_test_rob, scores)\n", + " benchmark_results.append((name, b_auc, b_pr))\n", + " except ValueError:\n", + " pass\n", + "\n", + "# Add our Robust Model to the list\n", + "ranking = sorted([(\"My Robust XGB\", roc_auc_r, pr_auc_r)] + benchmark_results, key=lambda x: x[1], reverse=True)\n", + "\n", + "for i, (name, a, p) in enumerate(ranking, 1):\n", + " print(f\"{i:2d}. {name:14s} AUC={a:.6f} PR={p:.6f}\")\n", + "\n", + "# ==========================================\n", + "# 6. Check New Top Features (What took over?)\n", + "# ==========================================\n", + "booster = model_robust.get_booster()\n", + "gain_scores = booster.get_score(importance_type='gain')\n", + "fi_rob = pd.DataFrame([\n", + " {\"feature\": k, \"gain\": v} for k,v in gain_scores.items()\n", + "]).sort_values(\"gain\", ascending=False)\n", + "\n", + "print(\"\\n\ud83c\udfc6 Top Features in Robust Model (Pure Biology):\")\n", + "print(fi_rob.head(10).to_string(index=False))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\udcbe Saving PathoPreter Artifacts...\n", + " -> Saved SOTA Model: pathopreter_sota_xgboost.json & pathopreter_sota_xgboost.pkl\n", + " -> Saved Robust Model: pathopreter_robust_biology.json & pathopreter_robust_biology.pkl\n", + "\n", + " Final Inventory (Root Directory):\n", + " - pathopreter_features.txt : 0.00 MB\n", + " - pathopreter_features_robust.txt : 0.00 MB\n", + " - pathopreter_features_sota.txt : 0.00 MB\n", + " - pathopreter_robust_biology.json : 19.56 MB\n", + " - pathopreter_robust_biology.pkl : 13.17 MB\n", + " - pathopreter_sota_xgboost.json : 50.47 MB\n", + " - pathopreter_sota_xgboost.pkl : 35.90 MB\n" + ] + } + ], + "source": [ + "# Cell 13: Save All Models with 'PathoPreter' Naming\n", + "import joblib\n", + "import os\n", + "\n", + "print(\"\ud83d\udcbe Saving PathoPreter Artifacts...\")\n", + "\n", + "# ==========================================\n", + "# 1. Save The \"SOTA\" Model (The Competition Winner)\n", + "# ==========================================\n", + "# Native XGBoost JSON (Portable)\n", + "sota_json = \"pathopreter_sota_xgboost.json\"\n", + "model.save_model(sota_json)\n", + "\n", + "# Joblib Pickle (Python Ready)\n", + "sota_pkl = \"pathopreter_sota_xgboost.pkl\"\n", + "joblib.dump(model, sota_pkl)\n", + "\n", + "# Feature List (SOTA)\n", + "with open(\"pathopreter_features_sota.txt\", \"w\") as f:\n", + " for feat in features:\n", + " f.write(f\"{feat}\\n\")\n", + "\n", + "print(f\" -> Saved SOTA Model: {sota_json} & {sota_pkl}\")\n", + "\n", + "# ==========================================\n", + "# 2. Save The \"Robust\" Model (The Scientist)\n", + "# ==========================================\n", + "# Native XGBoost JSON (Portable)\n", + "robust_json = \"pathopreter_robust_biology.json\"\n", + "model_robust.save_model(robust_json)\n", + "\n", + "# Joblib Pickle (Python Ready)\n", + "robust_pkl = \"pathopreter_robust_biology.pkl\"\n", + "joblib.dump(model_robust, robust_pkl)\n", + "\n", + "# Feature List (Robust - No Chrom/Pos)\n", + "with open(\"pathopreter_features_robust.txt\", \"w\") as f:\n", + " for feat in features_robust:\n", + " f.write(f\"{feat}\\n\")\n", + "\n", + "print(f\" -> Saved Robust Model: {robust_json} & {robust_pkl}\")\n", + "\n", + "# ==========================================\n", + "# 3. Verification\n", + "# ==========================================\n", + "print(\"\\n Final Inventory (Root Directory):\")\n", + "files = [f for f in os.listdir('.') if \"pathopreter\" in f]\n", + "for f in sorted(files):\n", + " size_mb = os.path.getsize(f) / (1024 * 1024)\n", + " print(f\" - {f:<35} : {size_mb:.2f} MB\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " Loading ONLY bottom 100000 rows from: 900k_test.parquet\n", + " Loading trained robust model from JSON...\n", + " Model loaded.\n", + " -> Training feature count: 8\n", + " -> Loaded shape: (100000, 17)\n", + " -> Label counts: {0: 92813, 1: 7187}\n", + " -> Final X_new shape: (100000, 8)\n", + "\n", + " Running predictions...\n", + "\n", + "==========================================\n", + " RESULTS ON BOTTOM 100K ROWS (NEW DATA)\n", + "==========================================\n", + "Robust Model AUC : 0.972796\n", + "Robust Model PR-AUC : 0.724422\n", + "Baseline AUC (With Loc) : 0.766809\n", + "Robust AUC (No Loc) : 0.972796\n", + "Difference : 0.205987\n", + "\n", + "==========================================\n", + " NEW LEADERBOARD (ROBUST MODEL ON 100K)\n", + "==========================================\n", + " 1. My Robust XGB AUC=0.972796 PR=0.724422\n", + " 2. ClinPred AUC=0.960248 PR=0.592119\n", + " 3. REVEL AUC=0.955430 PR=0.582778\n", + " 4. AlphaMissense AUC=0.952558 PR=0.576359\n", + " 5. BayesDel AUC=0.935482 PR=0.688013\n", + " 6. PrimateAI AUC=0.929806 PR=0.483381\n", + " 7. MPC AUC=0.921036 PR=0.495954\n", + " 8. CADD AUC=0.898042 PR=0.656561\n", + "\n", + " Top Features in Robust Model (Pure Biology):\n", + " feature gain\n", + " phyloP100way_vertebrate_rankscore 230.754532\n", + " phyloP470way_mammalian_rankscore 59.637676\n", + " gnomad_af 53.463173\n", + " phastCons17way_primate_rankscore 17.051365\n", + "phastCons470way_mammalian_rankscore 13.066975\n", + " ref 11.242688\n", + " alt 10.989820\n", + " GERP_91_mammals_rankscore 7.565748\n", + "\n", + " Done.\n" + ] + } + ], + "source": [ + "# ============================================================\n", + "# CELL: TEST ROBUST MODEL ON BOTTOM 100K ROWS (CATEGORICAL-SAFE)\n", + "# ============================================================\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "import xgboost as xgb\n", + "from sklearn.metrics import roc_auc_score, average_precision_score\n", + "\n", + "DATA_PATH = \"900k_test.parquet\"\n", + "MODEL_JSON = \"pathopreter_robust_biology.json\"\n", + "N_ROWS = 100_000\n", + "\n", + "print(f\"\\n Loading ONLY bottom {N_ROWS} rows from: {DATA_PATH}\")\n", + "\n", + "# ------------------------------------------------------------\n", + "# 1) LOAD YOUR ROBUST MODEL (JSON \u2013 SAFE FORMAT)\n", + "# ------------------------------------------------------------\n", + "print(\" Loading trained robust model from JSON...\")\n", + "\n", + "booster = xgb.Booster()\n", + "booster.load_model(MODEL_JSON)\n", + "\n", + "# Exact feature list used in training\n", + "train_feature_names = booster.feature_names\n", + "\n", + "print(f\" Model loaded.\")\n", + "print(f\" -> Training feature count: {len(train_feature_names)}\")\n", + "\n", + "# ------------------------------------------------------------\n", + "# 2) LOAD ONLY WHAT THE MODEL NEEDS + LABEL\n", + "# ------------------------------------------------------------\n", + "\n", + "# needed_cols = list(set(train_feature_names + [\"clean_label\"]))\n", + "needed_cols = list(set(\n", + " train_feature_names \n", + " + [\"clean_label\"] \n", + " + list(benchmarks.values())\n", + "))\n", + "df_new = (\n", + " pd.read_parquet(DATA_PATH, columns=needed_cols)\n", + " .tail(N_ROWS)\n", + " .reset_index(drop=True)\n", + ")\n", + "\n", + "# Make label exactly like training\n", + "df_new[\"y\"] = (df_new[\"clean_label\"] == \"Pathogenic\").astype(int)\n", + "\n", + "print(f\" -> Loaded shape: {df_new.shape}\")\n", + "print(f\" -> Label counts: {df_new['y'].value_counts().to_dict()}\")\n", + "\n", + "# ------------------------------------------------------------\n", + "# 3) BUILD X_new IN EXACT TRAINING ORDER (CRITICAL)\n", + "# ------------------------------------------------------------\n", + "\n", + "X_new = pd.DataFrame()\n", + "\n", + "for feat in train_feature_names:\n", + " if feat in df_new.columns:\n", + " X_new[feat] = df_new[feat]\n", + " else:\n", + " # If some feature is missing in new data, add it\n", + " X_new[feat] = 0.0\n", + "\n", + "# -------- HANDLE CATEGORICAL COLUMNS SAFELY --------\n", + "for cat_col in [\"ref\", \"alt\"]:\n", + " if cat_col in X_new.columns:\n", + " # Convert to categorical\n", + " X_new[cat_col] = X_new[cat_col].astype(\"category\")\n", + "\n", + " # If NaNs exist, add \"N\" (unknown nucleotide) as a valid category\n", + " if X_new[cat_col].isna().any():\n", + " if \"N\" not in X_new[cat_col].cat.categories:\n", + " X_new[cat_col] = X_new[cat_col].cat.add_categories(\"N\")\n", + " X_new[cat_col] = X_new[cat_col].fillna(\"N\")\n", + "\n", + "# -------- FILL NA FOR NUMERIC COLUMNS ONLY --------\n", + "num_cols = X_new.select_dtypes(include=[np.number]).columns\n", + "X_new[num_cols] = X_new[num_cols].fillna(0)\n", + "\n", + "y_new = df_new[\"y\"].values\n", + "\n", + "print(f\" -> Final X_new shape: {X_new.shape}\")\n", + "\n", + "# ------------------------------------------------------------\n", + "# 4) PREDICT (MATCHES YOUR TRAINING)\n", + "# ------------------------------------------------------------\n", + "print(\"\\n Running predictions...\")\n", + "\n", + "dmat = xgb.DMatrix(\n", + " X_new,\n", + " feature_names=train_feature_names,\n", + " enable_categorical=True # MATCHES YOUR TRAINING\n", + ")\n", + "\n", + "y_prob_new = booster.predict(dmat, validate_features=False)\n", + "\n", + "# ------------------------------------------------------------\n", + "# 5) METRICS (SAME STYLE AS YOUR CELL)\n", + "# ------------------------------------------------------------\n", + "roc_auc_new = roc_auc_score(y_new, y_prob_new)\n", + "pr_auc_new = average_precision_score(y_new, y_prob_new)\n", + "\n", + "print(\"\\n==========================================\")\n", + "print(\" RESULTS ON BOTTOM 100K ROWS (NEW DATA)\")\n", + "print(\"==========================================\")\n", + "print(f\"Robust Model AUC : {roc_auc_new:.6f}\")\n", + "print(f\"Robust Model PR-AUC : {pr_auc_new:.6f}\")\n", + "\n", + "# Baseline comparison (same logic as your notebook)\n", + "baseline_auc = roc_auc if 'roc_auc' in locals() else 0.747704\n", + "delta = roc_auc_new - baseline_auc\n", + "\n", + "print(f\"Baseline AUC (With Loc) : {baseline_auc:.6f}\")\n", + "print(f\"Robust AUC (No Loc) : {roc_auc_new:.6f}\")\n", + "print(f\"Difference : {delta:.6f}\")\n", + "\n", + "# ------------------------------------------------------------\n", + "# 6) BENCHMARK AGAINST SOTA (SAME AS YOUR CELL)\n", + "# ------------------------------------------------------------\n", + "benchmarks = {\n", + " \"CADD\": \"CADD_raw_rankscore\",\n", + " \"REVEL\": \"REVEL_rankscore\",\n", + " \"AlphaMissense\": \"AlphaMissense_rankscore\",\n", + " \"PrimateAI\": \"PrimateAI_rankscore\",\n", + " \"MPC\": \"MPC_rankscore\",\n", + " \"ClinPred\": \"ClinPred_rankscore\",\n", + " \"BayesDel\": \"BayesDel_addAF_rankscore\"\n", + "}\n", + "\n", + "print(\"\\n==========================================\")\n", + "print(\" NEW LEADERBOARD (ROBUST MODEL ON 100K)\")\n", + "print(\"==========================================\")\n", + "\n", + "benchmark_results = []\n", + "\n", + "for name, col in benchmarks.items():\n", + " if col not in df_new.columns:\n", + " print(f\" Skipping {name} (column missing)\")\n", + " continue\n", + "\n", + " scores = df_new[col].fillna(0).values\n", + "\n", + " try:\n", + " b_auc = roc_auc_score(y_new, scores)\n", + " b_pr = average_precision_score(y_new, scores)\n", + " benchmark_results.append((name, b_auc, b_pr))\n", + " except ValueError:\n", + " pass\n", + "\n", + "ranking = sorted(\n", + " [(\"My Robust XGB\", roc_auc_new, pr_auc_new)] + benchmark_results,\n", + " key=lambda x: x[1],\n", + " reverse=True\n", + ")\n", + "\n", + "for i, (name, a, p) in enumerate(ranking, 1):\n", + " print(f\"{i:2d}. {name:14s} AUC={a:.6f} PR={p:.6f}\")\n", + "\n", + "# ------------------------------------------------------------\n", + "# 7) TOP FEATURES (GAIN \u2014 SAME AS YOUR CELL)\n", + "# ------------------------------------------------------------\n", + "print(\"\\n Top Features in Robust Model (Pure Biology):\")\n", + "\n", + "gain_scores = booster.get_score(importance_type='gain')\n", + "fi_rob = (\n", + " pd.DataFrame({\"feature\": list(gain_scores.keys()),\n", + " \"gain\": list(gain_scores.values())})\n", + " .sort_values(\"gain\", ascending=False)\n", + ")\n", + "\n", + "print(fi_rob.head(10).to_string(index=False))\n", + "print(\"\\n Done.\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "914707\n" + ] + } + ], + "source": [ + "import pyarrow.parquet as pq\n", + "\n", + "pf = pq.ParquetFile(\"900k_test.parquet\")\n", + "print(pf.metadata.num_rows)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\ud83d\udd0d Loading ONLY bottom 900000 rows from: 900k_test.parquet\n", + "\ud83d\udd01 Loading trained robust model from JSON...\n", + "\u2705 Model loaded.\n", + " -> Training feature count: 8\n", + " -> Loaded shape: (900000, 17)\n", + " -> Label counts: {0: 891684, 1: 8316}\n", + " -> Final X_new shape: (900000, 8)\n", + "\n", + "\ud83d\udd25 Running predictions...\n", + "\n", + "==========================================\n", + "\ud83e\uddea RESULTS ON BOTTOM 100K ROWS (NEW DATA)\n", + "==========================================\n", + "Robust Model AUC : 0.982708\n", + "Robust Model PR-AUC : 0.523736\n", + "Baseline AUC (With Loc) : 0.766809\n", + "Robust AUC (No Loc) : 0.982708\n", + "Difference : 0.215899\n", + "\n", + "==========================================\n", + "\u2694\ufe0f NEW LEADERBOARD (ROBUST MODEL ON 900K)\n", + "==========================================\n", + " 1. My Robust XGB AUC=0.982708 PR=0.523736\n", + " 2. BayesDel AUC=0.949125 PR=0.669971\n", + " 3. ClinPred AUC=0.943464 PR=0.487158\n", + " 4. REVEL AUC=0.937387 PR=0.466993\n", + " 5. AlphaMissense AUC=0.933589 PR=0.433256\n", + " 6. PrimateAI AUC=0.911503 PR=0.224050\n", + " 7. CADD AUC=0.909277 PR=0.573139\n", + " 8. MPC AUC=0.901957 PR=0.250881\n", + "\n", + "\ud83c\udfc6 Top Features in Robust Model (Pure Biology):\n", + " feature gain\n", + " phyloP100way_vertebrate_rankscore 230.754532\n", + " phyloP470way_mammalian_rankscore 59.637676\n", + " gnomad_af 53.463173\n", + " phastCons17way_primate_rankscore 17.051365\n", + "phastCons470way_mammalian_rankscore 13.066975\n", + " ref 11.242688\n", + " alt 10.989820\n", + " GERP_91_mammals_rankscore 7.565748\n", + "\n", + "\u2705 Done.\n" + ] + } + ], + "source": [ + "# ============================================================\n", + "# CELL: TEST ROBUST MODEL ON BOTTOM 900K ROWS (CATEGORICAL-SAFE)\n", + "# ============================================================\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "import xgboost as xgb\n", + "from sklearn.metrics import roc_auc_score, average_precision_score\n", + "\n", + "DATA_PATH = \"900k_test.parquet\"\n", + "MODEL_JSON = \"pathopreter_robust_biology.json\"\n", + "N_ROWS = 900_000\n", + "\n", + "print(f\"\\n\ud83d\udd0d Loading ONLY bottom {N_ROWS} rows from: {DATA_PATH}\")\n", + "\n", + "# ------------------------------------------------------------\n", + "# 1) LOAD YOUR ROBUST MODEL (JSON \u2013 SAFE FORMAT)\n", + "# ------------------------------------------------------------\n", + "print(\"\ud83d\udd01 Loading trained robust model from JSON...\")\n", + "\n", + "booster = xgb.Booster()\n", + "booster.load_model(MODEL_JSON)\n", + "\n", + "# Exact feature list used in training\n", + "train_feature_names = booster.feature_names\n", + "\n", + "print(f\"\u2705 Model loaded.\")\n", + "print(f\" -> Training feature count: {len(train_feature_names)}\")\n", + "\n", + "# ------------------------------------------------------------\n", + "# 2) LOAD ONLY WHAT THE MODEL NEEDS + LABEL\n", + "# ------------------------------------------------------------\n", + "\n", + "# needed_cols = list(set(train_feature_names + [\"clean_label\"]))\n", + "needed_cols = list(set(\n", + " train_feature_names \n", + " + [\"clean_label\"] \n", + " + list(benchmarks.values())\n", + "))\n", + "df_new = (\n", + " pd.read_parquet(DATA_PATH, columns=needed_cols)\n", + " .tail(N_ROWS)\n", + " .reset_index(drop=True)\n", + ")\n", + "\n", + "# Make label exactly like training\n", + "df_new[\"y\"] = (df_new[\"clean_label\"] == \"Pathogenic\").astype(int)\n", + "\n", + "print(f\" -> Loaded shape: {df_new.shape}\")\n", + "print(f\" -> Label counts: {df_new['y'].value_counts().to_dict()}\")\n", + "\n", + "# ------------------------------------------------------------\n", + "# 3) BUILD X_new IN EXACT TRAINING ORDER (CRITICAL)\n", + "# ------------------------------------------------------------\n", + "\n", + "X_new = pd.DataFrame()\n", + "\n", + "for feat in train_feature_names:\n", + " if feat in df_new.columns:\n", + " X_new[feat] = df_new[feat]\n", + " else:\n", + " # If some feature is missing in new data, add it\n", + " X_new[feat] = 0.0\n", + "\n", + "# -------- HANDLE CATEGORICAL COLUMNS SAFELY --------\n", + "for cat_col in [\"ref\", \"alt\"]:\n", + " if cat_col in X_new.columns:\n", + " # Convert to categorical\n", + " X_new[cat_col] = X_new[cat_col].astype(\"category\")\n", + "\n", + " # If NaNs exist, add \"N\" (unknown nucleotide) as a valid category\n", + " if X_new[cat_col].isna().any():\n", + " if \"N\" not in X_new[cat_col].cat.categories:\n", + " X_new[cat_col] = X_new[cat_col].cat.add_categories(\"N\")\n", + " X_new[cat_col] = X_new[cat_col].fillna(\"N\")\n", + "\n", + "# -------- FILL NA FOR NUMERIC COLUMNS ONLY --------\n", + "num_cols = X_new.select_dtypes(include=[np.number]).columns\n", + "X_new[num_cols] = X_new[num_cols].fillna(0)\n", + "\n", + "y_new = df_new[\"y\"].values\n", + "\n", + "print(f\" -> Final X_new shape: {X_new.shape}\")\n", + "\n", + "# ------------------------------------------------------------\n", + "# 4) PREDICT (MATCHES YOUR TRAINING)\n", + "# ------------------------------------------------------------\n", + "print(\"\\n\ud83d\udd25 Running predictions...\")\n", + "\n", + "dmat = xgb.DMatrix(\n", + " X_new,\n", + " feature_names=train_feature_names,\n", + " enable_categorical=True # MATCHES YOUR TRAINING\n", + ")\n", + "\n", + "y_prob_new = booster.predict(dmat, validate_features=False)\n", + "\n", + "# ------------------------------------------------------------\n", + "# 5) METRICS (SAME STYLE AS YOUR CELL)\n", + "# ------------------------------------------------------------\n", + "roc_auc_new = roc_auc_score(y_new, y_prob_new)\n", + "pr_auc_new = average_precision_score(y_new, y_prob_new)\n", + "\n", + "print(\"\\n==========================================\")\n", + "print(\"\ud83e\uddea RESULTS ON BOTTOM 100K ROWS (NEW DATA)\")\n", + "print(\"==========================================\")\n", + "print(f\"Robust Model AUC : {roc_auc_new:.6f}\")\n", + "print(f\"Robust Model PR-AUC : {pr_auc_new:.6f}\")\n", + "\n", + "# Baseline comparison (same logic as your notebook)\n", + "baseline_auc = roc_auc if 'roc_auc' in locals() else 0.747704\n", + "delta = roc_auc_new - baseline_auc\n", + "\n", + "print(f\"Baseline AUC (With Loc) : {baseline_auc:.6f}\")\n", + "print(f\"Robust AUC (No Loc) : {roc_auc_new:.6f}\")\n", + "print(f\"Difference : {delta:.6f}\")\n", + "\n", + "# ------------------------------------------------------------\n", + "# 6) BENCHMARK AGAINST SOTA (SAME AS YOUR CELL)\n", + "# ------------------------------------------------------------\n", + "benchmarks = {\n", + " \"CADD\": \"CADD_raw_rankscore\",\n", + " \"REVEL\": \"REVEL_rankscore\",\n", + " \"AlphaMissense\": \"AlphaMissense_rankscore\",\n", + " \"PrimateAI\": \"PrimateAI_rankscore\",\n", + " \"MPC\": \"MPC_rankscore\",\n", + " \"ClinPred\": \"ClinPred_rankscore\",\n", + " \"BayesDel\": \"BayesDel_addAF_rankscore\"\n", + "}\n", + "\n", + "print(\"\\n==========================================\")\n", + "print(\"\u2694\ufe0f NEW LEADERBOARD (ROBUST MODEL ON 900K)\")\n", + "print(\"==========================================\")\n", + "\n", + "benchmark_results = []\n", + "\n", + "for name, col in benchmarks.items():\n", + " if col not in df_new.columns:\n", + " print(f\"\u26a0\ufe0f Skipping {name} (column missing)\")\n", + " continue\n", + "\n", + " scores = df_new[col].fillna(0).values\n", + "\n", + " try:\n", + " b_auc = roc_auc_score(y_new, scores)\n", + " b_pr = average_precision_score(y_new, scores)\n", + " benchmark_results.append((name, b_auc, b_pr))\n", + " except ValueError:\n", + " pass\n", + "\n", + "ranking = sorted(\n", + " [(\"My Robust XGB\", roc_auc_new, pr_auc_new)] + benchmark_results,\n", + " key=lambda x: x[1],\n", + " reverse=True\n", + ")\n", + "\n", + "for i, (name, a, p) in enumerate(ranking, 1):\n", + " print(f\"{i:2d}. {name:14s} AUC={a:.6f} PR={p:.6f}\")\n", + "\n", + "# ------------------------------------------------------------\n", + "# 7) TOP FEATURES (GAIN \u2014 SAME AS YOUR CELL)\n", + "# ------------------------------------------------------------\n", + "print(\"\\n\ud83c\udfc6 Top Features in Robust Model (Pure Biology):\")\n", + "\n", + "gain_scores = booster.get_score(importance_type='gain')\n", + "fi_rob = (\n", + " pd.DataFrame({\"feature\": list(gain_scores.keys()),\n", + " \"gain\": list(gain_scores.values())})\n", + " .sort_values(\"gain\", ascending=False)\n", + ")\n", + "\n", + "print(fi_rob.head(10).to_string(index=False))\n", + "print(\"\\n\u2705 Done.\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Loading test_enriched.parquet...\n", + "Using 10 features (training order)\n", + "\n", + " BASELINE: AUC=0.766809 | PR=0.771672\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 80%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 8/10 [00:01<00:00, 4.38it/s]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 10/10 [00:02<00:00, 4.35it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " PERMUTATION ABLATION RESULTS:\n", + " feature auc pr auc_drop pr_drop\n", + " gnomad_af 0.574536 0.625798 0.192274 0.145874\n", + " pos 0.743102 0.745121 0.023707 0.026551\n", + " chrom 0.747204 0.748561 0.019605 0.023111\n", + " ref 0.748308 0.747789 0.018501 0.023883\n", + " alt 0.748902 0.754349 0.017907 0.017323\n", + " phyloP100way_vertebrate_rankscore 0.753048 0.727441 0.013761 0.044231\n", + " phastCons17way_primate_rankscore 0.759336 0.760803 0.007473 0.010869\n", + " GERP_91_mammals_rankscore 0.765210 0.759358 0.001599 0.012314\n", + "phastCons470way_mammalian_rankscore 0.765913 0.771196 0.000896 0.000476\n", + " phyloP470way_mammalian_rankscore 0.773115 0.771422 -0.006306 0.000250\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "from sklearn.metrics import roc_auc_score, average_precision_score\n", + "from tqdm import tqdm\n", + "\n", + "SEED = 42\n", + "\n", + "print(\" Loading test_enriched.parquet...\")\n", + "test_df = pd.read_parquet(\"test_enriched_SEQUENCES.parquet\")\n", + "\n", + "# GET EXACT TRAINING FEATURE ORDER\n", + "features = model.get_booster().feature_names\n", + "\n", + "print(f\"Using {len(features)} features (training order)\")\n", + "\n", + "# ---- PREP ----\n", + "X_test = pd.DataFrame()\n", + "\n", + "# build in correct order\n", + "for feat in features:\n", + " if feat in test_df.columns:\n", + " X_test[feat] = test_df[feat]\n", + " else:\n", + " X_test[feat] = 0 # safety\n", + "\n", + "y_test = (test_df[\"clean_label\"] == \"Pathogenic\").astype(int).values\n", + "\n", + "# ---- FIX TYPES ----\n", + "for col in [\"chrom\", \"ref\", \"alt\"]:\n", + " if col in X_test.columns:\n", + " X_test[col] = X_test[col].astype(\"category\")\n", + "\n", + "for col in X_test.columns:\n", + " if col not in [\"chrom\", \"ref\", \"alt\"]:\n", + " X_test[col] = pd.to_numeric(X_test[col], errors=\"coerce\")\n", + "\n", + "X_test = X_test.fillna(0)\n", + "\n", + "# ---- BASELINE ----\n", + "y_prob = model.predict_proba(X_test)[:, 1]\n", + "\n", + "base_auc = roc_auc_score(y_test, y_prob)\n", + "base_pr = average_precision_score(y_test, y_prob)\n", + "\n", + "print(f\"\\n BASELINE: AUC={base_auc:.6f} | PR={base_pr:.6f}\")\n", + "\n", + "# ---- PERMUTATION ABLATION ----\n", + "results = []\n", + "\n", + "cat_cols = [\"chrom\", \"ref\", \"alt\"]\n", + "\n", + "for feat in tqdm(features):\n", + " X_perm = X_test.copy()\n", + "\n", + " # shuffle feature\n", + " X_perm[feat] = np.random.permutation(X_perm[feat].values)\n", + "\n", + " # \ud83d\udd25 FIX: restore dtypes AFTER permutation\n", + " for col in cat_cols:\n", + " if col in X_perm.columns:\n", + " X_perm[col] = X_perm[col].astype(\"category\")\n", + "\n", + " for col in X_perm.columns:\n", + " if col not in cat_cols:\n", + " X_perm[col] = pd.to_numeric(X_perm[col], errors=\"coerce\")\n", + "\n", + " X_perm = X_perm.fillna(0)\n", + "\n", + " # predict\n", + " y_prob_perm = model.predict_proba(X_perm)[:, 1]\n", + "\n", + " auc = roc_auc_score(y_test, y_prob_perm)\n", + " pr = average_precision_score(y_test, y_prob_perm)\n", + "\n", + " results.append({\n", + " \"feature\": feat,\n", + " \"auc\": auc,\n", + " \"pr\": pr,\n", + " \"auc_drop\": base_auc - auc,\n", + " \"pr_drop\": base_pr - pr\n", + " })\n", + "\n", + "# ---- RESULTS ----\n", + "ablation_df = pd.DataFrame(results).sort_values(\"auc_drop\", ascending=False)\n", + "\n", + "print(\"\\n PERMUTATION ABLATION RESULTS:\")\n", + "print(ablation_df.to_string(index=False))" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\udcc9 Generating Ablation ROC Curves...\n" + ] + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "from sklearn.metrics import roc_curve, auc\n", + "\n", + "print(\"\ud83d\udcc9 Generating Ablation ROC Curves...\")\n", + "\n", + "plt.figure(figsize=(12, 8))\n", + "\n", + "# baseline\n", + "fpr_base, tpr_base, _ = roc_curve(y_test, y_prob)\n", + "plt.plot(fpr_base, tpr_base, lw=3, color='black',\n", + " label=f'Baseline (AUC={base_auc:.3f})')\n", + "\n", + "# random line\n", + "plt.plot([0,1], [0,1], 'k--', lw=2)\n", + "\n", + "# store curves\n", + "curve_results = []\n", + "\n", + "for feat in features:\n", + " X_perm = X_test.copy()\n", + " X_perm[feat] = np.random.permutation(X_perm[feat].values)\n", + "\n", + " # FIX TYPES AGAIN\n", + " for col in [\"chrom\", \"ref\", \"alt\"]:\n", + " if col in X_perm.columns:\n", + " X_perm[col] = X_perm[col].astype(\"category\")\n", + "\n", + " for col in X_perm.columns:\n", + " if col not in [\"chrom\", \"ref\", \"alt\"]:\n", + " X_perm[col] = pd.to_numeric(X_perm[col], errors=\"coerce\")\n", + "\n", + " X_perm = X_perm.fillna(0)\n", + "\n", + " y_prob_perm = model.predict_proba(X_perm)[:, 1]\n", + "\n", + " fpr, tpr, _ = roc_curve(y_test, y_prob_perm)\n", + " roc_auc = auc(fpr, tpr)\n", + "\n", + " plt.plot(fpr, tpr, lw=1.5, alpha=0.7,\n", + " label=f'No {feat} (AUC={roc_auc:.3f})')\n", + "\n", + "# formatting\n", + "plt.xlim([0.0, 1.0])\n", + "plt.ylim([0.0, 1.05])\n", + "plt.xlabel('False Positive Rate')\n", + "plt.ylabel('True Positive Rate')\n", + "plt.title('Ablation: Feature Importance Analysis')\n", + "plt.legend(loc=\"lower right\", fontsize='small')\n", + "plt.grid(True, alpha=0.3)\n", + "\n", + "plt.savefig(\"ablation_roc.png\", dpi=300)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\ud83e\uddea Running NOISE-REMOVAL ABLATION (FINAL FIX)...\n", + "Removing 3 noisy features:\n", + "['GERP++_RS_rankscore', 'phyloP17way_primate_rankscore', 'phastCons100way_vertebrate_rankscore']\n", + "\n", + "==========================================\n", + "\ud83e\uddf9 NOISE-REMOVED MODEL PERFORMANCE\n", + "==========================================\n", + "AUC : 0.766809\n", + "PR-AUC : 0.771672\n", + "AUC Gain: 0.000000\n", + "PR Gain : 0.000000\n" + ] + } + ], + "source": [ + "print(\"\\n\ud83e\uddea Running NOISE-REMOVAL ABLATION (FINAL FIX)...\")\n", + "\n", + "# ---- DEFINE NOISY FEATURES ----\n", + "noise_features = [\n", + " \"GERP++_RS_rankscore\",\n", + " \"phyloP17way_primate_rankscore\",\n", + " \"phastCons100way_vertebrate_rankscore\"\n", + "]\n", + "\n", + "print(f\"Removing {len(noise_features)} noisy features:\")\n", + "print(noise_features)\n", + "\n", + "# ---- START FROM ORIGINAL TEST ----\n", + "X_clean = X_test.copy()\n", + "\n", + "# ---- \ud83d\udd25 ZERO OUT (DO NOT DROP) ----\n", + "for f in noise_features:\n", + " if f in X_clean.columns:\n", + " X_clean[f] = 0.0\n", + "\n", + "# ---- \ud83d\udd25 FORCE EXACT TRAINING ORDER (CRITICAL) ----\n", + "feature_order = model.get_booster().feature_names\n", + "X_clean = X_clean[feature_order]\n", + "\n", + "# ---- FIX TYPES ----\n", + "cat_cols = [\"chrom\", \"ref\", \"alt\"]\n", + "\n", + "for col in cat_cols:\n", + " if col in X_clean.columns:\n", + " X_clean[col] = X_clean[col].astype(\"category\")\n", + "\n", + "for col in X_clean.columns:\n", + " if col not in cat_cols:\n", + " X_clean[col] = pd.to_numeric(X_clean[col], errors=\"coerce\")\n", + "\n", + "X_clean = X_clean.fillna(0)\n", + "\n", + "# ---- PREDICT ----\n", + "y_prob_clean = model.predict_proba(X_clean)[:, 1]\n", + "\n", + "# ---- METRICS ----\n", + "clean_auc = roc_auc_score(y_test, y_prob_clean)\n", + "clean_pr = average_precision_score(y_test, y_prob_clean)\n", + "\n", + "print(\"\\n==========================================\")\n", + "print(\"\ud83e\uddf9 NOISE-REMOVED MODEL PERFORMANCE\")\n", + "print(\"==========================================\")\n", + "print(f\"AUC : {clean_auc:.6f}\")\n", + "print(f\"PR-AUC : {clean_pr:.6f}\")\n", + "print(f\"AUC Gain: {clean_auc - base_auc:.6f}\")\n", + "print(f\"PR Gain : {clean_pr - base_pr:.6f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "fig, ax = plt.subplots(figsize=(16, 11))\n", + "ax.axis('off')\n", + "\n", + "# ---------- Helper Functions ----------\n", + "def draw_box(text, xy, fc=\"#E8F0FE\", size=10):\n", + " ax.text(\n", + " xy[0], xy[1], text,\n", + " ha='center', va='center',\n", + " fontsize=size,\n", + " bbox=dict(boxstyle=\"round,pad=0.5\", facecolor=fc, edgecolor='black')\n", + " )\n", + "\n", + "def draw_arrow(start, end):\n", + " ax.annotate(\n", + " \"\", xy=end, xytext=start,\n", + " arrowprops=dict(arrowstyle=\"->\", lw=1.5)\n", + " )\n", + "\n", + "# ---------- TITLE ----------\n", + "plt.title(\"Framework Pipeline\", fontsize=18, weight='bold')\n", + "\n", + "# =====================================================\n", + "# DATA SOURCES (STRAIGHT + CLEAN)\n", + "# =====================================================\n", + "draw_box(\"ClinVar\\n(~7.7M variants)\", (0.1, 0.90))\n", + "draw_box(\"gnomAD\\n(AF)\", (0.3, 0.90))\n", + "draw_box(\"dbNSFP\\n(functional scores)\", (0.5, 0.90))\n", + "draw_box(\"GRCh38\\n(sequence)\", (0.7, 0.90))\n", + "\n", + "draw_box(\"Data Integration\", (0.4, 0.78), fc=\"#D1E7DD\", size=11)\n", + "\n", + "# straight arrows\n", + "draw_arrow((0.1, 0.86), (0.35, 0.80))\n", + "draw_arrow((0.3, 0.86), (0.38, 0.80))\n", + "draw_arrow((0.5, 0.86), (0.42, 0.80))\n", + "draw_arrow((0.7, 0.86), (0.45, 0.80))\n", + "\n", + "# =====================================================\n", + "# PIPELINE (MORE SPACING)\n", + "# =====================================================\n", + "draw_box(\n", + " \"Filtering & Cleaning\\n\"\n", + " \"\u2022 SNV only\\n\"\n", + " \"\u2022 Remove HGVS (Leakage Prevention)\\n\"\n", + " \"\u2022 Deduplication\",\n", + " (0.4, 0.64),\n", + ")\n", + "\n", + "draw_box(\n", + " \"Feature Engineering\\n\"\n", + " \"\u2022 gnomAD AF\\n\"\n", + " \"\u2022 Conservation (phyloP, phastCons)\\n\"\n", + " \"\u2022 Median Imputation\",\n", + " (0.4, 0.48),\n", + ")\n", + "\n", + "draw_box(\n", + " \"Dataset Construction\\n\"\n", + " \"Train: 300k (Balanced)\\n\"\n", + " \"Test: 14k (Rare Balanced)\\n\"\n", + " \"Test: 100k (Imbalanced)\",\n", + " (0.4, 0.32),\n", + " fc=\"#E2F0CB\"\n", + ")\n", + "\n", + "draw_box(\n", + " \"XGBoost Model\\n\"\n", + " \"\u2022 10 Core Features\\n\"\n", + " \"\u2022 Regularized Training\",\n", + " (0.4, 0.18),\n", + " fc=\"#D9EAD3\"\n", + ")\n", + "\n", + "draw_box(\n", + " \"Evaluation Pipeline\\n\"\n", + " \"\u2022 ROC-AUC / PR-AUC\\n\"\n", + " \"\u2022 Threshold Tuning (VAL \u2192 TEST)\\n\"\n", + " \"\u2022 SHAP + Ablation\",\n", + " (0.4, 0.05),\n", + " fc=\"#FCE5CD\"\n", + ")\n", + "\n", + "# arrows (vertical clean)\n", + "draw_arrow((0.4, 0.76), (0.4, 0.68))\n", + "draw_arrow((0.4, 0.60), (0.4, 0.52))\n", + "draw_arrow((0.4, 0.44), (0.4, 0.36))\n", + "draw_arrow((0.4, 0.28), (0.4, 0.21))\n", + "draw_arrow((0.4, 0.15), (0.4, 0.09))\n", + "\n", + "# =====================================================\n", + "# BENCHMARK (FIXED CLEANLY)\n", + "# =====================================================\n", + "draw_box(\n", + " \"Generalist Models\\n\"\n", + " \"\u2022 CADD\\n\u2022 REVEL\\n\u2022 AlphaMissense\\n\u2022 BayesDel\",\n", + " (0.75, 0.05),\n", + " fc=\"#FFF3CD\"\n", + ")\n", + "\n", + "# clean horizontal arrow\n", + "draw_arrow((0.50, 0.05), (0.70, 0.05))\n", + "\n", + "# aligned label\n", + "ax.text(0.60, 0.09, \"Benchmark Comparison\", fontsize=10, ha='center')\n", + "\n", + "# =====================================================\n", + "# SAVE\n", + "# =====================================================\n", + "# plt.savefig(\"framework_final.svg\", bbox_inches='tight')\n", + "# plt.savefig(\"framework_final.png\", dpi=300, bbox_inches='tight')\n", + "\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file