Upload folder using huggingface_hub
- KNN_SHAP_Explainability.ipynb +1234 -0
- LICENSE +21 -0
- README.md +274 -0
KNN_SHAP_Explainability.ipynb
ADDED
@@ -0,0 +1,1234 @@
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": "# KNN Explainability with SHAP on Cloud GPU\n\nThis notebook demonstrates how to use SHAP (SHapley Additive exPlanations) to interpret K-Nearest Neighbors (KNN) model predictions with comprehensive visualizations.\n\n**Environment:** Cloud GPU Instance (Running in VS Code)\n\n## Prerequisites\n- Cloud GPU instance with GPU (RTX 3090, RTX 4090, or A100 recommended)\n- VS Code with Jupyter extension installed\n- SSH connection to cloud instance"
|
| 7 |
+
},
|
| 8 |
+
{
|
| 9 |
+
"cell_type": "markdown",
|
| 10 |
+
"metadata": {},
|
| 11 |
+
"source": [
|
| 12 |
+
"## 1. Environment Setup and GPU Verification"
|
| 13 |
+
]
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"cell_type": "code",
|
| 17 |
+
"execution_count": null,
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"outputs": [],
|
| 20 |
+
"source": [
|
| 21 |
+
"# Check GPU availability and specifications\n",
|
| 22 |
+
"import subprocess\n",
|
| 23 |
+
"import sys\n",
|
| 24 |
+
"\n",
|
| 25 |
+
"print(\"=\" * 80)\n",
|
| 26 |
+
"print(\"VAST.AI GPU INFORMATION\")\n",
|
| 27 |
+
"print(\"=\" * 80)\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"try:\n",
|
| 30 |
+
" result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)\n",
|
| 31 |
+
" print(result.stdout)\n",
|
| 32 |
+
"except FileNotFoundError:\n",
|
| 33 |
+
" print(\"nvidia-smi not found. GPU may not be available.\")\n",
|
| 34 |
+
"\n",
|
| 35 |
+
"print(\"\\n\" + \"=\" * 80)\n",
|
| 36 |
+
"print(\"PYTHON ENVIRONMENT\")\n",
|
| 37 |
+
"print(\"=\" * 80)\n",
|
| 38 |
+
"print(f\"Python version: {sys.version}\")\n",
|
| 39 |
+
"print(f\"Python executable: {sys.executable}\")"
|
| 40 |
+
]
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"cell_type": "code",
|
| 44 |
+
"execution_count": null,
|
| 45 |
+
"metadata": {},
|
| 46 |
+
"outputs": [],
|
| 47 |
+
"source": [
|
| 48 |
+
"# Check PyTorch CUDA availability\n",
|
| 49 |
+
"try:\n",
|
| 50 |
+
" import torch\n",
|
| 51 |
+
" print(\"\\n\" + \"=\" * 80)\n",
|
| 52 |
+
" print(\"PYTORCH & CUDA INFORMATION\")\n",
|
| 53 |
+
" print(\"=\" * 80)\n",
|
| 54 |
+
" print(f\"PyTorch version: {torch.__version__}\")\n",
|
| 55 |
+
" print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
|
| 56 |
+
" \n",
|
| 57 |
+
" if torch.cuda.is_available():\n",
|
| 58 |
+
" print(f\"CUDA version: {torch.version.cuda}\")\n",
|
| 59 |
+
" print(f\"Number of GPUs: {torch.cuda.device_count()}\")\n",
|
| 60 |
+
" for i in range(torch.cuda.device_count()):\n",
|
| 61 |
+
" print(f\"\\nGPU {i}: {torch.cuda.get_device_name(i)}\")\n",
|
| 62 |
+
" print(f\" Memory Total: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB\")\n",
|
| 63 |
+
" print(f\" Memory Allocated: {torch.cuda.memory_allocated(i) / 1e9:.4f} GB\")\n",
|
| 64 |
+
" print(f\" Memory Cached: {torch.cuda.memory_reserved(i) / 1e9:.4f} GB\")\n",
|
| 65 |
+
" else:\n",
|
| 66 |
+
" print(\"\\nWARNING: CUDA not available. Running on CPU.\")\n",
|
| 67 |
+
" print(\"This notebook is optimized for GPU but will work on CPU (slower).\")\n",
|
| 68 |
+
"except ImportError:\n",
|
| 69 |
+
" print(\"\\nPyTorch not installed yet. Will install in next cell.\")"
|
| 70 |
+
]
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"cell_type": "code",
|
| 74 |
+
"execution_count": null,
|
| 75 |
+
"metadata": {},
|
| 76 |
+
"outputs": [],
|
| 77 |
+
"source": [
|
| 78 |
+
"# Install required packages\n",
|
| 79 |
+
"print(\"Installing required packages...\\n\")\n",
|
| 80 |
+
"\n",
|
| 81 |
+
"packages = [\n",
|
| 82 |
+
" 'torch',\n",
|
| 83 |
+
" 'shap',\n",
|
| 84 |
+
" 'scikit-learn',\n",
|
| 85 |
+
" 'matplotlib',\n",
|
| 86 |
+
" 'seaborn',\n",
|
| 87 |
+
" 'pandas',\n",
|
| 88 |
+
" 'numpy',\n",
|
| 89 |
+
" 'plotly',\n",
|
| 90 |
+
" 'ipywidgets',\n",
|
| 91 |
+
" 'tqdm'\n",
|
| 92 |
+
"]\n",
|
| 93 |
+
"\n",
|
| 94 |
+
"for package in packages:\n",
|
| 95 |
+
" print(f\"Installing {package}...\")\n",
|
| 96 |
+
" subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', package])\n",
|
| 97 |
+
"\n",
|
| 98 |
+
"print(\"\\n\" + \"=\"*80)\n",
|
| 99 |
+
"print(\"All packages installed successfully!\")\n",
|
| 100 |
+
"print(\"=\"*80)"
|
| 101 |
+
]
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"cell_type": "markdown",
|
| 105 |
+
"metadata": {},
|
| 106 |
+
"source": [
|
| 107 |
+
"## 2. Import Libraries"
|
| 108 |
+
]
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"cell_type": "code",
|
| 112 |
+
"execution_count": null,
|
| 113 |
+
"metadata": {},
|
| 114 |
+
"outputs": [],
|
| 115 |
+
"source": "# Import all necessary libraries\nimport numpy as np\nimport pandas as pd\nimport matplotlib.pyplot as plt\nimport seaborn as sns\nimport shap\nimport torch\nimport sklearn\nfrom sklearn.neighbors import KNeighborsClassifier\nfrom sklearn.model_selection import train_test_split, cross_val_score\nfrom sklearn.preprocessing import StandardScaler\nfrom sklearn.datasets import load_breast_cancer, load_wine, load_iris, make_classification\nfrom sklearn.metrics import (\n accuracy_score, \n classification_report, \n confusion_matrix,\n roc_curve,\n roc_auc_score,\n precision_recall_curve\n)\nfrom tqdm.auto import tqdm\nimport warnings\nwarnings.filterwarnings('ignore')\n\n# Set random seeds for reproducibility\nnp.random.seed(42)\ntorch.manual_seed(42)\nif torch.cuda.is_available():\n torch.cuda.manual_seed_all(42)\n\n# Configure plotting style\nsns.set_style('whitegrid')\nplt.rcParams['figure.figsize'] = (14, 8)\nplt.rcParams['font.size'] = 10\nplt.rcParams['axes.titlesize'] = 14\nplt.rcParams['axes.labelsize'] = 12\n\n# Initialize SHAP's JavaScript visualization\nshap.initjs()\n\nprint(\"=\" * 80)\nprint(\"LIBRARY VERSIONS\")\nprint(\"=\" * 80)\nprint(f\"NumPy version: {np.__version__}\")\nprint(f\"Pandas version: {pd.__version__}\")\nprint(f\"Matplotlib version: {plt.matplotlib.__version__}\")\nprint(f\"Seaborn version: {sns.__version__}\")\nprint(f\"SHAP version: {shap.__version__}\")\nprint(f\"PyTorch version: {torch.__version__}\")\nprint(f\"Scikit-learn version: {sklearn.__version__}\")\nprint(\"\\nLibraries imported successfully!\")"
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"cell_type": "markdown",
|
| 119 |
+
"metadata": {},
|
| 120 |
+
"source": [
|
| 121 |
+
"## 3. GPU Memory Management Functions"
|
| 122 |
+
]
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"cell_type": "code",
|
| 126 |
+
"execution_count": null,
|
| 127 |
+
"metadata": {},
|
| 128 |
+
"outputs": [],
|
| 129 |
+
"source": [
|
| 130 |
+
"# Utility functions for GPU memory management on Vast.ai\n",
|
| 131 |
+
"\n",
|
| 132 |
+
"def print_gpu_memory():\n",
|
| 133 |
+
" \"\"\"Print current GPU memory usage\"\"\"\n",
|
| 134 |
+
" if torch.cuda.is_available():\n",
|
| 135 |
+
" for i in range(torch.cuda.device_count()):\n",
|
| 136 |
+
" allocated = torch.cuda.memory_allocated(i) / 1e9\n",
|
| 137 |
+
" cached = torch.cuda.memory_reserved(i) / 1e9\n",
|
| 138 |
+
" total = torch.cuda.get_device_properties(i).total_memory / 1e9\n",
|
| 139 |
+
" print(f\"GPU {i} ({torch.cuda.get_device_name(i)}):\")\n",
|
| 140 |
+
" print(f\" Allocated: {allocated:.3f} GB / {total:.2f} GB ({allocated/total*100:.1f}%)\")\n",
|
| 141 |
+
" print(f\" Cached: {cached:.3f} GB\")\n",
|
| 142 |
+
" else:\n",
|
| 143 |
+
" print(\"No GPU available\")\n",
|
| 144 |
+
"\n",
|
| 145 |
+
"def clear_gpu_memory():\n",
|
| 146 |
+
" \"\"\"Clear GPU cache\"\"\"\n",
|
| 147 |
+
" if torch.cuda.is_available():\n",
|
| 148 |
+
" torch.cuda.empty_cache()\n",
|
| 149 |
+
" print(\"GPU cache cleared\")\n",
|
| 150 |
+
"\n",
|
| 151 |
+
"def get_optimal_device():\n",
|
| 152 |
+
" \"\"\"Get optimal device for computation\"\"\"\n",
|
| 153 |
+
" if torch.cuda.is_available():\n",
|
| 154 |
+
" device = torch.device('cuda')\n",
|
| 155 |
+
" print(f\"Using GPU: {torch.cuda.get_device_name(0)}\")\n",
|
| 156 |
+
" else:\n",
|
| 157 |
+
" device = torch.device('cpu')\n",
|
| 158 |
+
" print(\"Using CPU (GPU not available)\")\n",
|
| 159 |
+
" return device\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"# Initialize device\n",
|
| 162 |
+
"device = get_optimal_device()\n",
|
| 163 |
+
"print(\"\\nInitial GPU Memory:\")\n",
|
| 164 |
+
"print_gpu_memory()"
|
| 165 |
+
]
|
| 166 |
+
},
|
| 167 |
+
{
|
| 168 |
+
"cell_type": "markdown",
|
| 169 |
+
"metadata": {},
|
| 170 |
+
"source": [
|
| 171 |
+
"## 4. Data Loading and Exploration"
|
| 172 |
+
]
|
| 173 |
+
},
|
| 174 |
+
{
|
| 175 |
+
"cell_type": "code",
|
| 176 |
+
"execution_count": null,
|
| 177 |
+
"metadata": {},
|
| 178 |
+
"outputs": [],
|
| 179 |
+
"source": [
|
| 180 |
+
"# Load the Breast Cancer Wisconsin dataset\n",
|
| 181 |
+
"print(\"=\" * 80)\n",
|
| 182 |
+
"print(\"LOADING DATASET\")\n",
|
| 183 |
+
"print(\"=\" * 80)\n",
|
| 184 |
+
"\n",
|
| 185 |
+
"data = load_breast_cancer()\n",
|
| 186 |
+
"X = pd.DataFrame(data.data, columns=data.feature_names)\n",
|
| 187 |
+
"y = pd.Series(data.target, name='target')\n",
|
| 188 |
+
"\n",
|
| 189 |
+
"print(f\"\\nDataset: Breast Cancer Wisconsin (Diagnostic)\")\n",
|
| 190 |
+
"print(f\"Number of samples: {X.shape[0]}\")\n",
|
| 191 |
+
"print(f\"Number of features: {X.shape[1]}\")\n",
|
| 192 |
+
"print(f\"Number of classes: {len(data.target_names)}\")\n",
|
| 193 |
+
"print(f\"Class names: {list(data.target_names)}\")\n",
|
| 194 |
+
"print(f\"\\nTarget distribution:\")\n",
|
| 195 |
+
"for idx, name in enumerate(data.target_names):\n",
|
| 196 |
+
" count = (y == idx).sum()\n",
|
| 197 |
+
" percentage = count / len(y) * 100\n",
|
| 198 |
+
" print(f\" {name}: {count} ({percentage:.2f}%)\")\n",
|
| 199 |
+
"\n",
|
| 200 |
+
"print(f\"\\nFeature statistics:\")\n",
|
| 201 |
+
"print(X.describe().round(2))"
|
| 202 |
+
]
|
| 203 |
+
},
|
| 204 |
+
{
|
| 205 |
+
"cell_type": "code",
|
| 206 |
+
"execution_count": null,
|
| 207 |
+
"metadata": {},
|
| 208 |
+
"outputs": [],
|
| 209 |
+
"source": [
|
| 210 |
+
"# Display first few rows\n",
|
| 211 |
+
"print(\"\\nFirst 5 rows of the dataset:\")\n",
|
| 212 |
+
"display_df = X.head()\n",
|
| 213 |
+
"display_df['target'] = y.head().map({0: data.target_names[0], 1: data.target_names[1]})\n",
|
| 214 |
+
"display_df"
|
| 215 |
+
]
|
| 216 |
+
},
|
| 217 |
+
{
|
| 218 |
+
"cell_type": "markdown",
|
| 219 |
+
"metadata": {},
|
| 220 |
+
"source": [
|
| 221 |
+
"## 5. Exploratory Data Analysis"
|
| 222 |
+
]
|
| 223 |
+
},
|
| 224 |
+
{
|
| 225 |
+
"cell_type": "code",
|
| 226 |
+
"execution_count": null,
|
| 227 |
+
"metadata": {},
|
| 228 |
+
"outputs": [],
|
| 229 |
+
"source": [
|
| 230 |
+
"# Target distribution visualization\n",
|
| 231 |
+
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
|
| 232 |
+
"\n",
|
| 233 |
+
"# Bar plot\n",
|
| 234 |
+
"target_counts = y.value_counts().sort_index()\n",
|
| 235 |
+
"colors = ['#FF6B6B', '#4ECDC4']\n",
|
| 236 |
+
"axes[0].bar(data.target_names, target_counts.values, color=colors, edgecolor='black', linewidth=1.5)\n",
|
| 237 |
+
"axes[0].set_ylabel('Count', fontweight='bold')\n",
|
| 238 |
+
"axes[0].set_title('Target Class Distribution', fontweight='bold', fontsize=14)\n",
|
| 239 |
+
"axes[0].grid(axis='y', alpha=0.3)\n",
|
| 240 |
+
"\n",
|
| 241 |
+
"# Add count labels on bars\n",
|
| 242 |
+
"for i, (name, count) in enumerate(zip(data.target_names, target_counts.values)):\n",
|
| 243 |
+
" axes[0].text(i, count + 5, str(count), ha='center', fontweight='bold')\n",
|
| 244 |
+
"\n",
|
| 245 |
+
"# Pie chart\n",
|
| 246 |
+
"axes[1].pie(target_counts.values, labels=data.target_names, autopct='%1.1f%%',\n",
|
| 247 |
+
" colors=colors, startangle=90, textprops={'fontsize': 12, 'fontweight': 'bold'})\n",
|
| 248 |
+
"axes[1].set_title('Target Class Proportion', fontweight='bold', fontsize=14)\n",
|
| 249 |
+
"\n",
|
| 250 |
+
"plt.tight_layout()\n",
|
| 251 |
+
"plt.show()"
|
| 252 |
+
]
|
| 253 |
+
},
|
| 254 |
+
{
|
| 255 |
+
"cell_type": "code",
|
| 256 |
+
"execution_count": null,
|
| 257 |
+
"metadata": {},
|
| 258 |
+
"outputs": [],
|
| 259 |
+
"source": [
|
| 260 |
+
"# Feature distributions for key features\n",
|
| 261 |
+
"print(\"Feature Distributions (Top 12 Features):\")\n",
|
| 262 |
+
"fig, axes = plt.subplots(3, 4, figsize=(16, 12))\n",
|
| 263 |
+
"axes = axes.ravel()\n",
|
| 264 |
+
"\n",
|
| 265 |
+
"for idx, col in enumerate(X.columns[:12]):\n",
|
| 266 |
+
" # Histogram for each class\n",
|
| 267 |
+
" for target_idx, target_name in enumerate(data.target_names):\n",
|
| 268 |
+
" mask = y == target_idx\n",
|
| 269 |
+
" axes[idx].hist(X.loc[mask, col], bins=25, alpha=0.6, \n",
|
| 270 |
+
" label=target_name, color=colors[target_idx], edgecolor='black')\n",
|
| 271 |
+
" \n",
|
| 272 |
+
" axes[idx].set_title(f'{col}', fontsize=10, fontweight='bold')\n",
|
| 273 |
+
" axes[idx].set_xlabel('Value', fontsize=9)\n",
|
| 274 |
+
" axes[idx].set_ylabel('Frequency', fontsize=9)\n",
|
| 275 |
+
" axes[idx].legend(fontsize=8)\n",
|
| 276 |
+
" axes[idx].grid(alpha=0.3)\n",
|
| 277 |
+
"\n",
|
| 278 |
+
"plt.tight_layout()\n",
|
| 279 |
+
"plt.suptitle('Feature Distributions by Class', y=1.002, fontsize=16, fontweight='bold')\n",
|
| 280 |
+
"plt.show()"
|
| 281 |
+
]
|
| 282 |
+
},
|
| 283 |
+
{
|
| 284 |
+
"cell_type": "code",
|
| 285 |
+
"execution_count": null,
|
| 286 |
+
"metadata": {},
|
| 287 |
+
"outputs": [],
|
| 288 |
+
"source": [
|
| 289 |
+
"# Correlation heatmap for top features\n",
|
| 290 |
+
"print(\"\\nFeature Correlation Analysis (Top 15 Features):\")\n",
|
| 291 |
+
"top_features = X.columns[:15]\n",
|
| 292 |
+
"correlation_matrix = X[top_features].corr()\n",
|
| 293 |
+
"\n",
|
| 294 |
+
"plt.figure(figsize=(14, 12))\n",
|
| 295 |
+
"mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))\n",
|
| 296 |
+
"sns.heatmap(correlation_matrix, mask=mask, annot=True, fmt='.2f', \n",
|
| 297 |
+
" cmap='coolwarm', center=0, square=True, linewidths=0.5,\n",
|
| 298 |
+
" cbar_kws={\"shrink\": 0.8, \"label\": \"Correlation Coefficient\"})\n",
|
| 299 |
+
"plt.title('Feature Correlation Matrix (Top 15 Features)', fontsize=14, fontweight='bold', pad=20)\n",
|
| 300 |
+
"plt.tight_layout()\n",
|
| 301 |
+
"plt.show()\n",
|
| 302 |
+
"\n",
|
| 303 |
+
"# Find highly correlated pairs\n",
|
| 304 |
+
"high_corr_pairs = []\n",
|
| 305 |
+
"for i in range(len(correlation_matrix.columns)):\n",
|
| 306 |
+
" for j in range(i+1, len(correlation_matrix.columns)):\n",
|
| 307 |
+
" if abs(correlation_matrix.iloc[i, j]) > 0.8:\n",
|
| 308 |
+
" high_corr_pairs.append((\n",
|
| 309 |
+
" correlation_matrix.columns[i],\n",
|
| 310 |
+
" correlation_matrix.columns[j],\n",
|
| 311 |
+
" correlation_matrix.iloc[i, j]\n",
|
| 312 |
+
" ))\n",
|
| 313 |
+
"\n",
|
| 314 |
+
"if high_corr_pairs:\n",
|
| 315 |
+
" print(\"\\nHighly correlated feature pairs (|r| > 0.8):\")\n",
|
| 316 |
+
" for feat1, feat2, corr in high_corr_pairs[:5]:\n",
|
| 317 |
+
" print(f\" {feat1} <-> {feat2}: {corr:.3f}\")"
|
| 318 |
+
]
|
| 319 |
+
},
|
| 320 |
+
{
|
| 321 |
+
"cell_type": "markdown",
|
| 322 |
+
"metadata": {},
|
| 323 |
+
"source": [
|
| 324 |
+
"## 6. Data Preprocessing"
|
| 325 |
+
]
|
| 326 |
+
},
|
| 327 |
+
{
|
| 328 |
+
"cell_type": "code",
|
| 329 |
+
"execution_count": null,
|
| 330 |
+
"metadata": {},
|
| 331 |
+
"outputs": [],
|
| 332 |
+
"source": [
|
| 333 |
+
"# Split the data\n",
|
| 334 |
+
"print(\"=\" * 80)\n",
|
| 335 |
+
"print(\"DATA PREPROCESSING\")\n",
|
| 336 |
+
"print(\"=\" * 80)\n",
|
| 337 |
+
"\n",
|
| 338 |
+
"X_train, X_test, y_train, y_test = train_test_split(\n",
|
| 339 |
+
" X, y, test_size=0.2, random_state=42, stratify=y\n",
|
| 340 |
+
")\n",
|
| 341 |
+
"\n",
|
| 342 |
+
"print(f\"\\nTraining set size: {X_train.shape[0]} samples\")\n",
|
| 343 |
+
"print(f\"Test set size: {X_test.shape[0]} samples\")\n",
|
| 344 |
+
"print(f\"\\nTraining set class distribution:\")\n",
|
| 345 |
+
"for idx, name in enumerate(data.target_names):\n",
|
| 346 |
+
" count = (y_train == idx).sum()\n",
|
| 347 |
+
" percentage = count / len(y_train) * 100\n",
|
| 348 |
+
" print(f\" {name}: {count} ({percentage:.2f}%)\")\n",
|
| 349 |
+
"\n",
|
| 350 |
+
"# Feature scaling (critical for KNN)\n",
|
| 351 |
+
"print(\"\\nApplying StandardScaler...\")\n",
|
| 352 |
+
"scaler = StandardScaler()\n",
|
| 353 |
+
"X_train_scaled = scaler.fit_transform(X_train)\n",
|
| 354 |
+
"X_test_scaled = scaler.transform(X_test)\n",
|
| 355 |
+
"\n",
|
| 356 |
+
"# Convert back to DataFrame for better handling\n",
|
| 357 |
+
"X_train_scaled = pd.DataFrame(X_train_scaled, columns=X.columns, index=X_train.index)\n",
|
| 358 |
+
"X_test_scaled = pd.DataFrame(X_test_scaled, columns=X.columns, index=X_test.index)\n",
|
| 359 |
+
"\n",
|
| 360 |
+
"print(f\"\\nScaled features - Mean: {X_train_scaled.mean().mean():.6f}\")\n",
|
| 361 |
+
"print(f\"Scaled features - Std: {X_train_scaled.std().mean():.6f}\")\n",
|
| 362 |
+
"print(\"\\nData preprocessing completed!\")"
|
| 363 |
+
]
|
| 364 |
+
},
|
| 365 |
+
{
|
| 366 |
+
"cell_type": "code",
|
| 367 |
+
"execution_count": null,
|
| 368 |
+
"metadata": {},
|
| 369 |
+
"outputs": [],
|
| 370 |
+
"source": [
|
| 371 |
+
"# Visualize scaling effect\n",
|
| 372 |
+
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
|
| 373 |
+
"\n",
|
| 374 |
+
"# Before scaling\n",
|
| 375 |
+
"sample_features = X.columns[:5]\n",
|
| 376 |
+
"X_train[sample_features].boxplot(ax=axes[0])\n",
|
| 377 |
+
"axes[0].set_title('Feature Scales - Before Scaling', fontweight='bold', fontsize=12)\n",
|
| 378 |
+
"axes[0].set_ylabel('Value', fontweight='bold')\n",
|
| 379 |
+
"axes[0].tick_params(axis='x', rotation=45)\n",
|
| 380 |
+
"axes[0].grid(alpha=0.3)\n",
|
| 381 |
+
"\n",
|
| 382 |
+
"# After scaling\n",
|
| 383 |
+
"X_train_scaled[sample_features].boxplot(ax=axes[1])\n",
|
| 384 |
+
"axes[1].set_title('Feature Scales - After Scaling', fontweight='bold', fontsize=12)\n",
|
| 385 |
+
"axes[1].set_ylabel('Standardized Value', fontweight='bold')\n",
|
| 386 |
+
"axes[1].tick_params(axis='x', rotation=45)\n",
|
| 387 |
+
"axes[1].grid(alpha=0.3)\n",
|
| 388 |
+
"\n",
|
| 389 |
+
"plt.tight_layout()\n",
|
| 390 |
+
"plt.show()"
|
| 391 |
+
]
|
| 392 |
+
},
|
| 393 |
+
{
|
| 394 |
+
"cell_type": "markdown",
|
| 395 |
+
"metadata": {},
|
| 396 |
+
"source": [
|
| 397 |
+
"## 7. KNN Model Training and Optimization"
|
| 398 |
+
]
|
| 399 |
+
},
|
| 400 |
+
{
|
| 401 |
+
"cell_type": "code",
|
| 402 |
+
"execution_count": null,
|
| 403 |
+
"metadata": {},
|
| 404 |
+
"outputs": [],
|
| 405 |
+
"source": [
|
| 406 |
+
"# Find optimal K value using cross-validation\n",
|
| 407 |
+
"print(\"=\" * 80)\n",
|
| 408 |
+
"print(\"KNN MODEL TRAINING\")\n",
|
| 409 |
+
"print(\"=\" * 80)\n",
|
| 410 |
+
"print(\"\\nFinding optimal K value...\\n\")\n",
|
| 411 |
+
"\n",
|
| 412 |
+
"k_range = range(1, 31)\n",
|
| 413 |
+
"train_scores = []\n",
|
| 414 |
+
"test_scores = []\n",
|
| 415 |
+
"cv_scores = []\n",
|
| 416 |
+
"\n",
|
| 417 |
+
"# Use tqdm for progress bar\n",
|
| 418 |
+
"for k in tqdm(k_range, desc=\"Testing K values\"):\n",
|
| 419 |
+
" knn = KNeighborsClassifier(n_neighbors=k, n_jobs=-1)\n",
|
| 420 |
+
" \n",
|
| 421 |
+
" # Training score\n",
|
| 422 |
+
" knn.fit(X_train_scaled, y_train)\n",
|
| 423 |
+
" train_scores.append(knn.score(X_train_scaled, y_train))\n",
|
| 424 |
+
" \n",
|
| 425 |
+
" # Test score\n",
|
| 426 |
+
" test_scores.append(knn.score(X_test_scaled, y_test))\n",
|
| 427 |
+
" \n",
|
| 428 |
+
" # Cross-validation score\n",
|
| 429 |
+
" cv_score = cross_val_score(knn, X_train_scaled, y_train, cv=5, n_jobs=-1)\n",
|
| 430 |
+
" cv_scores.append(cv_score.mean())\n",
|
| 431 |
+
"\n",
|
| 432 |
+
"# Find best K\n",
|
| 433 |
+
"best_k_test = k_range[np.argmax(test_scores)]\n",
|
| 434 |
+
"best_k_cv = k_range[np.argmax(cv_scores)]\n",
|
| 435 |
+
"\n",
|
| 436 |
+
"print(f\"\\n✓ Optimal K (based on test accuracy): {best_k_test}\")\n",
|
| 437 |
+
"print(f\"✓ Optimal K (based on CV score): {best_k_cv}\")\n",
|
| 438 |
+
"print(f\"\\nUsing K = {best_k_cv} for final model\")"
|
| 439 |
+
]
|
| 440 |
+
},
|
| 441 |
+
{
|
| 442 |
+
"cell_type": "code",
|
| 443 |
+
"execution_count": null,
|
| 444 |
+
"metadata": {},
|
| 445 |
+
"outputs": [],
|
| 446 |
+
"source": [
|
| 447 |
+
"# Plot K vs Accuracy\n",
|
| 448 |
+
"fig, axes = plt.subplots(1, 2, figsize=(16, 6))\n",
|
| 449 |
+
"\n",
|
| 450 |
+
"# Left plot: All scores\n",
|
| 451 |
+
"axes[0].plot(k_range, train_scores, label='Training Accuracy', \n",
|
| 452 |
+
" marker='o', linewidth=2, markersize=4, color='#2ecc71')\n",
|
| 453 |
+
"axes[0].plot(k_range, test_scores, label='Test Accuracy', \n",
|
| 454 |
+
" marker='s', linewidth=2, markersize=4, color='#e74c3c')\n",
|
| 455 |
+
"axes[0].plot(k_range, cv_scores, label='CV Accuracy (5-fold)', \n",
|
| 456 |
+
" marker='^', linewidth=2, markersize=4, color='#3498db')\n",
|
| 457 |
+
"axes[0].axvline(x=best_k_cv, color='black', linestyle='--', alpha=0.5, label=f'Best K={best_k_cv}')\n",
|
| 458 |
+
"axes[0].set_xlabel('K Value (Number of Neighbors)', fontweight='bold')\n",
|
| 459 |
+
"axes[0].set_ylabel('Accuracy', fontweight='bold')\n",
|
| 460 |
+
"axes[0].set_title('KNN: K Value vs Accuracy', fontweight='bold', fontsize=14)\n",
|
| 461 |
+
"axes[0].legend(loc='best', fontsize=10)\n",
|
| 462 |
+
"axes[0].grid(True, alpha=0.3)\n",
|
| 463 |
+
"\n",
|
| 464 |
+
"# Right plot: Train-test gap\n",
|
| 465 |
+
"gap = np.array(train_scores) - np.array(test_scores)\n",
|
| 466 |
+
"axes[1].plot(k_range, gap, marker='o', linewidth=2, markersize=4, color='#9b59b6')\n",
|
| 467 |
+
"axes[1].axvline(x=best_k_cv, color='black', linestyle='--', alpha=0.5, label=f'Best K={best_k_cv}')\n",
|
| 468 |
+
"axes[1].axhline(y=0, color='red', linestyle='-', alpha=0.3)\n",
|
| 469 |
+
"axes[1].set_xlabel('K Value (Number of Neighbors)', fontweight='bold')\n",
|
| 470 |
+
"axes[1].set_ylabel('Train-Test Accuracy Gap', fontweight='bold')\n",
|
| 471 |
+
"axes[1].set_title('Overfitting Analysis', fontweight='bold', fontsize=14)\n",
|
| 472 |
+
"axes[1].legend(loc='best', fontsize=10)\n",
|
| 473 |
+
"axes[1].grid(True, alpha=0.3)\n",
|
| 474 |
+
"\n",
|
| 475 |
+
"plt.tight_layout()\n",
|
| 476 |
+
"plt.show()\n",
|
| 477 |
+
"\n",
|
| 478 |
+
"print(f\"\\nBest test accuracy: {max(test_scores):.4f} at K={best_k_test}\")\n",
|
| 479 |
+
"print(f\"Best CV accuracy: {max(cv_scores):.4f} at K={best_k_cv}\")"
|
| 480 |
+
]
|
| 481 |
+
},
|
| 482 |
+
{
|
| 483 |
+
"cell_type": "code",
|
| 484 |
+
"execution_count": null,
|
| 485 |
+
"metadata": {},
|
| 486 |
+
"outputs": [],
|
| 487 |
+
"source": [
|
| 488 |
+
"# Train final KNN model with optimal K\n",
|
| 489 |
+
"print(\"\\nTraining final KNN model...\")\n",
|
| 490 |
+
"optimal_k = best_k_cv\n",
|
| 491 |
+
"knn_model = KNeighborsClassifier(n_neighbors=optimal_k, n_jobs=-1)\n",
|
| 492 |
+
"knn_model.fit(X_train_scaled, y_train)\n",
|
| 493 |
+
"\n",
|
| 494 |
+
"# Make predictions\n",
|
| 495 |
+
"y_train_pred = knn_model.predict(X_train_scaled)\n",
|
| 496 |
+
"y_test_pred = knn_model.predict(X_test_scaled)\n",
|
| 497 |
+
"y_train_proba = knn_model.predict_proba(X_train_scaled)\n",
|
| 498 |
+
"y_test_proba = knn_model.predict_proba(X_test_scaled)\n",
|
| 499 |
+
"\n",
|
| 500 |
+
"# Calculate metrics\n",
|
| 501 |
+
"train_accuracy = accuracy_score(y_train, y_train_pred)\n",
|
| 502 |
+
"test_accuracy = accuracy_score(y_test, y_test_pred)\n",
|
| 503 |
+
"\n",
|
| 504 |
+
"print(\"\\n\" + \"=\" * 80)\n",
|
| 505 |
+
"print(\"MODEL PERFORMANCE\")\n",
|
| 506 |
+
"print(\"=\" * 80)\n",
|
| 507 |
+
"print(f\"\\nOptimal K: {optimal_k}\")\n",
|
| 508 |
+
"print(f\"Training Accuracy: {train_accuracy:.4f}\")\n",
|
| 509 |
+
"print(f\"Test Accuracy: {test_accuracy:.4f}\")\n",
|
| 510 |
+
"print(f\"\\nClassification Report (Test Set):\")\n",
|
| 511 |
+
"print(classification_report(y_test, y_test_pred, target_names=data.target_names, digits=4))"
|
| 512 |
+
]
|
| 513 |
+
},
|
| 514 |
+
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Confusion Matrix Visualization\n",
"fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
"\n",
"# Absolute counts\n",
"cm = confusion_matrix(y_test, y_test_pred)\n",
"sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',\n",
"            xticklabels=data.target_names,\n",
"            yticklabels=data.target_names,\n",
"            cbar_kws={'label': 'Count'},\n",
"            ax=axes[0],\n",
"            annot_kws={'fontsize': 14, 'fontweight': 'bold'})\n",
"axes[0].set_title('Confusion Matrix (Counts)', fontsize=14, fontweight='bold')\n",
"axes[0].set_ylabel('True Label', fontsize=12, fontweight='bold')\n",
"axes[0].set_xlabel('Predicted Label', fontsize=12, fontweight='bold')\n",
"\n",
"# Normalized\n",
"cm_norm = confusion_matrix(y_test, y_test_pred, normalize='true')\n",
"sns.heatmap(cm_norm, annot=True, fmt='.2%', cmap='Greens',\n",
"            xticklabels=data.target_names,\n",
"            yticklabels=data.target_names,\n",
"            cbar_kws={'label': 'Proportion'},\n",
"            ax=axes[1],\n",
"            annot_kws={'fontsize': 14, 'fontweight': 'bold'})\n",
"axes[1].set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')\n",
"axes[1].set_ylabel('True Label', fontsize=12, fontweight='bold')\n",
"axes[1].set_xlabel('Predicted Label', fontsize=12, fontweight='bold')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ROC Curve and AUC\n",
"fpr, tpr, thresholds = roc_curve(y_test, y_test_proba[:, 1])\n",
"roc_auc = roc_auc_score(y_test, y_test_proba[:, 1])\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
"\n",
"# ROC Curve\n",
"axes[0].plot(fpr, tpr, color='darkorange', lw=2,\n",
"             label=f'ROC curve (AUC = {roc_auc:.4f})')\n",
"axes[0].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random Classifier')\n",
"axes[0].set_xlim([0.0, 1.0])\n",
"axes[0].set_ylim([0.0, 1.05])\n",
"axes[0].set_xlabel('False Positive Rate', fontweight='bold')\n",
"axes[0].set_ylabel('True Positive Rate', fontweight='bold')\n",
"axes[0].set_title('Receiver Operating Characteristic (ROC) Curve', fontweight='bold', fontsize=12)\n",
"axes[0].legend(loc=\"lower right\", fontsize=10)\n",
"axes[0].grid(alpha=0.3)\n",
"\n",
"# Precision-Recall Curve\n",
"precision, recall, _ = precision_recall_curve(y_test, y_test_proba[:, 1])\n",
"axes[1].plot(recall, precision, color='green', lw=2, label='Precision-Recall curve')\n",
"axes[1].set_xlabel('Recall', fontweight='bold')\n",
"axes[1].set_ylabel('Precision', fontweight='bold')\n",
"axes[1].set_title('Precision-Recall Curve', fontweight='bold', fontsize=12)\n",
"axes[1].legend(loc=\"best\", fontsize=10)\n",
"axes[1].grid(alpha=0.3)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. SHAP Explainability Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"=\" * 80)\n",
"print(\"SHAP EXPLAINABILITY ANALYSIS\")\n",
"print(\"=\" * 80)\n",
"print(\"\\nSetting up SHAP explainer...\\n\")\n",
"\n",
"# Create background dataset for SHAP\n",
"# Using kmeans to select representative samples (faster for large datasets)\n",
"background_size = 100\n",
"background = shap.kmeans(X_train_scaled, background_size)\n",
"\n",
"print(f\"Background dataset size: {background_size} samples\")\n",
"print(f\"Background dataset shape: {background.data.shape}\")\n",
"\n",
"# Create SHAP explainer\n",
"# Using KernelExplainer (model-agnostic) for KNN\n",
"print(\"\\nCreating SHAP KernelExplainer (this may take a moment)...\")\n",
"explainer = shap.KernelExplainer(knn_model.predict_proba, background)\n",
"\n",
"print(\"\\n✓ SHAP explainer created successfully!\")\n",
"print(f\"Expected value (class 0): {explainer.expected_value[0]:.4f}\")\n",
"print(f\"Expected value (class 1): {explainer.expected_value[1]:.4f}\")\n",
"\n",
"# Check GPU memory after setup\n",
"print(\"\\nGPU Memory Status:\")\n",
"print_gpu_memory()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# Compute SHAP values for test set\n# Adjust sample size based on your GPU memory and time constraints\nn_samples = min(100, len(X_test_scaled))  # Use 100 samples or all test samples if less\nX_test_sample = X_test_scaled.iloc[:n_samples]\ny_test_sample = y_test.iloc[:n_samples]\n\nprint(f\"Computing SHAP values for {n_samples} test samples...\")\nprint(\"This may take several minutes depending on your GPU...\")\nprint(\"Progress will be shown below:\\n\")\n\n# Compute SHAP values with progress tracking\nimport time\nstart_time = time.time()\n\nshap_values = explainer.shap_values(X_test_sample, nsamples=100)  # nsamples controls accuracy/speed tradeoff\n\nelapsed_time = time.time() - start_time\n\nprint(f\"\\n✓ SHAP values computed successfully!\")\nprint(f\"Computation time: {elapsed_time:.2f} seconds ({elapsed_time/60:.2f} minutes)\")\nprint(f\"Time per sample: {elapsed_time/n_samples:.2f} seconds\")\n\n# Debug: Check the structure of shap_values\nprint(f\"\\nDEBUG: Type of shap_values: {type(shap_values)}\")\nif isinstance(shap_values, list):\n    print(f\"DEBUG: shap_values is a list with {len(shap_values)} elements\")\n    for i, sv in enumerate(shap_values):\n        print(f\"DEBUG: shap_values[{i}].shape = {np.array(sv).shape}\")\nelse:\n    print(f\"DEBUG: shap_values.shape = {np.array(shap_values).shape}\")\nprint(f\"DEBUG: X_test_sample.shape = {X_test_sample.shape}\")\n\nprint(f\"\\nSHAP values shape: {np.array(shap_values).shape}\")\nprint(f\"  - 2 classes (binary classification)\")\nprint(f\"  - {n_samples} samples explained\")\nprint(f\"  - {X_test_sample.shape[1]} features\")\n\n# Check GPU memory\nprint(\"\\nGPU Memory Status:\")\nprint_gpu_memory()"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9. SHAP Visualizations - Global Explanations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# Fix SHAP values format for binary classification\n# SHAP KernelExplainer returns shape (n_samples, n_features, n_classes) for predict_proba\n# We need to convert it to a list of arrays: [class_0_shap_values, class_1_shap_values]\nprint(\"Reshaping SHAP values for visualization...\")\nprint(f\"Original shape: {np.array(shap_values).shape}\")\n\nif isinstance(shap_values, np.ndarray) and shap_values.ndim == 3 and shap_values.shape[2] == 2:\n    # Convert from (n_samples, n_features, n_classes) to list of (n_samples, n_features)\n    shap_values_list = [shap_values[:, :, i] for i in range(shap_values.shape[2])]\n    print(\"Converted to list format:\")\n    print(f\"  Class 0 SHAP values shape: {shap_values_list[0].shape}\")\n    print(f\"  Class 1 SHAP values shape: {shap_values_list[1].shape}\")\n    shap_values = shap_values_list\nelse:\n    print(\"SHAP values already in correct format\")\n\nprint(\"✓ SHAP values ready for visualization!\")"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# SHAP Summary Plot (Beeswarm) - Shows global feature importance\nprint(\"=\" * 80)\nprint(\"SHAP VISUALIZATION 1: SUMMARY PLOT (BEESWARM)\")\nprint(\"=\" * 80)\nprint(\"\"\"\nThis plot shows:\n- Feature importance (vertical axis, ordered by importance)\n- SHAP values (horizontal axis, impact on prediction)\n- Feature values (color, red=high, blue=low)\n- Distribution across all samples (density)\n\nReading the plot:\n- Features at the top are most important\n- Points to the right increase probability of class 1 (malignant)\n- Points to the left decrease probability of class 1\n- Color shows whether high (red) or low (blue) feature values have that effect\n\"\"\")\n\nplt.figure(figsize=(14, 10))\nshap.summary_plot(shap_values[1], X_test_sample.values,\n                  feature_names=X_test_sample.columns.tolist(),\n                  plot_type=\"dot\", show=False, max_display=20)\nplt.title('SHAP Summary Plot - Global Feature Importance and Impact\\n(Predicting Malignant Class)',\n          fontsize=14, fontweight='bold', pad=20)\nplt.tight_layout()\nplt.show()"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# SHAP Bar Plot - Mean absolute SHAP values\nprint(\"=\" * 80)\nprint(\"SHAP VISUALIZATION 2: BAR PLOT\")\nprint(\"=\" * 80)\nprint(\"\"\"\nThis plot shows:\n- Average magnitude of feature impact (mean |SHAP value|)\n- Overall feature importance ranking\n- Which features have the strongest effect on predictions (regardless of direction)\n\"\"\")\n\nplt.figure(figsize=(14, 10))\nshap.summary_plot(shap_values[1], X_test_sample.values,\n                  feature_names=X_test_sample.columns.tolist(),\n                  plot_type=\"bar\", show=False, max_display=20)\nplt.title('SHAP Bar Plot - Mean Absolute Feature Importance',\n          fontsize=14, fontweight='bold', pad=20)\nplt.xlabel('Mean |SHAP value| (average impact on model output magnitude)', fontweight='bold')\nplt.tight_layout()\nplt.show()"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Calculate and display feature importance rankings\n",
"feature_importance = np.abs(shap_values[1]).mean(axis=0)\n",
"feature_importance_df = pd.DataFrame({\n",
"    'Feature': X_test_sample.columns,\n",
"    'Mean_Abs_SHAP': feature_importance,\n",
"    'Mean_SHAP': shap_values[1].mean(axis=0),\n",
"    'Std_SHAP': np.std(shap_values[1], axis=0)\n",
"}).sort_values('Mean_Abs_SHAP', ascending=False)\n",
"\n",
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"TOP 15 MOST IMPORTANT FEATURES (by mean |SHAP value|)\")\n",
"print(\"=\" * 80)\n",
"print(feature_importance_df.head(15).to_string(index=False))\n",
"\n",
"# Visualize feature importance with custom plot\n",
"fig, axes = plt.subplots(1, 2, figsize=(16, 8))\n",
"\n",
"# Top features by absolute importance\n",
"top_15 = feature_importance_df.head(15)\n",
"y_pos = np.arange(len(top_15))\n",
"axes[0].barh(y_pos, top_15['Mean_Abs_SHAP'], color='steelblue', edgecolor='black')\n",
"axes[0].set_yticks(y_pos)\n",
"axes[0].set_yticklabels(top_15['Feature'], fontsize=9)\n",
"axes[0].set_xlabel('Mean |SHAP Value|', fontweight='bold')\n",
"axes[0].set_title('Top 15 Features by Absolute Importance', fontweight='bold', fontsize=12)\n",
"axes[0].invert_yaxis()\n",
"axes[0].grid(axis='x', alpha=0.3)\n",
"\n",
"# Top features by directional importance (showing positive/negative effect)\n",
"colors = ['green' if x > 0 else 'red' for x in top_15['Mean_SHAP']]\n",
"axes[1].barh(y_pos, top_15['Mean_SHAP'], color=colors, edgecolor='black', alpha=0.7)\n",
"axes[1].set_yticks(y_pos)\n",
"axes[1].set_yticklabels(top_15['Feature'], fontsize=9)\n",
"axes[1].set_xlabel('Mean SHAP Value (Directional)', fontweight='bold')\n",
"axes[1].set_title('Top 15 Features - Directional Impact\\n(Green: Increases Malignant Prob, Red: Decreases)',\n",
"                  fontweight='bold', fontsize=12)\n",
"axes[1].axvline(x=0, color='black', linestyle='-', linewidth=0.8)\n",
"axes[1].invert_yaxis()\n",
"axes[1].grid(axis='x', alpha=0.3)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 10. SHAP Visualizations - Local Explanations (Individual Predictions)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Select interesting samples for detailed explanation\n",
"print(\"=\" * 80)\n",
"print(\"SELECTING SAMPLES FOR DETAILED EXPLANATION\")\n",
"print(\"=\" * 80)\n",
"\n",
"# Get predictions for our sample\n",
"y_pred_sample = knn_model.predict(X_test_sample)\n",
"y_proba_sample = knn_model.predict_proba(X_test_sample)\n",
"\n",
"# Find interesting samples\n",
"correct_mask = y_test_sample == y_pred_sample\n",
"incorrect_mask = ~correct_mask\n",
"\n",
"# High confidence correct\n",
"high_conf_correct_idx = np.where(correct_mask & (np.max(y_proba_sample, axis=1) > 0.9))[0]\n",
"# Low confidence correct\n",
"low_conf_correct_idx = np.where(correct_mask & (np.max(y_proba_sample, axis=1) < 0.7))[0]\n",
"# Misclassified\n",
"misclassified_idx = np.where(incorrect_mask)[0]\n",
"\n",
"print(f\"\\nTotal samples analyzed: {len(y_test_sample)}\")\n",
"print(f\"Correctly classified: {correct_mask.sum()} ({correct_mask.sum()/len(y_test_sample)*100:.1f}%)\")\n",
"print(f\"Misclassified: {incorrect_mask.sum()} ({incorrect_mask.sum()/len(y_test_sample)*100:.1f}%)\")\n",
"print(f\"\\nHigh confidence correct predictions: {len(high_conf_correct_idx)}\")\n",
"print(f\"Low confidence correct predictions: {len(low_conf_correct_idx)}\")\n",
"print(f\"Misclassified predictions: {len(misclassified_idx)}\")\n",
"\n",
"# Select samples to explain\n",
"samples_to_explain = []\n",
"sample_descriptions = []\n",
"\n",
"if len(high_conf_correct_idx) > 0:\n",
"    samples_to_explain.append(high_conf_correct_idx[0])\n",
"    sample_descriptions.append(\"High confidence correct\")\n",
"\n",
"if len(low_conf_correct_idx) > 0:\n",
"    samples_to_explain.append(low_conf_correct_idx[0])\n",
"    sample_descriptions.append(\"Low confidence correct\")\n",
"\n",
"if len(misclassified_idx) > 0:\n",
"    samples_to_explain.append(misclassified_idx[0])\n",
"    sample_descriptions.append(\"Misclassified\")\n",
"\n",
"if len(samples_to_explain) == 0:\n",
"    # If no special cases, just use first sample\n",
"    samples_to_explain = [0]\n",
"    sample_descriptions = [\"First sample\"]\n",
"\n",
"print(f\"\\nSelected {len(samples_to_explain)} samples for detailed explanation\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Waterfall plots for selected samples\n",
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"SHAP VISUALIZATION 3: WATERFALL PLOTS (Individual Predictions)\")\n",
"print(\"=\" * 80)\n",
"print(\"\"\"\n",
"Waterfall plots show how each feature contributes to pushing the prediction\n",
"from the base value (expected value) to the final prediction for a single sample.\n",
"\n",
"Reading the plot:\n",
"- Starts at E[f(x)] (expected value/average prediction)\n",
"- Each bar shows a feature's contribution\n",
"- Red bars push prediction higher (toward malignant)\n",
"- Blue bars push prediction lower (toward benign)\n",
"- Final value f(x) is the model's output for this sample\n",
"\"\"\")\n",
"\n",
"for idx, (sample_idx, description) in enumerate(zip(samples_to_explain, sample_descriptions)):\n",
"    true_label = data.target_names[y_test_sample.iloc[sample_idx]]\n",
"    pred_label = data.target_names[y_pred_sample[sample_idx]]\n",
"    pred_proba = y_proba_sample[sample_idx]\n",
"    \n",
"    print(f\"\\n{'-'*80}\")\n",
"    print(f\"Sample {idx+1}: {description} (Index {sample_idx})\")\n",
"    print(f\"{'-'*80}\")\n",
"    print(f\"True Label: {true_label}\")\n",
"    print(f\"Predicted Label: {pred_label}\")\n",
"    print(f\"Prediction Probabilities:\")\n",
"    for class_idx, class_name in enumerate(data.target_names):\n",
"        print(f\"  {class_name}: {pred_proba[class_idx]:.4f} ({pred_proba[class_idx]*100:.2f}%)\")\n",
"    print(f\"Correct: {true_label == pred_label}\")\n",
"    \n",
"    # Create waterfall plot\n",
"    shap.plots.waterfall(\n",
"        shap.Explanation(\n",
"            values=shap_values[1][sample_idx],\n",
"            base_values=explainer.expected_value[1],\n",
"            data=X_test_sample.iloc[sample_idx],\n",
"            feature_names=X_test_sample.columns.tolist()\n",
"        ),\n",
"        max_display=15\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Force plots for individual predictions\n",
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"SHAP VISUALIZATION 4: FORCE PLOTS (Individual Predictions)\")\n",
"print(\"=\" * 80)\n",
"print(\"\"\"\n",
"Force plots provide another view of individual predictions:\n",
"- Red features push prediction toward higher values (malignant)\n",
"- Blue features push prediction toward lower values (benign)\n",
"- Width of each feature shows magnitude of impact\n",
"\"\"\")\n",
"\n",
"for idx, (sample_idx, description) in enumerate(zip(samples_to_explain[:3], sample_descriptions[:3])):\n",
"    print(f\"\\nSample {idx+1}: {description}\")\n",
"    \n",
"    shap.force_plot(\n",
"        explainer.expected_value[1],\n",
"        shap_values[1][sample_idx],\n",
"        X_test_sample.iloc[sample_idx],\n",
"        matplotlib=True,\n",
"        show=False,\n",
"        figsize=(20, 3)\n",
"    )\n",
"    plt.title(f'Force Plot - {description} (Sample {sample_idx})',\n",
"              fontsize=12, fontweight='bold', pad=10)\n",
"    plt.tight_layout()\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Interactive force plot for multiple samples\n",
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"SHAP VISUALIZATION 5: INTERACTIVE FORCE PLOT (Multiple Predictions)\")\n",
"print(\"=\" * 80)\n",
"print(\"\"\"\n",
"This interactive visualization shows force plots for multiple samples simultaneously.\n",
"Samples are sorted by their predicted probability, allowing you to see patterns\n",
"across different prediction strengths.\n",
"\"\"\")\n",
"\n",
"# Use first 50 samples for visualization\n",
"n_force_samples = min(50, len(X_test_sample))\n",
"\n",
"shap.force_plot(\n",
"    explainer.expected_value[1],\n",
"    shap_values[1][:n_force_samples],\n",
"    X_test_sample.iloc[:n_force_samples]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 11. SHAP Dependence Plots - Feature Interactions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# Dependence plots for top features\nprint(\"=\" * 80)\nprint(\"SHAP VISUALIZATION 6: DEPENDENCE PLOTS\")\nprint(\"=\" * 80)\nprint(\"\"\"\nDependence plots show how feature values relate to their SHAP values:\n- X-axis: Feature value\n- Y-axis: SHAP value (impact on prediction)\n- Color: Another feature that may interact with this feature\n\nThese plots reveal:\n- Non-linear relationships between features and predictions\n- Feature interactions (shown by color patterns)\n- Threshold effects\n\"\"\")\n\n# Get top 6 most important features\ntop_features = feature_importance_df.head(6)['Feature'].values\n\nfig, axes = plt.subplots(2, 3, figsize=(18, 12))\naxes = axes.ravel()\n\nfor idx, feature in enumerate(top_features):\n    plt.sca(axes[idx])\n    shap.dependence_plot(\n        feature,\n        shap_values[1],\n        X_test_sample.values,\n        feature_names=X_test_sample.columns.tolist(),\n        show=False,\n        ax=axes[idx]\n    )\n    axes[idx].set_title(f'Dependence Plot: {feature}', fontsize=11, fontweight='bold')\n    axes[idx].grid(alpha=0.3)\n\nplt.suptitle('SHAP Dependence Plots - Top 6 Features',\n             fontsize=14, fontweight='bold', y=1.002)\nplt.tight_layout()\nplt.show()\n\nprint(\"\\nKey observations from dependence plots:\")\nprint(\"- Look for non-linear patterns in the scatter plots\")\nprint(\"- Color gradients indicate feature interactions\")\nprint(\"- Vertical spread at a given x-value suggests interactions with other features\")"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 12. SHAP Decision Plot"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Decision plot showing prediction paths\n",
"print(\"=\" * 80)\n",
"print(\"SHAP VISUALIZATION 7: DECISION PLOT\")\n",
"print(\"=\" * 80)\n",
"print(\"\"\"\n",
"Decision plots show the cumulative effect of features on predictions:\n",
"- Each line represents one sample's prediction path\n",
"- Starts from expected value at bottom\n",
"- Each feature shifts the prediction up or down\n",
"- Final position (top) is the model's prediction\n",
"- Color indicates the final predicted class\n",
"\n",
"This helps visualize:\n",
"- Which features drive different predictions\n",
"- Where predictions diverge\n",
"- Similarity between prediction paths\n",
"\"\"\")\n",
"\n",
"# Select diverse samples for decision plot\n",
"n_decision_samples = min(30, len(X_test_sample))\n",
"decision_indices = np.linspace(0, len(X_test_sample)-1, n_decision_samples, dtype=int)\n",
"\n",
"plt.figure(figsize=(14, 10))\n",
"shap.decision_plot(\n",
"    explainer.expected_value[1],\n",
"    shap_values[1][decision_indices],\n",
"    X_test_sample.iloc[decision_indices],\n",
"    show=False,\n",
"    feature_display_range=slice(-1, -21, -1)  # Show top 20 features\n",
")\n",
"plt.title(f'SHAP Decision Plot - Prediction Paths for {n_decision_samples} Samples\\n(Top 20 Features)',\n",
"          fontsize=14, fontweight='bold', pad=20)\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 13. Advanced Analysis - Correct vs Incorrect Predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# Compare SHAP patterns between correct and incorrect predictions\nprint(\"=\" * 80)\nprint(\"ADVANCED ANALYSIS: CORRECT VS INCORRECT PREDICTIONS\")\nprint(\"=\" * 80)\n\nif incorrect_mask.sum() > 0:\n    print(f\"\\nAnalyzing {incorrect_mask.sum()} misclassified samples...\\n\")\n    \n    # Compare average SHAP values\n    shap_correct = np.abs(shap_values[1][correct_mask]).mean(axis=0)\n    shap_incorrect = np.abs(shap_values[1][incorrect_mask]).mean(axis=0)\n    \n    comparison_df = pd.DataFrame({\n        'Feature': X_test_sample.columns,\n        'Correct_Predictions': shap_correct,\n        'Incorrect_Predictions': shap_incorrect,\n        'Difference': shap_incorrect - shap_correct\n    }).sort_values('Difference', ascending=False)\n    \n    print(\"Features with largest difference in importance:\")\n    print(\"\\nTop 10 features MORE important in incorrect predictions:\")\n    print(comparison_df.head(10).to_string(index=False))\n    \n    # Visualize comparison\n    fig, axes = plt.subplots(1, 2, figsize=(16, 8))\n    \n    # Convert mask to numpy array for indexing\n    correct_mask_np = correct_mask.values if hasattr(correct_mask, 'values') else correct_mask\n    incorrect_mask_np = incorrect_mask.values if hasattr(incorrect_mask, 'values') else incorrect_mask\n    \n    # Get the data for correct and incorrect predictions\n    X_correct = X_test_sample.values[correct_mask_np]\n    X_incorrect = X_test_sample.values[incorrect_mask_np]\n    \n    # Summary plot for correct predictions\n    plt.sca(axes[0])\n    shap.summary_plot(shap_values[1][correct_mask_np],\n                      X_correct,\n                      feature_names=X_test_sample.columns.tolist(),\n                      plot_type=\"bar\", show=False, max_display=15)\n    axes[0].set_title(f'Feature Importance - Correct Predictions (n={correct_mask.sum()})',\n                      fontweight='bold', fontsize=12)\n    \n    # Summary plot for incorrect predictions\n    plt.sca(axes[1])\n    shap.summary_plot(shap_values[1][incorrect_mask_np],\n                      X_incorrect,\n                      feature_names=X_test_sample.columns.tolist(),\n                      plot_type=\"bar\", show=False, max_display=15)\n    axes[1].set_title(f'Feature Importance - Incorrect Predictions (n={incorrect_mask.sum()})',\n                      fontweight='bold', fontsize=12)\n    \n    plt.tight_layout()\n    plt.show()\n    \nelse:\n    print(\"\\nAll samples in the test set were correctly classified!\")\n    print(\"This indicates excellent model performance.\")"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 14. Interactive Exploration Function"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create interactive exploration function\n",
"def explain_sample(sample_index):\n",
"    \"\"\"\n",
"    Provide detailed explanation for a specific sample prediction.\n",
"    \n",
"    Args:\n",
"        sample_index: Index of the sample to explain (0 to len(X_test_sample)-1)\n",
"    \"\"\"\n",
"    if sample_index < 0 or sample_index >= len(X_test_sample):\n",
"        print(f\"Error: Sample index out of range. Please use 0 to {len(X_test_sample)-1}\")\n",
"        return\n",
"    \n",
"    print(\"\\n\" + \"=\" * 80)\n",
"    print(f\"DETAILED EXPLANATION FOR SAMPLE {sample_index}\")\n",
"    print(\"=\" * 80)\n",
"    \n",
"    # Get prediction information\n",
"    true_label = data.target_names[y_test_sample.iloc[sample_index]]\n",
"    pred_label = data.target_names[y_pred_sample[sample_index]]\n",
"    pred_proba = y_proba_sample[sample_index]\n",
"    \n",
"    print(f\"\\n1. PREDICTION SUMMARY\")\n",
"    print(f\"{'-'*80}\")\n",
"    print(f\"True Label: {true_label}\")\n",
"    print(f\"Predicted Label: {pred_label}\")\n",
"    print(f\"Correct: {'✓ Yes' if true_label == pred_label else '✗ No'}\")\n",
"    print(f\"\\nPrediction Probabilities:\")\n",
"    for class_idx, class_name in enumerate(data.target_names):\n",
"        prob = pred_proba[class_idx]\n",
"        bar = '█' * int(prob * 50)\n",
"        print(f\"  {class_name:12s}: {prob:.4f} ({prob*100:5.2f}%) {bar}\")\n",
"    \n",
"    # SHAP explanation\n",
"    print(f\"\\n2. SHAP EXPLANATION\")\n",
"    print(f\"{'-'*80}\")\n",
"    print(f\"Base value (expected output): {explainer.expected_value[1]:.4f}\")\n",
"    print(f\"Model output for this sample: {explainer.expected_value[1] + shap_values[1][sample_index].sum():.4f}\")\n",
"    \n",
"    # Top positive and negative contributors\n",
"    shap_sample = shap_values[1][sample_index]\n",
"    feature_impacts = pd.DataFrame({\n",
"        'Feature': X_test_sample.columns,\n",
"        'Value': X_test_sample.iloc[sample_index].values,\n",
"        'SHAP_Value': shap_sample\n",
"    }).sort_values('SHAP_Value', key=abs, ascending=False)\n",
"    \n",
"    print(f\"\\nTop 5 features INCREASING malignant probability:\")\n",
"    positive_features = feature_impacts[feature_impacts['SHAP_Value'] > 0].head(5)\n",
"    for idx, row in positive_features.iterrows():\n",
" print(f\" {row['Feature']:30s}: {row['SHAP_Value']:+.4f} (value={row['Value']:.4f})\")\n",
|
| 1017 |
+
" \n",
|
| 1018 |
+
" print(f\"\\nTop 5 features DECREASING malignant probability:\")\n",
|
| 1019 |
+
" negative_features = feature_impacts[feature_impacts['SHAP_Value'] < 0].head(5)\n",
|
| 1020 |
+
" for idx, row in negative_features.iterrows():\n",
|
| 1021 |
+
" print(f\" {row['Feature']:30s}: {row['SHAP_Value']:+.4f} (value={row['Value']:.4f})\")\n",
|
| 1022 |
+
" \n",
|
| 1023 |
+
" # Visualizations\n",
|
| 1024 |
+
" print(f\"\\n3. VISUALIZATIONS\")\n",
|
| 1025 |
+
" print(f\"{'-'*80}\\n\")\n",
|
| 1026 |
+
" \n",
|
| 1027 |
+
" # Waterfall plot\n",
|
| 1028 |
+
" print(\"Waterfall Plot:\")\n",
|
| 1029 |
+
" shap.plots.waterfall(\n",
|
| 1030 |
+
" shap.Explanation(\n",
|
| 1031 |
+
" values=shap_values[1][sample_index],\n",
|
| 1032 |
+
" base_values=explainer.expected_value[1],\n",
|
| 1033 |
+
" data=X_test_sample.iloc[sample_index],\n",
|
| 1034 |
+
" feature_names=X_test_sample.columns.tolist()\n",
|
| 1035 |
+
" ),\n",
|
| 1036 |
+
" max_display=15\n",
|
| 1037 |
+
" )\n",
|
| 1038 |
+
" \n",
|
| 1039 |
+
" # Feature values comparison\n",
|
| 1040 |
+
" print(\"\\n4. FEATURE VALUES COMPARISON\")\n",
|
| 1041 |
+
" print(f\"{'-'*80}\")\n",
|
| 1042 |
+
" print(\"\\nTop 10 features by absolute value (original scale):\")\n",
|
| 1043 |
+
" original_values = X_test.iloc[sample_index].sort_values(ascending=False).head(10)\n",
|
| 1044 |
+
" for feature, value in original_values.items():\n",
|
| 1045 |
+
" mean_val = X_train[feature].mean()\n",
|
| 1046 |
+
" std_val = X_train[feature].std()\n",
|
| 1047 |
+
" z_score = (value - mean_val) / std_val\n",
|
| 1048 |
+
" print(f\" {feature:30s}: {value:10.4f} (μ={mean_val:8.4f}, z={z_score:+6.2f})\")\n",
|
| 1049 |
+
"\n",
|
| 1050 |
+
"# Display usage instructions\n",
|
| 1051 |
+
"print(\"=\" * 80)\n",
|
| 1052 |
+
"print(\"INTERACTIVE EXPLORATION\")\n",
|
| 1053 |
+
"print(\"=\" * 80)\n",
|
| 1054 |
+
"print(f\"\\nUse the explain_sample() function to explore any prediction:\")\n",
|
| 1055 |
+
"print(f\"\\nExample usage:\")\n",
|
| 1056 |
+
"print(f\" explain_sample(0) # Explain first sample\")\n",
|
| 1057 |
+
"print(f\" explain_sample(10) # Explain 11th sample\")\n",
|
| 1058 |
+
"print(f\"\\nValid range: 0 to {len(X_test_sample)-1}\")\n",
|
| 1059 |
+
"print(f\"\\nTry these interesting samples:\")\n",
|
| 1060 |
+
"if len(high_conf_correct_idx) > 0:\n",
|
| 1061 |
+
" print(f\" explain_sample({high_conf_correct_idx[0]}) # High confidence correct\")\n",
|
| 1062 |
+
"if len(low_conf_correct_idx) > 0:\n",
|
| 1063 |
+
" print(f\" explain_sample({low_conf_correct_idx[0]}) # Low confidence correct\")\n",
|
| 1064 |
+
"if len(misclassified_idx) > 0:\n",
|
| 1065 |
+
" print(f\" explain_sample({misclassified_idx[0]}) # Misclassified sample\")\n",
|
| 1066 |
+
"\n",
|
| 1067 |
+
"print(\"\\n\" + \"=\"*80)\n",
|
| 1068 |
+
"print(\"Example: Explaining sample 0\")\n",
|
| 1069 |
+
"print(\"=\"*80)\n",
|
| 1070 |
+
"explain_sample(0)"
|
| 1071 |
+
]
|
| 1072 |
+
},
|
| 1073 |
+
{
|
| 1074 |
+
"cell_type": "markdown",
|
| 1075 |
+
"metadata": {},
|
| 1076 |
+
"source": [
|
| 1077 |
+
"## 15. GPU Resource Monitoring"
|
| 1078 |
+
]
|
| 1079 |
+
},
|
| 1080 |
+
{
|
| 1081 |
+
"cell_type": "code",
|
| 1082 |
+
"execution_count": null,
|
| 1083 |
+
"metadata": {},
|
| 1084 |
+
"outputs": [],
|
| 1085 |
+
"source": [
|
| 1086 |
+
"# Final GPU memory check\n",
|
| 1087 |
+
"print(\"=\" * 80)\n",
|
| 1088 |
+
"print(\"VAST.AI GPU RESOURCE SUMMARY\")\n",
|
| 1089 |
+
"print(\"=\" * 80)\n",
|
| 1090 |
+
"print(\"\\nFinal GPU Memory Status:\")\n",
|
| 1091 |
+
"print_gpu_memory()\n",
|
| 1092 |
+
"\n",
|
| 1093 |
+
"print(\"\\nTo free GPU memory, run: clear_gpu_memory()\")"
|
| 1094 |
+
]
|
| 1095 |
+
},
|
| 1096 |
+
{
|
| 1097 |
+
"cell_type": "markdown",
|
| 1098 |
+
"metadata": {},
|
| 1099 |
+
"source": [
|
| 1100 |
+
"## 16. Summary and Key Insights"
|
| 1101 |
+
]
|
| 1102 |
+
},
|
| 1103 |
+
{
|
| 1104 |
+
"cell_type": "code",
|
| 1105 |
+
"execution_count": null,
|
| 1106 |
+
"metadata": {},
|
| 1107 |
+
"outputs": [],
|
| 1108 |
+
"source": [
|
| 1109 |
+
"# Comprehensive summary\n",
|
| 1110 |
+
"print(\"=\" * 80)\n",
|
| 1111 |
+
"print(\"COMPREHENSIVE SUMMARY\")\n",
|
| 1112 |
+
"print(\"=\" * 80)\n",
|
| 1113 |
+
"\n",
|
| 1114 |
+
"print(\"\\n1. MODEL PERFORMANCE\")\n",
|
| 1115 |
+
"print(\"-\" * 80)\n",
|
| 1116 |
+
"print(f\"Algorithm: K-Nearest Neighbors (KNN)\")\n",
|
| 1117 |
+
"print(f\"Optimal K: {optimal_k} neighbors\")\n",
|
| 1118 |
+
"print(f\"Training Accuracy: {train_accuracy:.4f} ({train_accuracy*100:.2f}%)\")\n",
|
| 1119 |
+
"print(f\"Test Accuracy: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)\")\n",
|
| 1120 |
+
"print(f\"ROC AUC Score: {roc_auc:.4f}\")\n",
|
| 1121 |
+
"\n",
|
| 1122 |
+
"print(\"\\n2. DATASET INFORMATION\")\n",
|
| 1123 |
+
"print(\"-\" * 80)\n",
|
| 1124 |
+
"print(f\"Dataset: Breast Cancer Wisconsin (Diagnostic)\")\n",
|
| 1125 |
+
"print(f\"Total Samples: {len(X)}\")\n",
|
| 1126 |
+
"print(f\"Features: {X.shape[1]}\")\n",
|
| 1127 |
+
"print(f\"Classes: {len(data.target_names)} ({', '.join(data.target_names)})\")\n",
|
| 1128 |
+
"print(f\"Train/Test Split: {len(X_train)}/{len(X_test)}\")\n",
|
| 1129 |
+
"\n",
|
| 1130 |
+
"print(\"\\n3. TOP 5 MOST IMPORTANT FEATURES (SHAP)\")\n",
|
| 1131 |
+
"print(\"-\" * 80)\n",
|
| 1132 |
+
"for idx, row in feature_importance_df.head(5).iterrows():\n",
|
| 1133 |
+
" print(f\"{idx+1}. {row['Feature']:30s} (mean |SHAP|={row['Mean_Abs_SHAP']:.4f})\")\n",
    "\n",
    "print(\"\\n4. SHAP EXPLAINABILITY INSIGHTS\")\n",
    "print(\"-\" * 80)\n",
    "print(f\"Samples explained: {n_samples}\")\n",
    "print(f\"Background dataset size: {background_size}\")\n",
    "print(f\"Computation time: {elapsed_time:.2f} seconds\")\n",
    "print(f\"Time per sample: {elapsed_time/n_samples:.2f} seconds\")\n",
    "\n",
    "print(\"\\n5. KEY TAKEAWAYS\")\n",
    "print(\"-\" * 80)\n",
    "print(\"\"\"\n",
    "✓ SHAP provides model-agnostic explainability for KNN predictions\n",
    "✓ Feature importance rankings identify which features drive predictions\n",
    "✓ Waterfall plots explain individual predictions step-by-step\n",
    "✓ Dependence plots reveal non-linear relationships and interactions\n",
    "✓ Summary plots show global patterns across all predictions\n",
    "✓ Decision plots visualize prediction paths for multiple samples\n",
    "✓ KNN combined with SHAP offers both accuracy and interpretability\n",
    "\"\"\")\n",
    "\n",
    "print(\"\\n6. SHAP VISUALIZATION TYPES USED\")\n",
    "print(\"-\" * 80)\n",
    "visualizations = [\n",
    "    (\"Summary Plot (Beeswarm)\", \"Global feature importance with value distributions\"),\n",
    "    (\"Bar Plot\", \"Mean absolute feature importance\"),\n",
    "    (\"Waterfall Plot\", \"Individual prediction breakdown\"),\n",
    "    (\"Force Plot\", \"Visual representation of feature contributions\"),\n",
    "    (\"Dependence Plot\", \"Feature-value relationships and interactions\"),\n",
    "    (\"Decision Plot\", \"Cumulative prediction paths for multiple samples\")\n",
    "]\n",
    "for viz_name, description in visualizations:\n",
    "    print(f\"  • {viz_name:30s}: {description}\")\n",
    "\n",
    "print(\"\\n7. VAST.AI ENVIRONMENT\")\n",
    "print(\"-\" * 80)\n",
    "if torch.cuda.is_available():\n",
    "    print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
    "    print(f\"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB\")\n",
    "    print(f\"CUDA Version: {torch.version.cuda}\")\n",
    "else:\n",
    "    print(\"Running on CPU (GPU not detected)\")\n",
    "\n",
    "print(\"\\n\" + \"=\" * 80)\n",
    "print(\"ANALYSIS COMPLETE\")\n",
    "print(\"=\" * 80)\n",
    "print(\"\\nNext steps:\")\n",
    "print(\"  1. Explore individual predictions using: explain_sample(index)\")\n",
    "print(\"  2. Try different K values for KNN\")\n",
    "print(\"  3. Test with other datasets (load_wine, load_iris, etc.)\")\n",
    "print(\"  4. Compare with other algorithms (Random Forest, SVM, XGBoost)\")\n",
    "print(\"  5. Use SHAP insights for feature engineering\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 17. Cleanup and Resource Management"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: Clear GPU memory when done\n",
    "# Uncomment the line below to free GPU memory\n",
    "# clear_gpu_memory()\n",
    "\n",
    "print(\"=\" * 80)\n",
    "print(\"NOTEBOOK COMPLETE\")\n",
    "print(\"=\" * 80)\n",
    "print(\"\\nThank you for using this SHAP explainability notebook!\")\n",
    "print(\"\\nTo clear GPU memory, run: clear_gpu_memory()\")\n",
    "print(\"To check GPU status, run: print_gpu_memory()\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 KNN SHAP Explainability Project

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
ADDED
@@ -0,0 +1,274 @@
# KNN Explainability with SHAP

A comprehensive Jupyter notebook demonstrating how to use SHAP (SHapley Additive exPlanations) to interpret K-Nearest Neighbors (KNN) model predictions with detailed visualizations.

## Overview

This project provides a complete walkthrough of:
- Training a K-Nearest Neighbors classifier on the Breast Cancer Wisconsin dataset
- Using SHAP to explain model predictions at both global and local levels
- Creating comprehensive visualizations to understand feature importance and model behavior
- Interactive exploration of individual predictions

## Features

### Model Training
- Optimal K value selection through cross-validation
- StandardScaler preprocessing for KNN optimization
- Comprehensive model evaluation metrics
- ROC curves and confusion matrices

### SHAP Explainability
- **Summary Plots**: Global feature importance with value distributions
- **Bar Plots**: Mean absolute feature importance rankings
- **Waterfall Plots**: Step-by-step breakdown of individual predictions
- **Force Plots**: Visual representation of feature contributions
- **Dependence Plots**: Feature-value relationships and interactions
- **Decision Plots**: Cumulative prediction paths for multiple samples

### Interactive Analysis
- Custom `explain_sample()` function for detailed prediction exploration
- Comparison of correct vs. incorrect predictions
- GPU memory management utilities
- Comprehensive reporting and summaries

## Prerequisites

### Required Packages
```
torch
shap
scikit-learn
matplotlib
seaborn
pandas
numpy
plotly
ipywidgets
tqdm
```

### Environment
- Python 3.7+
- Jupyter Notebook or JupyterLab
- VS Code with Jupyter extension (recommended)
- GPU support optional (CUDA-enabled PyTorch for faster computation)

## Installation

1. Clone this repository:
```bash
git clone <repository-url>
cd <repository-directory>
```

2. Install required packages:
```bash
pip install torch shap scikit-learn matplotlib seaborn pandas numpy plotly ipywidgets tqdm
```

3. Launch Jupyter:
```bash
jupyter notebook
```

4. Open `KNN_SHAP_Explainability.ipynb`

## Usage

### Basic Usage

Run the notebook cells sequentially from top to bottom. The notebook is structured in logical sections:

1. **Environment Setup**: GPU verification and package installation
2. **Data Loading**: Load and explore the Breast Cancer Wisconsin dataset
3. **Preprocessing**: Feature scaling and train-test split
4. **Model Training**: KNN optimization and evaluation
5. **SHAP Analysis**: Compute SHAP values for test samples
6. **Visualizations**: Generate all SHAP plots and explanations
7. **Interactive Exploration**: Use custom functions to explore predictions

### Interactive Exploration

After running all cells, use the `explain_sample()` function to explore any prediction:

```python
# Explain the first sample
explain_sample(0)

# Explain a high-confidence correct prediction
explain_sample(5)

# Explain a misclassified sample
explain_sample(42)
```

### GPU Memory Management

If running on GPU, monitor and manage memory:

```python
# Check current GPU memory usage
print_gpu_memory()

# Clear GPU cache
clear_gpu_memory()

# Get optimal device (CPU or GPU)
device = get_optimal_device()
```

## Dataset

The notebook uses the **Breast Cancer Wisconsin (Diagnostic)** dataset from scikit-learn:
- **Samples**: 569
- **Features**: 30 (mean, standard error, and worst values of 10 real-valued features)
- **Classes**: 2 (Malignant, Benign)
- **Task**: Binary classification

Features include radius, texture, perimeter, area, smoothness, compactness, concavity, concave points, symmetry, and fractal dimension.

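As a quick orientation, the dataset can be loaded directly from scikit-learn; this minimal sketch (using `load_breast_cancer`, the same loader the notebook uses) shows the shapes and labels described above:

```python
# Minimal sketch: load the Breast Cancer Wisconsin dataset as pandas objects.
from sklearn.datasets import load_breast_cancer

data = load_breast_cancer(as_frame=True)
X, y = data.data, data.target

print(X.shape)                  # (569, 30)
print(list(data.target_names))  # ['malignant', 'benign']
print(y.value_counts().to_dict())
```

Note that in scikit-learn's encoding, target 0 is malignant and target 1 is benign.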
## Model Performance

The KNN model achieves:
- High accuracy on both training and test sets
- Optimal K value determined through cross-validation
- ROC AUC score > 0.95 (typical)
- Interpretable predictions through SHAP analysis

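The K selection above can be sketched as follows. This is a hedged example, not the notebook's exact code: the search range and 5-fold setup are assumptions, and `optimal_k` mirrors the notebook's variable name:

```python
# Sketch: pick K for KNN by cross-validated accuracy on a scaled pipeline.
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_breast_cancer(return_X_y=True)

scores = {}
for k in range(1, 22, 2):  # odd K avoids ties in binary voting
    pipe = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=k))
    scores[k] = cross_val_score(pipe, X, y, cv=5).mean()

optimal_k = max(scores, key=scores.get)
print(f"Optimal K: {optimal_k} (CV accuracy {scores[optimal_k]:.4f})")
```

Scaling inside the pipeline keeps each fold's statistics leakage-free, which matters for a distance-based model like KNN.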
## SHAP Visualization Guide

### Summary Plot (Beeswarm)
- Shows global feature importance across all samples
- Color indicates feature value (red=high, blue=low)
- Horizontal position shows impact on prediction

### Waterfall Plot
- Explains individual predictions step-by-step
- Starts from base value (expected prediction)
- Each bar shows a feature's contribution
- Red pushes toward malignant, blue toward benign

### Dependence Plot
- Reveals non-linear feature relationships
- Shows feature interactions through color
- Identifies threshold effects

### Decision Plot
- Visualizes prediction paths for multiple samples
- Shows cumulative effect of features
- Helps identify prediction patterns

## Key Insights

1. **Feature Importance**: SHAP identifies the most critical features for cancer diagnosis
2. **Non-linearity**: Dependence plots reveal complex feature-value relationships
3. **Interactions**: Color gradients show which features interact
4. **Individual Explanations**: Each prediction can be fully explained and understood
5. **Model Trust**: Transparent explanations increase confidence in model decisions

## Customization

### Using Different Datasets

Replace the data loading section with your own dataset:

```python
# Load your dataset
X = pd.DataFrame(your_data)
y = pd.Series(your_labels)

# Continue with the rest of the notebook
```

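For a concrete swap, scikit-learn's bundled `load_wine` dataset works as a drop-in replacement (the rest of the pipeline — scaling, KNN, SHAP — then runs unchanged, though the class labels in plot titles would need updating):

```python
# Sketch: swap in the wine dataset in place of the breast-cancer data.
from sklearn.datasets import load_wine

wine = load_wine(as_frame=True)
X = wine.data    # 178 samples, 13 features
y = wine.target  # 3 classes

print(X.shape, list(wine.target_names))
```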
### Adjusting SHAP Computation

Modify SHAP parameters for a speed/accuracy tradeoff:

```python
# Faster computation (less accurate)
shap_values = explainer.shap_values(X_test_sample, nsamples=50)

# More accurate (slower)
shap_values = explainer.shap_values(X_test_sample, nsamples=200)

# Smaller background dataset (faster)
background = shap.kmeans(X_train_scaled, 50)
```

### Trying Other Algorithms

The SHAP approach works with any model:

```python
from sklearn.ensemble import RandomForestClassifier

# Train a Random Forest instead of KNN
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train_scaled, y_train)

# Use TreeExplainer for faster computation on tree-based models
explainer = shap.TreeExplainer(model)
```

## Performance Tips

1. **GPU Acceleration**: Use a GPU for faster PyTorch operations
2. **Background Size**: Reduce the background dataset size for faster SHAP computation
3. **Sample Size**: Start with fewer samples (e.g., 50) for quick testing
4. **nsamples Parameter**: Lower values speed up computation but reduce accuracy
5. **Memory Management**: Clear the GPU cache between major computations

## Troubleshooting

### Common Issues

**GPU not detected:**
- Check the CUDA installation
- Verify PyTorch GPU support: `torch.cuda.is_available()`
- The notebook will fall back to CPU automatically

**SHAP computation too slow:**
- Reduce the background dataset size
- Decrease the number of test samples
- Lower the `nsamples` parameter

**Memory errors:**
- Process fewer samples at once
- Clear the GPU cache with `clear_gpu_memory()`
- Reduce the background dataset size

**Visualization issues:**
- Ensure the matplotlib backend is compatible
- Update SHAP to the latest version
- Restart the kernel if plots don't render

## Contributing

Contributions are welcome! Please feel free to submit pull requests or open issues for bugs, questions, or new features.

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## Acknowledgments

- **SHAP**: Scott Lundberg et al. for the SHAP library
- **scikit-learn**: For the Breast Cancer Wisconsin dataset and ML tools
- **PyTorch**: For GPU acceleration capabilities
- **Community**: All contributors to the open-source ML/AI ecosystem

## References

- Lundberg, S. M., & Lee, S. I. (2017). A unified approach to interpreting model predictions. NeurIPS.
- SHAP Documentation: https://shap.readthedocs.io/
- scikit-learn Documentation: https://scikit-learn.org/
- Breast Cancer Wisconsin Dataset: https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)

## Contact

For questions or feedback, please open an issue in the repository.

---

**Note**: This notebook is designed for educational purposes and demonstrates best practices for ML model interpretability using SHAP.