diff --git "a/XGBOOST_XAI/ML2.json" "b/XGBOOST_XAI/ML2.json" deleted file mode 100644--- "a/XGBOOST_XAI/ML2.json" +++ /dev/null @@ -1,2172 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Collecting xgboost\n", - " Downloading xgboost-3.2.0-py3-none-manylinux_2_28_x86_64.whl.metadata (2.1 kB)\n", - "Collecting shap\n", - " Downloading shap-0.51.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (25 kB)\n", - "Requirement already satisfied: pandas in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (2.1.4)\n", - "Requirement already satisfied: numpy in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (1.26.4)\n", - "Requirement already satisfied: scikit-learn in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (1.3.2)\n", - "Requirement already satisfied: matplotlib in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (3.8.2)\n", - "Collecting seaborn\n", - " Downloading seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)\n", - "Collecting pyarrow\n", - " Downloading pyarrow-23.0.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.1 kB)\n", - "Collecting fastparquet\n", - " Downloading fastparquet-2026.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (4.8 kB)\n", - "Requirement already satisfied: nvidia-nccl-cu12 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from xgboost) (2.27.3)\n", - "Requirement already satisfied: scipy in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from xgboost) (1.11.4)\n", - "Collecting numpy\n", - " Downloading numpy-2.4.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (6.6 kB)\n", - "Requirement already satisfied: tqdm>=4.27.0 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages 
(from shap) (4.67.3)\n", - "Requirement already satisfied: packaging>20.9 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from shap) (25.0)\n", - "Collecting slicer==0.0.8 (from shap)\n", - " Downloading slicer-0.0.8-py3-none-any.whl.metadata (4.0 kB)\n", - "Collecting numba (from shap)\n", - " Downloading numba-0.64.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.9 kB)\n", - "Collecting llvmlite (from shap)\n", - " Downloading llvmlite-0.46.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (5.0 kB)\n", - "Collecting cloudpickle (from shap)\n", - " Downloading cloudpickle-3.1.2-py3-none-any.whl.metadata (7.1 kB)\n", - "Requirement already satisfied: typing-extensions in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from shap) (4.15.0)\n", - "INFO: pip is looking at multiple versions of pandas to determine which version is compatible with other requirements. This could take a while.\n", - "Collecting pandas\n", - " Downloading pandas-3.0.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (79 kB)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from pandas) (2.9.0.post0)\n", - "INFO: pip is looking at multiple versions of scikit-learn to determine which version is compatible with other requirements. 
This could take a while.\n", - "Collecting scikit-learn\n", - " Downloading scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (11 kB)\n", - "Requirement already satisfied: joblib>=1.3.0 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from scikit-learn) (1.5.3)\n", - "Requirement already satisfied: threadpoolctl>=3.2.0 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from scikit-learn) (3.6.0)\n", - "Requirement already satisfied: contourpy>=1.0.1 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from matplotlib) (1.3.3)\n", - "Requirement already satisfied: cycler>=0.10 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from matplotlib) (0.12.1)\n", - "Requirement already satisfied: fonttools>=4.22.0 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from matplotlib) (4.62.1)\n", - "Requirement already satisfied: kiwisolver>=1.3.1 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from matplotlib) (1.5.0)\n", - "INFO: pip is looking at multiple versions of matplotlib to determine which version is compatible with other requirements. 
This could take a while.\n", - "Collecting matplotlib\n", - " Downloading matplotlib-3.10.8-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (52 kB)\n", - "Requirement already satisfied: pillow>=8 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from matplotlib) (12.1.1)\n", - "Requirement already satisfied: pyparsing>=3 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from matplotlib) (3.3.2)\n", - "Collecting cramjam>=2.3 (from fastparquet)\n", - " Downloading cramjam-2.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.6 kB)\n", - "Requirement already satisfied: fsspec in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from fastparquet) (2026.2.0)\n", - "Requirement already satisfied: six>=1.5 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas) (1.17.0)\n", - "INFO: pip is looking at multiple versions of scipy to determine which version is compatible with other requirements. 
This could take a while.\n", - "Collecting scipy (from xgboost)\n", - " Downloading scipy-1.17.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (62 kB)\n", - "Downloading xgboost-3.2.0-py3-none-manylinux_2_28_x86_64.whl (131.7 MB)\n", - "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m131.7/131.7 MB\u001b[0m \u001b[31m210.6 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m0:00:01\u001b[0m00:01\u001b[0m\n", - "\u001b[?25hDownloading shap-0.51.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (1.1 MB)\n", - "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m104.0 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading slicer-0.0.8-py3-none-any.whl (15 kB)\n", - "Downloading pandas-3.0.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (10.9 MB)\n", - "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m10.9/10.9 MB\u001b[0m \u001b[31m190.0 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading numpy-2.4.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.6 MB)\n", - "\u001b[2K 
\u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m16.6/16.6 MB\u001b[0m \u001b[31m174.3 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (8.9 MB)\n", - "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m8.9/8.9 MB\u001b[0m \u001b[31m180.6 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading matplotlib-3.10.8-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (8.7 MB)\n", - "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m8.7/8.7 MB\u001b[0m \u001b[31m179.1 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading seaborn-0.13.2-py3-none-any.whl (294 kB)\n", - "Downloading pyarrow-23.0.1-cp312-cp312-manylinux_2_28_x86_64.whl (47.6 MB)\n", - "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m47.6/47.6 MB\u001b[0m \u001b[31m203.0 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0mm0:00:01\u001b[0m\n", - "\u001b[?25hDownloading fastparquet-2026.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (1.8 MB)\n", - 
"\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m161.4 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading cramjam-2.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)\n", - "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m168.1 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading scipy-1.17.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (35.2 MB)\n", - "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m35.2/35.2 MB\u001b[0m \u001b[31m198.6 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading cloudpickle-3.1.2-py3-none-any.whl (22 kB)\n", - "Downloading llvmlite-0.46.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (56.3 MB)\n", - "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m205.4 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0mm0:00:01\u001b[0m\n", - "\u001b[?25hDownloading numba-0.64.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (3.8 MB)\n", - 
"\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m3.8/3.8 MB\u001b[0m \u001b[31m179.7 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", - "\u001b[?25hInstalling collected packages: slicer, pyarrow, numpy, llvmlite, cramjam, cloudpickle, scipy, pandas, numba, xgboost, scikit-learn, matplotlib, fastparquet, shap, seaborn\n", - "\u001b[2K Attempting uninstall: numpy\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 1/15\u001b[0m [pyarrow]\n", - "\u001b[2K Found existing installation: numpy 1.26.4\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 1/15\u001b[0m [pyarrow]\n", - "\u001b[2K Uninstalling numpy-1.26.4:\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 1/15\u001b[0m [pyarrow]\n", - "\u001b[2K Successfully uninstalled numpy-1.26.4\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 1/15\u001b[0m [pyarrow]\n", - "\u001b[2K Attempting uninstall: scipy[0m\u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 4/15\u001b[0m [cramjam]]\n", - "\u001b[2K Found existing installation: scipy 
1.11.4\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 4/15\u001b[0m [cramjam]\n", - "\u001b[2K Uninstalling scipy-1.11.4:90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 4/15\u001b[0m [cramjam]\n", - "\u001b[2K Successfully uninstalled scipy-1.11.4\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 4/15\u001b[0m [cramjam]\n", - "\u001b[2K Attempting uninstall: pandas90m\u257a\u001b[0m\u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 6/15\u001b[0m [scipy]\n", - "\u001b[2K Found existing installation: pandas 2.1.4\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 6/15\u001b[0m [scipy]\n", - "\u001b[2K Uninstalling pandas-2.1.4:\u001b[91m\u2578\u001b[0m\u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 7/15\u001b[0m [pandas]\n", - "\u001b[2K Successfully uninstalled pandas-2.1.4\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 7/15\u001b[0m [pandas]\n", - "\u001b[2K Attempting uninstall: scikit-learn\u001b[90m\u257a\u001b[0m\u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 9/15\u001b[0m [xgboost]\n", - "\u001b[2K Found existing installation: scikit-learn 1.3.2\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m 9/15\u001b[0m [xgboost]\n", - "\u001b[2K Uninstalling 
scikit-learn-1.3.2:0m\u001b[91m\u2578\u001b[0m\u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m10/15\u001b[0m [scikit-learn]\n", - "\u001b[2K Successfully uninstalled scikit-learn-1.3.2\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m10/15\u001b[0m [scikit-learn]\n", - "\u001b[2K Attempting uninstall: matplotlib\u001b[0m\u001b[91m\u2578\u001b[0m\u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m10/15\u001b[0m [scikit-learn]\n", - "\u001b[2K Found existing installation: matplotlib 3.8.2\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m10/15\u001b[0m [scikit-learn]\n", - "\u001b[2K Uninstalling matplotlib-3.8.2:\u001b[91m\u2578\u001b[0m\u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m10/15\u001b[0m [scikit-learn]\n", - "\u001b[2K Successfully uninstalled matplotlib-3.8.2\u001b[0m\u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m11/15\u001b[0m [matplotlib]\n", - "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m15/15\u001b[0m [seaborn]4/15\u001b[0m [seaborn]uet]\n", - "\u001b[1A\u001b[2KSuccessfully installed cloudpickle-3.1.2 cramjam-2.11.0 fastparquet-2026.3.0 llvmlite-0.46.0 matplotlib-3.10.8 numba-0.64.0 numpy-2.4.3 pandas-3.0.1 pyarrow-23.0.1 scikit-learn-1.8.0 scipy-1.17.1 seaborn-0.13.2 shap-0.51.0 slicer-0.0.8 xgboost-3.2.0\n" - ] - } - ], - "source": [ - "#cell-0: Install Dependencies\n", - "!pip install xgboost shap pandas numpy scikit-learn matplotlib seaborn pyarrow fastparquet" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Launching Training on 32 Cores...\n" - ] - } - ], - "source": [ - "# Cell 1: Imports & Hardware Configuration\n", - "import pandas as pd\n", - "import numpy as np\n", - "import xgboost as xgb\n", - "import shap\n", - "import matplotlib.pyplot as plt\n", - "import seaborn as sns\n", - "import multiprocessing\n", - "import gc\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.metrics import roc_auc_score, average_precision_score, classification_report, confusion_matrix, f1_score, balanced_accuracy_score, roc_curve, auc\n", - "\n", - "# 1. Hardware Setup\n", - "n_cores = multiprocessing.cpu_count()\n", - "print(f\"Launching Training on {n_cores} Cores...\")\n", - "\n", - "# 2. Global Config\n", - "SEED = 42\n", - "pd.set_option('display.max_columns', None)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u26a1 Loading Parquet Files...\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train Shape: (299999, 66)\n", - "Test Shape: (14073, 66)\n", - "\n", - "Using 10 Tier-1 Features:\n", - "['chrom', 'pos', 'ref', 'alt', 'gnomad_af', 'GERP_91_mammals_rankscore', 'phyloP100way_vertebrate_rankscore', 'phyloP470way_mammalian_rankscore', 'phastCons470way_mammalian_rankscore', 'phastCons17way_primate_rankscore']\n", - " -> Encoding labels to 0 (Benign) and 1 (Pathogenic)...\n", - " -> Train unique labels: [1 0]\n", - " -> Test unique labels: [0 1]\n", - "Labels successfully encoded to integers.\n" - ] - } - ], - "source": [ - "# Cell 2: Load Data & Define Features (FIXED)\n", - "print(\"\u26a1 Loading Parquet Files...\")\n", - "\n", - "# Load your specific files\n", - "train_df = 
pd.read_parquet(\"Balanced_300k_SEQUENCES.parquet\")\n", - "test_df = pd.read_parquet(\"test_enriched_SEQUENCES.parquet\")\n", - "\n", - "print(f\"Train Shape: {train_df.shape}\")\n", - "print(f\"Test Shape: {test_df.shape}\")\n", - "\n", - "# ---- Define Tier-1 Features (Physics/Evolution only) ----\n", - "tier1_features = [\n", - " \"chrom\", \"pos\", \"ref\", \"alt\", \"gnomad_af\",\n", - " \"GERP_91_mammals_rankscore\", # keep (weak but not harmful)\n", - " \"phyloP100way_vertebrate_rankscore\", # strong\n", - " \"phyloP470way_mammalian_rankscore\", # moderate\n", - " \"phastCons470way_mammalian_rankscore\",\n", - " \"phastCons17way_primate_rankscore\"\n", - "]\n", - "\n", - "# Filter to keep only what exists in your DF\n", - "features = [c for c in tier1_features if c in train_df.columns]\n", - "print(f\"\\nUsing {len(features)} Tier-1 Features:\")\n", - "print(features)\n", - "\n", - "# ---- FIXED LABEL MAPPING ----\n", - "# We force conversion to string first to handle both 'category' and 'object' types\n", - "print(\" -> Encoding labels to 0 (Benign) and 1 (Pathogenic)...\")\n", - "\n", - "# 1. Force strings and strip whitespace to be safe\n", - "train_df[\"clean_label\"] = train_df[\"clean_label\"].astype(str).str.strip()\n", - "test_df[\"clean_label\"] = test_df[\"clean_label\"].astype(str).str.strip()\n", - "\n", - "# 2. Map explicitly\n", - "label_map = {\"Benign\": 0, \"Pathogenic\": 1}\n", - "\n", - "train_df[\"y\"] = train_df[\"clean_label\"].map(label_map)\n", - "test_df[\"y\"] = test_df[\"clean_label\"].map(label_map)\n", - "\n", - "# 3. Safety Check: Verify we have only 0s and 1s\n", - "print(f\" -> Train unique labels: {train_df['y'].unique()}\")\n", - "print(f\" -> Test unique labels: {test_df['y'].unique()}\")\n", - "\n", - "if train_df[\"y\"].isna().any():\n", - " raise ValueError(\"\u274c Error: Some labels in Train could not be mapped! 
Check for typos in 'Benign'/'Pathogenic'.\")\n", - "if test_df[\"y\"].isna().any():\n", - " raise ValueError(\"\u274c Error: Some labels in Test could not be mapped! Check for typos.\")\n", - "\n", - "print(\"Labels successfully encoded to integers.\")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " -> Casting chrom to 'category'...\n", - " -> Casting ref to 'category'...\n", - " -> Casting alt to 'category'...\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Data Ready: Train(269999), Val(30000), Hard Test(14073)\n" - ] - } - ], - "source": [ - "# Cell 3: Preprocessing & Type Casting\n", - "\n", - "# 1. Cast Categoricals for Native XGBoost Support\n", - "cat_cols = [\"chrom\", \"ref\", \"alt\"]\n", - "for col in cat_cols:\n", - " if col in features:\n", - " print(f\" -> Casting {col} to 'category'...\")\n", - " train_df[col] = train_df[col].astype(\"category\")\n", - " test_df[col] = test_df[col].astype(\"category\")\n", - "\n", - "# 2. Ensure 'pos' and scores are numeric\n", - "num_cols = [c for c in features if c not in cat_cols]\n", - "for col in num_cols:\n", - " train_df[col] = pd.to_numeric(train_df[col], errors='coerce').fillna(0)\n", - " test_df[col] = pd.to_numeric(test_df[col], errors='coerce').fillna(0)\n", - "\n", - "# 3. Create X and y matrices\n", - "X = train_df[features]\n", - "y = train_df['y']\n", - "\n", - "# Keep benchmarking columns in test_df separate, but extract X_test for prediction\n", - "X_test = test_df[features]\n", - "y_test = test_df['y']\n", - "\n", - "# 4. 
Stratified Split for Validation (Using 10% of Train)\n", - "X_train, X_val, y_train, y_val = train_test_split(\n", - " X, y, test_size=0.10, stratify=y, random_state=SEED\n", - ")\n", - "\n", - "print(f\"\\nData Ready: Train({len(X_train)}), Val({len(X_val)}), Hard Test({len(X_test)})\")\n", - "gc.collect()" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\ud83d\udee0 Configuring XGBoost SOTA Hyperparameters...\n" - ] - } - ], - "source": [ - "# Cell 4: Initialize SOTA Model\n", - "print(\"\ud83d\udee0 Configuring XGBoost SOTA Hyperparameters...\")\n", - "\n", - "model = xgb.XGBClassifier(\n", - " # --- Architecture ---\n", - " n_estimators=10000, # High ceiling (relies on early stopping)\n", - " learning_rate=0.01, # Very slow learning to squeeze out signal\n", - " max_depth=10, # Deep trees for complex genomic interactions\n", - " \n", - " # --- Regularization (Prevent Overfitting) ---\n", - " gamma=1.5, # High gamma: Conservative splitting (crucial for hard test sets)\n", - " min_child_weight=5, # Requires 5 samples to make a leaf\n", - " reg_alpha=0.5, # L1 Regularization\n", - " reg_lambda=1.0, # L2 Regularization\n", - " \n", - " # --- Sampling (Stochastic Gradient Boosting) ---\n", - " subsample=0.8, # Use 80% of rows per tree\n", - " colsample_bytree=0.8, # Use 80% of features per tree\n", - " \n", - " # --- Speed & Hardware ---\n", - " n_jobs=n_cores, # Use all 32 cores\n", - " tree_method='hist', # Histogram-based (Fastest)\n", - " enable_categorical=True, # Native categorical support\n", - " \n", - " # --- Objective ---\n", - " objective='binary:logistic',\n", - " eval_metric=['auc', 'logloss'],\n", - " early_stopping_rounds=200, # Patience\n", - " random_state=SEED\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Starting 
Training (this will take time due to LR=0.01)...\n", - "[0]\tvalidation_0-auc:0.96589\tvalidation_0-logloss:0.68516\tvalidation_1-auc:0.96544\tvalidation_1-logloss:0.68518\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[500]\tvalidation_0-auc:0.97908\tvalidation_0-logloss:0.18481\tvalidation_1-auc:0.97640\tvalidation_1-logloss:0.19517\n", - "[1000]\tvalidation_0-auc:0.98221\tvalidation_0-logloss:0.17017\tvalidation_1-auc:0.97810\tvalidation_1-logloss:0.18719\n", - "[1500]\tvalidation_0-auc:0.98369\tvalidation_0-logloss:0.16353\tvalidation_1-auc:0.97882\tvalidation_1-logloss:0.18417\n", - "[2000]\tvalidation_0-auc:0.98446\tvalidation_0-logloss:0.16001\tvalidation_1-auc:0.97909\tvalidation_1-logloss:0.18296\n", - "[2500]\tvalidation_0-auc:0.98493\tvalidation_0-logloss:0.15781\tvalidation_1-auc:0.97922\tvalidation_1-logloss:0.18237\n", - "[3000]\tvalidation_0-auc:0.98525\tvalidation_0-logloss:0.15626\tvalidation_1-auc:0.97929\tvalidation_1-logloss:0.18202\n", - "[3500]\tvalidation_0-auc:0.98549\tvalidation_0-logloss:0.15512\tvalidation_1-auc:0.97935\tvalidation_1-logloss:0.18178\n", - "[4000]\tvalidation_0-auc:0.98569\tvalidation_0-logloss:0.15414\tvalidation_1-auc:0.97939\tvalidation_1-logloss:0.18156\n", - "[4500]\tvalidation_0-auc:0.98585\tvalidation_0-logloss:0.15336\tvalidation_1-auc:0.97943\tvalidation_1-logloss:0.18136\n", - "[5000]\tvalidation_0-auc:0.98600\tvalidation_0-logloss:0.15262\tvalidation_1-auc:0.97945\tvalidation_1-logloss:0.18125\n", - "[5500]\tvalidation_0-auc:0.98612\tvalidation_0-logloss:0.15201\tvalidation_1-auc:0.97947\tvalidation_1-logloss:0.18117\n", - "[6000]\tvalidation_0-auc:0.98621\tvalidation_0-logloss:0.15157\tvalidation_1-auc:0.97948\tvalidation_1-logloss:0.18110\n", - "[6500]\tvalidation_0-auc:0.98631\tvalidation_0-logloss:0.15106\tvalidation_1-auc:0.97949\tvalidation_1-logloss:0.18103\n", - 
"[7000]\tvalidation_0-auc:0.98640\tvalidation_0-logloss:0.15059\tvalidation_1-auc:0.97951\tvalidation_1-logloss:0.18097\n", - "[7500]\tvalidation_0-auc:0.98649\tvalidation_0-logloss:0.15015\tvalidation_1-auc:0.97952\tvalidation_1-logloss:0.18091\n", - "[8000]\tvalidation_0-auc:0.98656\tvalidation_0-logloss:0.14978\tvalidation_1-auc:0.97953\tvalidation_1-logloss:0.18085\n", - "[8500]\tvalidation_0-auc:0.98661\tvalidation_0-logloss:0.14951\tvalidation_1-auc:0.97953\tvalidation_1-logloss:0.18084\n", - "[8582]\tvalidation_0-auc:0.98662\tvalidation_0-logloss:0.14946\tvalidation_1-auc:0.97953\tvalidation_1-logloss:0.18084\n", - "\n", - " Training Complete! Best Iteration: 8382\n" - ] - } - ], - "source": [ - "# Cell 5: Training Run\n", - "print(\"Starting Training (this will take time due to LR=0.01)...\")\n", - "\n", - "eval_set = [(X_train, y_train), (X_val, y_val)]\n", - "\n", - "model.fit(\n", - " X_train, y_train,\n", - " eval_set=eval_set,\n", - " verbose=500 # Log every 500 trees\n", - ")\n", - "\n", - "print(f\"\\n Training Complete! 
Best Iteration: {model.best_iteration}\")\n", - "\n", - "# --- Plot Learning Curve ---\n", - "results = model.evals_result()\n", - "epochs = len(results['validation_0']['auc'])\n", - "x_axis = range(0, epochs)\n", - "\n", - "plt.figure(figsize=(10, 6))\n", - "plt.plot(x_axis, results['validation_0']['auc'], label='Train')\n", - "plt.plot(x_axis, results['validation_1']['auc'], label='Validation')\n", - "plt.legend()\n", - "plt.ylabel('AUC Score')\n", - "plt.xlabel('Iterations')\n", - "plt.title('XGBoost Learning Curve (Check for divergence!)')\n", - "plt.grid(True)\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - ], - "source": [ - "# Cell 6: ROC-AUC Visualization\n", - "from sklearn.metrics import roc_curve, auc\n", - "\n", - "# Predict probabilities\n", - "y_prob = model.predict_proba(X_test)[:, 1]\n", - "\n", - "# Calculate ROC\n", - "fpr, tpr, _ = roc_curve(y_test, y_prob)\n", - "roc_auc = auc(fpr, tpr)\n", - "\n", - "plt.figure(figsize=(8, 8))\n", - "plt.plot(fpr, tpr, color='darkorange', lw=2, label=f' Model (AUC = {roc_auc:.4f})')\n", - "plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')\n", - "plt.xlim([0.0, 1.0])\n", - "plt.ylim([0.0, 1.05])\n", - "plt.xlabel('False Positive Rate')\n", - "plt.ylabel('True Positive Rate')\n", - "plt.title('ROC Curve on 14k Test Set')\n", - "plt.legend(loc=\"lower right\")\n", - "plt.grid(True, alpha=0.3)\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - " Calculating SHAP Values...\n" - ] - } - ], - "source": [ - "# Cell 7: SHAP Explainability\n", - "print(\"\\n Calculating SHAP Values...\")\n", - "\n", - "explainer = shap.TreeExplainer(model)\n", - "\n", - "# Subsample 5000 rows for speed (calculating on all 14k is fine on 32 cores too)\n", - "shap_sample = X_test.sample(n=5000, random_state=SEED) if len(X_test) > 
5000 else X_test\n", - "shap_values = explainer.shap_values(shap_sample)\n", - "\n", - "# Beeswarm Plot\n", - "plt.figure(figsize=(12, 8))\n", - "plt.title(\"SHAP Feature Importance (Global View)\")\n", - "shap.summary_plot(shap_values, shap_sample, show=False)\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "==========================================\n", - " XGBOOST MODEL PERFORMANCE\n", - "==========================================\n", - "ROC-AUC : 0.766809\n", - "PR-AUC : 0.771672\n", - "\n", - "==========================================\n", - " COMPARISION WITH ESTABLISHED TOOLS\n", - "==========================================\n", - "CADD ROC-AUC = 0.592509 PR-AUC = 0.704837\n", - "REVEL ROC-AUC = 0.587114 PR-AUC = 0.600423\n", - "AlphaMissense ROC-AUC = 0.592032 PR-AUC = 0.603459\n", - "PrimateAI ROC-AUC = 0.566141 PR-AUC = 0.581314\n", - "MPC ROC-AUC = 0.559678 PR-AUC = 0.578559\n", - "ClinPred ROC-AUC = 0.590860 PR-AUC = 0.606637\n", - "BayesDel ROC-AUC = 0.594821 PR-AUC = 0.706338\n", - "\n", - "==========================================\n", - " FINAL LEADERBOARD\n", - "==========================================\n", - " 1. My XGBoost AUC=0.766809 PR=0.771672\n", - " 2. BayesDel AUC=0.594821 PR=0.706338\n", - " 3. CADD AUC=0.592509 PR=0.704837\n", - " 4. AlphaMissense AUC=0.592032 PR=0.603459\n", - " 5. ClinPred AUC=0.590860 PR=0.606637\n", - " 6. REVEL AUC=0.587114 PR=0.600423\n", - " 7. PrimateAI AUC=0.566141 PR=0.581314\n", - " 8. 
MPC AUC=0.559678 PR=0.578559\n" - ] - } - ], - "source": [ - "# Cell 8: Benchmarking Logic + Plots\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import pandas as pd\n", - "from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve, precision_recall_curve\n", - "\n", - "roc_auc = roc_auc_score(y_test, y_prob)\n", - "pr_auc = average_precision_score(y_test, y_prob)\n", - "\n", - "print(\"\\n==========================================\")\n", - "print(f\" XGBOOST MODEL PERFORMANCE\")\n", - "print(\"==========================================\")\n", - "print(f\"ROC-AUC : {roc_auc:.6f}\")\n", - "print(f\"PR-AUC : {pr_auc:.6f}\")\n", - "\n", - "# ---- Benchmark against DBNSFP models ----\n", - "benchmarks = {\n", - " \"CADD\": \"CADD_raw_rankscore\",\n", - " \"REVEL\": \"REVEL_rankscore\",\n", - " \"AlphaMissense\": \"AlphaMissense_rankscore\",\n", - " \"PrimateAI\": \"PrimateAI_rankscore\",\n", - " \"MPC\": \"MPC_rankscore\",\n", - " \"ClinPred\": \"ClinPred_rankscore\",\n", - " \"BayesDel\": \"BayesDel_addAF_rankscore\"\n", - "}\n", - "\n", - "print(\"\\n==========================================\")\n", - "print(\" COMPARISION WITH ESTABLISHED TOOLS\")\n", - "print(\"==========================================\")\n", - "\n", - "benchmark_results = []\n", - "\n", - "for name, col in benchmarks.items():\n", - " if col not in test_df.columns:\n", - " continue\n", - " \n", - " scores = test_df[col].fillna(0).values\n", - " \n", - " try:\n", - " b_auc = roc_auc_score(y_test, scores)\n", - " b_pr = average_precision_score(y_test, scores)\n", - " benchmark_results.append((name, b_auc, b_pr))\n", - " print(f\"{name:14s} ROC-AUC = {b_auc:.6f} PR-AUC = {b_pr:.6f}\")\n", - " except ValueError:\n", - " pass\n", - "\n", - "# ---- Leaderboard ----\n", - "print(\"\\n==========================================\")\n", - "print(\" FINAL LEADERBOARD\")\n", - "print(\"==========================================\")\n", - "\n", - "ranking = sorted([(\"My XGBoost\", 
roc_auc, pr_auc)] + benchmark_results, key=lambda x: x[1], reverse=True)\n", - "\n", - "for i, (name, a, p) in enumerate(ranking, 1):\n", - " print(f\"{i:2d}. {name:14s} AUC={a:.6f} PR={p:.6f}\")\n", - "\n", - "# ---- ROC Curve ----\n", - "fpr, tpr, _ = roc_curve(y_test, y_prob)\n", - "\n", - "plt.figure()\n", - "plt.plot(fpr, tpr)\n", - "plt.xlabel(\"False Positive Rate\")\n", - "plt.ylabel(\"True Positive Rate\")\n", - "plt.title(f\"ROC Curve (AUC = {roc_auc:.4f})\")\n", - "plt.grid()\n", - "plt.show()\n", - "\n", - "# ---- Precision-Recall Curve ----\n", - "precision, recall, _ = precision_recall_curve(y_test, y_prob)\n", - "\n", - "plt.figure()\n", - "plt.plot(recall, precision)\n", - "plt.xlabel(\"Recall\")\n", - "plt.ylabel(\"Precision\")\n", - "plt.title(f\"Precision-Recall Curve (AUC = {pr_auc:.4f})\")\n", - "plt.grid()\n", - "plt.show()\n", - "\n", - "# ---- Bar Plot Comparison ----\n", - "df_bench = pd.DataFrame(\n", - " [(\"My XGBoost\", roc_auc, pr_auc)] + benchmark_results,\n", - " columns=[\"Model\", \"ROC-AUC\", \"PR-AUC\"]\n", - ")\n", - "\n", - "df_bench = df_bench.sort_values(by=\"ROC-AUC\", ascending=False)\n", - "\n", - "plt.figure(figsize=(10,5))\n", - "plt.bar(df_bench[\"Model\"], df_bench[\"ROC-AUC\"])\n", - "plt.xticks(rotation=45)\n", - "plt.ylabel(\"ROC-AUC\")\n", - "plt.title(\"Model Comparison (ROC-AUC)\")\n", - "plt.grid(axis='y')\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "==============================\n", - " VALIDATION RESULTS (TUNING)\n", - "==============================\n", - "Optimal Threshold : 0.488\n", - "Best F1 (Val) : 0.9242\n", - "Balanced Acc (Val): 0.9268\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "==============================\n", - " TEST PERFORMANCE (FINAL)\n", - "==============================\n", - "Threshold Used : 0.488\n", - "F1 
(Test) : 0.7215\n", - "Balanced Acc (Test): 0.6174\n", - "\n", - "Confusion Matrix (Test):\n", - "[[1754 5258]\n", - " [ 109 6952]]\n", - "\n", - "Classification Report (Test):\n", - " precision recall f1-score support\n", - "\n", - " 0 0.9415 0.2501 0.3953 7012\n", - " 1 0.5694 0.9846 0.7215 7061\n", - "\n", - " accuracy 0.6186 14073\n", - " macro avg 0.7554 0.6174 0.5584 14073\n", - "weighted avg 0.7548 0.6186 0.5590 14073\n", - "\n" - ] - } - ], - "source": [ - "# Cell 9: Threshold Tuning (FULL VERSION - VAL + TEST, NO LEAKAGE)\n", - "\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "from sklearn.metrics import (\n", - " f1_score, balanced_accuracy_score,\n", - " confusion_matrix, classification_report\n", - ")\n", - "\n", - "# =========================================================\n", - "# 1. GET PROBABILITIES\n", - "# =========================================================\n", - "y_val_prob = model.predict_proba(X_val)[:, 1]\n", - "y_test_prob = model.predict_proba(X_test)[:, 1]\n", - "\n", - "thresholds = np.linspace(0.01, 0.99, 200)\n", - "\n", - "# =========================================================\n", - "# 2. 
TUNE THRESHOLD ON VALIDATION SET\n", - "# =========================================================\n", - "best_f1 = 0\n", - "best_t = 0\n", - "best_bal_acc = 0\n", - "\n", - "f1_scores_val = []\n", - "bal_scores_val = []\n", - "\n", - "for t in thresholds:\n", - " y_pred_val = (y_val_prob >= t).astype(int)\n", - "\n", - " f1 = f1_score(y_val, y_pred_val)\n", - " bal = balanced_accuracy_score(y_val, y_pred_val)\n", - "\n", - " f1_scores_val.append(f1)\n", - " bal_scores_val.append(bal)\n", - "\n", - " if f1 > best_f1:\n", - " best_f1 = f1\n", - " best_t = t\n", - " best_bal_acc = bal\n", - "\n", - "print(\"\\n==============================\")\n", - "print(\" VALIDATION RESULTS (TUNING)\")\n", - "print(\"==============================\")\n", - "print(f\"Optimal Threshold : {best_t:.3f}\")\n", - "print(f\"Best F1 (Val) : {best_f1:.4f}\")\n", - "print(f\"Balanced Acc (Val): {best_bal_acc:.4f}\")\n", - "\n", - "# =========================================================\n", - "# 3. PLOT VALIDATION CURVE\n", - "# =========================================================\n", - "plt.figure(figsize=(8,5))\n", - "plt.plot(thresholds, f1_scores_val, label=\"F1 (Val)\")\n", - "plt.plot(thresholds, bal_scores_val, label=\"Balanced Acc (Val)\")\n", - "plt.axvline(best_t, linestyle='--', label=f\"Best T = {best_t:.3f}\")\n", - "\n", - "plt.xlabel(\"Threshold\")\n", - "plt.ylabel(\"Score\")\n", - "plt.title(\"Threshold Optimization (Validation Set)\")\n", - "plt.legend()\n", - "plt.grid()\n", - "plt.show()\n", - "\n", - "# =========================================================\n", - "# 4. 
APPLY THRESHOLD ON TEST SET\n", - "# =========================================================\n", - "y_pred_test = (y_test_prob >= best_t).astype(int)\n", - "\n", - "f1_test = f1_score(y_test, y_pred_test)\n", - "bal_test = balanced_accuracy_score(y_test, y_pred_test)\n", - "\n", - "print(\"\\n==============================\")\n", - "print(\" TEST PERFORMANCE (FINAL)\")\n", - "print(\"==============================\")\n", - "print(f\"Threshold Used : {best_t:.3f}\")\n", - "print(f\"F1 (Test) : {f1_test:.4f}\")\n", - "print(f\"Balanced Acc (Test): {bal_test:.4f}\")\n", - "\n", - "# =========================================================\n", - "# 5. TEST CONFUSION MATRIX\n", - "# =========================================================\n", - "cm_test = confusion_matrix(y_test, y_pred_test)\n", - "\n", - "print(\"\\nConfusion Matrix (Test):\")\n", - "print(cm_test)\n", - "\n", - "print(\"\\nClassification Report (Test):\")\n", - "print(classification_report(y_test, y_pred_test, digits=4))\n", - "\n", - "plt.figure(figsize=(6,5))\n", - "plt.imshow(cm_test)\n", - "\n", - "for i in range(cm_test.shape[0]):\n", - " for j in range(cm_test.shape[1]):\n", - " plt.text(j, i, cm_test[i, j], ha='center', va='center')\n", - "\n", - "plt.xlabel(\"Predicted Label\")\n", - "plt.ylabel(\"True Label\")\n", - "plt.title(\"Confusion Matrix (Test Set)\")\n", - "plt.xticks([0,1], [\"Class 0\", \"Class 1\"])\n", - "plt.yticks([0,1], [\"Class 0\", \"Class 1\"])\n", - "\n", - "plt.colorbar()\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# =========================================================\n", - "# 6. 
ALSO PLOT TEST CURVE (FOR ANALYSIS ONLY)\n", - "# =========================================================\n", - "f1_scores_test = []\n", - "bal_scores_test = []\n", - "\n", - "for t in thresholds:\n", - " y_pred_t = (y_test_prob >= t).astype(int)\n", - "\n", - " f1_scores_test.append(f1_score(y_test, y_pred_t))\n", - " bal_scores_test.append(balanced_accuracy_score(y_test, y_pred_t))\n", - "\n", - "plt.figure(figsize=(8,5))\n", - "plt.plot(thresholds, f1_scores_test, label=\"F1 (Test)\")\n", - "plt.plot(thresholds, bal_scores_test, label=\"Balanced Acc (Test)\")\n", - "plt.axvline(best_t, linestyle='--', label=f\"Chosen T = {best_t:.3f}\")\n", - "\n", - "plt.xlabel(\"Threshold\")\n", - "plt.ylabel(\"Score\")\n", - "plt.title(\"Test Threshold Curve (Analysis Only - Not Used for Tuning)\")\n", - "plt.legend()\n", - "plt.grid()\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " Generating Deep Feature Analysis...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/tmp/ipykernel_7026/3990650171.py:49: FutureWarning: \n", - "\n", - "Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. 
Assign the `y` variable to `hue` and set `legend=False` for the same effect.\n", - "\n", - " sns.barplot(x=\"gain\", y=\"feature\", data=fi.head(25), palette=\"viridis\")\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " Top Features by GAIN:\n", - "\n", - " feature gain split\n", - " phyloP100way_vertebrate_rankscore 178.044281 32840.0\n", - " phyloP470way_mammalian_rankscore 35.505932 34554.0\n", - " gnomad_af 31.278227 54468.0\n", - " chrom 25.794613 51327.0\n", - "phastCons470way_mammalian_rankscore 10.803547 10682.0\n", - " phastCons17way_primate_rankscore 9.607828 36827.0\n", - " pos 9.198639 101420.0\n", - " alt 7.743703 21750.0\n", - " ref 6.535082 21038.0\n", - " GERP_91_mammals_rankscore 6.072751 30064.0\n" - ] - } - ], - "source": [ - "# Cell 10: Advanced Feature Importance (SHAP + Gain)\n", - "import pandas as pd\n", - "import matplotlib.pyplot as plt\n", - "import seaborn as sns\n", - "import shap\n", - "\n", - "print(\" Generating Deep Feature Analysis...\")\n", - "\n", - "# -------------------------------------------------------\n", - "# 1. SHAP Mean Absolute Value (Global Feature Importance)\n", - "# -------------------------------------------------------\n", - "plt.figure(figsize=(10, 6))\n", - "plt.title(\"Mean Absolute SHAP Values (Global Impact on Model)\")\n", - "# plot_type=\"bar\" automatically calculates the mean absolute value for you\n", - "shap.summary_plot(shap_values, shap_sample, plot_type=\"bar\", show=False)\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# -------------------------------------------------------\n", - "# 2. 
XGBoost Native Feature Importance (Gain & Split)\n", - "# -------------------------------------------------------\n", - "# XGBoost handles this differently than LightGBM.\n", - "# We access the internal booster to get specific metrics.\n", - "booster = model.get_booster()\n", - "\n", - "# Get Gain (Average gain across all splits the feature is used in)\n", - "gain_scores = booster.get_score(importance_type='gain')\n", - "# Get Weight (Number of times a feature is used to split the data)\n", - "split_scores = booster.get_score(importance_type='weight')\n", - "\n", - "# Convert to DataFrame\n", - "# Note: XGBoost usually returns a dict {feature: score}, so we map it back\n", - "all_features = model.feature_names_in_\n", - "fi_data = []\n", - "\n", - "for feat in all_features:\n", - " fi_data.append({\n", - " \"feature\": feat,\n", - " \"gain\": gain_scores.get(feat, 0), # Default to 0 if not used\n", - " \"split\": split_scores.get(feat, 0)\n", - " })\n", - "\n", - "fi = pd.DataFrame(fi_data).sort_values(\"gain\", ascending=False)\n", - "\n", - "# -------------------------------------------------------\n", - "# 3. Plot Top 25 Features by GAIN\n", - "# -------------------------------------------------------\n", - "plt.figure(figsize=(10, 8))\n", - "sns.barplot(x=\"gain\", y=\"feature\", data=fi.head(25), palette=\"viridis\")\n", - "plt.title(\"Top Features by Gain (Structural Importance)\")\n", - "plt.xlabel(\"Average Gain (Contribution to Accuracy)\")\n", - "plt.ylabel(\"Feature\")\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# -------------------------------------------------------\n", - "# 4. 
Print Table\n", - "# -------------------------------------------------------\n", - "print(\" Top Features by GAIN:\\n\")\n", - "print(fi.head(30).to_string(index=False))" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\ud83d\udcbe Saving Model Artifacts...\n", - " -> Saved Native JSON: pathopreter_sota_xgboost.json\n", - " -> Saved Joblib Pickle: pathopreter_sota_xgboost.pkl\n", - " -> Saved Feature List: pathopreter_features.txt\n", - "\n", - " Verification - Files in Root:\n", - " - pathopreter_features.txt : 0.00 MB\n", - " - pathopreter_sota_xgboost.pkl : 35.90 MB\n", - " - pathopreter_sota_xgboost.json : 50.47 MB\n" - ] - } - ], - "source": [ - "# Cell 11: Save SOTA Model to Root\n", - "import joblib\n", - "import os\n", - "\n", - "print(\"\ud83d\udcbe Saving Model Artifacts...\")\n", - "\n", - "# 1. Save as Native XGBoost JSON (Best for portability/deployment)\n", - "# This saves the tree structure, feature names, and weights\n", - "model_filename_json = \"pathopreter_sota_xgboost.json\"\n", - "model.save_model(model_filename_json)\n", - "print(f\" -> Saved Native JSON: {model_filename_json}\")\n", - "\n", - "# 2. Save as Joblib Pickle (Best for Python reloading)\n", - "# This saves the entire Sklearn wrapper, including hyperparameters and configuration\n", - "model_filename_pkl = \"pathopreter_sota_xgboost.pkl\"\n", - "joblib.dump(model, model_filename_pkl)\n", - "print(f\" -> Saved Joblib Pickle: {model_filename_pkl}\")\n", - "\n", - "# 3. 
Save Feature List (Crucial for inference consistency)\n", - "# If you reload later, you MUST ensure columns are in this exact order\n", - "feature_filename = \"pathopreter_features.txt\"\n", - "with open(feature_filename, \"w\") as f:\n", - " for feat in features:\n", - " f.write(f\"{feat}\\n\")\n", - "print(f\" -> Saved Feature List: {feature_filename}\")\n", - "\n", - "# --- Verification ---\n", - "print(\"\\n Verification - Files in Root:\")\n", - "files = [f for f in os.listdir('.') if \"pathopreter\" in f]\n", - "for f in files:\n", - " size_mb = os.path.getsize(f) / (1024 * 1024)\n", - " print(f\" - {f} : {size_mb:.2f} MB\")" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " STARTING EXPERIMENT 2: ROBUSTNESS CHECK (NO LOCATION BIAS)...\n", - " -> Original Feature Count: 10\n", - " -> Robust Feature Count: 8\n", - " -> DROPPED: ['chrom', 'pos']\n", - "\ud83d\udd25 Training Robust Model (This may take a moment)...\n", - "[0]\tvalidation_0-auc:0.93564\tvalidation_0-logloss:0.68610\tvalidation_1-auc:0.93442\tvalidation_1-logloss:0.68610\n", - "[500]\tvalidation_0-auc:0.94672\tvalidation_0-logloss:0.27310\tvalidation_1-auc:0.94417\tvalidation_1-logloss:0.27910\n", - "[1000]\tvalidation_0-auc:0.94798\tvalidation_0-logloss:0.26800\tvalidation_1-auc:0.94445\tvalidation_1-logloss:0.27766\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[1500]\tvalidation_0-auc:0.94832\tvalidation_0-logloss:0.26676\tvalidation_1-auc:0.94446\tvalidation_1-logloss:0.27752\n", - "[2000]\tvalidation_0-auc:0.94851\tvalidation_0-logloss:0.26601\tvalidation_1-auc:0.94451\tvalidation_1-logloss:0.27748\n", - "[2500]\tvalidation_0-auc:0.94865\tvalidation_0-logloss:0.26549\tvalidation_1-auc:0.94456\tvalidation_1-logloss:0.27744\n", - "[3000]\tvalidation_0-auc:0.94878\tvalidation_0-logloss:0.26498\tvalidation_1-auc:0.94459\tvalidation_1-logloss:0.27742\n", - 
"[3104]\tvalidation_0-auc:0.94880\tvalidation_0-logloss:0.26489\tvalidation_1-auc:0.94460\tvalidation_1-logloss:0.27741\n", - "\n", - "==========================================\n", - " ROBUSTNESS CHECK RESULTS\n", - "==========================================\n", - "Baseline AUC (With Loc) : 0.766809\n", - "Robust AUC (No Loc) : 0.760368\n", - "Difference : -0.006441\n", - "\n", - "\u2705 VERDICT: PASSED. The model is learning BIOLOGY, not location.\n", - " It still beats REVEL/AlphaMissense without cheating!\n", - "\n", - "==========================================\n", - " NEW LEADERBOARD (ROBUST MODEL)\n", - "==========================================\n", - " 1. My Robust XGB AUC=0.760368 PR=0.769851\n", - " 2. BayesDel AUC=0.594821 PR=0.706338\n", - " 3. CADD AUC=0.592509 PR=0.704837\n", - " 4. AlphaMissense AUC=0.592032 PR=0.603459\n", - " 5. ClinPred AUC=0.590860 PR=0.606637\n", - " 6. REVEL AUC=0.587114 PR=0.600423\n", - " 7. PrimateAI AUC=0.566141 PR=0.581314\n", - " 8. MPC AUC=0.559678 PR=0.578559\n", - "\n", - "\ud83c\udfc6 Top Features in Robust Model (Pure Biology):\n", - " feature gain\n", - " phyloP100way_vertebrate_rankscore 230.754532\n", - " phyloP470way_mammalian_rankscore 59.637676\n", - " gnomad_af 53.463173\n", - " phastCons17way_primate_rankscore 17.051365\n", - "phastCons470way_mammalian_rankscore 13.066975\n", - " ref 11.242688\n", - " alt 10.989820\n", - " GERP_91_mammals_rankscore 7.565748\n" - ] - } - ], - "source": [ - "# Cell 12: Robustness Experiment (Drop Chrom/Pos) & Re-Benchmark\n", - "import pandas as pd\n", - "import xgboost as xgb\n", - "import numpy as np\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.metrics import roc_auc_score, average_precision_score\n", - "\n", - "print(\" STARTING EXPERIMENT 2: ROBUSTNESS CHECK (NO LOCATION BIAS)...\")\n", - "\n", - "# ==========================================\n", - "# 1. 
Define Robust Feature Set\n", - "# ==========================================\n", - "# Remove 'chrom' and 'pos' to force the model to learn biology only\n", - "features_robust = [f for f in features if f not in ['chrom', 'pos']]\n", - "\n", - "print(f\" -> Original Feature Count: {len(features)}\")\n", - "print(f\" -> Robust Feature Count: {len(features_robust)}\")\n", - "print(f\" -> DROPPED: ['chrom', 'pos']\")\n", - "\n", - "# ==========================================\n", - "# 2. Prepare Data (Robust Split)\n", - "# ==========================================\n", - "X_rob = train_df[features_robust]\n", - "y_rob = train_df['y']\n", - "\n", - "X_test_rob = test_df[features_robust]\n", - "y_test_rob = test_df['y']\n", - "\n", - "# Stratified Validation Split\n", - "X_train_r, X_val_r, y_train_r, y_val_r = train_test_split(\n", - " X_rob, y_rob, test_size=0.10, stratify=y_rob, random_state=SEED\n", - ")\n", - "\n", - "# ==========================================\n", - "# 3. Train Robust Model (SOTA Params)\n", - "# ==========================================\n", - "print(\"\ud83d\udd25 Training Robust Model (This may take a moment)...\")\n", - "\n", - "model_robust = xgb.XGBClassifier(\n", - " n_estimators=10000,\n", - " learning_rate=0.01,\n", - " max_depth=10,\n", - " gamma=1.5,\n", - " min_child_weight=5,\n", - " reg_alpha=0.5,\n", - " reg_lambda=1.0,\n", - " subsample=0.8,\n", - " colsample_bytree=0.8,\n", - " n_jobs=n_cores,\n", - " tree_method='hist',\n", - " enable_categorical=True, # Still needed for ref/alt\n", - " early_stopping_rounds=200,\n", - " random_state=SEED,\n", - " objective='binary:logistic',\n", - " eval_metric=['auc', 'logloss']\n", - ")\n", - "\n", - "model_robust.fit(\n", - " X_train_r, y_train_r,\n", - " eval_set=[(X_train_r, y_train_r), (X_val_r, y_val_r)],\n", - " verbose=500\n", - ")\n", - "\n", - "# ==========================================\n", - "# 4. 
Evaluation & Comparison\n", - "# ==========================================\n", - "print(\"\\n==========================================\")\n", - "print(\" ROBUSTNESS CHECK RESULTS\")\n", - "print(\"==========================================\")\n", - "\n", - "# Predict\n", - "y_prob_r = model_robust.predict_proba(X_test_rob)[:, 1]\n", - "\n", - "# Metrics\n", - "roc_auc_r = roc_auc_score(y_test_rob, y_prob_r)\n", - "pr_auc_r = average_precision_score(y_test_rob, y_prob_r)\n", - "\n", - "# Compare with Baseline (Assumes 'roc_auc' exists from previous cell)\n", - "# If running fresh, we assume the previous score was ~0.7477\n", - "baseline_auc = roc_auc if 'roc_auc' in locals() else 0.747704\n", - "delta = roc_auc_r - baseline_auc\n", - "\n", - "print(f\"Baseline AUC (With Loc) : {baseline_auc:.6f}\")\n", - "print(f\"Robust AUC (No Loc) : {roc_auc_r:.6f}\")\n", - "print(f\"Difference : {delta:.6f}\")\n", - "\n", - "if roc_auc_r > 0.70:\n", - " print(\"\\n\u2705 VERDICT: PASSED. The model is learning BIOLOGY, not location.\")\n", - " print(\" It still beats REVEL/AlphaMissense without cheating!\")\n", - "else:\n", - " print(\"\\n\u26a0\ufe0f VERDICT: WARNING. Significant drop. Re-evaluate features.\")\n", - "\n", - "# ==========================================\n", - "# 5. 
Re-Benchmark Against SOTA Tools\n", - "# ==========================================\n", - "benchmarks = {\n", - " \"CADD\": \"CADD_raw_rankscore\",\n", - " \"REVEL\": \"REVEL_rankscore\",\n", - " \"AlphaMissense\": \"AlphaMissense_rankscore\",\n", - " \"PrimateAI\": \"PrimateAI_rankscore\",\n", - " \"MPC\": \"MPC_rankscore\",\n", - " \"ClinPred\": \"ClinPred_rankscore\",\n", - " \"BayesDel\": \"BayesDel_addAF_rankscore\"\n", - "}\n", - "\n", - "print(\"\\n==========================================\")\n", - "print(\" NEW LEADERBOARD (ROBUST MODEL)\")\n", - "print(\"==========================================\")\n", - "\n", - "benchmark_results = []\n", - "\n", - "for name, col in benchmarks.items():\n", - " if col not in test_df.columns:\n", - " continue\n", - " \n", - " scores = test_df[col].fillna(0).values\n", - " \n", - " try:\n", - " b_auc = roc_auc_score(y_test_rob, scores)\n", - " b_pr = average_precision_score(y_test_rob, scores)\n", - " benchmark_results.append((name, b_auc, b_pr))\n", - " except ValueError:\n", - " pass\n", - "\n", - "# Add our Robust Model to the list\n", - "ranking = sorted([(\"My Robust XGB\", roc_auc_r, pr_auc_r)] + benchmark_results, key=lambda x: x[1], reverse=True)\n", - "\n", - "for i, (name, a, p) in enumerate(ranking, 1):\n", - " print(f\"{i:2d}. {name:14s} AUC={a:.6f} PR={p:.6f}\")\n", - "\n", - "# ==========================================\n", - "# 6. 
Check New Top Features (What took over?)\n", - "# ==========================================\n", - "booster = model_robust.get_booster()\n", - "gain_scores = booster.get_score(importance_type='gain')\n", - "fi_rob = pd.DataFrame([\n", - " {\"feature\": k, \"gain\": v} for k,v in gain_scores.items()\n", - "]).sort_values(\"gain\", ascending=False)\n", - "\n", - "print(\"\\n\ud83c\udfc6 Top Features in Robust Model (Pure Biology):\")\n", - "print(fi_rob.head(10).to_string(index=False))" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\ud83d\udcbe Saving PathoPreter Artifacts...\n", - " -> Saved SOTA Model: pathopreter_sota_xgboost.json & pathopreter_sota_xgboost.pkl\n", - " -> Saved Robust Model: pathopreter_robust_biology.json & pathopreter_robust_biology.pkl\n", - "\n", - " Final Inventory (Root Directory):\n", - " - pathopreter_features.txt : 0.00 MB\n", - " - pathopreter_features_robust.txt : 0.00 MB\n", - " - pathopreter_features_sota.txt : 0.00 MB\n", - " - pathopreter_robust_biology.json : 19.56 MB\n", - " - pathopreter_robust_biology.pkl : 13.17 MB\n", - " - pathopreter_sota_xgboost.json : 50.47 MB\n", - " - pathopreter_sota_xgboost.pkl : 35.90 MB\n" - ] - } - ], - "source": [ - "# Cell 13: Save All Models with 'PathoPreter' Naming\n", - "import joblib\n", - "import os\n", - "\n", - "print(\"\ud83d\udcbe Saving PathoPreter Artifacts...\")\n", - "\n", - "# ==========================================\n", - "# 1. 
Save The \"SOTA\" Model (The Competition Winner)\n", - "# ==========================================\n", - "# Native XGBoost JSON (Portable)\n", - "sota_json = \"pathopreter_sota_xgboost.json\"\n", - "model.save_model(sota_json)\n", - "\n", - "# Joblib Pickle (Python Ready)\n", - "sota_pkl = \"pathopreter_sota_xgboost.pkl\"\n", - "joblib.dump(model, sota_pkl)\n", - "\n", - "# Feature List (SOTA)\n", - "with open(\"pathopreter_features_sota.txt\", \"w\") as f:\n", - " for feat in features:\n", - " f.write(f\"{feat}\\n\")\n", - "\n", - "print(f\" -> Saved SOTA Model: {sota_json} & {sota_pkl}\")\n", - "\n", - "# ==========================================\n", - "# 2. Save The \"Robust\" Model (The Scientist)\n", - "# ==========================================\n", - "# Native XGBoost JSON (Portable)\n", - "robust_json = \"pathopreter_robust_biology.json\"\n", - "model_robust.save_model(robust_json)\n", - "\n", - "# Joblib Pickle (Python Ready)\n", - "robust_pkl = \"pathopreter_robust_biology.pkl\"\n", - "joblib.dump(model_robust, robust_pkl)\n", - "\n", - "# Feature List (Robust - No Chrom/Pos)\n", - "with open(\"pathopreter_features_robust.txt\", \"w\") as f:\n", - " for feat in features_robust:\n", - " f.write(f\"{feat}\\n\")\n", - "\n", - "print(f\" -> Saved Robust Model: {robust_json} & {robust_pkl}\")\n", - "\n", - "# ==========================================\n", - "# 3. 
Verification\n", - "# ==========================================\n", - "print(\"\\n Final Inventory (Root Directory):\")\n", - "files = [f for f in os.listdir('.') if \"pathopreter\" in f]\n", - "for f in sorted(files):\n", - " size_mb = os.path.getsize(f) / (1024 * 1024)\n", - " print(f\" - {f:<35} : {size_mb:.2f} MB\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - " Loading ONLY bottom 100000 rows from: 900k_test.parquet\n", - " Loading trained robust model from JSON...\n", - " Model loaded.\n", - " -> Training feature count: 8\n", - " -> Loaded shape: (100000, 17)\n", - " -> Label counts: {0: 92813, 1: 7187}\n", - " -> Final X_new shape: (100000, 8)\n", - "\n", - " Running predictions...\n", - "\n", - "==========================================\n", - " RESULTS ON BOTTOM 100K ROWS (NEW DATA)\n", - "==========================================\n", - "Robust Model AUC : 0.972796\n", - "Robust Model PR-AUC : 0.724422\n", - "Baseline AUC (With Loc) : 0.766809\n", - "Robust AUC (No Loc) : 0.972796\n", - "Difference : 0.205987\n", - "\n", - "==========================================\n", - " NEW LEADERBOARD (ROBUST MODEL ON 100K)\n", - "==========================================\n", - " 1. My Robust XGB AUC=0.972796 PR=0.724422\n", - " 2. ClinPred AUC=0.960248 PR=0.592119\n", - " 3. REVEL AUC=0.955430 PR=0.582778\n", - " 4. AlphaMissense AUC=0.952558 PR=0.576359\n", - " 5. BayesDel AUC=0.935482 PR=0.688013\n", - " 6. PrimateAI AUC=0.929806 PR=0.483381\n", - " 7. MPC AUC=0.921036 PR=0.495954\n", - " 8. 
CADD AUC=0.898042 PR=0.656561\n", - "\n", - " Top Features in Robust Model (Pure Biology):\n", - " feature gain\n", - " phyloP100way_vertebrate_rankscore 230.754532\n", - " phyloP470way_mammalian_rankscore 59.637676\n", - " gnomad_af 53.463173\n", - " phastCons17way_primate_rankscore 17.051365\n", - "phastCons470way_mammalian_rankscore 13.066975\n", - " ref 11.242688\n", - " alt 10.989820\n", - " GERP_91_mammals_rankscore 7.565748\n", - "\n", - " Done.\n" - ] - } - ], - "source": [ - "# ============================================================\n", - "# CELL: TEST ROBUST MODEL ON BOTTOM 100K ROWS (CATEGORICAL-SAFE)\n", - "# ============================================================\n", - "\n", - "import pandas as pd\n", - "import numpy as np\n", - "import xgboost as xgb\n", - "from sklearn.metrics import roc_auc_score, average_precision_score\n", - "\n", - "DATA_PATH = \"900k_test.parquet\"\n", - "MODEL_JSON = \"pathopreter_robust_biology.json\"\n", - "N_ROWS = 100_000\n", - "\n", - "print(f\"\\n Loading ONLY bottom {N_ROWS} rows from: {DATA_PATH}\")\n", - "\n", - "# ------------------------------------------------------------\n", - "# 1) LOAD YOUR ROBUST MODEL (JSON \u2013 SAFE FORMAT)\n", - "# ------------------------------------------------------------\n", - "print(\" Loading trained robust model from JSON...\")\n", - "\n", - "booster = xgb.Booster()\n", - "booster.load_model(MODEL_JSON)\n", - "\n", - "# Exact feature list used in training\n", - "train_feature_names = booster.feature_names\n", - "\n", - "print(f\" Model loaded.\")\n", - "print(f\" -> Training feature count: {len(train_feature_names)}\")\n", - "\n", - "# ------------------------------------------------------------\n", - "# 2) LOAD ONLY WHAT THE MODEL NEEDS + LABEL\n", - "# ------------------------------------------------------------\n", - "\n", - "# needed_cols = list(set(train_feature_names + [\"clean_label\"]))\n", - "needed_cols = list(set(\n", - " train_feature_names \n", - " + 
[\"clean_label\"] \n", - " + list(benchmarks.values())\n", - "))\n", - "df_new = (\n", - " pd.read_parquet(DATA_PATH, columns=needed_cols)\n", - " .tail(N_ROWS)\n", - " .reset_index(drop=True)\n", - ")\n", - "\n", - "# Make label exactly like training\n", - "df_new[\"y\"] = (df_new[\"clean_label\"] == \"Pathogenic\").astype(int)\n", - "\n", - "print(f\" -> Loaded shape: {df_new.shape}\")\n", - "print(f\" -> Label counts: {df_new['y'].value_counts().to_dict()}\")\n", - "\n", - "# ------------------------------------------------------------\n", - "# 3) BUILD X_new IN EXACT TRAINING ORDER (CRITICAL)\n", - "# ------------------------------------------------------------\n", - "\n", - "X_new = pd.DataFrame()\n", - "\n", - "for feat in train_feature_names:\n", - " if feat in df_new.columns:\n", - " X_new[feat] = df_new[feat]\n", - " else:\n", - " # If some feature is missing in new data, add it\n", - " X_new[feat] = 0.0\n", - "\n", - "# -------- HANDLE CATEGORICAL COLUMNS SAFELY --------\n", - "for cat_col in [\"ref\", \"alt\"]:\n", - " if cat_col in X_new.columns:\n", - " # Convert to categorical\n", - " X_new[cat_col] = X_new[cat_col].astype(\"category\")\n", - "\n", - " # If NaNs exist, add \"N\" (unknown nucleotide) as a valid category\n", - " if X_new[cat_col].isna().any():\n", - " if \"N\" not in X_new[cat_col].cat.categories:\n", - " X_new[cat_col] = X_new[cat_col].cat.add_categories(\"N\")\n", - " X_new[cat_col] = X_new[cat_col].fillna(\"N\")\n", - "\n", - "# -------- FILL NA FOR NUMERIC COLUMNS ONLY --------\n", - "num_cols = X_new.select_dtypes(include=[np.number]).columns\n", - "X_new[num_cols] = X_new[num_cols].fillna(0)\n", - "\n", - "y_new = df_new[\"y\"].values\n", - "\n", - "print(f\" -> Final X_new shape: {X_new.shape}\")\n", - "\n", - "# ------------------------------------------------------------\n", - "# 4) PREDICT (MATCHES YOUR TRAINING)\n", - "# ------------------------------------------------------------\n", - "print(\"\\n Running 
predictions...\")\n", - "\n", - "dmat = xgb.DMatrix(\n", - " X_new,\n", - " feature_names=train_feature_names,\n", - " enable_categorical=True # MATCHES YOUR TRAINING\n", - ")\n", - "\n", - "y_prob_new = booster.predict(dmat, validate_features=False)\n", - "\n", - "# ------------------------------------------------------------\n", - "# 5) METRICS (SAME STYLE AS YOUR CELL)\n", - "# ------------------------------------------------------------\n", - "roc_auc_new = roc_auc_score(y_new, y_prob_new)\n", - "pr_auc_new = average_precision_score(y_new, y_prob_new)\n", - "\n", - "print(\"\\n==========================================\")\n", - "print(\" RESULTS ON BOTTOM 100K ROWS (NEW DATA)\")\n", - "print(\"==========================================\")\n", - "print(f\"Robust Model AUC : {roc_auc_new:.6f}\")\n", - "print(f\"Robust Model PR-AUC : {pr_auc_new:.6f}\")\n", - "\n", - "# Baseline comparison (same logic as your notebook)\n", - "baseline_auc = roc_auc if 'roc_auc' in locals() else 0.747704\n", - "delta = roc_auc_new - baseline_auc\n", - "\n", - "print(f\"Baseline AUC (With Loc) : {baseline_auc:.6f}\")\n", - "print(f\"Robust AUC (No Loc) : {roc_auc_new:.6f}\")\n", - "print(f\"Difference : {delta:.6f}\")\n", - "\n", - "# ------------------------------------------------------------\n", - "# 6) BENCHMARK AGAINST SOTA (SAME AS YOUR CELL)\n", - "# ------------------------------------------------------------\n", - "benchmarks = {\n", - " \"CADD\": \"CADD_raw_rankscore\",\n", - " \"REVEL\": \"REVEL_rankscore\",\n", - " \"AlphaMissense\": \"AlphaMissense_rankscore\",\n", - " \"PrimateAI\": \"PrimateAI_rankscore\",\n", - " \"MPC\": \"MPC_rankscore\",\n", - " \"ClinPred\": \"ClinPred_rankscore\",\n", - " \"BayesDel\": \"BayesDel_addAF_rankscore\"\n", - "}\n", - "\n", - "print(\"\\n==========================================\")\n", - "print(\" NEW LEADERBOARD (ROBUST MODEL ON 100K)\")\n", - "print(\"==========================================\")\n", - "\n", - "benchmark_results = 
[]\n", - "\n", - "for name, col in benchmarks.items():\n", - " if col not in df_new.columns:\n", - " print(f\" Skipping {name} (column missing)\")\n", - " continue\n", - "\n", - " scores = df_new[col].fillna(0).values\n", - "\n", - " try:\n", - " b_auc = roc_auc_score(y_new, scores)\n", - " b_pr = average_precision_score(y_new, scores)\n", - " benchmark_results.append((name, b_auc, b_pr))\n", - " except ValueError:\n", - " pass\n", - "\n", - "ranking = sorted(\n", - " [(\"My Robust XGB\", roc_auc_new, pr_auc_new)] + benchmark_results,\n", - " key=lambda x: x[1],\n", - " reverse=True\n", - ")\n", - "\n", - "for i, (name, a, p) in enumerate(ranking, 1):\n", - " print(f\"{i:2d}. {name:14s} AUC={a:.6f} PR={p:.6f}\")\n", - "\n", - "# ------------------------------------------------------------\n", - "# 7) TOP FEATURES (GAIN \u2014 SAME AS YOUR CELL)\n", - "# ------------------------------------------------------------\n", - "print(\"\\n Top Features in Robust Model (Pure Biology):\")\n", - "\n", - "gain_scores = booster.get_score(importance_type='gain')\n", - "fi_rob = (\n", - " pd.DataFrame({\"feature\": list(gain_scores.keys()),\n", - " \"gain\": list(gain_scores.values())})\n", - " .sort_values(\"gain\", ascending=False)\n", - ")\n", - "\n", - "print(fi_rob.head(10).to_string(index=False))\n", - "print(\"\\n Done.\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "914707\n" - ] - } - ], - "source": [ - "import pyarrow.parquet as pq\n", - "\n", - "pf = pq.ParquetFile(\"900k_test.parquet\")\n", - "print(pf.metadata.num_rows)" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\ud83d\udd0d Loading ONLY bottom 900000 rows from: 900k_test.parquet\n", - "\ud83d\udd01 Loading trained robust model from JSON...\n", - "\u2705 Model loaded.\n", - " 
-> Training feature count: 8\n", - " -> Loaded shape: (900000, 17)\n", - " -> Label counts: {0: 891684, 1: 8316}\n", - " -> Final X_new shape: (900000, 8)\n", - "\n", - "\ud83d\udd25 Running predictions...\n", - "\n", - "==========================================\n", - "\ud83e\uddea RESULTS ON BOTTOM 100K ROWS (NEW DATA)\n", - "==========================================\n", - "Robust Model AUC : 0.982708\n", - "Robust Model PR-AUC : 0.523736\n", - "Baseline AUC (With Loc) : 0.766809\n", - "Robust AUC (No Loc) : 0.982708\n", - "Difference : 0.215899\n", - "\n", - "==========================================\n", - "\u2694\ufe0f NEW LEADERBOARD (ROBUST MODEL ON 900K)\n", - "==========================================\n", - " 1. My Robust XGB AUC=0.982708 PR=0.523736\n", - " 2. BayesDel AUC=0.949125 PR=0.669971\n", - " 3. ClinPred AUC=0.943464 PR=0.487158\n", - " 4. REVEL AUC=0.937387 PR=0.466993\n", - " 5. AlphaMissense AUC=0.933589 PR=0.433256\n", - " 6. PrimateAI AUC=0.911503 PR=0.224050\n", - " 7. CADD AUC=0.909277 PR=0.573139\n", - " 8. 
MPC AUC=0.901957 PR=0.250881\n", - "\n", - "\ud83c\udfc6 Top Features in Robust Model (Pure Biology):\n", - " feature gain\n", - " phyloP100way_vertebrate_rankscore 230.754532\n", - " phyloP470way_mammalian_rankscore 59.637676\n", - " gnomad_af 53.463173\n", - " phastCons17way_primate_rankscore 17.051365\n", - "phastCons470way_mammalian_rankscore 13.066975\n", - " ref 11.242688\n", - " alt 10.989820\n", - " GERP_91_mammals_rankscore 7.565748\n", - "\n", - "\u2705 Done.\n" - ] - } - ], - "source": [ - "# ============================================================\n", - "# CELL: TEST ROBUST MODEL ON BOTTOM 900K ROWS (CATEGORICAL-SAFE)\n", - "# ============================================================\n", - "\n", - "import pandas as pd\n", - "import numpy as np\n", - "import xgboost as xgb\n", - "from sklearn.metrics import roc_auc_score, average_precision_score\n", - "\n", - "DATA_PATH = \"900k_test.parquet\"\n", - "MODEL_JSON = \"pathopreter_robust_biology.json\"\n", - "N_ROWS = 900_000\n", - "\n", - "print(f\"\\n\ud83d\udd0d Loading ONLY bottom {N_ROWS} rows from: {DATA_PATH}\")\n", - "\n", - "# ------------------------------------------------------------\n", - "# 1) LOAD YOUR ROBUST MODEL (JSON \u2013 SAFE FORMAT)\n", - "# ------------------------------------------------------------\n", - "print(\"\ud83d\udd01 Loading trained robust model from JSON...\")\n", - "\n", - "booster = xgb.Booster()\n", - "booster.load_model(MODEL_JSON)\n", - "\n", - "# Exact feature list used in training\n", - "train_feature_names = booster.feature_names\n", - "\n", - "print(f\"\u2705 Model loaded.\")\n", - "print(f\" -> Training feature count: {len(train_feature_names)}\")\n", - "\n", - "# ------------------------------------------------------------\n", - "# 2) LOAD ONLY WHAT THE MODEL NEEDS + LABEL\n", - "# ------------------------------------------------------------\n", - "\n", - "# needed_cols = list(set(train_feature_names + [\"clean_label\"]))\n", - "needed_cols = 
list(set(\n", - " train_feature_names \n", - " + [\"clean_label\"] \n", - " + list(benchmarks.values())\n", - "))\n", - "df_new = (\n", - " pd.read_parquet(DATA_PATH, columns=needed_cols)\n", - " .tail(N_ROWS)\n", - " .reset_index(drop=True)\n", - ")\n", - "\n", - "# Make label exactly like training\n", - "df_new[\"y\"] = (df_new[\"clean_label\"] == \"Pathogenic\").astype(int)\n", - "\n", - "print(f\" -> Loaded shape: {df_new.shape}\")\n", - "print(f\" -> Label counts: {df_new['y'].value_counts().to_dict()}\")\n", - "\n", - "# ------------------------------------------------------------\n", - "# 3) BUILD X_new IN EXACT TRAINING ORDER (CRITICAL)\n", - "# ------------------------------------------------------------\n", - "\n", - "X_new = pd.DataFrame()\n", - "\n", - "for feat in train_feature_names:\n", - " if feat in df_new.columns:\n", - " X_new[feat] = df_new[feat]\n", - " else:\n", - " # If some feature is missing in new data, add it\n", - " X_new[feat] = 0.0\n", - "\n", - "# -------- HANDLE CATEGORICAL COLUMNS SAFELY --------\n", - "for cat_col in [\"ref\", \"alt\"]:\n", - " if cat_col in X_new.columns:\n", - " # Convert to categorical\n", - " X_new[cat_col] = X_new[cat_col].astype(\"category\")\n", - "\n", - " # If NaNs exist, add \"N\" (unknown nucleotide) as a valid category\n", - " if X_new[cat_col].isna().any():\n", - " if \"N\" not in X_new[cat_col].cat.categories:\n", - " X_new[cat_col] = X_new[cat_col].cat.add_categories(\"N\")\n", - " X_new[cat_col] = X_new[cat_col].fillna(\"N\")\n", - "\n", - "# -------- FILL NA FOR NUMERIC COLUMNS ONLY --------\n", - "num_cols = X_new.select_dtypes(include=[np.number]).columns\n", - "X_new[num_cols] = X_new[num_cols].fillna(0)\n", - "\n", - "y_new = df_new[\"y\"].values\n", - "\n", - "print(f\" -> Final X_new shape: {X_new.shape}\")\n", - "\n", - "# ------------------------------------------------------------\n", - "# 4) PREDICT (MATCHES YOUR TRAINING)\n", - "# 
------------------------------------------------------------\n", - "print(\"\\n\ud83d\udd25 Running predictions...\")\n", - "\n", - "dmat = xgb.DMatrix(\n", - " X_new,\n", - " feature_names=train_feature_names,\n", - " enable_categorical=True # MATCHES YOUR TRAINING\n", - ")\n", - "\n", - "y_prob_new = booster.predict(dmat, validate_features=False)\n", - "\n", - "# ------------------------------------------------------------\n", - "# 5) METRICS (SAME STYLE AS YOUR CELL)\n", - "# ------------------------------------------------------------\n", - "roc_auc_new = roc_auc_score(y_new, y_prob_new)\n", - "pr_auc_new = average_precision_score(y_new, y_prob_new)\n", - "\n", - "print(\"\\n==========================================\")\n", - "print(\"\ud83e\uddea RESULTS ON BOTTOM 900K ROWS (NEW DATA)\")\n", - "print(\"==========================================\")\n", - "print(f\"Robust Model AUC : {roc_auc_new:.6f}\")\n", - "print(f\"Robust Model PR-AUC : {pr_auc_new:.6f}\")\n", - "\n", - "# Baseline comparison (same logic as your notebook)\n", - "baseline_auc = roc_auc if 'roc_auc' in locals() else 0.747704\n", - "delta = roc_auc_new - baseline_auc\n", - "\n", - "print(f\"Baseline AUC (With Loc) : {baseline_auc:.6f}\")\n", - "print(f\"Robust AUC (No Loc) : {roc_auc_new:.6f}\")\n", - "print(f\"Difference : {delta:.6f}\")\n", - "\n", - "# ------------------------------------------------------------\n", - "# 6) BENCHMARK AGAINST SOTA (SAME AS YOUR CELL)\n", - "# ------------------------------------------------------------\n", - "benchmarks = {\n", - " \"CADD\": \"CADD_raw_rankscore\",\n", - " \"REVEL\": \"REVEL_rankscore\",\n", - " \"AlphaMissense\": \"AlphaMissense_rankscore\",\n", - " \"PrimateAI\": \"PrimateAI_rankscore\",\n", - " \"MPC\": \"MPC_rankscore\",\n", - " \"ClinPred\": \"ClinPred_rankscore\",\n", - " \"BayesDel\": \"BayesDel_addAF_rankscore\"\n", - "}\n", - "\n", - "print(\"\\n==========================================\")\n", - "print(\"\u2694\ufe0f NEW 
LEADERBOARD (ROBUST MODEL ON 900K)\")\n", - "print(\"==========================================\")\n", - "\n", - "benchmark_results = []\n", - "\n", - "for name, col in benchmarks.items():\n", - " if col not in df_new.columns:\n", - " print(f\"\u26a0\ufe0f Skipping {name} (column missing)\")\n", - " continue\n", - "\n", - " scores = df_new[col].fillna(0).values\n", - "\n", - " try:\n", - " b_auc = roc_auc_score(y_new, scores)\n", - " b_pr = average_precision_score(y_new, scores)\n", - " benchmark_results.append((name, b_auc, b_pr))\n", - " except ValueError:\n", - " pass\n", - "\n", - "ranking = sorted(\n", - " [(\"My Robust XGB\", roc_auc_new, pr_auc_new)] + benchmark_results,\n", - " key=lambda x: x[1],\n", - " reverse=True\n", - ")\n", - "\n", - "for i, (name, a, p) in enumerate(ranking, 1):\n", - " print(f\"{i:2d}. {name:14s} AUC={a:.6f} PR={p:.6f}\")\n", - "\n", - "# ------------------------------------------------------------\n", - "# 7) TOP FEATURES (GAIN \u2014 SAME AS YOUR CELL)\n", - "# ------------------------------------------------------------\n", - "print(\"\\n\ud83c\udfc6 Top Features in Robust Model (Pure Biology):\")\n", - "\n", - "gain_scores = booster.get_score(importance_type='gain')\n", - "fi_rob = (\n", - " pd.DataFrame({\"feature\": list(gain_scores.keys()),\n", - " \"gain\": list(gain_scores.values())})\n", - " .sort_values(\"gain\", ascending=False)\n", - ")\n", - "\n", - "print(fi_rob.head(10).to_string(index=False))\n", - "print(\"\\n\u2705 Done.\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " Loading test_enriched.parquet...\n", - "Using 10 features (training order)\n", - "\n", - " BASELINE: AUC=0.766809 | PR=0.771672\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 80%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 8/10 [00:01<00:00, 4.38it/s]" - ] - }, - { - "name": "stderr", - 
"output_type": "stream", - "text": [ - "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 10/10 [00:02<00:00, 4.35it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - " PERMUTATION ABLATION RESULTS:\n", - " feature auc pr auc_drop pr_drop\n", - " gnomad_af 0.574536 0.625798 0.192274 0.145874\n", - " pos 0.743102 0.745121 0.023707 0.026551\n", - " chrom 0.747204 0.748561 0.019605 0.023111\n", - " ref 0.748308 0.747789 0.018501 0.023883\n", - " alt 0.748902 0.754349 0.017907 0.017323\n", - " phyloP100way_vertebrate_rankscore 0.753048 0.727441 0.013761 0.044231\n", - " phastCons17way_primate_rankscore 0.759336 0.760803 0.007473 0.010869\n", - " GERP_91_mammals_rankscore 0.765210 0.759358 0.001599 0.012314\n", - "phastCons470way_mammalian_rankscore 0.765913 0.771196 0.000896 0.000476\n", - " phyloP470way_mammalian_rankscore 0.773115 0.771422 -0.006306 0.000250\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "import pandas as pd\n", - "import numpy as np\n", - "from sklearn.metrics import roc_auc_score, average_precision_score\n", - "from tqdm import tqdm\n", - "\n", - "SEED = 42\n", - "\n", - "print(\" Loading test_enriched.parquet...\")\n", - "test_df = pd.read_parquet(\"test_enriched_SEQUENCES.parquet\")\n", - "\n", - "# GET EXACT TRAINING FEATURE ORDER\n", - "features = model.get_booster().feature_names\n", - "\n", - "print(f\"Using {len(features)} features (training order)\")\n", - "\n", - "# ---- PREP ----\n", - "X_test = pd.DataFrame()\n", - "\n", - "# build in correct order\n", - "for feat in features:\n", - " if feat in test_df.columns:\n", - " X_test[feat] = test_df[feat]\n", - " else:\n", - " X_test[feat] = 0 # safety\n", - "\n", - "y_test = (test_df[\"clean_label\"] == \"Pathogenic\").astype(int).values\n", - "\n", - "# ---- FIX TYPES ----\n", - "for col in [\"chrom\", \"ref\", \"alt\"]:\n", - " if col in X_test.columns:\n", - " X_test[col] = 
X_test[col].astype(\"category\")\n", - "\n", - "for col in X_test.columns:\n", - " if col not in [\"chrom\", \"ref\", \"alt\"]:\n", - " X_test[col] = pd.to_numeric(X_test[col], errors=\"coerce\")\n", - "\n", - "X_test = X_test.fillna(0)\n", - "\n", - "# ---- BASELINE ----\n", - "y_prob = model.predict_proba(X_test)[:, 1]\n", - "\n", - "base_auc = roc_auc_score(y_test, y_prob)\n", - "base_pr = average_precision_score(y_test, y_prob)\n", - "\n", - "print(f\"\\n BASELINE: AUC={base_auc:.6f} | PR={base_pr:.6f}\")\n", - "\n", - "# ---- PERMUTATION ABLATION ----\n", - "results = []\n", - "\n", - "cat_cols = [\"chrom\", \"ref\", \"alt\"]\n", - "\n", - "for feat in tqdm(features):\n", - " X_perm = X_test.copy()\n", - "\n", - " # shuffle feature\n", - " X_perm[feat] = np.random.permutation(X_perm[feat].values)\n", - "\n", - " # \ud83d\udd25 FIX: restore dtypes AFTER permutation\n", - " for col in cat_cols:\n", - " if col in X_perm.columns:\n", - " X_perm[col] = X_perm[col].astype(\"category\")\n", - "\n", - " for col in X_perm.columns:\n", - " if col not in cat_cols:\n", - " X_perm[col] = pd.to_numeric(X_perm[col], errors=\"coerce\")\n", - "\n", - " X_perm = X_perm.fillna(0)\n", - "\n", - " # predict\n", - " y_prob_perm = model.predict_proba(X_perm)[:, 1]\n", - "\n", - " auc = roc_auc_score(y_test, y_prob_perm)\n", - " pr = average_precision_score(y_test, y_prob_perm)\n", - "\n", - " results.append({\n", - " \"feature\": feat,\n", - " \"auc\": auc,\n", - " \"pr\": pr,\n", - " \"auc_drop\": base_auc - auc,\n", - " \"pr_drop\": base_pr - pr\n", - " })\n", - "\n", - "# ---- RESULTS ----\n", - "ablation_df = pd.DataFrame(results).sort_values(\"auc_drop\", ascending=False)\n", - "\n", - "print(\"\\n PERMUTATION ABLATION RESULTS:\")\n", - "print(ablation_df.to_string(index=False))" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\ud83d\udcc9 Generating Ablation ROC 
Curves...\n" - ] - } - ], - "source": [ - "import matplotlib.pyplot as plt\n", - "from sklearn.metrics import roc_curve, auc\n", - "\n", - "print(\"\ud83d\udcc9 Generating Ablation ROC Curves...\")\n", - "\n", - "plt.figure(figsize=(12, 8))\n", - "\n", - "# baseline\n", - "fpr_base, tpr_base, _ = roc_curve(y_test, y_prob)\n", - "plt.plot(fpr_base, tpr_base, lw=3, color='black',\n", - " label=f'Baseline (AUC={base_auc:.3f})')\n", - "\n", - "# random line\n", - "plt.plot([0,1], [0,1], 'k--', lw=2)\n", - "\n", - "# store curves\n", - "curve_results = []\n", - "\n", - "for feat in features:\n", - " X_perm = X_test.copy()\n", - " X_perm[feat] = np.random.permutation(X_perm[feat].values)\n", - "\n", - " # FIX TYPES AGAIN\n", - " for col in [\"chrom\", \"ref\", \"alt\"]:\n", - " if col in X_perm.columns:\n", - " X_perm[col] = X_perm[col].astype(\"category\")\n", - "\n", - " for col in X_perm.columns:\n", - " if col not in [\"chrom\", \"ref\", \"alt\"]:\n", - " X_perm[col] = pd.to_numeric(X_perm[col], errors=\"coerce\")\n", - "\n", - " X_perm = X_perm.fillna(0)\n", - "\n", - " y_prob_perm = model.predict_proba(X_perm)[:, 1]\n", - "\n", - " fpr, tpr, _ = roc_curve(y_test, y_prob_perm)\n", - " roc_auc = auc(fpr, tpr)\n", - "\n", - " plt.plot(fpr, tpr, lw=1.5, alpha=0.7,\n", - " label=f'No {feat} (AUC={roc_auc:.3f})')\n", - "\n", - "# formatting\n", - "plt.xlim([0.0, 1.0])\n", - "plt.ylim([0.0, 1.05])\n", - "plt.xlabel('False Positive Rate')\n", - "plt.ylabel('True Positive Rate')\n", - "plt.title('Ablation: Feature Importance Analysis')\n", - "plt.legend(loc=\"lower right\", fontsize='small')\n", - "plt.grid(True, alpha=0.3)\n", - "\n", - "plt.savefig(\"ablation_roc.png\", dpi=300)\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\ud83e\uddea Running NOISE-REMOVAL ABLATION (FINAL FIX)...\n", - "Removing 3 noisy features:\n", - 
"['GERP++_RS_rankscore', 'phyloP17way_primate_rankscore', 'phastCons100way_vertebrate_rankscore']\n", - "\n", - "==========================================\n", - "\ud83e\uddf9 NOISE-REMOVED MODEL PERFORMANCE\n", - "==========================================\n", - "AUC : 0.766809\n", - "PR-AUC : 0.771672\n", - "AUC Gain: 0.000000\n", - "PR Gain : 0.000000\n" - ] - } - ], - "source": [ - "print(\"\\n\ud83e\uddea Running NOISE-REMOVAL ABLATION (FINAL FIX)...\")\n", - "\n", - "# ---- DEFINE NOISY FEATURES ----\n", - "noise_features = [\n", - " \"GERP++_RS_rankscore\",\n", - " \"phyloP17way_primate_rankscore\",\n", - " \"phastCons100way_vertebrate_rankscore\"\n", - "]\n", - "\n", - "print(f\"Removing {len(noise_features)} noisy features:\")\n", - "print(noise_features)\n", - "\n", - "# ---- START FROM ORIGINAL TEST ----\n", - "X_clean = X_test.copy()\n", - "\n", - "# ---- \ud83d\udd25 ZERO OUT (DO NOT DROP) ----\n", - "for f in noise_features:\n", - " if f in X_clean.columns:\n", - " X_clean[f] = 0.0\n", - "\n", - "# ---- \ud83d\udd25 FORCE EXACT TRAINING ORDER (CRITICAL) ----\n", - "feature_order = model.get_booster().feature_names\n", - "X_clean = X_clean[feature_order]\n", - "\n", - "# ---- FIX TYPES ----\n", - "cat_cols = [\"chrom\", \"ref\", \"alt\"]\n", - "\n", - "for col in cat_cols:\n", - " if col in X_clean.columns:\n", - " X_clean[col] = X_clean[col].astype(\"category\")\n", - "\n", - "for col in X_clean.columns:\n", - " if col not in cat_cols:\n", - " X_clean[col] = pd.to_numeric(X_clean[col], errors=\"coerce\")\n", - "\n", - "X_clean = X_clean.fillna(0)\n", - "\n", - "# ---- PREDICT ----\n", - "y_prob_clean = model.predict_proba(X_clean)[:, 1]\n", - "\n", - "# ---- METRICS ----\n", - "clean_auc = roc_auc_score(y_test, y_prob_clean)\n", - "clean_pr = average_precision_score(y_test, y_prob_clean)\n", - "\n", - "print(\"\\n==========================================\")\n", - "print(\"\ud83e\uddf9 NOISE-REMOVED MODEL PERFORMANCE\")\n", - 
"print(\"==========================================\")\n", - "print(f\"AUC : {clean_auc:.6f}\")\n", - "print(f\"PR-AUC : {clean_pr:.6f}\")\n", - "print(f\"AUC Gain: {clean_auc - base_auc:.6f}\")\n", - "print(f\"PR Gain : {clean_pr - base_pr:.6f}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - ], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "fig, ax = plt.subplots(figsize=(16, 11))\n", - "ax.axis('off')\n", - "\n", - "# ---------- Helper Functions ----------\n", - "def draw_box(text, xy, fc=\"#E8F0FE\", size=10):\n", - " ax.text(\n", - " xy[0], xy[1], text,\n", - " ha='center', va='center',\n", - " fontsize=size,\n", - " bbox=dict(boxstyle=\"round,pad=0.5\", facecolor=fc, edgecolor='black')\n", - " )\n", - "\n", - "def draw_arrow(start, end):\n", - " ax.annotate(\n", - " \"\", xy=end, xytext=start,\n", - " arrowprops=dict(arrowstyle=\"->\", lw=1.5)\n", - " )\n", - "\n", - "# ---------- TITLE ----------\n", - "plt.title(\"Framework Pipeline\", fontsize=18, weight='bold')\n", - "\n", - "# =====================================================\n", - "# DATA SOURCES (STRAIGHT + CLEAN)\n", - "# =====================================================\n", - "draw_box(\"ClinVar\\n(~7.7M variants)\", (0.1, 0.90))\n", - "draw_box(\"gnomAD\\n(AF)\", (0.3, 0.90))\n", - "draw_box(\"dbNSFP\\n(functional scores)\", (0.5, 0.90))\n", - "draw_box(\"GRCh38\\n(sequence)\", (0.7, 0.90))\n", - "\n", - "draw_box(\"Data Integration\", (0.4, 0.78), fc=\"#D1E7DD\", size=11)\n", - "\n", - "# straight arrows\n", - "draw_arrow((0.1, 0.86), (0.35, 0.80))\n", - "draw_arrow((0.3, 0.86), (0.38, 0.80))\n", - "draw_arrow((0.5, 0.86), (0.42, 0.80))\n", - "draw_arrow((0.7, 0.86), (0.45, 0.80))\n", - "\n", - "# =====================================================\n", - "# PIPELINE (MORE SPACING)\n", - "# =====================================================\n", - "draw_box(\n", - " \"Filtering & Cleaning\\n\"\n", - " \"\u2022 SNV 
only\\n\"\n", - " \"\u2022 Remove HGVS (Leakage Prevention)\\n\"\n", - " \"\u2022 Deduplication\",\n", - " (0.4, 0.64),\n", - ")\n", - "\n", - "draw_box(\n", - " \"Feature Engineering\\n\"\n", - " \"\u2022 gnomAD AF\\n\"\n", - " \"\u2022 Conservation (phyloP, phastCons)\\n\"\n", - " \"\u2022 Median Imputation\",\n", - " (0.4, 0.48),\n", - ")\n", - "\n", - "draw_box(\n", - " \"Dataset Construction\\n\"\n", - " \"Train: 300k (Balanced)\\n\"\n", - " \"Test: 14k (Rare Balanced)\\n\"\n", - " \"Test: 100k (Imbalanced)\",\n", - " (0.4, 0.32),\n", - " fc=\"#E2F0CB\"\n", - ")\n", - "\n", - "draw_box(\n", - " \"XGBoost Model\\n\"\n", - " \"\u2022 10 Core Features\\n\"\n", - " \"\u2022 Regularized Training\",\n", - " (0.4, 0.18),\n", - " fc=\"#D9EAD3\"\n", - ")\n", - "\n", - "draw_box(\n", - " \"Evaluation Pipeline\\n\"\n", - " \"\u2022 ROC-AUC / PR-AUC\\n\"\n", - " \"\u2022 Threshold Tuning (VAL \u2192 TEST)\\n\"\n", - " \"\u2022 SHAP + Ablation\",\n", - " (0.4, 0.05),\n", - " fc=\"#FCE5CD\"\n", - ")\n", - "\n", - "# arrows (vertical clean)\n", - "draw_arrow((0.4, 0.76), (0.4, 0.68))\n", - "draw_arrow((0.4, 0.60), (0.4, 0.52))\n", - "draw_arrow((0.4, 0.44), (0.4, 0.36))\n", - "draw_arrow((0.4, 0.28), (0.4, 0.21))\n", - "draw_arrow((0.4, 0.15), (0.4, 0.09))\n", - "\n", - "# =====================================================\n", - "# BENCHMARK (FIXED CLEANLY)\n", - "# =====================================================\n", - "draw_box(\n", - " \"Generalist Models\\n\"\n", - " \"\u2022 CADD\\n\u2022 REVEL\\n\u2022 AlphaMissense\\n\u2022 BayesDel\",\n", - " (0.75, 0.05),\n", - " fc=\"#FFF3CD\"\n", - ")\n", - "\n", - "# clean horizontal arrow\n", - "draw_arrow((0.50, 0.05), (0.70, 0.05))\n", - "\n", - "# aligned label\n", - "ax.text(0.60, 0.09, \"Benchmark Comparison\", fontsize=10, ha='center')\n", - "\n", - "# =====================================================\n", - "# SAVE\n", - "# =====================================================\n", - "# 
plt.savefig(\"framework_final.svg\", bbox_inches='tight')\n", - "# plt.savefig(\"framework_final.png\", dpi=300, bbox_inches='tight')\n", - "\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} \ No newline at end of file