ahczhg committed on
Commit
9aa33c0
·
verified ·
1 Parent(s): ac22b9d

Upload folder using huggingface_hub

Files changed (3)
  1. KNN_SHAP_Explainability.ipynb +1234 -0
  2. LICENSE +21 -0
  3. README.md +274 -0
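Before the notebook diff itself: the classification rule the notebook's KNN model applies (majority vote among the k nearest training points) can be sketched in a few lines of pure Python. This is an illustrative aside, not part of the committed files; `knn_predict` is a hypothetical helper, and sklearn's `KNeighborsClassifier` used in the notebook adds scaling-aware distance search and probability estimates on top of this rule.

```python
from collections import Counter
import math

def knn_predict(train_X, train_y, query, k=3):
    """Classify `query` by majority vote among its k nearest training
    points under Euclidean distance (the KNeighborsClassifier default)."""
    dists = sorted(
        (math.dist(x, query), label) for x, label in zip(train_X, train_y)
    )
    votes = Counter(label for _, label in dists[:k])
    return votes.most_common(1)[0][0]

# Tiny 2-D example: two well-separated clusters
train_X = [(0.0, 0.0), (0.1, 0.2), (0.2, 0.1),
           (5.0, 5.0), (5.1, 4.9), (4.9, 5.2)]
train_y = [0, 0, 0, 1, 1, 1]
print(knn_predict(train_X, train_y, (0.3, 0.3), k=3))  # → 0
print(knn_predict(train_X, train_y, (4.8, 5.1), k=3))  # → 1
```

Because the vote is distance-based, unscaled features with large ranges dominate the neighbor search, which is why the notebook applies `StandardScaler` before fitting.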
KNN_SHAP_Explainability.ipynb ADDED
@@ -0,0 +1,1234 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": "# KNN Explainability with SHAP on Cloud GPU\n\nThis notebook demonstrates how to use SHAP (SHapley Additive exPlanations) to interpret K-Nearest Neighbors (KNN) model predictions with comprehensive visualizations.\n\n**Environment:** Cloud GPU Instance (Running in VS Code)\n\n## Prerequisites\n- Cloud GPU instance with GPU (RTX 3090, RTX 4090, or A100 recommended)\n- VS Code with Jupyter extension installed\n- SSH connection to cloud instance"
7
+ },
8
+ {
9
+ "cell_type": "markdown",
10
+ "metadata": {},
11
+ "source": [
12
+ "## 1. Environment Setup and GPU Verification"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "# Check GPU availability and specifications\n",
22
+ "import subprocess\n",
23
+ "import sys\n",
24
+ "\n",
25
+ "print(\"=\" * 80)\n",
26
+ "print(\"CLOUD GPU INFORMATION\")\n",
27
+ "print(\"=\" * 80)\n",
28
+ "\n",
29
+ "try:\n",
30
+ " result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)\n",
31
+ " print(result.stdout)\n",
32
+ "except FileNotFoundError:\n",
33
+ " print(\"nvidia-smi not found. GPU may not be available.\")\n",
34
+ "\n",
35
+ "print(\"\\n\" + \"=\" * 80)\n",
36
+ "print(\"PYTHON ENVIRONMENT\")\n",
37
+ "print(\"=\" * 80)\n",
38
+ "print(f\"Python version: {sys.version}\")\n",
39
+ "print(f\"Python executable: {sys.executable}\")"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "metadata": {},
46
+ "outputs": [],
47
+ "source": [
48
+ "# Check PyTorch CUDA availability\n",
49
+ "try:\n",
50
+ " import torch\n",
51
+ " print(\"\\n\" + \"=\" * 80)\n",
52
+ " print(\"PYTORCH & CUDA INFORMATION\")\n",
53
+ " print(\"=\" * 80)\n",
54
+ " print(f\"PyTorch version: {torch.__version__}\")\n",
55
+ " print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
56
+ " \n",
57
+ " if torch.cuda.is_available():\n",
58
+ " print(f\"CUDA version: {torch.version.cuda}\")\n",
59
+ " print(f\"Number of GPUs: {torch.cuda.device_count()}\")\n",
60
+ " for i in range(torch.cuda.device_count()):\n",
61
+ " print(f\"\\nGPU {i}: {torch.cuda.get_device_name(i)}\")\n",
62
+ " print(f\" Memory Total: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB\")\n",
63
+ " print(f\" Memory Allocated: {torch.cuda.memory_allocated(i) / 1e9:.4f} GB\")\n",
64
+ " print(f\" Memory Cached: {torch.cuda.memory_reserved(i) / 1e9:.4f} GB\")\n",
65
+ " else:\n",
66
+ " print(\"\\nWARNING: CUDA not available. Running on CPU.\")\n",
67
+ " print(\"This notebook is optimized for GPU but will work on CPU (slower).\")\n",
68
+ "except ImportError:\n",
69
+ " print(\"\\nPyTorch not installed yet. Will install in next cell.\")"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": null,
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "# Install required packages\n",
79
+ "print(\"Installing required packages...\\n\")\n",
80
+ "\n",
81
+ "packages = [\n",
82
+ " 'torch',\n",
83
+ " 'shap',\n",
84
+ " 'scikit-learn',\n",
85
+ " 'matplotlib',\n",
86
+ " 'seaborn',\n",
87
+ " 'pandas',\n",
88
+ " 'numpy',\n",
89
+ " 'plotly',\n",
90
+ " 'ipywidgets',\n",
91
+ " 'tqdm'\n",
92
+ "]\n",
93
+ "\n",
94
+ "for package in packages:\n",
95
+ " print(f\"Installing {package}...\")\n",
96
+ " subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', package])\n",
97
+ "\n",
98
+ "print(\"\\n\" + \"=\"*80)\n",
99
+ "print(\"All packages installed successfully!\")\n",
100
+ "print(\"=\"*80)"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "markdown",
105
+ "metadata": {},
106
+ "source": [
107
+ "## 2. Import Libraries"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "metadata": {},
114
+ "outputs": [],
115
+ "source": "# Import all necessary libraries\nimport numpy as np\nimport pandas as pd\nimport matplotlib.pyplot as plt\nimport seaborn as sns\nimport shap\nimport torch\nimport sklearn\nfrom sklearn.neighbors import KNeighborsClassifier\nfrom sklearn.model_selection import train_test_split, cross_val_score\nfrom sklearn.preprocessing import StandardScaler\nfrom sklearn.datasets import load_breast_cancer, load_wine, load_iris, make_classification\nfrom sklearn.metrics import (\n accuracy_score, \n classification_report, \n confusion_matrix,\n roc_curve,\n roc_auc_score,\n precision_recall_curve\n)\nfrom tqdm.auto import tqdm\nimport warnings\nwarnings.filterwarnings('ignore')\n\n# Set random seeds for reproducibility\nnp.random.seed(42)\ntorch.manual_seed(42)\nif torch.cuda.is_available():\n torch.cuda.manual_seed_all(42)\n\n# Configure plotting style\nsns.set_style('whitegrid')\nplt.rcParams['figure.figsize'] = (14, 8)\nplt.rcParams['font.size'] = 10\nplt.rcParams['axes.titlesize'] = 14\nplt.rcParams['axes.labelsize'] = 12\n\n# Initialize SHAP's JavaScript visualization\nshap.initjs()\n\nprint(\"=\" * 80)\nprint(\"LIBRARY VERSIONS\")\nprint(\"=\" * 80)\nprint(f\"NumPy version: {np.__version__}\")\nprint(f\"Pandas version: {pd.__version__}\")\nprint(f\"Matplotlib version: {plt.matplotlib.__version__}\")\nprint(f\"Seaborn version: {sns.__version__}\")\nprint(f\"SHAP version: {shap.__version__}\")\nprint(f\"PyTorch version: {torch.__version__}\")\nprint(f\"Scikit-learn version: {sklearn.__version__}\")\nprint(\"\\nLibraries imported successfully!\")"
116
+ },
117
+ {
118
+ "cell_type": "markdown",
119
+ "metadata": {},
120
+ "source": [
121
+ "## 3. GPU Memory Management Functions"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": null,
127
+ "metadata": {},
128
+ "outputs": [],
129
+ "source": [
130
+ "# Utility functions for GPU memory management on the cloud instance\n",
131
+ "\n",
132
+ "def print_gpu_memory():\n",
133
+ " \"\"\"Print current GPU memory usage\"\"\"\n",
134
+ " if torch.cuda.is_available():\n",
135
+ " for i in range(torch.cuda.device_count()):\n",
136
+ " allocated = torch.cuda.memory_allocated(i) / 1e9\n",
137
+ " cached = torch.cuda.memory_reserved(i) / 1e9\n",
138
+ " total = torch.cuda.get_device_properties(i).total_memory / 1e9\n",
139
+ " print(f\"GPU {i} ({torch.cuda.get_device_name(i)}):\")\n",
140
+ " print(f\" Allocated: {allocated:.3f} GB / {total:.2f} GB ({allocated/total*100:.1f}%)\")\n",
141
+ " print(f\" Cached: {cached:.3f} GB\")\n",
142
+ " else:\n",
143
+ " print(\"No GPU available\")\n",
144
+ "\n",
145
+ "def clear_gpu_memory():\n",
146
+ " \"\"\"Clear GPU cache\"\"\"\n",
147
+ " if torch.cuda.is_available():\n",
148
+ " torch.cuda.empty_cache()\n",
149
+ " print(\"GPU cache cleared\")\n",
150
+ "\n",
151
+ "def get_optimal_device():\n",
152
+ " \"\"\"Get optimal device for computation\"\"\"\n",
153
+ " if torch.cuda.is_available():\n",
154
+ " device = torch.device('cuda')\n",
155
+ " print(f\"Using GPU: {torch.cuda.get_device_name(0)}\")\n",
156
+ " else:\n",
157
+ " device = torch.device('cpu')\n",
158
+ " print(\"Using CPU (GPU not available)\")\n",
159
+ " return device\n",
160
+ "\n",
161
+ "# Initialize device\n",
162
+ "device = get_optimal_device()\n",
163
+ "print(\"\\nInitial GPU Memory:\")\n",
164
+ "print_gpu_memory()"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "markdown",
169
+ "metadata": {},
170
+ "source": [
171
+ "## 4. Data Loading and Exploration"
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": null,
177
+ "metadata": {},
178
+ "outputs": [],
179
+ "source": [
180
+ "# Load the Breast Cancer Wisconsin dataset\n",
181
+ "print(\"=\" * 80)\n",
182
+ "print(\"LOADING DATASET\")\n",
183
+ "print(\"=\" * 80)\n",
184
+ "\n",
185
+ "data = load_breast_cancer()\n",
186
+ "X = pd.DataFrame(data.data, columns=data.feature_names)\n",
187
+ "y = pd.Series(data.target, name='target')\n",
188
+ "\n",
189
+ "print(f\"\\nDataset: Breast Cancer Wisconsin (Diagnostic)\")\n",
190
+ "print(f\"Number of samples: {X.shape[0]}\")\n",
191
+ "print(f\"Number of features: {X.shape[1]}\")\n",
192
+ "print(f\"Number of classes: {len(data.target_names)}\")\n",
193
+ "print(f\"Class names: {list(data.target_names)}\")\n",
194
+ "print(f\"\\nTarget distribution:\")\n",
195
+ "for idx, name in enumerate(data.target_names):\n",
196
+ " count = (y == idx).sum()\n",
197
+ " percentage = count / len(y) * 100\n",
198
+ " print(f\" {name}: {count} ({percentage:.2f}%)\")\n",
199
+ "\n",
200
+ "print(f\"\\nFeature statistics:\")\n",
201
+ "print(X.describe().round(2))"
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": null,
207
+ "metadata": {},
208
+ "outputs": [],
209
+ "source": [
210
+ "# Display first few rows\n",
211
+ "print(\"\\nFirst 5 rows of the dataset:\")\n",
212
+ "display_df = X.head().copy()  # copy to avoid mutating X / SettingWithCopyWarning\n",
213
+ "display_df['target'] = y.head().map({0: data.target_names[0], 1: data.target_names[1]})\n",
214
+ "display_df"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "markdown",
219
+ "metadata": {},
220
+ "source": [
221
+ "## 5. Exploratory Data Analysis"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "execution_count": null,
227
+ "metadata": {},
228
+ "outputs": [],
229
+ "source": [
230
+ "# Target distribution visualization\n",
231
+ "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
232
+ "\n",
233
+ "# Bar plot\n",
234
+ "target_counts = y.value_counts().sort_index()\n",
235
+ "colors = ['#FF6B6B', '#4ECDC4']\n",
236
+ "axes[0].bar(data.target_names, target_counts.values, color=colors, edgecolor='black', linewidth=1.5)\n",
237
+ "axes[0].set_ylabel('Count', fontweight='bold')\n",
238
+ "axes[0].set_title('Target Class Distribution', fontweight='bold', fontsize=14)\n",
239
+ "axes[0].grid(axis='y', alpha=0.3)\n",
240
+ "\n",
241
+ "# Add count labels on bars\n",
242
+ "for i, (name, count) in enumerate(zip(data.target_names, target_counts.values)):\n",
243
+ " axes[0].text(i, count + 5, str(count), ha='center', fontweight='bold')\n",
244
+ "\n",
245
+ "# Pie chart\n",
246
+ "axes[1].pie(target_counts.values, labels=data.target_names, autopct='%1.1f%%',\n",
247
+ " colors=colors, startangle=90, textprops={'fontsize': 12, 'fontweight': 'bold'})\n",
248
+ "axes[1].set_title('Target Class Proportion', fontweight='bold', fontsize=14)\n",
249
+ "\n",
250
+ "plt.tight_layout()\n",
251
+ "plt.show()"
252
+ ]
253
+ },
254
+ {
255
+ "cell_type": "code",
256
+ "execution_count": null,
257
+ "metadata": {},
258
+ "outputs": [],
259
+ "source": [
260
+ "# Feature distributions for key features\n",
261
+ "print(\"Feature Distributions (Top 12 Features):\")\n",
262
+ "fig, axes = plt.subplots(3, 4, figsize=(16, 12))\n",
263
+ "axes = axes.ravel()\n",
264
+ "\n",
265
+ "for idx, col in enumerate(X.columns[:12]):\n",
266
+ " # Histogram for each class\n",
267
+ " for target_idx, target_name in enumerate(data.target_names):\n",
268
+ " mask = y == target_idx\n",
269
+ " axes[idx].hist(X.loc[mask, col], bins=25, alpha=0.6, \n",
270
+ " label=target_name, color=colors[target_idx], edgecolor='black')\n",
271
+ " \n",
272
+ " axes[idx].set_title(f'{col}', fontsize=10, fontweight='bold')\n",
273
+ " axes[idx].set_xlabel('Value', fontsize=9)\n",
274
+ " axes[idx].set_ylabel('Frequency', fontsize=9)\n",
275
+ " axes[idx].legend(fontsize=8)\n",
276
+ " axes[idx].grid(alpha=0.3)\n",
277
+ "\n",
278
+ "plt.tight_layout()\n",
279
+ "plt.suptitle('Feature Distributions by Class', y=1.002, fontsize=16, fontweight='bold')\n",
280
+ "plt.show()"
281
+ ]
282
+ },
283
+ {
284
+ "cell_type": "code",
285
+ "execution_count": null,
286
+ "metadata": {},
287
+ "outputs": [],
288
+ "source": [
289
+ "# Correlation heatmap for top features\n",
290
+ "print(\"\\nFeature Correlation Analysis (Top 15 Features):\")\n",
291
+ "top_features = X.columns[:15]\n",
292
+ "correlation_matrix = X[top_features].corr()\n",
293
+ "\n",
294
+ "plt.figure(figsize=(14, 12))\n",
295
+ "mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))\n",
296
+ "sns.heatmap(correlation_matrix, mask=mask, annot=True, fmt='.2f', \n",
297
+ " cmap='coolwarm', center=0, square=True, linewidths=0.5,\n",
298
+ " cbar_kws={\"shrink\": 0.8, \"label\": \"Correlation Coefficient\"})\n",
299
+ "plt.title('Feature Correlation Matrix (Top 15 Features)', fontsize=14, fontweight='bold', pad=20)\n",
300
+ "plt.tight_layout()\n",
301
+ "plt.show()\n",
302
+ "\n",
303
+ "# Find highly correlated pairs\n",
304
+ "high_corr_pairs = []\n",
305
+ "for i in range(len(correlation_matrix.columns)):\n",
306
+ " for j in range(i+1, len(correlation_matrix.columns)):\n",
307
+ " if abs(correlation_matrix.iloc[i, j]) > 0.8:\n",
308
+ " high_corr_pairs.append((\n",
309
+ " correlation_matrix.columns[i],\n",
310
+ " correlation_matrix.columns[j],\n",
311
+ " correlation_matrix.iloc[i, j]\n",
312
+ " ))\n",
313
+ "\n",
314
+ "if high_corr_pairs:\n",
315
+ " print(\"\\nHighly correlated feature pairs (|r| > 0.8):\")\n",
316
+ " for feat1, feat2, corr in high_corr_pairs[:5]:\n",
317
+ " print(f\" {feat1} <-> {feat2}: {corr:.3f}\")"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "markdown",
322
+ "metadata": {},
323
+ "source": [
324
+ "## 6. Data Preprocessing"
325
+ ]
326
+ },
327
+ {
328
+ "cell_type": "code",
329
+ "execution_count": null,
330
+ "metadata": {},
331
+ "outputs": [],
332
+ "source": [
333
+ "# Split the data\n",
334
+ "print(\"=\" * 80)\n",
335
+ "print(\"DATA PREPROCESSING\")\n",
336
+ "print(\"=\" * 80)\n",
337
+ "\n",
338
+ "X_train, X_test, y_train, y_test = train_test_split(\n",
339
+ " X, y, test_size=0.2, random_state=42, stratify=y\n",
340
+ ")\n",
341
+ "\n",
342
+ "print(f\"\\nTraining set size: {X_train.shape[0]} samples\")\n",
343
+ "print(f\"Test set size: {X_test.shape[0]} samples\")\n",
344
+ "print(f\"\\nTraining set class distribution:\")\n",
345
+ "for idx, name in enumerate(data.target_names):\n",
346
+ " count = (y_train == idx).sum()\n",
347
+ " percentage = count / len(y_train) * 100\n",
348
+ " print(f\" {name}: {count} ({percentage:.2f}%)\")\n",
349
+ "\n",
350
+ "# Feature scaling (critical for KNN)\n",
351
+ "print(\"\\nApplying StandardScaler...\")\n",
352
+ "scaler = StandardScaler()\n",
353
+ "X_train_scaled = scaler.fit_transform(X_train)\n",
354
+ "X_test_scaled = scaler.transform(X_test)\n",
355
+ "\n",
356
+ "# Convert back to DataFrame for better handling\n",
357
+ "X_train_scaled = pd.DataFrame(X_train_scaled, columns=X.columns, index=X_train.index)\n",
358
+ "X_test_scaled = pd.DataFrame(X_test_scaled, columns=X.columns, index=X_test.index)\n",
359
+ "\n",
360
+ "print(f\"\\nScaled features - Mean: {X_train_scaled.mean().mean():.6f}\")\n",
361
+ "print(f\"Scaled features - Std: {X_train_scaled.std().mean():.6f}\")\n",
362
+ "print(\"\\nData preprocessing completed!\")"
363
+ ]
364
+ },
365
+ {
366
+ "cell_type": "code",
367
+ "execution_count": null,
368
+ "metadata": {},
369
+ "outputs": [],
370
+ "source": [
371
+ "# Visualize scaling effect\n",
372
+ "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
373
+ "\n",
374
+ "# Before scaling\n",
375
+ "sample_features = X.columns[:5]\n",
376
+ "X_train[sample_features].boxplot(ax=axes[0])\n",
377
+ "axes[0].set_title('Feature Scales - Before Scaling', fontweight='bold', fontsize=12)\n",
378
+ "axes[0].set_ylabel('Value', fontweight='bold')\n",
379
+ "axes[0].tick_params(axis='x', rotation=45)\n",
380
+ "axes[0].grid(alpha=0.3)\n",
381
+ "\n",
382
+ "# After scaling\n",
383
+ "X_train_scaled[sample_features].boxplot(ax=axes[1])\n",
384
+ "axes[1].set_title('Feature Scales - After Scaling', fontweight='bold', fontsize=12)\n",
385
+ "axes[1].set_ylabel('Standardized Value', fontweight='bold')\n",
386
+ "axes[1].tick_params(axis='x', rotation=45)\n",
387
+ "axes[1].grid(alpha=0.3)\n",
388
+ "\n",
389
+ "plt.tight_layout()\n",
390
+ "plt.show()"
391
+ ]
392
+ },
393
+ {
394
+ "cell_type": "markdown",
395
+ "metadata": {},
396
+ "source": [
397
+ "## 7. KNN Model Training and Optimization"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": null,
403
+ "metadata": {},
404
+ "outputs": [],
405
+ "source": [
406
+ "# Find optimal K value using cross-validation\n",
407
+ "print(\"=\" * 80)\n",
408
+ "print(\"KNN MODEL TRAINING\")\n",
409
+ "print(\"=\" * 80)\n",
410
+ "print(\"\\nFinding optimal K value...\\n\")\n",
411
+ "\n",
412
+ "k_range = range(1, 31)\n",
413
+ "train_scores = []\n",
414
+ "test_scores = []\n",
415
+ "cv_scores = []\n",
416
+ "\n",
417
+ "# Use tqdm for progress bar\n",
418
+ "for k in tqdm(k_range, desc=\"Testing K values\"):\n",
419
+ " knn = KNeighborsClassifier(n_neighbors=k, n_jobs=-1)\n",
420
+ " \n",
421
+ " # Training score\n",
422
+ " knn.fit(X_train_scaled, y_train)\n",
423
+ " train_scores.append(knn.score(X_train_scaled, y_train))\n",
424
+ " \n",
425
+ " # Test score\n",
426
+ " test_scores.append(knn.score(X_test_scaled, y_test))\n",
427
+ " \n",
428
+ " # Cross-validation score\n",
429
+ " cv_score = cross_val_score(knn, X_train_scaled, y_train, cv=5, n_jobs=-1)\n",
430
+ " cv_scores.append(cv_score.mean())\n",
431
+ "\n",
432
+ "# Find best K\n",
433
+ "best_k_test = k_range[np.argmax(test_scores)]\n",
434
+ "best_k_cv = k_range[np.argmax(cv_scores)]\n",
435
+ "\n",
436
+ "print(f\"\\n✓ Optimal K (based on test accuracy): {best_k_test}\")\n",
437
+ "print(f\"✓ Optimal K (based on CV score): {best_k_cv}\")\n",
438
+ "print(f\"\\nUsing K = {best_k_cv} for final model\")"
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "code",
443
+ "execution_count": null,
444
+ "metadata": {},
445
+ "outputs": [],
446
+ "source": [
447
+ "# Plot K vs Accuracy\n",
448
+ "fig, axes = plt.subplots(1, 2, figsize=(16, 6))\n",
449
+ "\n",
450
+ "# Left plot: All scores\n",
451
+ "axes[0].plot(k_range, train_scores, label='Training Accuracy', \n",
452
+ " marker='o', linewidth=2, markersize=4, color='#2ecc71')\n",
453
+ "axes[0].plot(k_range, test_scores, label='Test Accuracy', \n",
454
+ " marker='s', linewidth=2, markersize=4, color='#e74c3c')\n",
455
+ "axes[0].plot(k_range, cv_scores, label='CV Accuracy (5-fold)', \n",
456
+ " marker='^', linewidth=2, markersize=4, color='#3498db')\n",
457
+ "axes[0].axvline(x=best_k_cv, color='black', linestyle='--', alpha=0.5, label=f'Best K={best_k_cv}')\n",
458
+ "axes[0].set_xlabel('K Value (Number of Neighbors)', fontweight='bold')\n",
459
+ "axes[0].set_ylabel('Accuracy', fontweight='bold')\n",
460
+ "axes[0].set_title('KNN: K Value vs Accuracy', fontweight='bold', fontsize=14)\n",
461
+ "axes[0].legend(loc='best', fontsize=10)\n",
462
+ "axes[0].grid(True, alpha=0.3)\n",
463
+ "\n",
464
+ "# Right plot: Train-test gap\n",
465
+ "gap = np.array(train_scores) - np.array(test_scores)\n",
466
+ "axes[1].plot(k_range, gap, marker='o', linewidth=2, markersize=4, color='#9b59b6')\n",
467
+ "axes[1].axvline(x=best_k_cv, color='black', linestyle='--', alpha=0.5, label=f'Best K={best_k_cv}')\n",
468
+ "axes[1].axhline(y=0, color='red', linestyle='-', alpha=0.3)\n",
469
+ "axes[1].set_xlabel('K Value (Number of Neighbors)', fontweight='bold')\n",
470
+ "axes[1].set_ylabel('Train-Test Accuracy Gap', fontweight='bold')\n",
471
+ "axes[1].set_title('Overfitting Analysis', fontweight='bold', fontsize=14)\n",
472
+ "axes[1].legend(loc='best', fontsize=10)\n",
473
+ "axes[1].grid(True, alpha=0.3)\n",
474
+ "\n",
475
+ "plt.tight_layout()\n",
476
+ "plt.show()\n",
477
+ "\n",
478
+ "print(f\"\\nBest test accuracy: {max(test_scores):.4f} at K={best_k_test}\")\n",
479
+ "print(f\"Best CV accuracy: {max(cv_scores):.4f} at K={best_k_cv}\")"
480
+ ]
481
+ },
482
+ {
483
+ "cell_type": "code",
484
+ "execution_count": null,
485
+ "metadata": {},
486
+ "outputs": [],
487
+ "source": [
488
+ "# Train final KNN model with optimal K\n",
489
+ "print(\"\\nTraining final KNN model...\")\n",
490
+ "optimal_k = best_k_cv\n",
491
+ "knn_model = KNeighborsClassifier(n_neighbors=optimal_k, n_jobs=-1)\n",
492
+ "knn_model.fit(X_train_scaled, y_train)\n",
493
+ "\n",
494
+ "# Make predictions\n",
495
+ "y_train_pred = knn_model.predict(X_train_scaled)\n",
496
+ "y_test_pred = knn_model.predict(X_test_scaled)\n",
497
+ "y_train_proba = knn_model.predict_proba(X_train_scaled)\n",
498
+ "y_test_proba = knn_model.predict_proba(X_test_scaled)\n",
499
+ "\n",
500
+ "# Calculate metrics\n",
501
+ "train_accuracy = accuracy_score(y_train, y_train_pred)\n",
502
+ "test_accuracy = accuracy_score(y_test, y_test_pred)\n",
503
+ "\n",
504
+ "print(\"\\n\" + \"=\" * 80)\n",
505
+ "print(\"MODEL PERFORMANCE\")\n",
506
+ "print(\"=\" * 80)\n",
507
+ "print(f\"\\nOptimal K: {optimal_k}\")\n",
508
+ "print(f\"Training Accuracy: {train_accuracy:.4f}\")\n",
509
+ "print(f\"Test Accuracy: {test_accuracy:.4f}\")\n",
510
+ "print(f\"\\nClassification Report (Test Set):\")\n",
511
+ "print(classification_report(y_test, y_test_pred, target_names=data.target_names, digits=4))"
512
+ ]
513
+ },
514
+ {
515
+ "cell_type": "code",
516
+ "execution_count": null,
517
+ "metadata": {},
518
+ "outputs": [],
519
+ "source": [
520
+ "# Confusion Matrix Visualization\n",
521
+ "fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
522
+ "\n",
523
+ "# Absolute counts\n",
524
+ "cm = confusion_matrix(y_test, y_test_pred)\n",
525
+ "sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', \n",
526
+ " xticklabels=data.target_names, \n",
527
+ " yticklabels=data.target_names,\n",
528
+ " cbar_kws={'label': 'Count'},\n",
529
+ " ax=axes[0],\n",
530
+ " annot_kws={'fontsize': 14, 'fontweight': 'bold'})\n",
531
+ "axes[0].set_title('Confusion Matrix (Counts)', fontsize=14, fontweight='bold')\n",
532
+ "axes[0].set_ylabel('True Label', fontsize=12, fontweight='bold')\n",
533
+ "axes[0].set_xlabel('Predicted Label', fontsize=12, fontweight='bold')\n",
534
+ "\n",
535
+ "# Normalized\n",
536
+ "cm_norm = confusion_matrix(y_test, y_test_pred, normalize='true')\n",
537
+ "sns.heatmap(cm_norm, annot=True, fmt='.2%', cmap='Greens', \n",
538
+ " xticklabels=data.target_names, \n",
539
+ " yticklabels=data.target_names,\n",
540
+ " cbar_kws={'label': 'Proportion'},\n",
541
+ " ax=axes[1],\n",
542
+ " annot_kws={'fontsize': 14, 'fontweight': 'bold'})\n",
543
+ "axes[1].set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')\n",
544
+ "axes[1].set_ylabel('True Label', fontsize=12, fontweight='bold')\n",
545
+ "axes[1].set_xlabel('Predicted Label', fontsize=12, fontweight='bold')\n",
546
+ "\n",
547
+ "plt.tight_layout()\n",
548
+ "plt.show()"
549
+ ]
550
+ },
551
+ {
552
+ "cell_type": "code",
553
+ "execution_count": null,
554
+ "metadata": {},
555
+ "outputs": [],
556
+ "source": [
557
+ "# ROC Curve and AUC\n",
558
+ "fpr, tpr, thresholds = roc_curve(y_test, y_test_proba[:, 1])\n",
559
+ "roc_auc = roc_auc_score(y_test, y_test_proba[:, 1])\n",
560
+ "\n",
561
+ "fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
562
+ "\n",
563
+ "# ROC Curve\n",
564
+ "axes[0].plot(fpr, tpr, color='darkorange', lw=2, \n",
565
+ " label=f'ROC curve (AUC = {roc_auc:.4f})')\n",
566
+ "axes[0].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random Classifier')\n",
567
+ "axes[0].set_xlim([0.0, 1.0])\n",
568
+ "axes[0].set_ylim([0.0, 1.05])\n",
569
+ "axes[0].set_xlabel('False Positive Rate', fontweight='bold')\n",
570
+ "axes[0].set_ylabel('True Positive Rate', fontweight='bold')\n",
571
+ "axes[0].set_title('Receiver Operating Characteristic (ROC) Curve', fontweight='bold', fontsize=12)\n",
572
+ "axes[0].legend(loc=\"lower right\", fontsize=10)\n",
573
+ "axes[0].grid(alpha=0.3)\n",
574
+ "\n",
575
+ "# Precision-Recall Curve\n",
576
+ "precision, recall, _ = precision_recall_curve(y_test, y_test_proba[:, 1])\n",
577
+ "axes[1].plot(recall, precision, color='green', lw=2, label='Precision-Recall curve')\n",
578
+ "axes[1].set_xlabel('Recall', fontweight='bold')\n",
579
+ "axes[1].set_ylabel('Precision', fontweight='bold')\n",
580
+ "axes[1].set_title('Precision-Recall Curve', fontweight='bold', fontsize=12)\n",
581
+ "axes[1].legend(loc=\"best\", fontsize=10)\n",
582
+ "axes[1].grid(alpha=0.3)\n",
583
+ "\n",
584
+ "plt.tight_layout()\n",
585
+ "plt.show()"
586
+ ]
587
+ },
588
+ {
589
+ "cell_type": "markdown",
590
+ "metadata": {},
591
+ "source": [
592
+ "## 8. SHAP Explainability Setup"
593
+ ]
594
+ },
595
+ {
596
+ "cell_type": "code",
597
+ "execution_count": null,
598
+ "metadata": {},
599
+ "outputs": [],
600
+ "source": [
601
+ "print(\"=\" * 80)\n",
602
+ "print(\"SHAP EXPLAINABILITY ANALYSIS\")\n",
603
+ "print(\"=\" * 80)\n",
604
+ "print(\"\\nSetting up SHAP explainer...\\n\")\n",
605
+ "\n",
606
+ "# Create background dataset for SHAP\n",
607
+ "# Using kmeans to select representative samples (faster for large datasets)\n",
608
+ "background_size = 100\n",
609
+ "background = shap.kmeans(X_train_scaled, background_size)\n",
610
+ "\n",
611
+ "print(f\"Background dataset size: {background_size} samples\")\n",
612
+ "print(f\"Background dataset shape: {background.data.shape}\")\n",
613
+ "\n",
614
+ "# Create SHAP explainer\n",
615
+ "# Using KernelExplainer (model-agnostic) for KNN\n",
616
+ "print(\"\\nCreating SHAP KernelExplainer (this may take a moment)...\")\n",
617
+ "explainer = shap.KernelExplainer(knn_model.predict_proba, background)\n",
618
+ "\n",
619
+ "print(\"\\n✓ SHAP explainer created successfully!\")\n",
620
+ "print(f\"Expected value (class 0): {explainer.expected_value[0]:.4f}\")\n",
621
+ "print(f\"Expected value (class 1): {explainer.expected_value[1]:.4f}\")\n",
622
+ "\n",
623
+ "# Check GPU memory after setup\n",
624
+ "print(\"\\nGPU Memory Status:\")\n",
625
+ "print_gpu_memory()"
626
+ ]
627
+ },
628
+ {
629
+ "cell_type": "code",
630
+ "execution_count": null,
631
+ "metadata": {},
632
+ "outputs": [],
633
+ "source": "# Compute SHAP values for test set\n# Adjust sample size based on your GPU memory and time constraints\nn_samples = min(100, len(X_test_scaled)) # Use 100 samples or all test samples if less\nX_test_sample = X_test_scaled.iloc[:n_samples]\ny_test_sample = y_test.iloc[:n_samples]\n\nprint(f\"Computing SHAP values for {n_samples} test samples...\")\nprint(\"This may take several minutes depending on your GPU...\")\nprint(\"Progress will be shown below:\\n\")\n\n# Compute SHAP values with progress tracking\nimport time\nstart_time = time.time()\n\nshap_values = explainer.shap_values(X_test_sample, nsamples=100) # nsamples controls accuracy/speed tradeoff\n\nelapsed_time = time.time() - start_time\n\nprint(f\"\\n✓ SHAP values computed successfully!\")\nprint(f\"Computation time: {elapsed_time:.2f} seconds ({elapsed_time/60:.2f} minutes)\")\nprint(f\"Time per sample: {elapsed_time/n_samples:.2f} seconds\")\n\n# Debug: Check the structure of shap_values\nprint(f\"\\nDEBUG: Type of shap_values: {type(shap_values)}\")\nif isinstance(shap_values, list):\n print(f\"DEBUG: shap_values is a list with {len(shap_values)} elements\")\n for i, sv in enumerate(shap_values):\n print(f\"DEBUG: shap_values[{i}].shape = {np.array(sv).shape}\")\nelse:\n print(f\"DEBUG: shap_values.shape = {np.array(shap_values).shape}\")\nprint(f\"DEBUG: X_test_sample.shape = {X_test_sample.shape}\")\n\nprint(f\"\\nSHAP values shape: {np.array(shap_values).shape}\")\nprint(f\" - 2 classes (binary classification)\")\nprint(f\" - {n_samples} samples explained\")\nprint(f\" - {X_test_sample.shape[1]} features\")\n\n# Check GPU memory\nprint(\"\\nGPU Memory Status:\")\nprint_gpu_memory()"
634
+ },
635
+ {
636
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "## 9. SHAP Visualizations - Global Explanations"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": "# Fix SHAP values format for binary classification\n# Recent SHAP versions return an array of shape (n_samples, n_features, n_classes)\n# from KernelExplainer with predict_proba; older versions return a list of two\n# (n_samples, n_features) arrays. Normalize to the list format used by the plots below.\nprint(\"Reshaping SHAP values for visualization...\")\n\nif not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 2:\n    print(f\"Original shape: {shap_values.shape}\")\n    shap_values = [shap_values[:, :, i] for i in range(shap_values.shape[2])]\n    print(f\"  Class 0 SHAP values shape: {shap_values[0].shape}\")\n    print(f\"  Class 1 SHAP values shape: {shap_values[1].shape}\")\nelse:\n    print(\"SHAP values already in list format\")\n\nprint(\"✓ SHAP values ready for visualization!\")"
639
+ },
640
+ {
641
+ "cell_type": "code",
642
+ "execution_count": null,
643
+ "metadata": {},
644
+ "outputs": [],
645
+ "source": "# SHAP Summary Plot (Beeswarm) - Shows global feature importance\nprint(\"=\" * 80)\nprint(\"SHAP VISUALIZATION 1: SUMMARY PLOT (BEESWARM)\")\nprint(\"=\" * 80)\nprint(\"\"\"\nThis plot shows:\n- Feature importance (vertical axis, ordered by importance)\n- SHAP values (horizontal axis, impact on prediction)\n- Feature values (color, red=high, blue=low)\n- Distribution across all samples (density)\n\nReading the plot:\n- Features at the top are most important\n- Points to the right increase probability of class 1 (malignant)\n- Points to the left decrease probability of class 1\n- Color shows whether high (red) or low (blue) feature values have that effect\n\"\"\")\n\nplt.figure(figsize=(14, 10))\nshap.summary_plot(shap_values[1], X_test_sample.values, \n feature_names=X_test_sample.columns.tolist(),\n plot_type=\"dot\", show=False, max_display=20)\nplt.title('SHAP Summary Plot - Global Feature Importance and Impact\\n(Predicting Malignant Class)', \n fontsize=14, fontweight='bold', pad=20)\nplt.tight_layout()\nplt.show()"
646
+ },
647
+ {
648
+ "cell_type": "code",
649
+ "execution_count": null,
650
+ "metadata": {},
651
+ "outputs": [],
652
+ "source": "# SHAP Bar Plot - Mean absolute SHAP values\nprint(\"=\" * 80)\nprint(\"SHAP VISUALIZATION 2: BAR PLOT\")\nprint(\"=\" * 80)\nprint(\"\"\"\nThis plot shows:\n- Average magnitude of feature impact (mean |SHAP value|)\n- Overall feature importance ranking\n- Which features have the strongest effect on predictions (regardless of direction)\n\"\"\")\n\nplt.figure(figsize=(14, 10))\nshap.summary_plot(shap_values[1], X_test_sample.values,\n feature_names=X_test_sample.columns.tolist(),\n plot_type=\"bar\", show=False, max_display=20)\nplt.title('SHAP Bar Plot - Mean Absolute Feature Importance', \n fontsize=14, fontweight='bold', pad=20)\nplt.xlabel('Mean |SHAP value| (average impact on model output magnitude)', fontweight='bold')\nplt.tight_layout()\nplt.show()"
653
+ },
654
+ {
655
+ "cell_type": "code",
656
+ "execution_count": null,
657
+ "metadata": {},
658
+ "outputs": [],
659
+ "source": [
660
+ "# Calculate and display feature importance rankings\n",
661
+ "feature_importance = np.abs(shap_values[1]).mean(axis=0)\n",
662
+ "feature_importance_df = pd.DataFrame({\n",
663
+ " 'Feature': X_test_sample.columns,\n",
664
+ " 'Mean_Abs_SHAP': feature_importance,\n",
665
+ " 'Mean_SHAP': shap_values[1].mean(axis=0),\n",
666
+ " 'Std_SHAP': np.std(shap_values[1], axis=0)\n",
667
+ "}).sort_values('Mean_Abs_SHAP', ascending=False)\n",
668
+ "\n",
669
+ "print(\"\\n\" + \"=\" * 80)\n",
670
+ "print(\"TOP 15 MOST IMPORTANT FEATURES (by mean |SHAP value|)\")\n",
671
+ "print(\"=\" * 80)\n",
672
+ "print(feature_importance_df.head(15).to_string(index=False))\n",
673
+ "\n",
674
+ "# Visualize feature importance with custom plot\n",
675
+ "fig, axes = plt.subplots(1, 2, figsize=(16, 8))\n",
676
+ "\n",
677
+ "# Top features by absolute importance\n",
678
+ "top_15 = feature_importance_df.head(15)\n",
679
+ "y_pos = np.arange(len(top_15))\n",
680
+ "axes[0].barh(y_pos, top_15['Mean_Abs_SHAP'], color='steelblue', edgecolor='black')\n",
681
+ "axes[0].set_yticks(y_pos)\n",
682
+ "axes[0].set_yticklabels(top_15['Feature'], fontsize=9)\n",
683
+ "axes[0].set_xlabel('Mean |SHAP Value|', fontweight='bold')\n",
684
+ "axes[0].set_title('Top 15 Features by Absolute Importance', fontweight='bold', fontsize=12)\n",
685
+ "axes[0].invert_yaxis()\n",
686
+ "axes[0].grid(axis='x', alpha=0.3)\n",
687
+ "\n",
688
+ "# Top features by directional importance (showing positive/negative effect)\n",
689
+ "colors = ['green' if x > 0 else 'red' for x in top_15['Mean_SHAP']]\n",
690
+ "axes[1].barh(y_pos, top_15['Mean_SHAP'], color=colors, edgecolor='black', alpha=0.7)\n",
691
+ "axes[1].set_yticks(y_pos)\n",
692
+ "axes[1].set_yticklabels(top_15['Feature'], fontsize=9)\n",
693
+ "axes[1].set_xlabel('Mean SHAP Value (Directional)', fontweight='bold')\n",
694
+ "axes[1].set_title('Top 15 Features - Directional Impact\\n(Green: Increases Malignant Prob, Red: Decreases)', \n",
695
+ " fontweight='bold', fontsize=12)\n",
696
+ "axes[1].axvline(x=0, color='black', linestyle='-', linewidth=0.8)\n",
697
+ "axes[1].invert_yaxis()\n",
698
+ "axes[1].grid(axis='x', alpha=0.3)\n",
699
+ "\n",
700
+ "plt.tight_layout()\n",
701
+ "plt.show()"
702
+ ]
703
+ },
704
+ {
705
+ "cell_type": "markdown",
706
+ "metadata": {},
707
+ "source": [
708
+ "## 10. SHAP Visualizations - Local Explanations (Individual Predictions)"
709
+ ]
710
+ },
711
+ {
712
+ "cell_type": "code",
713
+ "execution_count": null,
714
+ "metadata": {},
715
+ "outputs": [],
716
+ "source": [
717
+ "# Select interesting samples for detailed explanation\n",
718
+ "print(\"=\" * 80)\n",
719
+ "print(\"SELECTING SAMPLES FOR DETAILED EXPLANATION\")\n",
720
+ "print(\"=\" * 80)\n",
721
+ "\n",
722
+ "# Get predictions for our sample\n",
723
+ "y_pred_sample = knn_model.predict(X_test_sample)\n",
724
+ "y_proba_sample = knn_model.predict_proba(X_test_sample)\n",
725
+ "\n",
726
+ "# Find interesting samples\n",
727
+ "correct_mask = y_test_sample == y_pred_sample\n",
728
+ "incorrect_mask = ~correct_mask\n",
729
+ "\n",
730
+ "# High confidence correct\n",
731
+ "high_conf_correct_idx = np.where(correct_mask & (np.max(y_proba_sample, axis=1) > 0.9))[0]\n",
732
+ "# Low confidence correct\n",
733
+ "low_conf_correct_idx = np.where(correct_mask & (np.max(y_proba_sample, axis=1) < 0.7))[0]\n",
734
+ "# Misclassified\n",
735
+ "misclassified_idx = np.where(incorrect_mask)[0]\n",
736
+ "\n",
737
+ "print(f\"\\nTotal samples analyzed: {len(y_test_sample)}\")\n",
738
+ "print(f\"Correctly classified: {correct_mask.sum()} ({correct_mask.sum()/len(y_test_sample)*100:.1f}%)\")\n",
739
+ "print(f\"Misclassified: {incorrect_mask.sum()} ({incorrect_mask.sum()/len(y_test_sample)*100:.1f}%)\")\n",
740
+ "print(f\"\\nHigh confidence correct predictions: {len(high_conf_correct_idx)}\")\n",
741
+ "print(f\"Low confidence correct predictions: {len(low_conf_correct_idx)}\")\n",
742
+ "print(f\"Misclassified predictions: {len(misclassified_idx)}\")\n",
743
+ "\n",
744
+ "# Select samples to explain\n",
745
+ "samples_to_explain = []\n",
746
+ "sample_descriptions = []\n",
747
+ "\n",
748
+ "if len(high_conf_correct_idx) > 0:\n",
749
+ " samples_to_explain.append(high_conf_correct_idx[0])\n",
750
+ " sample_descriptions.append(\"High confidence correct\")\n",
751
+ "\n",
752
+ "if len(low_conf_correct_idx) > 0:\n",
753
+ " samples_to_explain.append(low_conf_correct_idx[0])\n",
754
+ " sample_descriptions.append(\"Low confidence correct\")\n",
755
+ "\n",
756
+ "if len(misclassified_idx) > 0:\n",
757
+ " samples_to_explain.append(misclassified_idx[0])\n",
758
+ " sample_descriptions.append(\"Misclassified\")\n",
759
+ "\n",
760
+ "if len(samples_to_explain) == 0:\n",
761
+ " # If no special cases, just use first sample\n",
762
+ " samples_to_explain = [0]\n",
763
+ " sample_descriptions = [\"First sample\"]\n",
764
+ "\n",
765
+ "print(f\"\\nSelected {len(samples_to_explain)} samples for detailed explanation\")"
766
+ ]
767
+ },
768
+ {
769
+ "cell_type": "code",
770
+ "execution_count": null,
771
+ "metadata": {},
772
+ "outputs": [],
773
+ "source": [
774
+ "# Waterfall plots for selected samples\n",
775
+ "print(\"\\n\" + \"=\" * 80)\n",
776
+ "print(\"SHAP VISUALIZATION 3: WATERFALL PLOTS (Individual Predictions)\")\n",
777
+ "print(\"=\" * 80)\n",
778
+ "print(\"\"\"\n",
779
+ "Waterfall plots show how each feature contributes to pushing the prediction\n",
780
+ "from the base value (expected value) to the final prediction for a single sample.\n",
781
+ "\n",
782
+ "Reading the plot:\n",
783
+ "- Starts at E[f(x)] (expected value/average prediction)\n",
784
+ "- Each bar shows a feature's contribution\n",
785
+ "- Red bars push prediction higher (toward malignant)\n",
786
+ "- Blue bars push prediction lower (toward benign)\n",
787
+ "- Final value f(x) is the model's output for this sample\n",
788
+ "\"\"\")\n",
789
+ "\n",
790
+ "for idx, (sample_idx, description) in enumerate(zip(samples_to_explain, sample_descriptions)):\n",
791
+ " true_label = data.target_names[y_test_sample.iloc[sample_idx]]\n",
792
+ " pred_label = data.target_names[y_pred_sample[sample_idx]]\n",
793
+ " pred_proba = y_proba_sample[sample_idx]\n",
794
+ " \n",
795
+ " print(f\"\\n{'-'*80}\")\n",
796
+ " print(f\"Sample {idx+1}: {description} (Index {sample_idx})\")\n",
797
+ " print(f\"{'-'*80}\")\n",
798
+ " print(f\"True Label: {true_label}\")\n",
799
+ " print(f\"Predicted Label: {pred_label}\")\n",
800
+ " print(f\"Prediction Probabilities:\")\n",
801
+ " for class_idx, class_name in enumerate(data.target_names):\n",
802
+ " print(f\" {class_name}: {pred_proba[class_idx]:.4f} ({pred_proba[class_idx]*100:.2f}%)\")\n",
803
+ " print(f\"Correct: {true_label == pred_label}\")\n",
804
+ " \n",
805
+ " # Create waterfall plot\n",
806
+ " shap.plots.waterfall(\n",
807
+ " shap.Explanation(\n",
808
+ " values=shap_values[1][sample_idx],\n",
809
+ " base_values=explainer.expected_value[1],\n",
810
+ " data=X_test_sample.iloc[sample_idx],\n",
811
+ " feature_names=X_test_sample.columns.tolist()\n",
812
+ " ),\n",
813
+ " max_display=15\n",
814
+ " )"
815
+ ]
816
+ },
817
+ {
818
+ "cell_type": "code",
819
+ "execution_count": null,
820
+ "metadata": {},
821
+ "outputs": [],
822
+ "source": [
823
+ "# Force plots for individual predictions\n",
824
+ "print(\"\\n\" + \"=\" * 80)\n",
825
+ "print(\"SHAP VISUALIZATION 4: FORCE PLOTS (Individual Predictions)\")\n",
826
+ "print(\"=\" * 80)\n",
827
+ "print(\"\"\"\n",
828
+ "Force plots provide another view of individual predictions:\n",
829
+ "- Red features push prediction toward higher values (malignant)\n",
830
+ "- Blue features push prediction toward lower values (benign)\n",
831
+ "- Width of each feature shows magnitude of impact\n",
832
+ "\"\"\")\n",
833
+ "\n",
834
+ "for idx, (sample_idx, description) in enumerate(zip(samples_to_explain[:3], sample_descriptions[:3])):\n",
835
+ " print(f\"\\nSample {idx+1}: {description}\")\n",
836
+ " \n",
837
+ " shap.force_plot(\n",
838
+ " explainer.expected_value[1],\n",
839
+ " shap_values[1][sample_idx],\n",
840
+ " X_test_sample.iloc[sample_idx],\n",
841
+ " matplotlib=True,\n",
842
+ " show=False,\n",
843
+ " figsize=(20, 3)\n",
844
+ " )\n",
845
+ " plt.title(f'Force Plot - {description} (Sample {sample_idx})', \n",
846
+ " fontsize=12, fontweight='bold', pad=10)\n",
847
+ " plt.tight_layout()\n",
848
+ " plt.show()"
849
+ ]
850
+ },
851
+ {
852
+ "cell_type": "code",
853
+ "execution_count": null,
854
+ "metadata": {},
855
+ "outputs": [],
856
+ "source": [
857
+ "# Interactive force plot for multiple samples\n",
858
+ "print(\"\\n\" + \"=\" * 80)\n",
859
+ "print(\"SHAP VISUALIZATION 5: INTERACTIVE FORCE PLOT (Multiple Predictions)\")\n",
860
+ "print(\"=\" * 80)\n",
861
+ "print(\"\"\"\n",
862
+ "This interactive visualization shows force plots for multiple samples simultaneously.\n",
863
+ "Samples are sorted by their predicted probability, allowing you to see patterns\n",
864
+ "across different prediction strengths.\n",
865
+ "\"\"\")\n",
866
+ "\n",
867
+ "# Use first 50 samples for visualization\n",
868
+ "n_force_samples = min(50, len(X_test_sample))\n",
869
+ "\n",
870
+ "shap.force_plot(\n",
871
+ " explainer.expected_value[1],\n",
872
+ " shap_values[1][:n_force_samples],\n",
873
+ " X_test_sample.iloc[:n_force_samples]\n",
874
+ ")"
875
+ ]
876
+ },
877
+ {
878
+ "cell_type": "markdown",
879
+ "metadata": {},
880
+ "source": [
881
+ "## 11. SHAP Dependence Plots - Feature Interactions"
882
+ ]
883
+ },
884
+ {
885
+ "cell_type": "code",
886
+ "execution_count": null,
887
+ "metadata": {},
888
+ "outputs": [],
889
+ "source": "# Dependence plots for top features\nprint(\"=\" * 80)\nprint(\"SHAP VISUALIZATION 6: DEPENDENCE PLOTS\")\nprint(\"=\" * 80)\nprint(\"\"\"\nDependence plots show how feature values relate to their SHAP values:\n- X-axis: Feature value\n- Y-axis: SHAP value (impact on prediction)\n- Color: Another feature that may interact with this feature\n\nThese plots reveal:\n- Non-linear relationships between features and predictions\n- Feature interactions (shown by color patterns)\n- Threshold effects\n\"\"\")\n\n# Get top 6 most important features\ntop_features = feature_importance_df.head(6)['Feature'].values\n\nfig, axes = plt.subplots(2, 3, figsize=(18, 12))\naxes = axes.ravel()\n\nfor idx, feature in enumerate(top_features):\n plt.sca(axes[idx])\n shap.dependence_plot(\n feature,\n shap_values[1],\n X_test_sample.values,\n feature_names=X_test_sample.columns.tolist(),\n show=False,\n ax=axes[idx]\n )\n axes[idx].set_title(f'Dependence Plot: {feature}', fontsize=11, fontweight='bold')\n axes[idx].grid(alpha=0.3)\n\nplt.suptitle('SHAP Dependence Plots - Top 6 Features', \n fontsize=14, fontweight='bold', y=1.002)\nplt.tight_layout()\nplt.show()\n\nprint(\"\\nKey observations from dependence plots:\")\nprint(\"- Look for non-linear patterns in the scatter plots\")\nprint(\"- Color gradients indicate feature interactions\")\nprint(\"- Vertical spread at a given x-value suggests interactions with other features\")"
890
+ },
891
+ {
892
+ "cell_type": "markdown",
893
+ "metadata": {},
894
+ "source": [
895
+ "## 12. SHAP Decision Plot"
896
+ ]
897
+ },
898
+ {
899
+ "cell_type": "code",
900
+ "execution_count": null,
901
+ "metadata": {},
902
+ "outputs": [],
903
+ "source": [
904
+ "# Decision plot showing prediction paths\n",
905
+ "print(\"=\" * 80)\n",
906
+ "print(\"SHAP VISUALIZATION 7: DECISION PLOT\")\n",
907
+ "print(\"=\" * 80)\n",
908
+ "print(\"\"\"\n",
909
+ "Decision plots show the cumulative effect of features on predictions:\n",
910
+ "- Each line represents one sample's prediction path\n",
911
+ "- Starts from expected value at bottom\n",
912
+ "- Each feature shifts the prediction up or down\n",
913
+ "- Final position (top) is the model's prediction\n",
914
+ "- Color indicates the final predicted class\n",
915
+ "\n",
916
+ "This helps visualize:\n",
917
+ "- Which features drive different predictions\n",
918
+ "- Where predictions diverge\n",
919
+ "- Similarity between prediction paths\n",
920
+ "\"\"\")\n",
921
+ "\n",
922
+ "# Select diverse samples for decision plot\n",
923
+ "n_decision_samples = min(30, len(X_test_sample))\n",
924
+ "decision_indices = np.linspace(0, len(X_test_sample)-1, n_decision_samples, dtype=int)\n",
925
+ "\n",
926
+ "plt.figure(figsize=(14, 10))\n",
927
+ "shap.decision_plot(\n",
928
+ " explainer.expected_value[1],\n",
929
+ " shap_values[1][decision_indices],\n",
930
+ " X_test_sample.iloc[decision_indices],\n",
931
+ " show=False,\n",
932
+ " feature_display_range=slice(-1, -21, -1) # Show top 20 features\n",
933
+ ")\n",
934
+ "plt.title(f'SHAP Decision Plot - Prediction Paths for {n_decision_samples} Samples\\n(Top 20 Features)', \n",
935
+ " fontsize=14, fontweight='bold', pad=20)\n",
936
+ "plt.tight_layout()\n",
937
+ "plt.show()"
938
+ ]
939
+ },
940
+ {
941
+ "cell_type": "markdown",
942
+ "metadata": {},
943
+ "source": [
944
+ "## 13. Advanced Analysis - Correct vs Incorrect Predictions"
945
+ ]
946
+ },
947
+ {
948
+ "cell_type": "code",
949
+ "execution_count": null,
950
+ "metadata": {},
951
+ "outputs": [],
952
+ "source": "# Compare SHAP patterns between correct and incorrect predictions\nprint(\"=\" * 80)\nprint(\"ADVANCED ANALYSIS: CORRECT VS INCORRECT PREDICTIONS\")\nprint(\"=\" * 80)\n\nif incorrect_mask.sum() > 0:\n print(f\"\\nAnalyzing {incorrect_mask.sum()} misclassified samples...\\n\")\n \n # Compare average SHAP values\n shap_correct = np.abs(shap_values[1][correct_mask]).mean(axis=0)\n shap_incorrect = np.abs(shap_values[1][incorrect_mask]).mean(axis=0)\n \n comparison_df = pd.DataFrame({\n 'Feature': X_test_sample.columns,\n 'Correct_Predictions': shap_correct,\n 'Incorrect_Predictions': shap_incorrect,\n 'Difference': shap_incorrect - shap_correct\n }).sort_values('Difference', ascending=False)\n \n print(\"Features with largest difference in importance:\")\n print(\"\\nTop 10 features MORE important in incorrect predictions:\")\n print(comparison_df.head(10).to_string(index=False))\n \n # Visualize comparison\n fig, axes = plt.subplots(1, 2, figsize=(16, 8))\n \n # Convert mask to numpy array for indexing\n correct_mask_np = correct_mask.values if hasattr(correct_mask, 'values') else correct_mask\n incorrect_mask_np = incorrect_mask.values if hasattr(incorrect_mask, 'values') else incorrect_mask\n \n # Get the data for correct and incorrect predictions\n X_correct = X_test_sample.values[correct_mask_np]\n X_incorrect = X_test_sample.values[incorrect_mask_np]\n \n # Summary plot for correct predictions\n plt.sca(axes[0])\n shap.summary_plot(shap_values[1][correct_mask_np], \n X_correct,\n feature_names=X_test_sample.columns.tolist(),\n plot_type=\"bar\", show=False, max_display=15)\n axes[0].set_title(f'Feature Importance - Correct Predictions (n={correct_mask.sum()})', \n fontweight='bold', fontsize=12)\n \n # Summary plot for incorrect predictions\n plt.sca(axes[1])\n shap.summary_plot(shap_values[1][incorrect_mask_np], \n X_incorrect,\n feature_names=X_test_sample.columns.tolist(),\n plot_type=\"bar\", show=False, max_display=15)\n 
axes[1].set_title(f'Feature Importance - Incorrect Predictions (n={incorrect_mask.sum()})', \n fontweight='bold', fontsize=12)\n \n plt.tight_layout()\n plt.show()\n \nelse:\n print(\"\\nAll samples in the test set were correctly classified!\")\n print(\"This indicates excellent model performance.\")"
953
+ },
954
+ {
955
+ "cell_type": "markdown",
956
+ "metadata": {},
957
+ "source": [
958
+ "## 14. Interactive Exploration Function"
959
+ ]
960
+ },
961
+ {
962
+ "cell_type": "code",
963
+ "execution_count": null,
964
+ "metadata": {},
965
+ "outputs": [],
966
+ "source": [
967
+ "# Create interactive exploration function\n",
968
+ "def explain_sample(sample_index):\n",
969
+ " \"\"\"\n",
970
+ " Provide detailed explanation for a specific sample prediction.\n",
971
+ " \n",
972
+ " Args:\n",
973
+ " sample_index: Index of the sample to explain (0 to len(X_test_sample)-1)\n",
974
+ " \"\"\"\n",
975
+ " if sample_index < 0 or sample_index >= len(X_test_sample):\n",
976
+ " print(f\"Error: Sample index out of range. Please use 0 to {len(X_test_sample)-1}\")\n",
977
+ " return\n",
978
+ " \n",
979
+ " print(\"\\n\" + \"=\" * 80)\n",
980
+ " print(f\"DETAILED EXPLANATION FOR SAMPLE {sample_index}\")\n",
981
+ " print(\"=\" * 80)\n",
982
+ " \n",
983
+ " # Get prediction information\n",
984
+ " true_label = data.target_names[y_test_sample.iloc[sample_index]]\n",
985
+ " pred_label = data.target_names[y_pred_sample[sample_index]]\n",
986
+ " pred_proba = y_proba_sample[sample_index]\n",
987
+ " \n",
988
+ " print(f\"\\n1. PREDICTION SUMMARY\")\n",
989
+ " print(f\"{'-'*80}\")\n",
990
+ " print(f\"True Label: {true_label}\")\n",
991
+ " print(f\"Predicted Label: {pred_label}\")\n",
992
+ " print(f\"Correct: {'✓ Yes' if true_label == pred_label else '✗ No'}\")\n",
993
+ " print(f\"\\nPrediction Probabilities:\")\n",
994
+ " for class_idx, class_name in enumerate(data.target_names):\n",
995
+ " prob = pred_proba[class_idx]\n",
996
+ " bar = '█' * int(prob * 50)\n",
997
+ " print(f\" {class_name:12s}: {prob:.4f} ({prob*100:5.2f}%) {bar}\")\n",
998
+ " \n",
999
+ " # SHAP explanation\n",
1000
+ " print(f\"\\n2. SHAP EXPLANATION\")\n",
1001
+ " print(f\"{'-'*80}\")\n",
1002
+ " print(f\"Base value (expected output): {explainer.expected_value[1]:.4f}\")\n",
1003
+ " print(f\"Model output for this sample: {explainer.expected_value[1] + shap_values[1][sample_index].sum():.4f}\")\n",
1004
+ " \n",
1005
+ " # Top positive and negative contributors\n",
1006
+ " shap_sample = shap_values[1][sample_index]\n",
1007
+ " feature_impacts = pd.DataFrame({\n",
1008
+ " 'Feature': X_test_sample.columns,\n",
1009
+ " 'Value': X_test_sample.iloc[sample_index].values,\n",
1010
+ " 'SHAP_Value': shap_sample\n",
1011
+ " }).sort_values('SHAP_Value', key=abs, ascending=False)\n",
1012
+ " \n",
1013
+ " print(f\"\\nTop 5 features INCREASING malignant probability:\")\n",
1014
+ " positive_features = feature_impacts[feature_impacts['SHAP_Value'] > 0].head(5)\n",
1015
+ " for idx, row in positive_features.iterrows():\n",
1016
+ " print(f\" {row['Feature']:30s}: {row['SHAP_Value']:+.4f} (value={row['Value']:.4f})\")\n",
1017
+ " \n",
1018
+ " print(f\"\\nTop 5 features DECREASING malignant probability:\")\n",
1019
+ " negative_features = feature_impacts[feature_impacts['SHAP_Value'] < 0].head(5)\n",
1020
+ " for idx, row in negative_features.iterrows():\n",
1021
+ " print(f\" {row['Feature']:30s}: {row['SHAP_Value']:+.4f} (value={row['Value']:.4f})\")\n",
1022
+ " \n",
1023
+ " # Visualizations\n",
1024
+ " print(f\"\\n3. VISUALIZATIONS\")\n",
1025
+ " print(f\"{'-'*80}\\n\")\n",
1026
+ " \n",
1027
+ " # Waterfall plot\n",
1028
+ " print(\"Waterfall Plot:\")\n",
1029
+ " shap.plots.waterfall(\n",
1030
+ " shap.Explanation(\n",
1031
+ " values=shap_values[1][sample_index],\n",
1032
+ " base_values=explainer.expected_value[1],\n",
1033
+ " data=X_test_sample.iloc[sample_index],\n",
1034
+ " feature_names=X_test_sample.columns.tolist()\n",
1035
+ " ),\n",
1036
+ " max_display=15\n",
1037
+ " )\n",
1038
+ " \n",
1039
+ " # Feature values comparison\n",
1040
+ " print(\"\\n4. FEATURE VALUES COMPARISON\")\n",
1041
+ " print(f\"{'-'*80}\")\n",
1042
+ " print(\"\\nTop 10 features by absolute value (original scale):\")\n",
1043
+ " original_values = X_test.iloc[sample_index].sort_values(ascending=False).head(10)\n",
1044
+ " for feature, value in original_values.items():\n",
1045
+ " mean_val = X_train[feature].mean()\n",
1046
+ " std_val = X_train[feature].std()\n",
1047
+ " z_score = (value - mean_val) / std_val\n",
1048
+ " print(f\" {feature:30s}: {value:10.4f} (μ={mean_val:8.4f}, z={z_score:+6.2f})\")\n",
1049
+ "\n",
1050
+ "# Display usage instructions\n",
1051
+ "print(\"=\" * 80)\n",
1052
+ "print(\"INTERACTIVE EXPLORATION\")\n",
1053
+ "print(\"=\" * 80)\n",
1054
+ "print(f\"\\nUse the explain_sample() function to explore any prediction:\")\n",
1055
+ "print(f\"\\nExample usage:\")\n",
1056
+ "print(f\" explain_sample(0) # Explain first sample\")\n",
1057
+ "print(f\" explain_sample(10) # Explain 11th sample\")\n",
1058
+ "print(f\"\\nValid range: 0 to {len(X_test_sample)-1}\")\n",
1059
+ "print(f\"\\nTry these interesting samples:\")\n",
1060
+ "if len(high_conf_correct_idx) > 0:\n",
1061
+ " print(f\" explain_sample({high_conf_correct_idx[0]}) # High confidence correct\")\n",
1062
+ "if len(low_conf_correct_idx) > 0:\n",
1063
+ " print(f\" explain_sample({low_conf_correct_idx[0]}) # Low confidence correct\")\n",
1064
+ "if len(misclassified_idx) > 0:\n",
1065
+ " print(f\" explain_sample({misclassified_idx[0]}) # Misclassified sample\")\n",
1066
+ "\n",
1067
+ "print(\"\\n\" + \"=\"*80)\n",
1068
+ "print(\"Example: Explaining sample 0\")\n",
1069
+ "print(\"=\"*80)\n",
1070
+ "explain_sample(0)"
1071
+ ]
1072
+ },
1073
+ {
1074
+ "cell_type": "markdown",
1075
+ "metadata": {},
1076
+ "source": [
1077
+ "## 15. GPU Resource Monitoring"
1078
+ ]
1079
+ },
1080
+ {
1081
+ "cell_type": "code",
1082
+ "execution_count": null,
1083
+ "metadata": {},
1084
+ "outputs": [],
1085
+ "source": [
1086
+ "# Final GPU memory check\n",
1087
+ "print(\"=\" * 80)\n",
1088
+ "print(\"VAST.AI GPU RESOURCE SUMMARY\")\n",
1089
+ "print(\"=\" * 80)\n",
1090
+ "print(\"\\nFinal GPU Memory Status:\")\n",
1091
+ "print_gpu_memory()\n",
1092
+ "\n",
1093
+ "print(\"\\nTo free GPU memory, run: clear_gpu_memory()\")"
1094
+ ]
1095
+ },
1096
+ {
1097
+ "cell_type": "markdown",
1098
+ "metadata": {},
1099
+ "source": [
1100
+ "## 16. Summary and Key Insights"
1101
+ ]
1102
+ },
1103
+ {
1104
+ "cell_type": "code",
1105
+ "execution_count": null,
1106
+ "metadata": {},
1107
+ "outputs": [],
1108
+ "source": [
1109
+ "# Comprehensive summary\n",
1110
+ "print(\"=\" * 80)\n",
1111
+ "print(\"COMPREHENSIVE SUMMARY\")\n",
1112
+ "print(\"=\" * 80)\n",
1113
+ "\n",
1114
+ "print(\"\\n1. MODEL PERFORMANCE\")\n",
1115
+ "print(\"-\" * 80)\n",
1116
+ "print(f\"Algorithm: K-Nearest Neighbors (KNN)\")\n",
1117
+ "print(f\"Optimal K: {optimal_k} neighbors\")\n",
1118
+ "print(f\"Training Accuracy: {train_accuracy:.4f} ({train_accuracy*100:.2f}%)\")\n",
1119
+ "print(f\"Test Accuracy: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)\")\n",
1120
+ "print(f\"ROC AUC Score: {roc_auc:.4f}\")\n",
1121
+ "\n",
1122
+ "print(\"\\n2. DATASET INFORMATION\")\n",
1123
+ "print(\"-\" * 80)\n",
1124
+ "print(f\"Dataset: Breast Cancer Wisconsin (Diagnostic)\")\n",
1125
+ "print(f\"Total Samples: {len(X)}\")\n",
1126
+ "print(f\"Features: {X.shape[1]}\")\n",
1127
+ "print(f\"Classes: {len(data.target_names)} ({', '.join(data.target_names)})\")\n",
1128
+ "print(f\"Train/Test Split: {len(X_train)}/{len(X_test)}\")\n",
1129
+ "\n",
1130
+ "print(\"\\n3. TOP 5 MOST IMPORTANT FEATURES (SHAP)\")\n",
1131
+ "print(\"-\" * 80)\n",
1132
+ "for idx, row in feature_importance_df.head(5).iterrows():\n",
1133
+ " print(f\"{idx+1}. {row['Feature']:30s} (mean |SHAP|={row['Mean_Abs_SHAP']:.4f})\")\n",
1134
+ "\n",
1135
+ "print(\"\\n4. SHAP EXPLAINABILITY INSIGHTS\")\n",
1136
+ "print(\"-\" * 80)\n",
1137
+ "print(f\"Samples explained: {n_samples}\")\n",
1138
+ "print(f\"Background dataset size: {background_size}\")\n",
1139
+ "print(f\"Computation time: {elapsed_time:.2f} seconds\")\n",
1140
+ "print(f\"Time per sample: {elapsed_time/n_samples:.2f} seconds\")\n",
1141
+ "\n",
1142
+ "print(\"\\n5. KEY TAKEAWAYS\")\n",
1143
+ "print(\"-\" * 80)\n",
1144
+ "print(\"\"\"\n",
1145
+ "✓ SHAP provides model-agnostic explainability for KNN predictions\n",
1146
+ "✓ Feature importance rankings identify which features drive predictions\n",
1147
+ "✓ Waterfall plots explain individual predictions step-by-step\n",
1148
+ "✓ Dependence plots reveal non-linear relationships and interactions\n",
1149
+ "✓ Summary plots show global patterns across all predictions\n",
1150
+ "✓ Decision plots visualize prediction paths for multiple samples\n",
1151
+ "✓ KNN combined with SHAP offers both accuracy and interpretability\n",
1152
+ "\"\"\")\n",
1153
+ "\n",
1154
+ "print(\"\\n6. SHAP VISUALIZATION TYPES USED\")\n",
1155
+ "print(\"-\" * 80)\n",
1156
+ "visualizations = [\n",
1157
+ " (\"Summary Plot (Beeswarm)\", \"Global feature importance with value distributions\"),\n",
1158
+ " (\"Bar Plot\", \"Mean absolute feature importance\"),\n",
1159
+ " (\"Waterfall Plot\", \"Individual prediction breakdown\"),\n",
1160
+ " (\"Force Plot\", \"Visual representation of feature contributions\"),\n",
1161
+ " (\"Dependence Plot\", \"Feature-value relationships and interactions\"),\n",
1162
+ " (\"Decision Plot\", \"Cumulative prediction paths for multiple samples\")\n",
1163
+ "]\n",
1164
+ "for viz_name, description in visualizations:\n",
1165
+ " print(f\" • {viz_name:30s}: {description}\")\n",
1166
+ "\n",
1167
+ "print(\"\\n7. VAST.AI ENVIRONMENT\")\n",
1168
+ "print(\"-\" * 80)\n",
1169
+ "if torch.cuda.is_available():\n",
1170
+ " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
1171
+ " print(f\"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB\")\n",
1172
+ " print(f\"CUDA Version: {torch.version.cuda}\")\n",
1173
+ "else:\n",
1174
+ " print(\"Running on CPU (GPU not detected)\")\n",
1175
+ "\n",
1176
+ "print(\"\\n\" + \"=\" * 80)\n",
1177
+ "print(\"ANALYSIS COMPLETE\")\n",
1178
+ "print(\"=\" * 80)\n",
1179
+ "print(\"\\nNext steps:\")\n",
1180
+ "print(\" 1. Explore individual predictions using: explain_sample(index)\")\n",
1181
+ "print(\" 2. Try different K values for KNN\")\n",
1182
+ "print(\" 3. Test with other datasets (load_wine, load_iris, etc.)\")\n",
1183
+ "print(\" 4. Compare with other algorithms (Random Forest, SVM, XGBoost)\")\n",
1184
+ "print(\" 5. Use SHAP insights for feature engineering\")\n"
1185
+ ]
1186
+ },
1187
+ {
1188
+ "cell_type": "markdown",
1189
+ "metadata": {},
1190
+ "source": [
1191
+ "## 17. Cleanup and Resource Management"
1192
+ ]
1193
+ },
1194
+ {
1195
+ "cell_type": "code",
1196
+ "execution_count": null,
1197
+ "metadata": {},
1198
+ "outputs": [],
1199
+ "source": [
1200
+ "# Optional: Clear GPU memory when done\n",
1201
+ "# Uncomment the line below to free GPU memory\n",
1202
+ "# clear_gpu_memory()\n",
1203
+ "\n",
1204
+ "print(\"=\" * 80)\n",
1205
+ "print(\"NOTEBOOK COMPLETE\")\n",
1206
+ "print(\"=\" * 80)\n",
1207
+ "print(\"\\nThank you for using this SHAP explainability notebook!\")\n",
1208
+ "print(\"\\nTo clear GPU memory, run: clear_gpu_memory()\")\n",
1209
+ "print(\"To check GPU status, run: print_gpu_memory()\")"
1210
+ ]
1211
+ }
1212
+ ],
1213
+ "metadata": {
1214
+ "kernelspec": {
1215
+ "display_name": "Python 3",
1216
+ "language": "python",
1217
+ "name": "python3"
1218
+ },
1219
+ "language_info": {
1220
+ "codemirror_mode": {
1221
+ "name": "ipython",
1222
+ "version": 3
1223
+ },
1224
+ "file_extension": ".py",
1225
+ "mimetype": "text/x-python",
1226
+ "name": "python",
1227
+ "nbconvert_exporter": "python",
1228
+ "pygments_lexer": "ipython3",
1229
+ "version": "3.8.10"
1230
+ }
1231
+ },
1232
+ "nbformat": 4,
1233
+ "nbformat_minor": 4
1234
+ }
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 KNN SHAP Explainability Project
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,274 @@
1
+ # KNN Explainability with SHAP
2
+
3
+ A comprehensive Jupyter notebook demonstrating how to use SHAP (SHapley Additive exPlanations) to interpret K-Nearest Neighbors (KNN) model predictions with detailed visualizations.
4
+
5
+ ## Overview
6
+
7
+ This project provides a complete walkthrough of:
8
+ - Training a K-Nearest Neighbors classifier on the Breast Cancer Wisconsin dataset
9
+ - Using SHAP to explain model predictions at both global and local levels
10
+ - Creating comprehensive visualizations to understand feature importance and model behavior
11
+ - Interactive exploration of individual predictions
12
+
13
+ ## Features
14
+
15
+ ### Model Training
16
+ - Optimal K value selection through cross-validation
17
+ - StandardScaler preprocessing for KNN optimization
18
+ - Comprehensive model evaluation metrics
19
+ - ROC curves and confusion matrices
20
+
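
The K-selection step listed above can be sketched as follows. This is a minimal illustration, not the notebook's exact code: the search grid (odd K from 1 to 21) and the 5-fold CV are assumptions, and scaling is placed inside a pipeline so each fold is standardized independently.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Evaluate candidate K values with 5-fold cross-validation.
# Scaling inside the pipeline avoids leaking test-fold statistics.
k_values = list(range(1, 22, 2))
cv_scores = [
    cross_val_score(
        make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=k)),
        X_train, y_train, cv=5,
    ).mean()
    for k in k_values
]
optimal_k = k_values[int(np.argmax(cv_scores))]
print(f"Optimal K: {optimal_k}")
```

Selecting K on cross-validated accuracy rather than training accuracy matters for KNN, since K=1 always fits the training set perfectly.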
21
+ ### SHAP Explainability
22
+ - **Summary Plots**: Global feature importance with value distributions
23
+ - **Bar Plots**: Mean absolute feature importance rankings
24
+ - **Waterfall Plots**: Step-by-step breakdown of individual predictions
25
+ - **Force Plots**: Visual representation of feature contributions
26
+ - **Dependence Plots**: Feature-value relationships and interactions
27
+ - **Decision Plots**: Cumulative prediction paths for multiple samples
28
+
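
The global importance rankings behind the summary and bar plots are simply mean absolute SHAP values per feature. A self-contained sketch of that aggregation, using a synthetic SHAP array and illustrative feature names (the notebook computes the same statistics from real KernelExplainer output):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
# Stand-in for one class's SHAP values: shape (n_samples, n_features)
shap_vals = rng.normal(size=(100, 4))
feature_names = ["mean radius", "mean texture", "mean perimeter", "mean area"]

importance = pd.DataFrame({
    "Feature": feature_names,
    "Mean_Abs_SHAP": np.abs(shap_vals).mean(axis=0),  # magnitude -> ranking
    "Mean_SHAP": shap_vals.mean(axis=0),              # sign -> direction of effect
}).sort_values("Mean_Abs_SHAP", ascending=False)

print(importance.to_string(index=False))
```

Mean |SHAP| gives the ranking used in bar plots; the signed mean shows whether a feature, on average, pushes predictions toward one class or the other.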
### Interactive Analysis
- Custom `explain_sample()` function for detailed prediction exploration
- Comparison of correct vs. incorrect predictions
- GPU memory management utilities
- Comprehensive reporting and summaries

## Prerequisites

### Required Packages
```
torch
shap
scikit-learn
matplotlib
seaborn
pandas
numpy
plotly
ipywidgets
tqdm
```

### Environment
- Python 3.7+
- Jupyter Notebook or JupyterLab
- VS Code with the Jupyter extension (recommended)
- GPU support optional (CUDA-enabled PyTorch for faster computation)

## Installation

1. Clone this repository:
```bash
git clone <repository-url>
cd <repository-directory>
```

2. Install the required packages:
```bash
pip install torch shap scikit-learn matplotlib seaborn pandas numpy plotly ipywidgets tqdm
```

3. Launch Jupyter:
```bash
jupyter notebook
```

4. Open `KNN_SHAP_Explainability.ipynb`

## Usage

### Basic Usage

Run the notebook cells sequentially from top to bottom. The notebook is structured in logical sections:

1. **Environment Setup**: GPU verification and package installation
2. **Data Loading**: Load and explore the Breast Cancer Wisconsin dataset
3. **Preprocessing**: Feature scaling and train-test split
4. **Model Training**: KNN optimization and evaluation
5. **SHAP Analysis**: Compute SHAP values for test samples
6. **Visualizations**: Generate all SHAP plots and explanations
7. **Interactive Exploration**: Use custom functions to explore predictions

### Interactive Exploration

After running all cells, use the `explain_sample()` function to explore any prediction:

```python
# Explain the first test sample
explain_sample(0)

# Explain a high-confidence correct prediction
explain_sample(5)

# Explain a misclassified sample
explain_sample(42)
```

### GPU Memory Management

If running on a GPU, monitor and manage memory:

```python
# Check current GPU memory usage
print_gpu_memory()

# Clear the GPU cache
clear_gpu_memory()

# Get the optimal device (CPU or GPU)
device = get_optimal_device()
```

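These helper names come from the notebook; their bodies are not shown here, but a plausible sketch (an assumption, written to degrade gracefully when CUDA or even PyTorch is absent) looks like:

```python
# Assumed implementations of the notebook's GPU helpers; these are
# sketches, not the notebook's exact code. All fall back to CPU.
import gc

def get_optimal_device():
    """Return a CUDA device when available, otherwise the CPU."""
    try:
        import torch
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    except ImportError:
        return "cpu"

def clear_gpu_memory():
    """Run the garbage collector and release cached CUDA memory."""
    gc.collect()
    try:
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except ImportError:
        pass

def print_gpu_memory():
    """Report allocated/reserved CUDA memory in MB (no-op on CPU)."""
    try:
        import torch
        if torch.cuda.is_available():
            alloc = torch.cuda.memory_allocated() / 1024 ** 2
            reserved = torch.cuda.memory_reserved() / 1024 ** 2
            print(f"allocated: {alloc:.1f} MB, reserved: {reserved:.1f} MB")
            return
    except ImportError:
        pass
    print("no GPU available")
```
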
## Dataset

The notebook uses the **Breast Cancer Wisconsin (Diagnostic)** dataset from scikit-learn:
- **Samples**: 569
- **Features**: 30 (mean, standard error, and worst values of 10 real-valued features)
- **Classes**: 2 (malignant, benign)
- **Task**: Binary classification

Features include radius, texture, perimeter, area, smoothness, compactness, concavity, concave points, symmetry, and fractal dimension.

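Loading and framing the data takes only a few lines with scikit-learn:

```python
# Load the Breast Cancer Wisconsin (Diagnostic) dataset as a DataFrame
import pandas as pd
from sklearn.datasets import load_breast_cancer

data = load_breast_cancer()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = pd.Series(data.target, name="target")

print(X.shape)            # (569, 30)
print(data.target_names)  # ['malignant' 'benign']
```
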
## Model Performance

The KNN model achieves:
- High accuracy on both training and test sets
- An optimal K value determined through cross-validation
- A ROC AUC score above 0.95 (typical)
- Interpretable predictions through SHAP analysis

## SHAP Visualization Guide

### Summary Plot (Beeswarm)
- Shows global feature importance across all samples
- Color indicates feature value (red = high, blue = low)
- Horizontal position shows impact on the prediction

### Waterfall Plot
- Explains individual predictions step by step
- Starts from the base value (the expected prediction)
- Each bar shows a feature's contribution
- Red pushes toward malignant, blue toward benign

### Dependence Plot
- Reveals non-linear feature relationships
- Shows feature interactions through color
- Identifies threshold effects

### Decision Plot
- Visualizes prediction paths for multiple samples
- Shows the cumulative effect of features
- Helps identify prediction patterns

## Key Insights

1. **Feature Importance**: SHAP identifies the features most critical to the diagnosis
2. **Non-linearity**: Dependence plots reveal complex feature-value relationships
3. **Interactions**: Color gradients show which features interact
4. **Individual Explanations**: Each prediction can be fully explained and understood
5. **Model Trust**: Transparent explanations increase confidence in model decisions

## Customization

### Using Different Datasets

Replace the data loading section with your own dataset:

```python
# Load your dataset
X = pd.DataFrame(your_data)
y = pd.Series(your_labels)

# Continue with the rest of the notebook
```

### Adjusting SHAP Computation

Modify SHAP parameters to trade speed against accuracy:

```python
# Faster computation (less accurate)
shap_values = explainer.shap_values(X_test_sample, nsamples=50)

# More accurate (slower)
shap_values = explainer.shap_values(X_test_sample, nsamples=200)

# Smaller background dataset (faster)
background = shap.kmeans(X_train_scaled, 50)
```

### Trying Other Algorithms

The SHAP approach works with any model:

```python
from sklearn.ensemble import RandomForestClassifier

# Train a Random Forest instead of KNN
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train_scaled, y_train)

# Use TreeExplainer for faster computation on tree-based models
explainer = shap.TreeExplainer(model)
```

## Performance Tips

1. **GPU Acceleration**: Use a GPU for faster PyTorch operations
2. **Background Size**: Reduce the background dataset size for faster SHAP computation
3. **Sample Size**: Start with fewer samples (e.g., 50) for quick testing
4. **`nsamples` Parameter**: Lower values speed up computation but reduce accuracy
5. **Memory Management**: Clear the GPU cache between major computations

## Troubleshooting

### Common Issues

**GPU not detected:**
- Check your CUDA installation
- Verify PyTorch GPU support with `torch.cuda.is_available()`
- The notebook will fall back to CPU automatically

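The automatic fallback amounts to a one-line check:

```python
# Prefer CUDA when PyTorch reports a usable GPU, otherwise use CPU
try:
    import torch
    device = "cuda" if torch.cuda.is_available() else "cpu"
except ImportError:  # PyTorch not installed; everything runs on CPU
    device = "cpu"
print(f"using device: {device}")
```
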
**SHAP computation too slow:**
- Reduce the background dataset size
- Decrease the number of test samples
- Lower the `nsamples` parameter

**Memory errors:**
- Process fewer samples at once
- Clear the GPU cache with `clear_gpu_memory()`
- Reduce the background dataset size

**Visualization issues:**
- Ensure the matplotlib backend is compatible
- Update SHAP to the latest version
- Restart the kernel if plots don't render

## Contributing

Contributions are welcome! Feel free to submit pull requests or open issues for bugs, questions, or feature requests.

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## Acknowledgments

- **SHAP**: Scott Lundberg et al. for the SHAP library
- **scikit-learn**: For the Breast Cancer Wisconsin dataset and ML tools
- **PyTorch**: For GPU acceleration capabilities
- **Community**: All contributors to the open-source ML/AI ecosystem

## References

- Lundberg, S. M., & Lee, S. I. (2017). A unified approach to interpreting model predictions. NeurIPS.
- SHAP documentation: https://shap.readthedocs.io/
- scikit-learn documentation: https://scikit-learn.org/
- Breast Cancer Wisconsin dataset: https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)

## Contact

For questions or feedback, please open an issue in the repository.

---

**Note**: This notebook is designed for educational purposes and demonstrates best practices for ML model interpretability using SHAP.