Upload 4 files
Browse files- dataPreprocessing.ipynb +1203 -0
- detectionLayer.ipynb +1799 -0
- regressionLayer.ipynb +962 -0
- severityPredictionLayer.ipynb +1858 -0
dataPreprocessing.ipynb
ADDED
|
@@ -0,0 +1,1203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "37cfe918-891e-4c07-914b-a5c42ff01f12",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"### Data Import"
|
| 9 |
+
]
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"cell_type": "code",
|
| 13 |
+
"execution_count": 1,
|
| 14 |
+
"id": "db8bb49a-2e2d-4408-b571-44c7547b463b",
|
| 15 |
+
"metadata": {},
|
| 16 |
+
"outputs": [],
|
| 17 |
+
"source": [
|
| 18 |
+
"# Libraries\n",
|
| 19 |
+
"import pandas as pd\n",
|
| 20 |
+
"import numpy as np\n",
|
| 21 |
+
"import random as rnd\n",
|
| 22 |
+
"import seaborn as sns\n",
|
| 23 |
+
"import matplotlib.pyplot as plt\n",
|
| 24 |
+
"%matplotlib inline"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "code",
|
| 29 |
+
"execution_count": 2,
|
| 30 |
+
"id": "8d470c8b-9188-45ba-9356-ec22c1b146bf",
|
| 31 |
+
"metadata": {
|
| 32 |
+
"scrolled": true
|
| 33 |
+
},
|
| 34 |
+
"outputs": [
|
| 35 |
+
{
|
| 36 |
+
"name": "stdout",
|
| 37 |
+
"output_type": "stream",
|
| 38 |
+
"text": [
|
| 39 |
+
"<class 'pandas.core.frame.DataFrame'>\n",
|
| 40 |
+
"RangeIndex: 3005 entries, 0 to 3004\n",
|
| 41 |
+
"Columns: 820 entries, ID to CASEDIF\n",
|
| 42 |
+
"dtypes: float64(93), int64(6), object(721)\n",
|
| 43 |
+
"memory usage: 18.8+ MB\n"
|
| 44 |
+
]
|
| 45 |
+
}
|
| 46 |
+
],
|
| 47 |
+
"source": [
|
# Load the raw survey export and report its shape/dtypes.
DATA_PATH = './data.csv'
pd.set_option('display.max_rows', 10)  # keep notebook previews compact
df = pd.read_csv(DATA_PATH)
df.info()
|
| 52 |
+
]
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"cell_type": "markdown",
|
| 56 |
+
"id": "6c6fe342-0591-48e4-9d5f-d981d62d50c9",
|
| 57 |
+
"metadata": {},
|
| 58 |
+
"source": [
|
| 59 |
+
"# Feature Deletion"
|
| 60 |
+
]
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"cell_type": "code",
|
| 64 |
+
"execution_count": 3,
|
| 65 |
+
"id": "6508cb2e-4304-4e2c-a101-79e7caeaffc2",
|
| 66 |
+
"metadata": {},
|
| 67 |
+
"outputs": [],
|
| 68 |
+
"source": [
|
# Drop every feature whose share of missing entries exceeds 25%.
# isna().mean() gives the per-column fraction of missing values directly,
# which is the same as isna().sum() / len(df).
missing_pct = df.isna().mean() * 100
cols_to_drop = missing_pct[missing_pct > 25].index.tolist()
df = df.drop(columns=cols_to_drop)
|
| 76 |
+
]
|
| 77 |
+
},
|
| 78 |
+
{
|
| 79 |
+
"cell_type": "code",
|
| 80 |
+
"execution_count": 4,
|
| 81 |
+
"id": "f67860b6-5837-49de-a970-50f40a7e4bc7",
|
| 82 |
+
"metadata": {},
|
| 83 |
+
"outputs": [
|
| 84 |
+
{
|
| 85 |
+
"name": "stdout",
|
| 86 |
+
"output_type": "stream",
|
| 87 |
+
"text": [
|
| 88 |
+
"<class 'pandas.core.frame.DataFrame'>\n",
|
| 89 |
+
"RangeIndex: 3005 entries, 0 to 3004\n",
|
| 90 |
+
"Columns: 556 entries, ID to CASEDIF\n",
|
| 91 |
+
"dtypes: float64(49), int64(6), object(501)\n",
|
| 92 |
+
"memory usage: 12.7+ MB\n"
|
| 93 |
+
]
|
| 94 |
+
}
|
| 95 |
+
],
|
| 96 |
+
"source": [
|
| 97 |
+
"df.info()"
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"cell_type": "code",
|
| 102 |
+
"execution_count": 5,
|
| 103 |
+
"id": "2020f8b3-bda8-459b-819c-f601c4aa6a53",
|
| 104 |
+
"metadata": {},
|
| 105 |
+
"outputs": [],
|
| 106 |
+
"source": [
|
| 107 |
+
"# Medicine features except for ones with high correlation with depression\n",
|
| 108 |
+
"medicine_cols_to_keep = ['PSYCHOTHERAGENTS', 'ANTIDEPRESSANTS', 'SSRIANTIDEPRESSA',\n",
|
| 109 |
+
" 'ANXIOLYTICSSEDAT', 'ANTICONVULSANTS', 'BENZODIAZEPINES']\n",
|
| 110 |
+
"medicine_cols = [\n",
|
| 111 |
+
" \"antiinfectives\",\n",
|
| 112 |
+
" \"amebicides\",\n",
|
| 113 |
+
" \"antifungals\",\n",
|
| 114 |
+
" \"antimalarialagen\",\n",
|
| 115 |
+
" \"antituberagents\",\n",
|
| 116 |
+
" \"cephalosporins\",\n",
|
| 117 |
+
" \"leprostatics\",\n",
|
| 118 |
+
" \"macrolidederivat\",\n",
|
| 119 |
+
" \"miscantibiotics\",\n",
|
| 120 |
+
" \"penicillins\",\n",
|
| 121 |
+
" \"quinolones\",\n",
|
| 122 |
+
" \"sulfonamides\",\n",
|
| 123 |
+
" \"tetracyclines\",\n",
|
| 124 |
+
" \"urinaryantiinfec\",\n",
|
| 125 |
+
" \"antihyplipagents\",\n",
|
| 126 |
+
" \"antineoplastics\",\n",
|
| 127 |
+
" \"alkylatingagents\",\n",
|
| 128 |
+
" \"antimetabolites\",\n",
|
| 129 |
+
" \"hormonesantineop\",\n",
|
| 130 |
+
" \"miscantineoplast\",\n",
|
| 131 |
+
" \"biologicals\",\n",
|
| 132 |
+
" \"recombinanthuman\",\n",
|
| 133 |
+
" \"cardiovascularag\",\n",
|
| 134 |
+
" \"angiotensinconve\",\n",
|
| 135 |
+
" \"antiadrenergperi\",\n",
|
| 136 |
+
" \"antiadrenergcent\",\n",
|
| 137 |
+
" \"antianginalagent\",\n",
|
| 138 |
+
" \"antiarrhythmicag\",\n",
|
| 139 |
+
" \"betaadrenergicbl\",\n",
|
| 140 |
+
" \"calciumchannelbl\",\n",
|
| 141 |
+
" \"diuretics\",\n",
|
| 142 |
+
" \"inotropicagents\",\n",
|
| 143 |
+
" \"misccardiovascul\",\n",
|
| 144 |
+
" \"peripheralvasodi\",\n",
|
| 145 |
+
" \"vasodilators\",\n",
|
| 146 |
+
" \"vasopressors\",\n",
|
| 147 |
+
" \"antihypertensive\",\n",
|
| 148 |
+
" \"angiotensiniiinh\",\n",
|
| 149 |
+
" \"centralnervoussy\",\n",
|
| 150 |
+
" \"analgesics\",\n",
|
| 151 |
+
" \"miscanalgesics\",\n",
|
| 152 |
+
" \"narcanalgs\",\n",
|
| 153 |
+
" \"nonsteroidalanti\",\n",
|
| 154 |
+
" \"salicylates\",\n",
|
| 155 |
+
" \"analgesiccombina\",\n",
|
| 156 |
+
" \"anticonvulsants\",\n",
|
| 157 |
+
" \"antiemeticantive\",\n",
|
| 158 |
+
" \"antiparkinsonage\",\n",
|
| 159 |
+
" \"anxiolyticssedat\",\n",
|
| 160 |
+
" \"barbiturates\",\n",
|
| 161 |
+
" \"benzodiazepines\",\n",
|
| 162 |
+
" \"miscanxiolyticss\",\n",
|
| 163 |
+
" \"cnsstimulants\",\n",
|
| 164 |
+
" \"musclerelaxants\",\n",
|
| 165 |
+
" \"miscantidepressa\",\n",
|
| 166 |
+
" \"miscantipsychoti\",\n",
|
| 167 |
+
" \"psychothercombin\",\n",
|
| 168 |
+
" \"misccentralnervo\",\n",
|
| 169 |
+
" \"coagulationmodif\",\n",
|
| 170 |
+
" \"anticoagulants\",\n",
|
| 171 |
+
" \"antiplateletagen\",\n",
|
| 172 |
+
" \"misccoagulationm\",\n",
|
| 173 |
+
" \"gastrointestinal\",\n",
|
| 174 |
+
" \"antacids\",\n",
|
| 175 |
+
" \"anticholsantispa\",\n",
|
| 176 |
+
" \"antidiarrheals\",\n",
|
| 177 |
+
" \"digestiveenzymes\",\n",
|
| 178 |
+
" \"gallstonesolubil\",\n",
|
| 179 |
+
" \"gistimulants\",\n",
|
| 180 |
+
" \"h2antagonists\",\n",
|
| 181 |
+
" \"laxatives\",\n",
|
| 182 |
+
" \"miscgiagents\",\n",
|
| 183 |
+
" \"hormones\",\n",
|
| 184 |
+
" \"adrenalcorticals\",\n",
|
| 185 |
+
" \"antidiabeticagen\",\n",
|
| 186 |
+
" \"mischormones\",\n",
|
| 187 |
+
" \"sexhormones\",\n",
|
| 188 |
+
" \"contraceptives\",\n",
|
| 189 |
+
" \"thyroiddrugs\",\n",
|
| 190 |
+
" \"immunosuppressiv\",\n",
|
| 191 |
+
" \"miscagents\",\n",
|
| 192 |
+
" \"antidotes\",\n",
|
| 193 |
+
" \"chelatingagents\",\n",
|
| 194 |
+
" \"cholinergicmuscl\",\n",
|
| 195 |
+
" \"localinjectablea\",\n",
|
| 196 |
+
" \"miscuncategorize\",\n",
|
| 197 |
+
" \"genitourinarytra\",\n",
|
| 198 |
+
" \"nutritionalprods\",\n",
|
| 199 |
+
" \"ironproducts\",\n",
|
| 200 |
+
" \"mineralsandelect\",\n",
|
| 201 |
+
" \"vitamins\",\n",
|
| 202 |
+
" \"vitaminmineral\",\n",
|
| 203 |
+
" \"respiratoryagent\",\n",
|
| 204 |
+
" \"antihistamines\",\n",
|
| 205 |
+
" \"antitussives\",\n",
|
| 206 |
+
" \"bronchodilators\",\n",
|
| 207 |
+
" \"methylxanthines\",\n",
|
| 208 |
+
" \"decongestants\",\n",
|
| 209 |
+
" \"expectorants\",\n",
|
| 210 |
+
" \"miscrespiratorya\",\n",
|
| 211 |
+
" \"respiratoryinhal\",\n",
|
| 212 |
+
" \"upperrespiratory\",\n",
|
| 213 |
+
" \"topicalagents\",\n",
|
| 214 |
+
" \"dermatologicalag\",\n",
|
| 215 |
+
" \"topicalantiinfec\",\n",
|
| 216 |
+
" \"topicalsteroids\",\n",
|
| 217 |
+
" \"topicalanestheti\",\n",
|
| 218 |
+
" \"misctopicalagent\",\n",
|
| 219 |
+
" \"topicalacneagent\",\n",
|
| 220 |
+
" \"mouthandthroatpr\",\n",
|
| 221 |
+
" \"ophthalpreparati\",\n",
|
| 222 |
+
" \"oticpreparations\",\n",
|
| 223 |
+
" \"vaginalpreparati\",\n",
|
| 224 |
+
" \"loopdiuretics\",\n",
|
| 225 |
+
" \"potassiumsparing\",\n",
|
| 226 |
+
" \"thiazidediuretic\",\n",
|
| 227 |
+
" \"carbonicanhydras\",\n",
|
| 228 |
+
" \"firstgenerationc\",\n",
|
| 229 |
+
" \"thirdgenerationc\",\n",
|
| 230 |
+
" \"ophthalantiinfec\",\n",
|
| 231 |
+
" \"ophthalglaucomaa\",\n",
|
| 232 |
+
" \"ophthalsteroids\",\n",
|
| 233 |
+
" \"ophthalsteroidsw\",\n",
|
| 234 |
+
" \"ophthalantiinfla\",\n",
|
| 235 |
+
" \"miscophthalagent\",\n",
|
| 236 |
+
" \"oticsteroidswith\",\n",
|
| 237 |
+
" \"miscoticagents\",\n",
|
| 238 |
+
" \"hmgcoareductasei\",\n",
|
| 239 |
+
" \"miscantihyplipag\",\n",
|
| 240 |
+
" \"skelmuscrels\",\n",
|
| 241 |
+
" \"adrenergicbronch\",\n",
|
| 242 |
+
" \"bronchodilatorco\",\n",
|
| 243 |
+
" \"androgensandanab\",\n",
|
| 244 |
+
" \"estrogens\",\n",
|
| 245 |
+
" \"progestins\",\n",
|
| 246 |
+
" \"sexhormonecombin\",\n",
|
| 247 |
+
" \"narcanalgcombina\",\n",
|
| 248 |
+
" \"antirheumatics\",\n",
|
| 249 |
+
" \"antimigraineagen\",\n",
|
| 250 |
+
" \"antigoutagents\",\n",
|
| 251 |
+
" \"fiveht3receptora\",\n",
|
| 252 |
+
" \"phenthiazantieme\",\n",
|
| 253 |
+
" \"anticholantiemet\",\n",
|
| 254 |
+
" \"miscantiemetics\",\n",
|
| 255 |
+
" \"hydantoinanticon\",\n",
|
| 256 |
+
" \"barbiturateantic\",\n",
|
| 257 |
+
" \"benzodiazepinean\",\n",
|
| 258 |
+
" \"miscanticonvulsa\",\n",
|
| 259 |
+
" \"anticholantipark\",\n",
|
| 260 |
+
" \"ssriantidepressa\",\n",
|
| 261 |
+
" \"tricyclicantidep\",\n",
|
| 262 |
+
" \"phenthiazantipsy\",\n",
|
| 263 |
+
" \"plateletaggregat\",\n",
|
| 264 |
+
" \"sulfonylureas\",\n",
|
| 265 |
+
" \"nonsulfonylureas\",\n",
|
| 266 |
+
" \"insulin\",\n",
|
| 267 |
+
" \"alphaglucosidase\",\n",
|
| 268 |
+
" \"bisphosphonates\",\n",
|
| 269 |
+
" \"alternativemeds\",\n",
|
| 270 |
+
" \"nutraceuticals\",\n",
|
| 271 |
+
" \"herbalproducts\",\n",
|
| 272 |
+
" \"penicillinaseres\",\n",
|
| 273 |
+
" \"aminopenicillins\",\n",
|
| 274 |
+
" \"betalactamaseinh\",\n",
|
| 275 |
+
" \"adamantaneantivi\",\n",
|
| 276 |
+
" \"purinenucleoside\",\n",
|
| 277 |
+
" \"miscantituberage\",\n",
|
| 278 |
+
" \"polyenes\",\n",
|
| 279 |
+
" \"azoleantifungals\",\n",
|
| 280 |
+
" \"miscantifungals\",\n",
|
| 281 |
+
" \"antimalarialquin\",\n",
|
| 282 |
+
" \"miscantimalarial\",\n",
|
| 283 |
+
" \"lincomycinderiva\",\n",
|
| 284 |
+
" \"fibricacidderiva\",\n",
|
| 285 |
+
" \"psychotheragents\",\n",
|
| 286 |
+
" \"leukotrienemodif\",\n",
|
| 287 |
+
" \"nasallubricants\",\n",
|
| 288 |
+
" \"nasalsteroids\",\n",
|
| 289 |
+
" \"nasalantihistami\",\n",
|
| 290 |
+
" \"nasalpreparation\",\n",
|
| 291 |
+
" \"antidepressants\",\n",
|
| 292 |
+
" \"monoamineoxidase\",\n",
|
| 293 |
+
" \"antipsychotics\",\n",
|
| 294 |
+
" \"bileacidsequestr\",\n",
|
| 295 |
+
" \"anorexiants\",\n",
|
| 296 |
+
" \"immunologicagent\",\n",
|
| 297 |
+
" \"monoclonalantibo\",\n",
|
| 298 |
+
" \"heparins\",\n",
|
| 299 |
+
" \"coumarinsandinda\",\n",
|
| 300 |
+
" \"impotenceagents\",\n",
|
| 301 |
+
" \"urinaryantispasm\",\n",
|
| 302 |
+
" \"urinaryphmodifie\",\n",
|
| 303 |
+
" \"miscgenitourinar\",\n",
|
| 304 |
+
" \"ophthalantihista\",\n",
|
| 305 |
+
" \"miscvaginalagent\",\n",
|
| 306 |
+
" \"antipsoriatics\",\n",
|
| 307 |
+
" \"thiazolidinedion\",\n",
|
| 308 |
+
" \"protonpumpinhibi\",\n",
|
| 309 |
+
" \"cardioselectiveb\",\n",
|
| 310 |
+
" \"noncardioselecti\",\n",
|
| 311 |
+
" \"dopaminergicanti\",\n",
|
| 312 |
+
" \"fiveaminosalic\",\n",
|
| 313 |
+
" \"cox2inhibitors\",\n",
|
| 314 |
+
" \"meglitinides\",\n",
|
| 315 |
+
" \"fivealphareducti\",\n",
|
| 316 |
+
" \"antihyperuricemi\",\n",
|
| 317 |
+
" \"topicalantibioti\",\n",
|
| 318 |
+
" \"topicalantifunga\",\n",
|
| 319 |
+
" \"inhaledcorticost\",\n",
|
| 320 |
+
" \"mastcellstabiliz\",\n",
|
| 321 |
+
" \"anticholbronchod\",\n",
|
| 322 |
+
" \"glucocorticoids\",\n",
|
| 323 |
+
" \"mineralocorticoi\",\n",
|
| 324 |
+
" \"agentsforpulmona\",\n",
|
| 325 |
+
" \"macrolides\",\n",
|
| 326 |
+
" \"ketolides\",\n",
|
| 327 |
+
" \"phenylpiperazine\",\n",
|
| 328 |
+
" \"tetracyclicantid\",\n",
|
| 329 |
+
" \"ssnriantidepress\",\n",
|
| 330 |
+
" \"miscantidiabetic\",\n",
|
| 331 |
+
" \"dibenzazepineant\",\n",
|
| 332 |
+
" \"cholinergicagoni\",\n",
|
| 333 |
+
" \"cholinesterasein\",\n",
|
| 334 |
+
" \"antidiabeticcomb\",\n",
|
| 335 |
+
" \"cholesterolabsor\",\n",
|
| 336 |
+
" \"antihyplipcombin\",\n",
|
| 337 |
+
" \"smokingcessation\",\n",
|
| 338 |
+
" \"othersupplements\"\n",
|
| 339 |
+
"]\n",
|
| 340 |
+
"\n",
|
| 341 |
+
"medicine_cols_to_drop = [x.upper() for x in medicine_cols]\n",
|
| 342 |
+
" \n",
|
| 343 |
+
"medicine_cols_to_drop = [element for element in medicine_cols_to_drop if element not in medicine_cols_to_keep]\n",
|
| 344 |
+
"\n",
|
| 345 |
+
"df = df.drop(columns = medicine_cols_to_drop)"
|
| 346 |
+
]
|
| 347 |
+
},
|
| 348 |
+
{
|
| 349 |
+
"cell_type": "code",
|
| 350 |
+
"execution_count": 6,
|
| 351 |
+
"id": "1b6f4c04-d4af-48cb-941e-d206c5fea6ff",
|
| 352 |
+
"metadata": {},
|
| 353 |
+
"outputs": [
|
| 354 |
+
{
|
| 355 |
+
"name": "stdout",
|
| 356 |
+
"output_type": "stream",
|
| 357 |
+
"text": [
|
| 358 |
+
"<class 'pandas.core.frame.DataFrame'>\n",
|
| 359 |
+
"RangeIndex: 3005 entries, 0 to 3004\n",
|
| 360 |
+
"Columns: 334 entries, ID to CASEDIF\n",
|
| 361 |
+
"dtypes: float64(49), int64(6), object(279)\n",
|
| 362 |
+
"memory usage: 7.7+ MB\n"
|
| 363 |
+
]
|
| 364 |
+
}
|
| 365 |
+
],
|
| 366 |
+
"source": [
|
| 367 |
+
"df.info()"
|
| 368 |
+
]
|
| 369 |
+
},
|
| 370 |
+
{
|
| 371 |
+
"cell_type": "code",
|
| 372 |
+
"execution_count": 7,
|
| 373 |
+
"id": "a282ca80-1dfd-41a6-8773-67fbc747e254",
|
| 374 |
+
"metadata": {
|
| 375 |
+
"scrolled": true
|
| 376 |
+
},
|
| 377 |
+
"outputs": [],
|
| 378 |
+
"source": [
|
| 379 |
+
"# Delete useless features\n",
|
| 380 |
+
"useless_cols = ['ID', 'FI_ID', 'PATH', 'VERSION', 'INT_START',\n",
|
| 381 |
+
" 'WEIGHT_SEL', 'WEIGHT_ADJ', 'STRATUM', \n",
|
| 382 |
+
" 'CLUSTER']\n",
|
| 383 |
+
"\n",
|
| 384 |
+
"df = df.drop(columns = useless_cols)"
|
| 385 |
+
]
|
| 386 |
+
},
|
| 387 |
+
{
|
| 388 |
+
"cell_type": "code",
|
| 389 |
+
"execution_count": 8,
|
| 390 |
+
"id": "33bce8f0-e10f-4adf-8775-a36b202e41ab",
|
| 391 |
+
"metadata": {},
|
| 392 |
+
"outputs": [
|
| 393 |
+
{
|
| 394 |
+
"name": "stdout",
|
| 395 |
+
"output_type": "stream",
|
| 396 |
+
"text": [
|
| 397 |
+
"<class 'pandas.core.frame.DataFrame'>\n",
|
| 398 |
+
"RangeIndex: 3005 entries, 0 to 3004\n",
|
| 399 |
+
"Columns: 325 entries, GENDER to CASEDIF\n",
|
| 400 |
+
"dtypes: float64(47), int64(1), object(277)\n",
|
| 401 |
+
"memory usage: 7.5+ MB\n"
|
| 402 |
+
]
|
| 403 |
+
}
|
| 404 |
+
],
|
| 405 |
+
"source": [
|
| 406 |
+
"df.info()"
|
| 407 |
+
]
|
| 408 |
+
},
|
| 409 |
+
{
|
| 410 |
+
"cell_type": "markdown",
|
| 411 |
+
"id": "d8130d58-98ce-456e-8992-545b737fae14",
|
| 412 |
+
"metadata": {},
|
| 413 |
+
"source": [
|
| 414 |
+
"# Depression Scale Creation and Entry Cleaning"
|
| 415 |
+
]
|
| 416 |
+
},
|
| 417 |
+
{
|
| 418 |
+
"cell_type": "code",
|
| 419 |
+
"execution_count": 9,
|
| 420 |
+
"id": "4315d3d1-daa1-4add-a891-8b420b21faba",
|
| 421 |
+
"metadata": {},
|
| 422 |
+
"outputs": [
|
| 423 |
+
{
|
| 424 |
+
"name": "stdout",
|
| 425 |
+
"output_type": "stream",
|
| 426 |
+
"text": [
|
| 427 |
+
"RESTLES 222\n",
|
| 428 |
+
"CONFIDNT 216\n",
|
| 429 |
+
"WORRY 212\n",
|
| 430 |
+
"RELAXED 211\n",
|
| 431 |
+
"FLTEFF 13\n",
|
| 432 |
+
"NOTGETGO 10\n",
|
| 433 |
+
"FLTDEP 9\n",
|
| 434 |
+
"NOSLEEP 7\n",
|
| 435 |
+
"NOTEAT 7\n",
|
| 436 |
+
"dtype: int64\n"
|
| 437 |
+
]
|
| 438 |
+
}
|
| 439 |
+
],
|
| 440 |
+
"source": [
|
# Keep a handle to the frame before deriving the PHQ-style symptom subset.
# NOTE(review): this is an alias, not a copy — writes through either name
# affect the same underlying frame; confirm that is intended.
df_pure = df

# The nine questionnaire items that feed the depression score.
symptoms_array = [
    "NOTGETGO",  # Difficulty getting going
    "FLTDEP",    # Feeling of deep sadness or emptiness
    "NOSLEEP",   # Insomnia or sleeping too much
    "RESTLES",   # Feeling restless
    "NOTEAT",    # Changes in appetite or weight
    "CONFIDNT",  # Lack of confidence
    "FLTEFF",    # Feeling things are out of control
    "RELAXED",   # Unable to feel relaxed
    "WORRY",     # Worrying thoughts
]

df_phq9 = df_pure[symptoms_array]

# Rank the symptom columns by how many responses are missing.
missing_counts_sorted = df_phq9.isna().sum().sort_values(ascending=False)
print(missing_counts_sorted)
|
| 462 |
+
]
|
| 463 |
+
},
|
| 464 |
+
{
|
| 465 |
+
"cell_type": "code",
|
| 466 |
+
"execution_count": 10,
|
| 467 |
+
"id": "98932270-aeae-4758-a9d9-5234087e6734",
|
| 468 |
+
"metadata": {},
|
| 469 |
+
"outputs": [
|
| 470 |
+
{
|
| 471 |
+
"name": "stdout",
|
| 472 |
+
"output_type": "stream",
|
| 473 |
+
"text": [
|
| 474 |
+
"<class 'pandas.core.frame.DataFrame'>\n",
|
| 475 |
+
"RangeIndex: 3005 entries, 0 to 3004\n",
|
| 476 |
+
"Data columns (total 9 columns):\n",
|
| 477 |
+
" # Column Non-Null Count Dtype \n",
|
| 478 |
+
"--- ------ -------------- ----- \n",
|
| 479 |
+
" 0 NOTGETGO 2995 non-null object\n",
|
| 480 |
+
" 1 FLTDEP 2996 non-null object\n",
|
| 481 |
+
" 2 NOSLEEP 2998 non-null object\n",
|
| 482 |
+
" 3 RESTLES 2783 non-null object\n",
|
| 483 |
+
" 4 NOTEAT 2998 non-null object\n",
|
| 484 |
+
" 5 CONFIDNT 2789 non-null object\n",
|
| 485 |
+
" 6 FLTEFF 2992 non-null object\n",
|
| 486 |
+
" 7 RELAXED 2794 non-null object\n",
|
| 487 |
+
" 8 WORRY 2793 non-null object\n",
|
| 488 |
+
"dtypes: object(9)\n",
|
| 489 |
+
"memory usage: 211.4+ KB\n"
|
| 490 |
+
]
|
| 491 |
+
}
|
| 492 |
+
],
|
| 493 |
+
"source": [
|
| 494 |
+
"df_phq9.info()"
|
| 495 |
+
]
|
| 496 |
+
},
|
| 497 |
+
{
|
| 498 |
+
"cell_type": "code",
|
| 499 |
+
"execution_count": 11,
|
| 500 |
+
"id": "7a9e9234-32bd-4d21-8017-db9ef7354bd9",
|
| 501 |
+
"metadata": {},
|
| 502 |
+
"outputs": [
|
| 503 |
+
{
|
| 504 |
+
"name": "stdout",
|
| 505 |
+
"output_type": "stream",
|
| 506 |
+
"text": [
|
| 507 |
+
"<class 'pandas.core.frame.DataFrame'>\n",
|
| 508 |
+
"Index: 2763 entries, 0 to 3004\n",
|
| 509 |
+
"Data columns (total 9 columns):\n",
|
| 510 |
+
" # Column Non-Null Count Dtype \n",
|
| 511 |
+
"--- ------ -------------- ----- \n",
|
| 512 |
+
" 0 NOTGETGO 2763 non-null object\n",
|
| 513 |
+
" 1 FLTDEP 2763 non-null object\n",
|
| 514 |
+
" 2 NOSLEEP 2763 non-null object\n",
|
| 515 |
+
" 3 RESTLES 2763 non-null object\n",
|
| 516 |
+
" 4 NOTEAT 2763 non-null object\n",
|
| 517 |
+
" 5 CONFIDNT 2763 non-null object\n",
|
| 518 |
+
" 6 FLTEFF 2763 non-null object\n",
|
| 519 |
+
" 7 RELAXED 2763 non-null object\n",
|
| 520 |
+
" 8 WORRY 2763 non-null object\n",
|
| 521 |
+
"dtypes: object(9)\n",
|
| 522 |
+
"memory usage: 215.9+ KB\n"
|
| 523 |
+
]
|
| 524 |
+
}
|
| 525 |
+
],
|
| 526 |
+
"source": [
|
# Drop every row that is missing any of the nine symptom responses.
# dropna() already returns a new frame, so no explicit copy is needed first.
df_phq9_2 = df_phq9.dropna()
df_phq9_2.info()  # loses roughly 244 rows (3005 -> 2763)
|
| 531 |
+
]
|
| 532 |
+
},
|
| 533 |
+
{
|
| 534 |
+
"cell_type": "code",
|
| 535 |
+
"execution_count": 12,
|
| 536 |
+
"id": "5ba060e8-46e9-4f4f-a680-33660116e35e",
|
| 537 |
+
"metadata": {},
|
| 538 |
+
"outputs": [],
|
| 539 |
+
"source": [
|
# Making the depression scale: encode each symptom's response categories as
# integers 0..k-1. Uses pandas category codes instead of sklearn's
# LabelEncoder — for string data pandas infers categories in sorted
# (lexicographic) order, which is exactly the ordering LabelEncoder produces,
# and it avoids pulling in sklearn for a one-liner pandas already covers.
# NOTE(review): that ordering is *alphabetical*, not necessarily the
# none -> severe ordering of the questionnaire answers — confirm the response
# strings sort in the intended severity order.
for col in df_phq9_2.columns:
    df_phq9_2.loc[:, col] = df_phq9_2[col].astype('category').cat.codes

# CONFIDNT and RELAXED are phrased positively (good -> bad is reversed),
# so flip their 0..3 scale to 3..0 before summing.
df_phq9_2.loc[:, 'CONFIDNT'] = 3 - df_phq9_2['CONFIDNT']
df_phq9_2.loc[:, 'RELAXED'] = 3 - df_phq9_2['RELAXED']

# Total severity score across all nine encoded symptoms.
df_phq9_2['total_sum'] = df_phq9_2.sum(axis=1)
|
| 552 |
+
]
|
| 553 |
+
},
|
| 554 |
+
{
|
| 555 |
+
"cell_type": "code",
|
| 556 |
+
"execution_count": 13,
|
| 557 |
+
"id": "c7684ba0-4a07-4cf4-a7a6-3d2087e5e125",
|
| 558 |
+
"metadata": {},
|
| 559 |
+
"outputs": [],
|
| 560 |
+
"source": [
|
# Categorize depression severity from the summed symptom score.
def categorize_score(score):
    """Map a total symptom score to a severity band.

    0-4 -> 'Normal', 5-9 -> 'Mild', 10-27 -> 'ModerateSevere'.
    Returns None for values outside 0-27 (negative or > 27), matching the
    questionnaire's defined score range; previously this fall-through to
    None was implicit.
    """
    if 0 <= score <= 4:
        return 'Normal'
    elif 5 <= score <= 9:
        return 'Mild'
    elif 10 <= score <= 27:
        return 'ModerateSevere'
    return None  # out-of-range scores stay uncategorized
|
| 569 |
+
"\n",
|
| 570 |
+
"# Applying the categorization function to the 'Total_Sum' column\n",
|
| 571 |
+
"df_phq9_2['depression_category'] = df_phq9_2['total_sum'].apply(categorize_score)"
|
| 572 |
+
]
|
| 573 |
+
},
|
| 574 |
+
{
|
| 575 |
+
"cell_type": "code",
|
| 576 |
+
"execution_count": 14,
|
| 577 |
+
"id": "eb79389e-0a1e-4f8f-b27e-efb320610038",
|
| 578 |
+
"metadata": {},
|
| 579 |
+
"outputs": [],
|
| 580 |
+
"source": [
|
# Separate by category: one view of the scored frame per severity band.
category = df_phq9_2['depression_category']
df_phq9_normal = df_phq9_2[category == 'Normal']
df_phq9_mild = df_phq9_2[category == 'Mild']
df_phq9_moderatesevere = df_phq9_2[category == 'ModerateSevere']
|
| 585 |
+
]
|
| 586 |
+
},
|
| 587 |
+
{
|
| 588 |
+
"cell_type": "code",
|
| 589 |
+
"execution_count": 15,
|
| 590 |
+
"id": "c03612d6-788b-4cdd-85f6-8e8e7c541956",
|
| 591 |
+
"metadata": {
|
| 592 |
+
"scrolled": true
|
| 593 |
+
},
|
| 594 |
+
"outputs": [
|
| 595 |
+
{
|
| 596 |
+
"name": "stdout",
|
| 597 |
+
"output_type": "stream",
|
| 598 |
+
"text": [
|
| 599 |
+
"<class 'pandas.core.frame.DataFrame'>\n",
|
| 600 |
+
"Index: 1308 entries, 0 to 3003\n",
|
| 601 |
+
"Data columns (total 11 columns):\n",
|
| 602 |
+
" # Column Non-Null Count Dtype \n",
|
| 603 |
+
"--- ------ -------------- ----- \n",
|
| 604 |
+
" 0 NOTGETGO 1308 non-null object\n",
|
| 605 |
+
" 1 FLTDEP 1308 non-null object\n",
|
| 606 |
+
" 2 NOSLEEP 1308 non-null object\n",
|
| 607 |
+
" 3 RESTLES 1308 non-null object\n",
|
| 608 |
+
" 4 NOTEAT 1308 non-null object\n",
|
| 609 |
+
" 5 CONFIDNT 1308 non-null object\n",
|
| 610 |
+
" 6 FLTEFF 1308 non-null object\n",
|
| 611 |
+
" 7 RELAXED 1308 non-null object\n",
|
| 612 |
+
" 8 WORRY 1308 non-null object\n",
|
| 613 |
+
" 9 total_sum 1308 non-null object\n",
|
| 614 |
+
" 10 depression_category 1308 non-null object\n",
|
| 615 |
+
"dtypes: object(11)\n",
|
| 616 |
+
"memory usage: 122.6+ KB\n",
|
| 617 |
+
"<class 'pandas.core.frame.DataFrame'>\n",
|
| 618 |
+
"Index: 940 entries, 2 to 3004\n",
|
| 619 |
+
"Data columns (total 11 columns):\n",
|
| 620 |
+
" # Column Non-Null Count Dtype \n",
|
| 621 |
+
"--- ------ -------------- ----- \n",
|
| 622 |
+
" 0 NOTGETGO 940 non-null object\n",
|
| 623 |
+
" 1 FLTDEP 940 non-null object\n",
|
| 624 |
+
" 2 NOSLEEP 940 non-null object\n",
|
| 625 |
+
" 3 RESTLES 940 non-null object\n",
|
| 626 |
+
" 4 NOTEAT 940 non-null object\n",
|
| 627 |
+
" 5 CONFIDNT 940 non-null object\n",
|
| 628 |
+
" 6 FLTEFF 940 non-null object\n",
|
| 629 |
+
" 7 RELAXED 940 non-null object\n",
|
| 630 |
+
" 8 WORRY 940 non-null object\n",
|
| 631 |
+
" 9 total_sum 940 non-null object\n",
|
| 632 |
+
" 10 depression_category 940 non-null object\n",
|
| 633 |
+
"dtypes: object(11)\n",
|
| 634 |
+
"memory usage: 88.1+ KB\n",
|
| 635 |
+
"<class 'pandas.core.frame.DataFrame'>\n",
|
| 636 |
+
"Index: 515 entries, 3 to 3000\n",
|
| 637 |
+
"Data columns (total 11 columns):\n",
|
| 638 |
+
" # Column Non-Null Count Dtype \n",
|
| 639 |
+
"--- ------ -------------- ----- \n",
|
| 640 |
+
" 0 NOTGETGO 515 non-null object\n",
|
| 641 |
+
" 1 FLTDEP 515 non-null object\n",
|
| 642 |
+
" 2 NOSLEEP 515 non-null object\n",
|
| 643 |
+
" 3 RESTLES 515 non-null object\n",
|
| 644 |
+
" 4 NOTEAT 515 non-null object\n",
|
| 645 |
+
" 5 CONFIDNT 515 non-null object\n",
|
| 646 |
+
" 6 FLTEFF 515 non-null object\n",
|
| 647 |
+
" 7 RELAXED 515 non-null object\n",
|
| 648 |
+
" 8 WORRY 515 non-null object\n",
|
| 649 |
+
" 9 total_sum 515 non-null object\n",
|
| 650 |
+
" 10 depression_category 515 non-null object\n",
|
| 651 |
+
"dtypes: object(11)\n",
|
| 652 |
+
"memory usage: 48.3+ KB\n"
|
| 653 |
+
]
|
| 654 |
+
}
|
| 655 |
+
],
|
| 656 |
+
"source": [
|
| 657 |
+
"df_phq9_normal.info()\n",
|
| 658 |
+
"df_phq9_mild.info()\n",
|
| 659 |
+
"df_phq9_moderatesevere.info()"
|
| 660 |
+
]
|
| 661 |
+
},
|
| 662 |
+
{
|
| 663 |
+
"cell_type": "code",
|
| 664 |
+
"execution_count": 16,
|
| 665 |
+
"id": "7b8ab44d-aab2-4b5c-ba3e-e1eabd3d7239",
|
| 666 |
+
"metadata": {},
|
| 667 |
+
"outputs": [],
|
| 668 |
+
"source": [
|
| 669 |
+
"# Connect with original dataset\n",
|
| 670 |
+
"mild_indices = df_phq9_mild.index\n",
|
| 671 |
+
"\n",
|
| 672 |
+
"df_mild_connection = df_pure.loc[mild_indices]\n",
|
| 673 |
+
"df_mild_connection['depression_category'] = 'mild'\n",
|
| 674 |
+
"df_mild_connection\n",
|
| 675 |
+
"\n",
|
| 676 |
+
"moderatesevere_indices = df_phq9_moderatesevere.index\n",
|
| 677 |
+
"\n",
|
| 678 |
+
"df_moderatesevere_connection = df_pure.loc[moderatesevere_indices]\n",
|
| 679 |
+
"df_moderatesevere_connection['depression_category'] = 'moderatesevere'\n",
|
| 680 |
+
"df_moderatesevere_connection\n",
|
| 681 |
+
"\n",
|
| 682 |
+
"normal_indices = df_phq9_normal.index\n",
|
| 683 |
+
"\n",
|
| 684 |
+
"df_normal_connection = df_pure.loc[normal_indices]\n",
|
| 685 |
+
"df_normal_connection['depression_category'] = 'normal'\n",
|
| 686 |
+
"df_normal_connection\n",
|
| 687 |
+
"\n",
|
| 688 |
+
"df_appended = pd.concat([df_normal_connection, df_mild_connection, df_moderatesevere_connection], ignore_index=False)\n",
|
| 689 |
+
"\n",
|
| 690 |
+
"symptoms_array = [\n",
|
| 691 |
+
" \"NOTGETGO\", # Difficulty getting going\n",
|
| 692 |
+
" \"FLTDEP\", # Feeling of deep sadness or emptiness\n",
|
| 693 |
+
" \"NOSLEEP\", # Insomnia or sleeping too much\n",
|
| 694 |
+
" \"RESTLES\", # Feeling restless\n",
|
| 695 |
+
" \"NOTEAT\", # Changes in appetite or weight\n",
|
| 696 |
+
" \"CONFIDNT\", # Lack of confidence\n",
|
| 697 |
+
" \"UNCNTRL\", # Feeling things are out of control\n",
|
| 698 |
+
" \"RELAXED\", # Unable to feel relaxed\n",
|
| 699 |
+
" \"WORRY\" # Worrying thoughts\n",
|
| 700 |
+
"]\n",
|
| 701 |
+
"\n",
|
| 702 |
+
"df_appended = df_appended.drop(columns = symptoms_array)\n",
|
| 703 |
+
"df_appended['total_sum'] = df_phq9_2['total_sum']"
|
| 704 |
+
]
|
| 705 |
+
},
|
| 706 |
+
{
|
| 707 |
+
"cell_type": "code",
|
| 708 |
+
"execution_count": 17,
|
| 709 |
+
"id": "d6d0ae13-8b6a-466f-af99-571509ef1201",
|
| 710 |
+
"metadata": {
|
| 711 |
+
"scrolled": true
|
| 712 |
+
},
|
| 713 |
+
"outputs": [
|
| 714 |
+
{
|
| 715 |
+
"data": {
|
| 716 |
+
"text/html": [
|
| 717 |
+
"<div>\n",
|
| 718 |
+
"<style scoped>\n",
|
| 719 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 720 |
+
" vertical-align: middle;\n",
|
| 721 |
+
" }\n",
|
| 722 |
+
"\n",
|
| 723 |
+
" .dataframe tbody tr th {\n",
|
| 724 |
+
" vertical-align: top;\n",
|
| 725 |
+
" }\n",
|
| 726 |
+
"\n",
|
| 727 |
+
" .dataframe thead th {\n",
|
| 728 |
+
" text-align: right;\n",
|
| 729 |
+
" }\n",
|
| 730 |
+
"</style>\n",
|
| 731 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 732 |
+
" <thead>\n",
|
| 733 |
+
" <tr style=\"text-align: right;\">\n",
|
| 734 |
+
" <th></th>\n",
|
| 735 |
+
" <th>GENDER</th>\n",
|
| 736 |
+
" <th>AGE</th>\n",
|
| 737 |
+
" <th>AGEGRP</th>\n",
|
| 738 |
+
" <th>DEGREE_RECODE</th>\n",
|
| 739 |
+
" <th>EDUC</th>\n",
|
| 740 |
+
" <th>RACE_RECODE</th>\n",
|
| 741 |
+
" <th>HISPANIC</th>\n",
|
| 742 |
+
" <th>ETHGRP</th>\n",
|
| 743 |
+
" <th>MILITARY</th>\n",
|
| 744 |
+
" <th>JAIL</th>\n",
|
| 745 |
+
" <th>...</th>\n",
|
| 746 |
+
" <th>IWLOC5</th>\n",
|
| 747 |
+
" <th>IWLOC6</th>\n",
|
| 748 |
+
" <th>STRUCTQ</th>\n",
|
| 749 |
+
" <th>BUILD</th>\n",
|
| 750 |
+
" <th>OTBUILD</th>\n",
|
| 751 |
+
" <th>COMBUILD</th>\n",
|
| 752 |
+
" <th>CASECOMP</th>\n",
|
| 753 |
+
" <th>CASEDIF</th>\n",
|
| 754 |
+
" <th>depression_category</th>\n",
|
| 755 |
+
" <th>total_sum</th>\n",
|
| 756 |
+
" </tr>\n",
|
| 757 |
+
" </thead>\n",
|
| 758 |
+
" <tbody>\n",
|
| 759 |
+
" <tr>\n",
|
| 760 |
+
" <th>0</th>\n",
|
| 761 |
+
" <td>(2) female</td>\n",
|
| 762 |
+
" <td>62</td>\n",
|
| 763 |
+
" <td>(1) 57-64</td>\n",
|
| 764 |
+
" <td>(5) masters</td>\n",
|
| 765 |
+
" <td>(4) bachelors or more</td>\n",
|
| 766 |
+
" <td>(1) white/caucasian</td>\n",
|
| 767 |
+
" <td>(0) no</td>\n",
|
| 768 |
+
" <td>(1) white</td>\n",
|
| 769 |
+
" <td>(0) no</td>\n",
|
| 770 |
+
" <td>(0) no</td>\n",
|
| 771 |
+
" <td>...</td>\n",
|
| 772 |
+
" <td>(1) 1 (quiet)</td>\n",
|
| 773 |
+
" <td>(1) 1 (no smell)</td>\n",
|
| 774 |
+
" <td>(02) detached single family house</td>\n",
|
| 775 |
+
" <td>(4) very well kept</td>\n",
|
| 776 |
+
" <td>(3) fairly well kept (needs cosmetic work)</td>\n",
|
| 777 |
+
" <td>(3) average</td>\n",
|
| 778 |
+
" <td>(11) eleventh case or more</td>\n",
|
| 779 |
+
" <td>(2) somewhat difficult</td>\n",
|
| 780 |
+
" <td>normal</td>\n",
|
| 781 |
+
" <td>2</td>\n",
|
| 782 |
+
" </tr>\n",
|
| 783 |
+
" <tr>\n",
|
| 784 |
+
" <th>1</th>\n",
|
| 785 |
+
" <td>(2) female</td>\n",
|
| 786 |
+
" <td>79</td>\n",
|
| 787 |
+
" <td>(3) 75-85</td>\n",
|
| 788 |
+
" <td>(2) high school diploma/equivalency</td>\n",
|
| 789 |
+
" <td>(3) voc cert/some college/assoc</td>\n",
|
| 790 |
+
" <td>(1) white/caucasian</td>\n",
|
| 791 |
+
" <td>(0) no</td>\n",
|
| 792 |
+
" <td>(1) white</td>\n",
|
| 793 |
+
" <td>(0) no</td>\n",
|
| 794 |
+
" <td>(0) no</td>\n",
|
| 795 |
+
" <td>...</td>\n",
|
| 796 |
+
" <td>(1) 1 (quiet)</td>\n",
|
| 797 |
+
" <td>(1) 1 (no smell)</td>\n",
|
| 798 |
+
" <td>(02) detached single family house</td>\n",
|
| 799 |
+
" <td>(4) very well kept</td>\n",
|
| 800 |
+
" <td>(3) fairly well kept (needs cosmetic work)</td>\n",
|
| 801 |
+
" <td>(4) above average</td>\n",
|
| 802 |
+
" <td>(11) eleventh case or more</td>\n",
|
| 803 |
+
" <td>(3) not very difficult</td>\n",
|
| 804 |
+
" <td>normal</td>\n",
|
| 805 |
+
" <td>3</td>\n",
|
| 806 |
+
" </tr>\n",
|
| 807 |
+
" <tr>\n",
|
| 808 |
+
" <th>17</th>\n",
|
| 809 |
+
" <td>(1) male</td>\n",
|
| 810 |
+
" <td>58</td>\n",
|
| 811 |
+
" <td>(1) 57-64</td>\n",
|
| 812 |
+
" <td>(4) bachelors</td>\n",
|
| 813 |
+
" <td>(4) bachelors or more</td>\n",
|
| 814 |
+
" <td>(1) white/caucasian</td>\n",
|
| 815 |
+
" <td>(0) no</td>\n",
|
| 816 |
+
" <td>(1) white</td>\n",
|
| 817 |
+
" <td>(0) no</td>\n",
|
| 818 |
+
" <td>(0) no</td>\n",
|
| 819 |
+
" <td>...</td>\n",
|
| 820 |
+
" <td>(1) 1 (quiet)</td>\n",
|
| 821 |
+
" <td>(1) 1 (no smell)</td>\n",
|
| 822 |
+
" <td>(02) detached single family house</td>\n",
|
| 823 |
+
" <td>(3) fairly well kept (needs cosmetic work)</td>\n",
|
| 824 |
+
" <td>(3) fairly well kept (needs cosmetic work)</td>\n",
|
| 825 |
+
" <td>(3) average</td>\n",
|
| 826 |
+
" <td>(11) eleventh case or more</td>\n",
|
| 827 |
+
" <td>(2) somewhat difficult</td>\n",
|
| 828 |
+
" <td>normal</td>\n",
|
| 829 |
+
" <td>2</td>\n",
|
| 830 |
+
" </tr>\n",
|
| 831 |
+
" <tr>\n",
|
| 832 |
+
" <th>24</th>\n",
|
| 833 |
+
" <td>(1) male</td>\n",
|
| 834 |
+
" <td>79</td>\n",
|
| 835 |
+
" <td>(3) 75-85</td>\n",
|
| 836 |
+
" <td>(5) masters</td>\n",
|
| 837 |
+
" <td>(4) bachelors or more</td>\n",
|
| 838 |
+
" <td>(1) white/caucasian</td>\n",
|
| 839 |
+
" <td>(0) no</td>\n",
|
| 840 |
+
" <td>(1) white</td>\n",
|
| 841 |
+
" <td>(1) yes</td>\n",
|
| 842 |
+
" <td>(0) no</td>\n",
|
| 843 |
+
" <td>...</td>\n",
|
| 844 |
+
" <td>(1) 1 (quiet)</td>\n",
|
| 845 |
+
" <td>(1) 1 (no smell)</td>\n",
|
| 846 |
+
" <td>(02) detached single family house</td>\n",
|
| 847 |
+
" <td>(4) very well kept</td>\n",
|
| 848 |
+
" <td>(4) very well kept</td>\n",
|
| 849 |
+
" <td>(5) far above average</td>\n",
|
| 850 |
+
" <td>(08) eighth case</td>\n",
|
| 851 |
+
" <td>(2) somewhat difficult</td>\n",
|
| 852 |
+
" <td>normal</td>\n",
|
| 853 |
+
" <td>4</td>\n",
|
| 854 |
+
" </tr>\n",
|
| 855 |
+
" <tr>\n",
|
| 856 |
+
" <th>26</th>\n",
|
| 857 |
+
" <td>(2) female</td>\n",
|
| 858 |
+
" <td>68</td>\n",
|
| 859 |
+
" <td>(2) 65-74</td>\n",
|
| 860 |
+
" <td>(6) law, md or phd</td>\n",
|
| 861 |
+
" <td>(4) bachelors or more</td>\n",
|
| 862 |
+
" <td>(1) white/caucasian</td>\n",
|
| 863 |
+
" <td>(0) no</td>\n",
|
| 864 |
+
" <td>(1) white</td>\n",
|
| 865 |
+
" <td>(0) no</td>\n",
|
| 866 |
+
" <td>(0) no</td>\n",
|
| 867 |
+
" <td>...</td>\n",
|
| 868 |
+
" <td>(1) 1 (quiet)</td>\n",
|
| 869 |
+
" <td>(1) 1 (no smell)</td>\n",
|
| 870 |
+
" <td>(02) detached single family house</td>\n",
|
| 871 |
+
" <td>(4) very well kept</td>\n",
|
| 872 |
+
" <td>(3) fairly well kept (needs cosmetic work)</td>\n",
|
| 873 |
+
" <td>(4) above average</td>\n",
|
| 874 |
+
" <td>(11) eleventh case or more</td>\n",
|
| 875 |
+
" <td>(2) somewhat difficult</td>\n",
|
| 876 |
+
" <td>normal</td>\n",
|
| 877 |
+
" <td>3</td>\n",
|
| 878 |
+
" </tr>\n",
|
| 879 |
+
" <tr>\n",
|
| 880 |
+
" <th>...</th>\n",
|
| 881 |
+
" <td>...</td>\n",
|
| 882 |
+
" <td>...</td>\n",
|
| 883 |
+
" <td>...</td>\n",
|
| 884 |
+
" <td>...</td>\n",
|
| 885 |
+
" <td>...</td>\n",
|
| 886 |
+
" <td>...</td>\n",
|
| 887 |
+
" <td>...</td>\n",
|
| 888 |
+
" <td>...</td>\n",
|
| 889 |
+
" <td>...</td>\n",
|
| 890 |
+
" <td>...</td>\n",
|
| 891 |
+
" <td>...</td>\n",
|
| 892 |
+
" <td>...</td>\n",
|
| 893 |
+
" <td>...</td>\n",
|
| 894 |
+
" <td>...</td>\n",
|
| 895 |
+
" <td>...</td>\n",
|
| 896 |
+
" <td>...</td>\n",
|
| 897 |
+
" <td>...</td>\n",
|
| 898 |
+
" <td>...</td>\n",
|
| 899 |
+
" <td>...</td>\n",
|
| 900 |
+
" <td>...</td>\n",
|
| 901 |
+
" <td>...</td>\n",
|
| 902 |
+
" </tr>\n",
|
| 903 |
+
" <tr>\n",
|
| 904 |
+
" <th>2988</th>\n",
|
| 905 |
+
" <td>(2) female</td>\n",
|
| 906 |
+
" <td>61</td>\n",
|
| 907 |
+
" <td>(1) 57-64</td>\n",
|
| 908 |
+
" <td>(1) none</td>\n",
|
| 909 |
+
" <td>(1) < hs</td>\n",
|
| 910 |
+
" <td>(1) white/caucasian</td>\n",
|
| 911 |
+
" <td>(1) yes</td>\n",
|
| 912 |
+
" <td>(3) hispanic, non-black</td>\n",
|
| 913 |
+
" <td>(0) no</td>\n",
|
| 914 |
+
" <td>(0) no</td>\n",
|
| 915 |
+
" <td>...</td>\n",
|
| 916 |
+
" <td>(1) 1 (quiet)</td>\n",
|
| 917 |
+
" <td>(1) 1 (no smell)</td>\n",
|
| 918 |
+
" <td>(01) trailer</td>\n",
|
| 919 |
+
" <td>(3) fairly well kept (needs cosmetic work)</td>\n",
|
| 920 |
+
" <td>(1) very poorly kept (needs major repairs)</td>\n",
|
| 921 |
+
" <td>(4) above average</td>\n",
|
| 922 |
+
" <td>(11) eleventh case or more</td>\n",
|
| 923 |
+
" <td>(3) not very difficult</td>\n",
|
| 924 |
+
" <td>moderatesevere</td>\n",
|
| 925 |
+
" <td>14</td>\n",
|
| 926 |
+
" </tr>\n",
|
| 927 |
+
" <tr>\n",
|
| 928 |
+
" <th>2991</th>\n",
|
| 929 |
+
" <td>(2) female</td>\n",
|
| 930 |
+
" <td>70</td>\n",
|
| 931 |
+
" <td>(2) 65-74</td>\n",
|
| 932 |
+
" <td>(1) none</td>\n",
|
| 933 |
+
" <td>(1) < hs</td>\n",
|
| 934 |
+
" <td>(3) asian, pacific islander, american indian o...</td>\n",
|
| 935 |
+
" <td>(1) yes</td>\n",
|
| 936 |
+
" <td>(3) hispanic, non-black</td>\n",
|
| 937 |
+
" <td>(0) no</td>\n",
|
| 938 |
+
" <td>(0) no</td>\n",
|
| 939 |
+
" <td>...</td>\n",
|
| 940 |
+
" <td>(1) 1 (quiet)</td>\n",
|
| 941 |
+
" <td>(1) 1 (no smell)</td>\n",
|
| 942 |
+
" <td>(02) detached single family house</td>\n",
|
| 943 |
+
" <td>NaN</td>\n",
|
| 944 |
+
" <td>NaN</td>\n",
|
| 945 |
+
" <td>NaN</td>\n",
|
| 946 |
+
" <td>(11) eleventh case or more</td>\n",
|
| 947 |
+
" <td>(3) not very difficult</td>\n",
|
| 948 |
+
" <td>moderatesevere</td>\n",
|
| 949 |
+
" <td>11</td>\n",
|
| 950 |
+
" </tr>\n",
|
| 951 |
+
" <tr>\n",
|
| 952 |
+
" <th>2992</th>\n",
|
| 953 |
+
" <td>(2) female</td>\n",
|
| 954 |
+
" <td>70</td>\n",
|
| 955 |
+
" <td>(2) 65-74</td>\n",
|
| 956 |
+
" <td>(3) associates</td>\n",
|
| 957 |
+
" <td>(3) voc cert/some college/assoc</td>\n",
|
| 958 |
+
" <td>(1) white/caucasian</td>\n",
|
| 959 |
+
" <td>(0) no</td>\n",
|
| 960 |
+
" <td>(1) white</td>\n",
|
| 961 |
+
" <td>(0) no</td>\n",
|
| 962 |
+
" <td>(0) no</td>\n",
|
| 963 |
+
" <td>...</td>\n",
|
| 964 |
+
" <td>(1) 1 (quiet)</td>\n",
|
| 965 |
+
" <td>(4) 4</td>\n",
|
| 966 |
+
" <td>(02) detached single family house</td>\n",
|
| 967 |
+
" <td>(3) fairly well kept (needs cosmetic work)</td>\n",
|
| 968 |
+
" <td>(3) fairly well kept (needs cosmetic work)</td>\n",
|
| 969 |
+
" <td>(3) average</td>\n",
|
| 970 |
+
" <td>(11) eleventh case or more</td>\n",
|
| 971 |
+
" <td>(4) not at all difficult</td>\n",
|
| 972 |
+
" <td>moderatesevere</td>\n",
|
| 973 |
+
" <td>10</td>\n",
|
| 974 |
+
" </tr>\n",
|
| 975 |
+
" <tr>\n",
|
| 976 |
+
" <th>2993</th>\n",
|
| 977 |
+
" <td>(1) male</td>\n",
|
| 978 |
+
" <td>63</td>\n",
|
| 979 |
+
" <td>(1) 57-64</td>\n",
|
| 980 |
+
" <td>(4) bachelors</td>\n",
|
| 981 |
+
" <td>(4) bachelors or more</td>\n",
|
| 982 |
+
" <td>(1) white/caucasian</td>\n",
|
| 983 |
+
" <td>(0) no</td>\n",
|
| 984 |
+
" <td>(1) white</td>\n",
|
| 985 |
+
" <td>(0) no</td>\n",
|
| 986 |
+
" <td>(0) no</td>\n",
|
| 987 |
+
" <td>...</td>\n",
|
| 988 |
+
" <td>(2) 2</td>\n",
|
| 989 |
+
" <td>(1) 1 (no smell)</td>\n",
|
| 990 |
+
" <td>(02) detached single family house</td>\n",
|
| 991 |
+
" <td>(4) very well kept</td>\n",
|
| 992 |
+
" <td>(4) very well kept</td>\n",
|
| 993 |
+
" <td>(3) average</td>\n",
|
| 994 |
+
" <td>(07) seventh case</td>\n",
|
| 995 |
+
" <td>(4) not at all difficult</td>\n",
|
| 996 |
+
" <td>moderatesevere</td>\n",
|
| 997 |
+
" <td>11</td>\n",
|
| 998 |
+
" </tr>\n",
|
| 999 |
+
" <tr>\n",
|
| 1000 |
+
" <th>3000</th>\n",
|
| 1001 |
+
" <td>(2) female</td>\n",
|
| 1002 |
+
" <td>73</td>\n",
|
| 1003 |
+
" <td>(2) 65-74</td>\n",
|
| 1004 |
+
" <td>(3) associates</td>\n",
|
| 1005 |
+
" <td>(3) voc cert/some college/assoc</td>\n",
|
| 1006 |
+
" <td>(1) white/caucasian</td>\n",
|
| 1007 |
+
" <td>(0) no</td>\n",
|
| 1008 |
+
" <td>(1) white</td>\n",
|
| 1009 |
+
" <td>(0) no</td>\n",
|
| 1010 |
+
" <td>(0) no</td>\n",
|
| 1011 |
+
" <td>...</td>\n",
|
| 1012 |
+
" <td>(2) 2</td>\n",
|
| 1013 |
+
" <td>(2) 2</td>\n",
|
| 1014 |
+
" <td>(02) detached single family house</td>\n",
|
| 1015 |
+
" <td>(4) very well kept</td>\n",
|
| 1016 |
+
" <td>(4) very well kept</td>\n",
|
| 1017 |
+
" <td>(3) average</td>\n",
|
| 1018 |
+
" <td>(05) fifth case</td>\n",
|
| 1019 |
+
" <td>(3) not very difficult</td>\n",
|
| 1020 |
+
" <td>moderatesevere</td>\n",
|
| 1021 |
+
" <td>17</td>\n",
|
| 1022 |
+
" </tr>\n",
|
| 1023 |
+
" </tbody>\n",
|
| 1024 |
+
"</table>\n",
|
| 1025 |
+
"<p>2763 rows × 318 columns</p>\n",
|
| 1026 |
+
"</div>"
|
| 1027 |
+
],
|
| 1028 |
+
"text/plain": [
|
| 1029 |
+
" GENDER AGE AGEGRP DEGREE_RECODE \\\n",
|
| 1030 |
+
"0 (2) female 62 (1) 57-64 (5) masters \n",
|
| 1031 |
+
"1 (2) female 79 (3) 75-85 (2) high school diploma/equivalency \n",
|
| 1032 |
+
"17 (1) male 58 (1) 57-64 (4) bachelors \n",
|
| 1033 |
+
"24 (1) male 79 (3) 75-85 (5) masters \n",
|
| 1034 |
+
"26 (2) female 68 (2) 65-74 (6) law, md or phd \n",
|
| 1035 |
+
"... ... ... ... ... \n",
|
| 1036 |
+
"2988 (2) female 61 (1) 57-64 (1) none \n",
|
| 1037 |
+
"2991 (2) female 70 (2) 65-74 (1) none \n",
|
| 1038 |
+
"2992 (2) female 70 (2) 65-74 (3) associates \n",
|
| 1039 |
+
"2993 (1) male 63 (1) 57-64 (4) bachelors \n",
|
| 1040 |
+
"3000 (2) female 73 (2) 65-74 (3) associates \n",
|
| 1041 |
+
"\n",
|
| 1042 |
+
" EDUC \\\n",
|
| 1043 |
+
"0 (4) bachelors or more \n",
|
| 1044 |
+
"1 (3) voc cert/some college/assoc \n",
|
| 1045 |
+
"17 (4) bachelors or more \n",
|
| 1046 |
+
"24 (4) bachelors or more \n",
|
| 1047 |
+
"26 (4) bachelors or more \n",
|
| 1048 |
+
"... ... \n",
|
| 1049 |
+
"2988 (1) < hs \n",
|
| 1050 |
+
"2991 (1) < hs \n",
|
| 1051 |
+
"2992 (3) voc cert/some college/assoc \n",
|
| 1052 |
+
"2993 (4) bachelors or more \n",
|
| 1053 |
+
"3000 (3) voc cert/some college/assoc \n",
|
| 1054 |
+
"\n",
|
| 1055 |
+
" RACE_RECODE HISPANIC \\\n",
|
| 1056 |
+
"0 (1) white/caucasian (0) no \n",
|
| 1057 |
+
"1 (1) white/caucasian (0) no \n",
|
| 1058 |
+
"17 (1) white/caucasian (0) no \n",
|
| 1059 |
+
"24 (1) white/caucasian (0) no \n",
|
| 1060 |
+
"26 (1) white/caucasian (0) no \n",
|
| 1061 |
+
"... ... ... \n",
|
| 1062 |
+
"2988 (1) white/caucasian (1) yes \n",
|
| 1063 |
+
"2991 (3) asian, pacific islander, american indian o... (1) yes \n",
|
| 1064 |
+
"2992 (1) white/caucasian (0) no \n",
|
| 1065 |
+
"2993 (1) white/caucasian (0) no \n",
|
| 1066 |
+
"3000 (1) white/caucasian (0) no \n",
|
| 1067 |
+
"\n",
|
| 1068 |
+
" ETHGRP MILITARY JAIL ... IWLOC5 \\\n",
|
| 1069 |
+
"0 (1) white (0) no (0) no ... (1) 1 (quiet) \n",
|
| 1070 |
+
"1 (1) white (0) no (0) no ... (1) 1 (quiet) \n",
|
| 1071 |
+
"17 (1) white (0) no (0) no ... (1) 1 (quiet) \n",
|
| 1072 |
+
"24 (1) white (1) yes (0) no ... (1) 1 (quiet) \n",
|
| 1073 |
+
"26 (1) white (0) no (0) no ... (1) 1 (quiet) \n",
|
| 1074 |
+
"... ... ... ... ... ... \n",
|
| 1075 |
+
"2988 (3) hispanic, non-black (0) no (0) no ... (1) 1 (quiet) \n",
|
| 1076 |
+
"2991 (3) hispanic, non-black (0) no (0) no ... (1) 1 (quiet) \n",
|
| 1077 |
+
"2992 (1) white (0) no (0) no ... (1) 1 (quiet) \n",
|
| 1078 |
+
"2993 (1) white (0) no (0) no ... (2) 2 \n",
|
| 1079 |
+
"3000 (1) white (0) no (0) no ... (2) 2 \n",
|
| 1080 |
+
"\n",
|
| 1081 |
+
" IWLOC6 STRUCTQ \\\n",
|
| 1082 |
+
"0 (1) 1 (no smell) (02) detached single family house \n",
|
| 1083 |
+
"1 (1) 1 (no smell) (02) detached single family house \n",
|
| 1084 |
+
"17 (1) 1 (no smell) (02) detached single family house \n",
|
| 1085 |
+
"24 (1) 1 (no smell) (02) detached single family house \n",
|
| 1086 |
+
"26 (1) 1 (no smell) (02) detached single family house \n",
|
| 1087 |
+
"... ... ... \n",
|
| 1088 |
+
"2988 (1) 1 (no smell) (01) trailer \n",
|
| 1089 |
+
"2991 (1) 1 (no smell) (02) detached single family house \n",
|
| 1090 |
+
"2992 (4) 4 (02) detached single family house \n",
|
| 1091 |
+
"2993 (1) 1 (no smell) (02) detached single family house \n",
|
| 1092 |
+
"3000 (2) 2 (02) detached single family house \n",
|
| 1093 |
+
"\n",
|
| 1094 |
+
" BUILD \\\n",
|
| 1095 |
+
"0 (4) very well kept \n",
|
| 1096 |
+
"1 (4) very well kept \n",
|
| 1097 |
+
"17 (3) fairly well kept (needs cosmetic work) \n",
|
| 1098 |
+
"24 (4) very well kept \n",
|
| 1099 |
+
"26 (4) very well kept \n",
|
| 1100 |
+
"... ... \n",
|
| 1101 |
+
"2988 (3) fairly well kept (needs cosmetic work) \n",
|
| 1102 |
+
"2991 NaN \n",
|
| 1103 |
+
"2992 (3) fairly well kept (needs cosmetic work) \n",
|
| 1104 |
+
"2993 (4) very well kept \n",
|
| 1105 |
+
"3000 (4) very well kept \n",
|
| 1106 |
+
"\n",
|
| 1107 |
+
" OTBUILD COMBUILD \\\n",
|
| 1108 |
+
"0 (3) fairly well kept (needs cosmetic work) (3) average \n",
|
| 1109 |
+
"1 (3) fairly well kept (needs cosmetic work) (4) above average \n",
|
| 1110 |
+
"17 (3) fairly well kept (needs cosmetic work) (3) average \n",
|
| 1111 |
+
"24 (4) very well kept (5) far above average \n",
|
| 1112 |
+
"26 (3) fairly well kept (needs cosmetic work) (4) above average \n",
|
| 1113 |
+
"... ... ... \n",
|
| 1114 |
+
"2988 (1) very poorly kept (needs major repairs) (4) above average \n",
|
| 1115 |
+
"2991 NaN NaN \n",
|
| 1116 |
+
"2992 (3) fairly well kept (needs cosmetic work) (3) average \n",
|
| 1117 |
+
"2993 (4) very well kept (3) average \n",
|
| 1118 |
+
"3000 (4) very well kept (3) average \n",
|
| 1119 |
+
"\n",
|
| 1120 |
+
" CASECOMP CASEDIF \\\n",
|
| 1121 |
+
"0 (11) eleventh case or more (2) somewhat difficult \n",
|
| 1122 |
+
"1 (11) eleventh case or more (3) not very difficult \n",
|
| 1123 |
+
"17 (11) eleventh case or more (2) somewhat difficult \n",
|
| 1124 |
+
"24 (08) eighth case (2) somewhat difficult \n",
|
| 1125 |
+
"26 (11) eleventh case or more (2) somewhat difficult \n",
|
| 1126 |
+
"... ... ... \n",
|
| 1127 |
+
"2988 (11) eleventh case or more (3) not very difficult \n",
|
| 1128 |
+
"2991 (11) eleventh case or more (3) not very difficult \n",
|
| 1129 |
+
"2992 (11) eleventh case or more (4) not at all difficult \n",
|
| 1130 |
+
"2993 (07) seventh case (4) not at all difficult \n",
|
| 1131 |
+
"3000 (05) fifth case (3) not very difficult \n",
|
| 1132 |
+
"\n",
|
| 1133 |
+
" depression_category total_sum \n",
|
| 1134 |
+
"0 normal 2 \n",
|
| 1135 |
+
"1 normal 3 \n",
|
| 1136 |
+
"17 normal 2 \n",
|
| 1137 |
+
"24 normal 4 \n",
|
| 1138 |
+
"26 normal 3 \n",
|
| 1139 |
+
"... ... ... \n",
|
| 1140 |
+
"2988 moderatesevere 14 \n",
|
| 1141 |
+
"2991 moderatesevere 11 \n",
|
| 1142 |
+
"2992 moderatesevere 10 \n",
|
| 1143 |
+
"2993 moderatesevere 11 \n",
|
| 1144 |
+
"3000 moderatesevere 17 \n",
|
| 1145 |
+
"\n",
|
| 1146 |
+
"[2763 rows x 318 columns]"
|
| 1147 |
+
]
|
| 1148 |
+
},
|
| 1149 |
+
"execution_count": 17,
|
| 1150 |
+
"metadata": {},
|
| 1151 |
+
"output_type": "execute_result"
|
| 1152 |
+
}
|
| 1153 |
+
],
|
| 1154 |
+
"source": [
|
| 1155 |
+
"df_appended"
|
| 1156 |
+
]
|
| 1157 |
+
},
|
| 1158 |
+
{
|
| 1159 |
+
"cell_type": "code",
|
| 1160 |
+
"execution_count": 18,
|
| 1161 |
+
"id": "a25493c0-42b4-4ce4-8fb2-c3579ec28253",
|
| 1162 |
+
"metadata": {},
|
| 1163 |
+
"outputs": [],
|
| 1164 |
+
"source": [
|
| 1165 |
+
"# Export for regression\n",
|
| 1166 |
+
"df_appended.to_csv('3labelv4Regression.csv', index = False)"
|
| 1167 |
+
]
|
| 1168 |
+
},
|
| 1169 |
+
{
|
| 1170 |
+
"cell_type": "code",
|
| 1171 |
+
"execution_count": 19,
|
| 1172 |
+
"id": "4946608b-95e0-4752-b35b-a895d0a85763",
|
| 1173 |
+
"metadata": {},
|
| 1174 |
+
"outputs": [],
|
| 1175 |
+
"source": [
|
| 1176 |
+
"# Export for classification\n",
|
| 1177 |
+
"df_appended_classification = df_appended.drop('total_sum', axis = 1)\n",
|
| 1178 |
+
"df_appended_classification.to_csv('3labelv4Classification.csv', index = False)"
|
| 1179 |
+
]
|
| 1180 |
+
}
|
| 1181 |
+
],
|
| 1182 |
+
"metadata": {
|
| 1183 |
+
"kernelspec": {
|
| 1184 |
+
"display_name": "Python 3 (ipykernel)",
|
| 1185 |
+
"language": "python",
|
| 1186 |
+
"name": "python3"
|
| 1187 |
+
},
|
| 1188 |
+
"language_info": {
|
| 1189 |
+
"codemirror_mode": {
|
| 1190 |
+
"name": "ipython",
|
| 1191 |
+
"version": 3
|
| 1192 |
+
},
|
| 1193 |
+
"file_extension": ".py",
|
| 1194 |
+
"mimetype": "text/x-python",
|
| 1195 |
+
"name": "python",
|
| 1196 |
+
"nbconvert_exporter": "python",
|
| 1197 |
+
"pygments_lexer": "ipython3",
|
| 1198 |
+
"version": "3.12.3"
|
| 1199 |
+
}
|
| 1200 |
+
},
|
| 1201 |
+
"nbformat": 4,
|
| 1202 |
+
"nbformat_minor": 5
|
| 1203 |
+
}
|
detectionLayer.ipynb
ADDED
|
@@ -0,0 +1,1799 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "f9fd590d-a3c3-4e94-b5ca-59a6ee9b29c3",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# Detection Layer"
|
| 9 |
+
]
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"cell_type": "code",
|
| 13 |
+
"execution_count": null,
|
| 14 |
+
"id": "dddcb2a8-73d3-4476-b88a-bc1494a2c830",
|
| 15 |
+
"metadata": {},
|
| 16 |
+
"outputs": [],
|
| 17 |
+
"source": [
|
| 18 |
+
"import pandas as pd\n",
|
| 19 |
+
"import numpy as np\n",
|
| 20 |
+
"import matplotlib.pyplot as plt\n",
|
| 21 |
+
"import seaborn as sns\n",
|
| 22 |
+
"from sklearn.model_selection import cross_val_score, KFold\n",
|
| 23 |
+
"from sklearn.impute import KNNImputer\n",
|
| 24 |
+
"from sklearn.pipeline import make_pipeline\n",
|
| 25 |
+
"from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, HistGradientBoostingClassifier\n",
|
| 26 |
+
"from xgboost import XGBClassifier\n",
|
| 27 |
+
"from sklearn.impute import SimpleImputer\n",
|
| 28 |
+
"from sklearn.experimental import enable_iterative_imputer\n",
|
| 29 |
+
"from imblearn.pipeline import make_pipeline as make_pipeline_imb\n",
|
| 30 |
+
"from imblearn.over_sampling import SMOTE,SMOTENC\n",
|
| 31 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 32 |
+
"from collections import Counter\n",
|
| 33 |
+
"from sklearn.metrics import classification_report, accuracy_score, confusion_matrix\n",
|
| 34 |
+
"from sklearn.ensemble import GradientBoostingClassifier\n",
|
| 35 |
+
"from sklearn.ensemble import VotingClassifier\n",
|
| 36 |
+
"from sklearn.svm import SVC\n",
|
| 37 |
+
"from sklearn.linear_model import LogisticRegression\n",
|
| 38 |
+
"from sklearn.tree import DecisionTreeClassifier\n",
|
| 39 |
+
"from sklearn.ensemble import BaggingClassifier\n",
|
| 40 |
+
"from sklearn.neighbors import KNeighborsClassifier\n",
|
| 41 |
+
"from sklearn.ensemble import ExtraTreesClassifier\n",
|
| 42 |
+
"from deslib.dcs import APosteriori\n",
|
| 43 |
+
"from deslib.des import KNORAE, KNORAU, KNOP, DESMI\n",
|
| 44 |
+
"from sklearn.neighbors import LocalOutlierFactor\n",
|
| 45 |
+
"from sklearn.utils import resample\n",
|
| 46 |
+
"import warnings\n",
|
| 47 |
+
"from imblearn.over_sampling import RandomOverSampler\n",
|
| 48 |
+
"from imblearn.under_sampling import RandomUnderSampler\n",
|
| 49 |
+
"from sklearn.preprocessing import MinMaxScaler\n",
|
| 50 |
+
"from imblearn.pipeline import Pipeline\n",
|
| 51 |
+
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
|
| 52 |
+
"from sklearn.preprocessing import LabelEncoder, PowerTransformer\n",
|
| 53 |
+
"from collections import defaultdict\n",
|
| 54 |
+
"from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, AdaBoostClassifier, VotingClassifier\n",
|
| 55 |
+
"from catboost import CatBoostClassifier\n",
|
| 56 |
+
"from lightgbm import LGBMClassifier\n",
|
| 57 |
+
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc\n",
|
| 58 |
+
"from sklearn.naive_bayes import GaussianNB\n",
|
| 59 |
+
"from sklearn.neural_network import MLPClassifier\n",
|
| 60 |
+
"import Orange\n",
|
| 61 |
+
"from scipy.stats import friedmanchisquare, rankdata\n",
|
| 62 |
+
"import shap\n",
|
| 63 |
+
"import scikit_posthocs as sp\n",
|
| 64 |
+
"from sklearn.feature_selection import SelectFromModel\n",
|
| 65 |
+
"from IPython.display import FileLink, display\n",
|
| 66 |
+
"import math\n",
|
| 67 |
+
"from sklearn.ensemble import RandomForestClassifier\n",
|
| 68 |
+
"from skopt.space import Integer, Real\n",
|
| 69 |
+
"from sklearn.model_selection import StratifiedKFold\n",
|
| 70 |
+
"from skopt import BayesSearchCV\n",
|
| 71 |
+
"import xgboost as xgb\n",
|
| 72 |
+
"from imblearn.over_sampling import SMOTE\n",
|
| 73 |
+
"from sklearn.tree import DecisionTreeClassifier, export_text\n",
|
| 74 |
+
"from sklearn import tree\n",
|
| 75 |
+
"from skopt.space import Real, Integer, Categorical\n",
|
| 76 |
+
"from skopt.callbacks import VerboseCallback\n",
|
| 77 |
+
"from deslib.des.knora_e import KNORAE\n",
|
| 78 |
+
"from deslib.des.knora_u import KNORAU\n",
|
| 79 |
+
"from deslib.des.knop import KNOP\n",
|
| 80 |
+
"from deslib.des.meta_des import METADES\n",
|
| 81 |
+
"from deslib.des.des_knn import DESKNN\n",
|
| 82 |
+
"from deslib.des.des_p import DESP"
|
| 83 |
+
]
|
| 84 |
+
},
|
| 85 |
+
{
|
| 86 |
+
"cell_type": "markdown",
|
| 87 |
+
"id": "39d6683c-fd3f-4daa-afdc-2e7f83c3fce3",
|
| 88 |
+
"metadata": {},
|
| 89 |
+
"source": [
|
| 90 |
+
"### Preparation before training"
|
| 91 |
+
]
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"cell_type": "code",
|
| 95 |
+
"execution_count": null,
|
| 96 |
+
"id": "ec17e47a-8d92-498d-8f28-d8259d6ebc4e",
|
| 97 |
+
"metadata": {},
|
| 98 |
+
"outputs": [],
|
| 99 |
+
"source": [
|
| 100 |
+
"# Call Dataset\n",
|
| 101 |
+
"pd.set_option('display.max_rows', 10)\n",
|
| 102 |
+
"initial_df = pd.read_csv('3labelv4Classification.csv')\n",
|
| 103 |
+
"initial_df.info()"
|
| 104 |
+
]
|
| 105 |
+
},
|
| 106 |
+
{
|
| 107 |
+
"cell_type": "code",
|
| 108 |
+
"execution_count": null,
|
| 109 |
+
"id": "d3751352-73bc-42c6-b6a5-84a3badf14ff",
|
| 110 |
+
"metadata": {},
|
| 111 |
+
"outputs": [],
|
| 112 |
+
"source": [
|
| 113 |
+
"# All categorical features except for label\n",
|
| 114 |
+
"cols = initial_df.columns\n",
|
| 115 |
+
"num_cols = initial_df._get_numeric_data().columns\n",
|
| 116 |
+
"categorical_features = list(set(cols) - set(num_cols))\n",
|
| 117 |
+
"categorical_features.remove('depression_category')\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"# Label Encode all categorical, but keep missing values\n",
|
| 120 |
+
"le_initial_df = initial_df.copy()\n",
|
| 121 |
+
"dropped_labels = le_initial_df['depression_category']\n",
|
| 122 |
+
"le_initial_df = le_initial_df.drop('depression_category', axis = 1)\n",
|
| 123 |
+
"\n",
|
| 124 |
+
"for col in le_initial_df.columns:\n",
|
| 125 |
+
" if le_initial_df[col].dtype == 'object':\n",
|
| 126 |
+
" le_initial_df[col] = le_initial_df[col].fillna('missing')\n",
|
| 127 |
+
"\n",
|
| 128 |
+
" label_encoder = LabelEncoder()\n",
|
| 129 |
+
" le_initial_df[col] = label_encoder.fit_transform(le_initial_df[col])\n",
|
| 130 |
+
"\n",
|
| 131 |
+
" missing_value_index = np.where(label_encoder.classes_ == 'missing')[0]\n",
|
| 132 |
+
" \n",
|
| 133 |
+
" le_initial_df[col] = le_initial_df[col].replace(missing_value_index, np.nan)\n",
|
| 134 |
+
"\n",
|
| 135 |
+
"le_initial_df = pd.concat([le_initial_df, dropped_labels], axis = 1)"
|
| 136 |
+
]
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"cell_type": "code",
|
| 140 |
+
"execution_count": null,
|
| 141 |
+
"id": "9fa21a95-21f1-4274-86ba-d7977813066b",
|
| 142 |
+
"metadata": {},
|
| 143 |
+
"outputs": [],
|
| 144 |
+
"source": [
|
| 145 |
+
"le_initial_df"
|
| 146 |
+
]
|
| 147 |
+
},
|
| 148 |
+
{
|
| 149 |
+
"cell_type": "code",
|
| 150 |
+
"execution_count": null,
|
| 151 |
+
"id": "0ba54a81-d4de-4884-99d3-29867eb7ea40",
|
| 152 |
+
"metadata": {},
|
| 153 |
+
"outputs": [],
|
| 154 |
+
"source": [
|
| 155 |
+
"# Seperate and Combine\n",
|
| 156 |
+
"le_df_normal = le_initial_df[le_initial_df['depression_category'] == 'normal']\n",
|
| 157 |
+
"le_df_mild = le_initial_df[le_initial_df['depression_category'] == 'mild']\n",
|
| 158 |
+
"le_df_moderatesevere = le_initial_df[le_initial_df['depression_category'] == 'moderatesevere']\n",
|
| 159 |
+
"\n",
|
| 160 |
+
"le_df_depression = pd.concat([le_df_mild, le_df_moderatesevere], ignore_index = False)\n",
|
| 161 |
+
"\n",
|
| 162 |
+
"le_df_depression['depression_category'] = 'depression'\n",
|
| 163 |
+
"\n",
|
| 164 |
+
"# Check depression category counts\n",
|
| 165 |
+
"dataframes = [le_df_normal, le_df_depression]\n",
|
| 166 |
+
"le_initial_df = pd.concat(dataframes, ignore_index=True)\n",
|
| 167 |
+
"label_counts = le_initial_df['depression_category'].value_counts()\n",
|
| 168 |
+
"label_counts"
|
| 169 |
+
]
|
| 170 |
+
},
|
| 171 |
+
{
|
| 172 |
+
"cell_type": "code",
|
| 173 |
+
"execution_count": null,
|
| 174 |
+
"id": "c517213f-f6bb-4298-99ee-3b29f6f7d0cb",
|
| 175 |
+
"metadata": {},
|
| 176 |
+
"outputs": [],
|
| 177 |
+
"source": [
|
| 178 |
+
"# Some outlier\n",
|
| 179 |
+
"threshold = int(0.8 * le_df_normal.shape[1])\n",
|
| 180 |
+
"le_df_normal = le_df_normal.dropna(thresh = threshold)\n",
|
| 181 |
+
"threshold = int(0.8 * le_df_depression.shape[1])\n",
|
| 182 |
+
"le_df_depression = le_df_depression.dropna(thresh = threshold)\n",
|
| 183 |
+
"\n",
|
| 184 |
+
"# Check depression category counts\n",
|
| 185 |
+
"dataframes = [le_df_normal, le_df_depression]\n",
|
| 186 |
+
"le_initial_df = pd.concat(dataframes, ignore_index=True)\n",
|
| 187 |
+
"label_counts = le_initial_df['depression_category'].value_counts()"
|
| 188 |
+
]
|
| 189 |
+
},
|
| 190 |
+
{
|
| 191 |
+
"cell_type": "code",
|
| 192 |
+
"execution_count": null,
|
| 193 |
+
"id": "8c5bdf52-c645-4a11-920a-579e28db1a50",
|
| 194 |
+
"metadata": {},
|
| 195 |
+
"outputs": [],
|
| 196 |
+
"source": [
|
| 197 |
+
"# Imputation\n",
|
| 198 |
+
"different_le_dfs = [le_df_normal, le_df_depression]\n",
|
| 199 |
+
"imputed_le_dfs = []\n",
|
| 200 |
+
"from sklearn.impute import IterativeImputer\n",
|
| 201 |
+
"for le_df in different_le_dfs:\n",
|
| 202 |
+
" y = le_df['depression_category']\n",
|
| 203 |
+
" X = le_df.drop('depression_category', axis = 1)\n",
|
| 204 |
+
" \n",
|
| 205 |
+
" imputer = SimpleImputer(strategy='median')\n",
|
| 206 |
+
" imputed_data = imputer.fit_transform(X)\n",
|
| 207 |
+
" imputed_df = pd.DataFrame(imputed_data, columns = X.columns)\n",
|
| 208 |
+
"\n",
|
| 209 |
+
" imputed_df['depression_category'] = y.reset_index(drop = True)\n",
|
| 210 |
+
" imputed_le_dfs.append(imputed_df)\n",
|
| 211 |
+
"\n",
|
| 212 |
+
"concatenated_le_dfs = pd.concat(imputed_le_dfs, ignore_index = True)\n",
|
| 213 |
+
"concatenated_le_dfs"
|
| 214 |
+
]
|
| 215 |
+
},
|
| 216 |
+
{
|
| 217 |
+
"cell_type": "code",
|
| 218 |
+
"execution_count": null,
|
| 219 |
+
"id": "0e871a5f-beab-4f90-b8c4-e922ef86d0f0",
|
| 220 |
+
"metadata": {},
|
| 221 |
+
"outputs": [],
|
| 222 |
+
"source": [
|
| 223 |
+
"# Full label encode depression category\n",
|
| 224 |
+
"fully_LE_concatenated_le_dfs = concatenated_le_dfs.copy()\n",
|
| 225 |
+
"fully_LE_concatenated_le_dfs['depression_category'] = label_encoder.fit_transform(fully_LE_concatenated_le_dfs['depression_category'])\n",
|
| 226 |
+
"\n",
|
| 227 |
+
"# The dataset after category connect, imputation, and label encoding\n",
|
| 228 |
+
"splitted_dataset = fully_LE_concatenated_le_dfs.copy()\n",
|
| 229 |
+
"splitted_dataset"
|
| 230 |
+
]
|
| 231 |
+
},
|
| 232 |
+
{
|
| 233 |
+
"cell_type": "markdown",
|
| 234 |
+
"id": "198fdc3c-ccbc-4ac5-8621-c0c5dc771cbb",
|
| 235 |
+
"metadata": {},
|
| 236 |
+
"source": [
|
| 237 |
+
"### Setup for training"
|
| 238 |
+
]
|
| 239 |
+
},
|
| 240 |
+
{
|
| 241 |
+
"cell_type": "code",
|
| 242 |
+
"execution_count": null,
|
| 243 |
+
"id": "101c5549-7bc3-4573-867b-f09776e254db",
|
| 244 |
+
"metadata": {
|
| 245 |
+
"jupyter": {
|
| 246 |
+
"source_hidden": true
|
| 247 |
+
}
|
| 248 |
+
},
|
| 249 |
+
"outputs": [],
|
| 250 |
+
"source": [
|
| 251 |
+
"def plot_combined_roc_curve(roc_curves, classifier_names):\n",
|
| 252 |
+
" plt.figure(figsize=(12, 8))\n",
|
| 253 |
+
" mean_fpr = np.linspace(0, 1, 100)\n",
|
| 254 |
+
" colors = plt.cm.get_cmap('tab20', len(classifier_names))\n",
|
| 255 |
+
" \n",
|
| 256 |
+
" for i, clf_name in enumerate(classifier_names):\n",
|
| 257 |
+
" tprs = []\n",
|
| 258 |
+
" for fpr, tpr in roc_curves[clf_name]:\n",
|
| 259 |
+
" tprs.append(np.interp(mean_fpr, fpr, tpr))\n",
|
| 260 |
+
" mean_tpr = np.mean(tprs, axis=0)\n",
|
| 261 |
+
" mean_tpr[-1] = 1.0\n",
|
| 262 |
+
" mean_auc = auc(mean_fpr, mean_tpr)\n",
|
| 263 |
+
" plt.plot(mean_fpr, mean_tpr, color=colors(i), lw=2, linestyle='-', marker='o', markersize=4, \n",
|
| 264 |
+
" label=f'{clf_name} (AUC = {mean_auc:.3f})')\n",
|
| 265 |
+
"\n",
|
| 266 |
+
" plt.plot([0, 1], [0, 1], color='grey', lw=2, linestyle='--')\n",
|
| 267 |
+
" plt.xlim([0.0, 1.0])\n",
|
| 268 |
+
" plt.ylim([0.0, 1.05])\n",
|
| 269 |
+
" plt.xlabel('False Positive Rate', fontsize=26)\n",
|
| 270 |
+
" plt.ylabel('True Positive Rate', fontsize=26)\n",
|
| 271 |
+
" plt.xticks(fontsize=30)\n",
|
| 272 |
+
" plt.yticks(fontsize=30)\n",
|
| 273 |
+
" plt.legend(loc=\"lower right\", fontsize=22, frameon=True, framealpha=0.9)\n",
|
| 274 |
+
" plt.grid(True)\n",
|
| 275 |
+
"\n",
|
| 276 |
+
" filename='bonk.svg'\n",
|
| 277 |
+
"\n",
|
| 278 |
+
" plt.savefig(filename, format='svg')\n",
|
| 279 |
+
" plt.show()\n",
|
| 280 |
+
"\n",
|
| 281 |
+
" display(FileLink(filename))\n",
|
| 282 |
+
"\n",
|
| 283 |
+
"# Preparation code to make CD diagram from older version of Orange\n",
|
| 284 |
+
"def compute_CD(avranks, n, alpha=\"0.05\", test=\"nemenyi\"):\n",
|
| 285 |
+
" \"\"\"\n",
|
| 286 |
+
" Returns critical difference for Nemenyi or Bonferroni-Dunn test\n",
|
| 287 |
+
" according to given alpha (either alpha=\"0.05\" or alpha=\"0.1\") for average\n",
|
| 288 |
+
" ranks and number of tested datasets N. Test can be either \"nemenyi\" for\n",
|
| 289 |
+
" for Nemenyi two tailed test or \"bonferroni-dunn\" for Bonferroni-Dunn test.\n",
|
| 290 |
+
"\n",
|
| 291 |
+
" This function is deprecated and will be removed in Orange 3.34.\n",
|
| 292 |
+
" \"\"\"\n",
|
| 293 |
+
" k = len(avranks)\n",
|
| 294 |
+
" d = {(\"nemenyi\", \"0.05\"): [0, 0, 1.959964, 2.343701, 2.569032, 2.727774,\n",
|
| 295 |
+
" 2.849705, 2.94832, 3.030879, 3.101730, 3.163684,\n",
|
| 296 |
+
" 3.218654, 3.268004, 3.312739, 3.353618, 3.39123,\n",
|
| 297 |
+
" 3.426041, 3.458425, 3.488685, 3.517073,\n",
|
| 298 |
+
" 3.543799],\n",
|
| 299 |
+
" (\"nemenyi\", \"0.1\"): [0, 0, 1.644854, 2.052293, 2.291341, 2.459516,\n",
|
| 300 |
+
" 2.588521, 2.692732, 2.779884, 2.854606, 2.919889,\n",
|
| 301 |
+
" 2.977768, 3.029694, 3.076733, 3.119693, 3.159199,\n",
|
| 302 |
+
" 3.195743, 3.229723, 3.261461, 3.291224, 3.319233],\n",
|
| 303 |
+
" (\"bonferroni-dunn\", \"0.05\"): [0, 0, 1.960, 2.241, 2.394, 2.498, 2.576,\n",
|
| 304 |
+
" 2.638, 2.690, 2.724, 2.773],\n",
|
| 305 |
+
" (\"bonferroni-dunn\", \"0.1\"): [0, 0, 1.645, 1.960, 2.128, 2.241, 2.326,\n",
|
| 306 |
+
" 2.394, 2.450, 2.498, 2.539]}\n",
|
| 307 |
+
" q = d[(test, alpha)]\n",
|
| 308 |
+
" cd = q[k] * (k * (k + 1) / (6.0 * n)) ** 0.5\n",
|
| 309 |
+
" return cd\n",
|
| 310 |
+
"\n",
|
| 311 |
+
"\n",
|
| 312 |
+
"def graph_ranks(avranks, names, cd=None, cdmethod=None, lowv=None, highv=None,\n",
|
| 313 |
+
" width=6, textspace=1, reverse=False, filename=None, **kwargs):\n",
|
| 314 |
+
" \"\"\"\n",
|
| 315 |
+
" Draws a CD graph, which is used to display the differences in methods'\n",
|
| 316 |
+
" performance. See Janez Demsar, Statistical Comparisons of Classifiers over\n",
|
| 317 |
+
" Multiple Data Sets, 7(Jan):1--30, 2006.\n",
|
| 318 |
+
"\n",
|
| 319 |
+
" Needs matplotlib to work.\n",
|
| 320 |
+
"\n",
|
| 321 |
+
" The image is ploted on `plt` imported using\n",
|
| 322 |
+
" `import matplotlib.pyplot as plt`.\n",
|
| 323 |
+
"\n",
|
| 324 |
+
" This function is deprecated and will be removed in Orange 3.34.\n",
|
| 325 |
+
"\n",
|
| 326 |
+
" Args:\n",
|
| 327 |
+
" avranks (list of float): average ranks of methods.\n",
|
| 328 |
+
" names (list of str): names of methods.\n",
|
| 329 |
+
" cd (float): Critical difference used for statistically significance of\n",
|
| 330 |
+
" difference between methods.\n",
|
| 331 |
+
" cdmethod (int, optional): the method that is compared with other methods\n",
|
| 332 |
+
" If omitted, show pairwise comparison of methods\n",
|
| 333 |
+
" lowv (int, optional): the lowest shown rank\n",
|
| 334 |
+
" highv (int, optional): the highest shown rank\n",
|
| 335 |
+
" width (int, optional): default width in inches (default: 6)\n",
|
| 336 |
+
" textspace (int, optional): space on figure sides (in inches) for the\n",
|
| 337 |
+
" method names (default: 1)\n",
|
| 338 |
+
" reverse (bool, optional): if set to `True`, the lowest rank is on the\n",
|
| 339 |
+
" right (default: `False`)\n",
|
| 340 |
+
" filename (str, optional): output file name (with extension). If not\n",
|
| 341 |
+
" given, the function does not write a file.\n",
|
| 342 |
+
" \"\"\"\n",
|
| 343 |
+
" try:\n",
|
| 344 |
+
" import matplotlib.pyplot as plt\n",
|
| 345 |
+
" from matplotlib.backends.backend_agg import FigureCanvasAgg\n",
|
| 346 |
+
" except ImportError:\n",
|
| 347 |
+
" raise ImportError(\"Function graph_ranks requires matplotlib.\")\n",
|
| 348 |
+
"\n",
|
| 349 |
+
" width = float(width)\n",
|
| 350 |
+
" textspace = float(textspace)\n",
|
| 351 |
+
"\n",
|
| 352 |
+
" def nth(l, n):\n",
|
| 353 |
+
" \"\"\"\n",
|
| 354 |
+
" Returns only nth elemnt in a list.\n",
|
| 355 |
+
" \"\"\"\n",
|
| 356 |
+
" n = lloc(l, n)\n",
|
| 357 |
+
" return [a[n] for a in l]\n",
|
| 358 |
+
"\n",
|
| 359 |
+
" def lloc(l, n):\n",
|
| 360 |
+
" \"\"\"\n",
|
| 361 |
+
" List location in list of list structure.\n",
|
| 362 |
+
" Enable the use of negative locations:\n",
|
| 363 |
+
" -1 is the last element, -2 second last...\n",
|
| 364 |
+
" \"\"\"\n",
|
| 365 |
+
" if n < 0:\n",
|
| 366 |
+
" return len(l[0]) + n\n",
|
| 367 |
+
" else:\n",
|
| 368 |
+
" return n\n",
|
| 369 |
+
"\n",
|
| 370 |
+
" def mxrange(lr):\n",
|
| 371 |
+
" \"\"\"\n",
|
| 372 |
+
" Multiple xranges. Can be used to traverse matrices.\n",
|
| 373 |
+
" This function is very slow due to unknown number of\n",
|
| 374 |
+
" parameters.\n",
|
| 375 |
+
"\n",
|
| 376 |
+
" >>> mxrange([3,5])\n",
|
| 377 |
+
" [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]\n",
|
| 378 |
+
"\n",
|
| 379 |
+
" >>> mxrange([[3,5,1],[9,0,-3]])\n",
|
| 380 |
+
" [(3, 9), (3, 6), (3, 3), (4, 9), (4, 6), (4, 3)]\n",
|
| 381 |
+
"\n",
|
| 382 |
+
" \"\"\"\n",
|
| 383 |
+
" if not len(lr):\n",
|
| 384 |
+
" yield ()\n",
|
| 385 |
+
" else:\n",
|
| 386 |
+
" # it can work with single numbers\n",
|
| 387 |
+
" index = lr[0]\n",
|
| 388 |
+
" if isinstance(index, int):\n",
|
| 389 |
+
" index = [index]\n",
|
| 390 |
+
" for a in range(*index):\n",
|
| 391 |
+
" for b in mxrange(lr[1:]):\n",
|
| 392 |
+
" yield tuple([a] + list(b))\n",
|
| 393 |
+
"\n",
|
| 394 |
+
" def print_figure(fig, *args, **kwargs):\n",
|
| 395 |
+
" canvas = FigureCanvasAgg(fig)\n",
|
| 396 |
+
" canvas.print_figure(*args, **kwargs)\n",
|
| 397 |
+
"\n",
|
| 398 |
+
" sums = avranks\n",
|
| 399 |
+
"\n",
|
| 400 |
+
" tempsort = sorted([(a, i) for i, a in enumerate(sums)], reverse=reverse)\n",
|
| 401 |
+
" ssums = nth(tempsort, 0)\n",
|
| 402 |
+
" sortidx = nth(tempsort, 1)\n",
|
| 403 |
+
" nnames = [names[x] for x in sortidx]\n",
|
| 404 |
+
"\n",
|
| 405 |
+
" if lowv is None:\n",
|
| 406 |
+
" lowv = min(1, int(math.floor(min(ssums))))\n",
|
| 407 |
+
" if highv is None:\n",
|
| 408 |
+
" highv = max(len(avranks), int(math.ceil(max(ssums))))\n",
|
| 409 |
+
"\n",
|
| 410 |
+
" cline = 0.4\n",
|
| 411 |
+
"\n",
|
| 412 |
+
" k = len(sums)\n",
|
| 413 |
+
"\n",
|
| 414 |
+
" lines = None\n",
|
| 415 |
+
"\n",
|
| 416 |
+
" linesblank = 0\n",
|
| 417 |
+
" scalewidth = width - 2 * textspace\n",
|
| 418 |
+
"\n",
|
| 419 |
+
" def rankpos(rank):\n",
|
| 420 |
+
" if not reverse:\n",
|
| 421 |
+
" a = rank - lowv\n",
|
| 422 |
+
" else:\n",
|
| 423 |
+
" a = highv - rank\n",
|
| 424 |
+
" return textspace + scalewidth / (highv - lowv) * a\n",
|
| 425 |
+
"\n",
|
| 426 |
+
" distanceh = 0.25\n",
|
| 427 |
+
"\n",
|
| 428 |
+
" if cd and cdmethod is None:\n",
|
| 429 |
+
" # get pairs of non significant methods\n",
|
| 430 |
+
"\n",
|
| 431 |
+
" def get_lines(sums, hsd):\n",
|
| 432 |
+
" # get all pairs\n",
|
| 433 |
+
" lsums = len(sums)\n",
|
| 434 |
+
" allpairs = [(i, j) for i, j in mxrange([[lsums], [lsums]]) if j > i]\n",
|
| 435 |
+
" # remove not significant\n",
|
| 436 |
+
" notSig = [(i, j) for i, j in allpairs\n",
|
| 437 |
+
" if abs(sums[i] - sums[j]) <= hsd]\n",
|
| 438 |
+
" # keep only longest\n",
|
| 439 |
+
"\n",
|
| 440 |
+
" def no_longer(ij_tuple, notSig):\n",
|
| 441 |
+
" i, j = ij_tuple\n",
|
| 442 |
+
" for i1, j1 in notSig:\n",
|
| 443 |
+
" if (i1 <= i and j1 > j) or (i1 < i and j1 >= j):\n",
|
| 444 |
+
" return False\n",
|
| 445 |
+
" return True\n",
|
| 446 |
+
"\n",
|
| 447 |
+
" longest = [(i, j) for i, j in notSig if no_longer((i, j), notSig)]\n",
|
| 448 |
+
"\n",
|
| 449 |
+
" return longest\n",
|
| 450 |
+
"\n",
|
| 451 |
+
" lines = get_lines(ssums, cd)\n",
|
| 452 |
+
" linesblank = 0.2 + 0.2 + (len(lines) - 1) * 0.1\n",
|
| 453 |
+
"\n",
|
| 454 |
+
" # add scale\n",
|
| 455 |
+
" distanceh = 0.25\n",
|
| 456 |
+
" cline += distanceh\n",
|
| 457 |
+
"\n",
|
| 458 |
+
" # calculate height needed height of an image\n",
|
| 459 |
+
" minnotsignificant = max(2 * 0.2, linesblank)\n",
|
| 460 |
+
" height = cline + ((k + 1) / 2) * 0.2 + minnotsignificant\n",
|
| 461 |
+
"\n",
|
| 462 |
+
" fig = plt.figure(figsize=(width, height))\n",
|
| 463 |
+
" fig.set_facecolor('white')\n",
|
| 464 |
+
" ax = fig.add_axes([0, 0, 1, 1]) # reverse y axis\n",
|
| 465 |
+
" ax.set_axis_off()\n",
|
| 466 |
+
"\n",
|
| 467 |
+
" hf = 1. / height # height factor\n",
|
| 468 |
+
" wf = 1. / width\n",
|
| 469 |
+
"\n",
|
| 470 |
+
" def hfl(l):\n",
|
| 471 |
+
" return [a * hf for a in l]\n",
|
| 472 |
+
"\n",
|
| 473 |
+
" def wfl(l):\n",
|
| 474 |
+
" return [a * wf for a in l]\n",
|
| 475 |
+
"\n",
|
| 476 |
+
"\n",
|
| 477 |
+
" # Upper left corner is (0,0).\n",
|
| 478 |
+
" ax.plot([0, 1], [0, 1], c=\"w\")\n",
|
| 479 |
+
" ax.set_xlim(0, 1)\n",
|
| 480 |
+
" ax.set_ylim(1, 0)\n",
|
| 481 |
+
"\n",
|
| 482 |
+
" def line(l, color='k', **kwargs):\n",
|
| 483 |
+
" \"\"\"\n",
|
| 484 |
+
" Input is a list of pairs of points.\n",
|
| 485 |
+
" \"\"\"\n",
|
| 486 |
+
" ax.plot(wfl(nth(l, 0)), hfl(nth(l, 1)), color=color, **kwargs)\n",
|
| 487 |
+
"\n",
|
| 488 |
+
" def text(x, y, s, *args, **kwargs):\n",
|
| 489 |
+
" ax.text(wf * x, hf * y, s, fontsize = 14, *args, **kwargs)\n",
|
| 490 |
+
"\n",
|
| 491 |
+
" line([(textspace, cline), (width - textspace, cline)], linewidth=0.7)\n",
|
| 492 |
+
"\n",
|
| 493 |
+
" bigtick = 0.1\n",
|
| 494 |
+
" smalltick = 0.05\n",
|
| 495 |
+
"\n",
|
| 496 |
+
" tick = None\n",
|
| 497 |
+
" for a in list(np.arange(lowv, highv, 0.5)) + [highv]:\n",
|
| 498 |
+
" tick = smalltick\n",
|
| 499 |
+
" if a == int(a):\n",
|
| 500 |
+
" tick = bigtick\n",
|
| 501 |
+
" line([(rankpos(a), cline - tick / 2),\n",
|
| 502 |
+
" (rankpos(a), cline)],\n",
|
| 503 |
+
" linewidth=0.7)\n",
|
| 504 |
+
"\n",
|
| 505 |
+
" for a in range(lowv, highv + 1):\n",
|
| 506 |
+
" text(rankpos(a), cline - tick / 2 - 0.05, str(a),\n",
|
| 507 |
+
" ha=\"center\", va=\"bottom\")\n",
|
| 508 |
+
"\n",
|
| 509 |
+
" k = len(ssums)\n",
|
| 510 |
+
"\n",
|
| 511 |
+
" for i in range(math.ceil(k / 2)):\n",
|
| 512 |
+
" chei = cline + minnotsignificant + i * 0.2\n",
|
| 513 |
+
" line([(rankpos(ssums[i]), cline),\n",
|
| 514 |
+
" (rankpos(ssums[i]), chei),\n",
|
| 515 |
+
" (textspace - 0.1, chei)],\n",
|
| 516 |
+
" linewidth=0.7)\n",
|
| 517 |
+
" text(textspace - 0.2, chei, nnames[i], ha=\"right\", va=\"center\")\n",
|
| 518 |
+
"\n",
|
| 519 |
+
" for i in range(math.ceil(k / 2), k):\n",
|
| 520 |
+
" chei = cline + minnotsignificant + (k - i - 1) * 0.2\n",
|
| 521 |
+
" line([(rankpos(ssums[i]), cline),\n",
|
| 522 |
+
" (rankpos(ssums[i]), chei),\n",
|
| 523 |
+
" (textspace + scalewidth + 0.1, chei)],\n",
|
| 524 |
+
" linewidth=0.7)\n",
|
| 525 |
+
" text(textspace + scalewidth + 0.2, chei, nnames[i],\n",
|
| 526 |
+
" ha=\"left\", va=\"center\")\n",
|
| 527 |
+
"\n",
|
| 528 |
+
" if cd and cdmethod is None:\n",
|
| 529 |
+
" # upper scale\n",
|
| 530 |
+
" if not reverse:\n",
|
| 531 |
+
" begin, end = rankpos(lowv), rankpos(lowv + cd)\n",
|
| 532 |
+
" else:\n",
|
| 533 |
+
" begin, end = rankpos(highv), rankpos(highv - cd)\n",
|
| 534 |
+
"\n",
|
| 535 |
+
" line([(begin, distanceh), (end, distanceh)], linewidth=0.7)\n",
|
| 536 |
+
" line([(begin, distanceh + bigtick / 2),\n",
|
| 537 |
+
" (begin, distanceh - bigtick / 2)],\n",
|
| 538 |
+
" linewidth=0.7)\n",
|
| 539 |
+
" line([(end, distanceh + bigtick / 2),\n",
|
| 540 |
+
" (end, distanceh - bigtick / 2)],\n",
|
| 541 |
+
" linewidth=0.7)\n",
|
| 542 |
+
" text((begin + end) / 2, distanceh - 0.05, \"CD\",\n",
|
| 543 |
+
" ha=\"center\", va=\"bottom\")\n",
|
| 544 |
+
"\n",
|
| 545 |
+
" # no-significance lines\n",
|
| 546 |
+
" def draw_lines(lines, side=0.05, height=0.1):\n",
|
| 547 |
+
" start = cline + 0.2\n",
|
| 548 |
+
" for l, r in lines:\n",
|
| 549 |
+
" line([(rankpos(ssums[l]) - side, start),\n",
|
| 550 |
+
" (rankpos(ssums[r]) + side, start)],\n",
|
| 551 |
+
" linewidth=2.5)\n",
|
| 552 |
+
" start += height\n",
|
| 553 |
+
"\n",
|
| 554 |
+
" draw_lines(lines)\n",
|
| 555 |
+
"\n",
|
| 556 |
+
" elif cd:\n",
|
| 557 |
+
" begin = rankpos(avranks[cdmethod] - cd)\n",
|
| 558 |
+
" end = rankpos(avranks[cdmethod] + cd)\n",
|
| 559 |
+
" line([(begin, cline), (end, cline)],\n",
|
| 560 |
+
" linewidth=2.5)\n",
|
| 561 |
+
" line([(begin, cline + bigtick / 2),\n",
|
| 562 |
+
" (begin, cline - bigtick / 2)],\n",
|
| 563 |
+
" linewidth=2.5)\n",
|
| 564 |
+
" line([(end, cline + bigtick / 2),\n",
|
| 565 |
+
" (end, cline - bigtick / 2)],\n",
|
| 566 |
+
" linewidth=2.5)\n",
|
| 567 |
+
"\n",
|
| 568 |
+
" if filename:\n",
|
| 569 |
+
" print_figure(fig, filename, **kwargs)\n",
|
| 570 |
+
"\n",
|
| 571 |
+
"def train_evaluate_model(clf, X_train, y_train, X_test, y_test, clf_name='Classifier'):\n",
|
| 572 |
+
" clf.fit(X_train, y_train)\n",
|
| 573 |
+
" y_pred = clf.predict(X_test)\n",
|
| 574 |
+
" \n",
|
| 575 |
+
" accuracy = accuracy_score(y_test, y_pred)\n",
|
| 576 |
+
" precision = precision_score(y_test, y_pred, average='weighted')\n",
|
| 577 |
+
" recall = recall_score(y_test, y_pred, average='weighted')\n",
|
| 578 |
+
" f1 = f1_score(y_test, y_pred, average='weighted')\n",
|
| 579 |
+
" conf_matrix = confusion_matrix(y_test, y_pred)\n",
|
| 580 |
+
" \n",
|
| 581 |
+
" if hasattr(clf, 'predict_proba'):\n",
|
| 582 |
+
" y_score = clf.predict_proba(X_test)[:, 1]\n",
|
| 583 |
+
" else:\n",
|
| 584 |
+
" y_score = clf.decision_function(X_test)\n",
|
| 585 |
+
" \n",
|
| 586 |
+
" fpr, tpr, _ = roc_curve(y_test, y_score)\n",
|
| 587 |
+
" roc_auc = auc(fpr, tpr)\n",
|
| 588 |
+
" \n",
|
| 589 |
+
" print(f'{clf_name} - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}')\n",
|
| 590 |
+
" return accuracy, precision, recall, f1, conf_matrix, fpr, tpr, roc_auc"
|
| 591 |
+
]
|
| 592 |
+
},
|
| 593 |
+
{
|
| 594 |
+
"cell_type": "code",
|
| 595 |
+
"execution_count": null,
|
| 596 |
+
"id": "3e602f4d-efb2-4ac9-99bb-c6fbeca791cc",
|
| 597 |
+
"metadata": {
|
| 598 |
+
"jupyter": {
|
| 599 |
+
"source_hidden": true
|
| 600 |
+
}
|
| 601 |
+
},
|
| 602 |
+
"outputs": [],
|
| 603 |
+
"source": [
|
| 604 |
+
"warnings.filterwarnings('ignore')"
|
| 605 |
+
]
|
| 606 |
+
},
|
| 607 |
+
{
|
| 608 |
+
"cell_type": "markdown",
|
| 609 |
+
"id": "773b2524-af11-464f-97a9-f4add85be0d2",
|
| 610 |
+
"metadata": {},
|
| 611 |
+
"source": [
|
| 612 |
+
"### Training (classic/static)\n",
|
| 613 |
+
"In order to run classical/static, make sure to uncomment the one you need. \"Post Training\" is after one of these classical/static is done."
|
| 614 |
+
]
|
| 615 |
+
},
|
| 616 |
+
{
|
| 617 |
+
"cell_type": "markdown",
|
| 618 |
+
"id": "f5b8873c-13d3-43b7-8902-14b73e8d5409",
|
| 619 |
+
"metadata": {},
|
| 620 |
+
"source": [
|
| 621 |
+
"#### Classical Classifiers"
|
| 622 |
+
]
|
| 623 |
+
},
|
| 624 |
+
{
|
| 625 |
+
"cell_type": "code",
|
| 626 |
+
"execution_count": null,
|
| 627 |
+
"id": "264648a0-b0d8-4a63-9167-cea3a4ed1e18",
|
| 628 |
+
"metadata": {},
|
| 629 |
+
"outputs": [],
|
| 630 |
+
"source": [
|
| 631 |
+
"# Optimized Classifiers\n",
|
| 632 |
+
"# classifiers = {\n",
|
| 633 |
+
"# 'DT': DecisionTreeClassifier(\n",
|
| 634 |
+
"# random_state=0, \n",
|
| 635 |
+
"# criterion='gini', \n",
|
| 636 |
+
"# max_depth=6, \n",
|
| 637 |
+
"# min_samples_leaf=10, \n",
|
| 638 |
+
"# min_samples_split=9\n",
|
| 639 |
+
"# ),\n",
|
| 640 |
+
"# 'LR': LogisticRegression(\n",
|
| 641 |
+
"# random_state=0, \n",
|
| 642 |
+
"# C=0.09659168435718246, \n",
|
| 643 |
+
"# max_iter=100, \n",
|
| 644 |
+
"# solver='lbfgs'\n",
|
| 645 |
+
"# ),\n",
|
| 646 |
+
"# 'NB': GaussianNB(\n",
|
| 647 |
+
"# var_smoothing=0.0058873326349240295\n",
|
| 648 |
+
"# ),\n",
|
| 649 |
+
"# 'KN': KNeighborsClassifier(\n",
|
| 650 |
+
"# metric='manhattan', \n",
|
| 651 |
+
"# n_neighbors=8, \n",
|
| 652 |
+
"# weights='uniform'\n",
|
| 653 |
+
"# ),\n",
|
| 654 |
+
"# 'MLP': MLPClassifier(\n",
|
| 655 |
+
"# random_state=0, \n",
|
| 656 |
+
"# max_iter=1000, \n",
|
| 657 |
+
"# alpha=0.0003079393718075164, \n",
|
| 658 |
+
"# hidden_layer_sizes=195, \n",
|
| 659 |
+
"# learning_rate_init=0.0001675266159417717\n",
|
| 660 |
+
"# ),\n",
|
| 661 |
+
"# 'SVC': SVC(probability=True, kernel = 'rbf', C = 0.95, gamma = 'scale')}\n",
|
| 662 |
+
"\n",
|
| 663 |
+
"# Default classifiers\n",
|
| 664 |
+
"# classifiers = {\n",
|
| 665 |
+
"# 'DecisionTree': DecisionTreeClassifier(random_state=0),\n",
|
| 666 |
+
"# 'LogisticRegression': LogisticRegression(max_iter=1000, random_state=0),\n",
|
| 667 |
+
"# 'NaiveBayes': GaussianNB(),\n",
|
| 668 |
+
"# 'KNeighbors': KNeighborsClassifier(),\n",
|
| 669 |
+
"# 'MLP': MLPClassifier(max_iter=1000, random_state=0),\n",
|
| 670 |
+
"# 'SVC': SVC(probability=True, random_state=0)\n",
|
| 671 |
+
"# }\n",
|
| 672 |
+
"\n",
|
| 673 |
+
"# Main\n",
|
| 674 |
+
"# Initialize\n",
|
| 675 |
+
"metric_sums = defaultdict(lambda: {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0})\n",
|
| 676 |
+
"conf_matrices = defaultdict(list)\n",
|
| 677 |
+
"roc_curves = defaultdict(list)\n",
|
| 678 |
+
"roc_aucs = defaultdict(list)\n",
|
| 679 |
+
"accuracy_scores = defaultdict(list)\n",
|
| 680 |
+
"precision_scores = defaultdict(list)\n",
|
| 681 |
+
"recall_scores = defaultdict(list)\n",
|
| 682 |
+
"f1_scores = defaultdict(list)\n",
|
| 683 |
+
"\n",
|
| 684 |
+
"# Loop over 10 different random states\n",
|
| 685 |
+
"for random_state in range(10):\n",
|
| 686 |
+
" print(f\"Processing for Random State: {random_state}\")\n",
|
| 687 |
+
"\n",
|
| 688 |
+
" # Splitting the data\n",
|
| 689 |
+
" X = splitted_dataset.drop('depression_category', axis=1)\n",
|
| 690 |
+
" y = splitted_dataset['depression_category']\n",
|
| 691 |
+
" \n",
|
| 692 |
+
" X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=random_state)\n",
|
| 693 |
+
" \n",
|
| 694 |
+
" # Identify outliers in the training dataset\n",
|
| 695 |
+
" lof = LocalOutlierFactor()\n",
|
| 696 |
+
" yhat = lof.fit_predict(X_train)\n",
|
| 697 |
+
" # Select all rows that are not outliers\n",
|
| 698 |
+
" mask = yhat != -1\n",
|
| 699 |
+
" X_train, y_train = X_train[mask], y_train[mask]\n",
|
| 700 |
+
" \n",
|
| 701 |
+
" original_columns = X.columns.tolist()\n",
|
| 702 |
+
"\n",
|
| 703 |
+
" # SMOTE\n",
|
| 704 |
+
" smote = SMOTE(random_state=random_state)\n",
|
| 705 |
+
" X_res, y_res = smote.fit_resample(X_train, y_train)\n",
|
| 706 |
+
"\n",
|
| 707 |
+
" print(f\"Number of training labels after ROS: {y_res.value_counts()}\")\n",
|
| 708 |
+
"\n",
|
| 709 |
+
" print(f\"Number of test labels before resampling: {y_test.value_counts()}\") \n",
|
| 710 |
+
" \n",
|
| 711 |
+
" sampling_strategy_undersample = {0: 372}\n",
|
| 712 |
+
" rus = RandomUnderSampler(sampling_strategy=sampling_strategy_undersample, random_state=random_state)\n",
|
| 713 |
+
" X_test, y_test = rus.fit_resample(X_test, y_test)\n",
|
| 714 |
+
" print(f\"Number of test labels after resampling: {y_test.value_counts()}\")\n",
|
| 715 |
+
"\n",
|
| 716 |
+
" # Normalization\n",
|
| 717 |
+
" scaler = MinMaxScaler()\n",
|
| 718 |
+
" \n",
|
| 719 |
+
" X_res = scaler.fit_transform(X_res)\n",
|
| 720 |
+
" X_test = scaler.transform(X_test)\n",
|
| 721 |
+
" \n",
|
| 722 |
+
" X_res = pd.DataFrame(X_res, columns=original_columns)\n",
|
| 723 |
+
" X_test = pd.DataFrame(X_test, columns=original_columns)\n",
|
| 724 |
+
"\n",
|
| 725 |
+
" # Correlation Feat Analysis\n",
|
| 726 |
+
" corr_df = X_res.copy()\n",
|
| 727 |
+
" corr_df['target'] = y_res\n",
|
| 728 |
+
" \n",
|
| 729 |
+
" corr_mat = corr_df.corr()\n",
|
| 730 |
+
" target_correlation = corr_mat['target'].drop('target')\n",
|
| 731 |
+
" top_features = target_correlation.abs().sort_values(ascending=False).head(200).index.tolist()\n",
|
| 732 |
+
" \n",
|
| 733 |
+
" # Only take top features\n",
|
| 734 |
+
" X_res_fi = X_res[top_features]\n",
|
| 735 |
+
" X_test_fi = X_test[top_features]\n",
|
| 736 |
+
"\n",
|
| 737 |
+
" # Evaluate classifiers\n",
|
| 738 |
+
" for clf_name, clf in classifiers.items():\n",
|
| 739 |
+
" # Ensure the random state for classifiers is consistent\n",
|
| 740 |
+
" if hasattr(clf, 'random_state'):\n",
|
| 741 |
+
" clf.set_params(random_state=random_state)\n",
|
| 742 |
+
" accuracy, precision, recall, f1, conf_matrix, fpr, tpr, roc_auc = train_evaluate_model(clf, X_res_fi, y_res, X_test_fi, y_test, clf_name=clf_name)\n",
|
| 743 |
+
" metric_sums[clf_name]['accuracy'] += accuracy\n",
|
| 744 |
+
" metric_sums[clf_name]['precision'] += precision\n",
|
| 745 |
+
" metric_sums[clf_name]['recall'] += recall\n",
|
| 746 |
+
" metric_sums[clf_name]['f1'] += f1\n",
|
| 747 |
+
" conf_matrices[clf_name].append(conf_matrix)\n",
|
| 748 |
+
" roc_curves[clf_name].append((fpr, tpr))\n",
|
| 749 |
+
" roc_aucs[clf_name].append(roc_auc)\n",
|
| 750 |
+
" accuracy_scores[clf_name].append(accuracy)\n",
|
| 751 |
+
" precision_scores[clf_name].append(precision)\n",
|
| 752 |
+
" recall_scores[clf_name].append(recall)\n",
|
| 753 |
+
" f1_scores[clf_name].append(f1)"
|
| 754 |
+
]
|
| 755 |
+
},
|
| 756 |
+
{
|
| 757 |
+
"cell_type": "markdown",
|
| 758 |
+
"id": "0ec7e67e-a086-48a7-83de-26b54ef03899",
|
| 759 |
+
"metadata": {},
|
| 760 |
+
"source": [
|
| 761 |
+
"#### Static Classifiers"
|
| 762 |
+
]
|
| 763 |
+
},
|
| 764 |
+
{
|
| 765 |
+
"cell_type": "code",
|
| 766 |
+
"execution_count": null,
|
| 767 |
+
"id": "e9e20dd2-bdd8-4626-be89-8bb866212de9",
|
| 768 |
+
"metadata": {
|
| 769 |
+
"scrolled": true
|
| 770 |
+
},
|
| 771 |
+
"outputs": [],
|
| 772 |
+
"source": [
|
| 773 |
+
"# Initialize\n",
|
| 774 |
+
"metric_sums = defaultdict(lambda: {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0})\n",
|
| 775 |
+
"conf_matrices = defaultdict(list)\n",
|
| 776 |
+
"roc_curves = defaultdict(list)\n",
|
| 777 |
+
"roc_aucs = defaultdict(list)\n",
|
| 778 |
+
"accuracy_scores = defaultdict(list)\n",
|
| 779 |
+
"precision_scores = defaultdict(list)\n",
|
| 780 |
+
"recall_scores = defaultdict(list)\n",
|
| 781 |
+
"f1_scores = defaultdict(list)\n",
|
| 782 |
+
"\n",
|
| 783 |
+
"# Optimized Classifiers\n",
|
| 784 |
+
"classifiers = {\n",
|
| 785 |
+
" 'RF': RandomForestClassifier(n_estimators=143, criterion='entropy', max_depth=15, random_state=0),\n",
|
| 786 |
+
" 'XGB': XGBClassifier(n_estimators=200, max_depth=3, learning_rate=0.1, use_label_encoder=False, eval_metric='mlogloss', random_state=0),\n",
|
| 787 |
+
" 'GB': GradientBoostingClassifier(n_estimators=300, max_depth=3, learning_rate=0.05),\n",
|
| 788 |
+
" # 'AB': AdaBoostClassifier(n_estimators=400, learning_rate=0.1),\n",
|
| 789 |
+
" # 'CB': CatBoostClassifier(depth = 3, iterations = 168, learning_rate = 0.1, verbose = 0),\n",
|
| 790 |
+
" # 'LGBM': LGBMClassifier(learning_rate = 0.1, max_depth = 3, n_estimators = 200) \n",
|
| 791 |
+
"}\n",
|
| 792 |
+
"\n",
|
| 793 |
+
"# Default Classifiers\n",
|
| 794 |
+
"# classifiers = {\n",
|
| 795 |
+
"# 'RandomForest': RandomForestClassifier(n_estimators=100, criterion='entropy', max_depth=7, random_state=0),\n",
|
| 796 |
+
"# 'XGBoost': XGBClassifier(n_estimators=100, max_depth=7, use_label_encoder=False, eval_metric='mlogloss', random_state=0),\n",
|
| 797 |
+
"# 'GradientBoosting': GradientBoostingClassifier(n_estimators=100, random_state=0),\n",
|
| 798 |
+
"# 'AdaBoost': AdaBoostClassifier(n_estimators=100, random_state=0),\n",
|
| 799 |
+
"# 'CatBoost': CatBoostClassifier(n_estimators=100, verbose=0, random_state=0),\n",
|
| 800 |
+
"# 'LightGBM': LGBMClassifier(n_estimators=100, random_state=0)\n",
|
| 801 |
+
"# }\n",
|
| 802 |
+
"\n",
|
| 803 |
+
"# voting_clf = VotingClassifier(estimators=[\n",
|
| 804 |
+
"# ('rf', classifiers['RF']),\n",
|
| 805 |
+
"# ('xgb', classifiers['XGB']),\n",
|
| 806 |
+
"# ('gb', classifiers['GB']),\n",
|
| 807 |
+
"# ('ada', classifiers['AB']),\n",
|
| 808 |
+
"# ('cat', classifiers['CB']),\n",
|
| 809 |
+
"# ('lgbm', classifiers['LGBM'])\n",
|
| 810 |
+
"# ], voting='soft', n_jobs=1)\n",
|
| 811 |
+
"\n",
|
| 812 |
+
"# classifiers['Vot'] = voting_clf\n",
|
| 813 |
+
"\n",
|
| 814 |
+
"# Define the number of features for each classifier\n",
|
| 815 |
+
"num_features = {\n",
|
| 816 |
+
" 'RF': 150,\n",
|
| 817 |
+
" 'XGB': 150,\n",
|
| 818 |
+
" 'GB': 150,\n",
|
| 819 |
+
" # 'AB': 150,\n",
|
| 820 |
+
" # 'CB': 150,\n",
|
| 821 |
+
" # 'LGBM': 150,\n",
|
| 822 |
+
" # 'Vot': 150\n",
|
| 823 |
+
"}\n",
|
| 824 |
+
"\n",
|
| 825 |
+
"for random_state in range(10):\n",
|
| 826 |
+
" print(f\"Processing for Random State: {random_state}\")\n",
|
| 827 |
+
"\n",
|
| 828 |
+
" X = splitted_dataset.drop('depression_category', axis=1)\n",
|
| 829 |
+
" y = splitted_dataset['depression_category']\n",
|
| 830 |
+
" X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=random_state)\n",
|
| 831 |
+
" \n",
|
| 832 |
+
" lof = LocalOutlierFactor()\n",
|
| 833 |
+
" yhat = lof.fit_predict(X_train)\n",
|
| 834 |
+
" mask = yhat != -1\n",
|
| 835 |
+
" X_train, y_train = X_train[mask], y_train[mask]\n",
|
| 836 |
+
" \n",
|
| 837 |
+
" original_columns = X.columns.tolist()\n",
|
| 838 |
+
"\n",
|
| 839 |
+
" smote = SMOTE(random_state=random_state)\n",
|
| 840 |
+
" X_res, y_res = smote.fit_resample(X_train, y_train)\n",
|
| 841 |
+
"\n",
|
| 842 |
+
" print(f\"Number of training labels after ROS: {y_res.value_counts()}\")\n",
|
| 843 |
+
"\n",
|
| 844 |
+
" print(f\"Number of test labels before resampling: {y_test.value_counts()}\") \n",
|
| 845 |
+
" \n",
|
| 846 |
+
" sampling_strategy_undersample = {0: 372}\n",
|
| 847 |
+
" rus = RandomUnderSampler(sampling_strategy=sampling_strategy_undersample, random_state=random_state)\n",
|
| 848 |
+
" X_test, y_test = rus.fit_resample(X_test, y_test)\n",
|
| 849 |
+
"\n",
|
| 850 |
+
" print(f\"Number of test labels after resampling: {y_test.value_counts()}\")\n",
|
| 851 |
+
"\n",
|
| 852 |
+
" scaler = MinMaxScaler()\n",
|
| 853 |
+
" \n",
|
| 854 |
+
" X_res = scaler.fit_transform(X_res)\n",
|
| 855 |
+
" X_test = scaler.transform(X_test)\n",
|
| 856 |
+
" \n",
|
| 857 |
+
" X_res = pd.DataFrame(X_res, columns=original_columns)\n",
|
| 858 |
+
" X_test = pd.DataFrame(X_test, columns=original_columns)\n",
|
| 859 |
+
"\n",
|
| 860 |
+
" log_reg = LogisticRegression(C=0.09659168435718246, max_iter=100, solver='lbfgs', random_state=random_state)\n",
|
| 861 |
+
" log_reg.fit(X_res, y_res)\n",
|
| 862 |
+
" selector = SelectFromModel(log_reg, prefit=True)\n",
|
| 863 |
+
" \n",
|
| 864 |
+
" importance = np.abs(log_reg.coef_[0])\n",
|
| 865 |
+
" indices = np.argsort(importance)[::-1]\n",
|
| 866 |
+
" important_features = [original_columns[i] for i in indices]\n",
|
| 867 |
+
" \n",
|
| 868 |
+
" for clf_name, clf in classifiers.items():\n",
|
| 869 |
+
" num_top_features = num_features[clf_name]\n",
|
| 870 |
+
" selected_features = important_features[:num_top_features]\n",
|
| 871 |
+
" \n",
|
| 872 |
+
" X_res_fi = pd.DataFrame(X_res, columns=original_columns)[selected_features]\n",
|
| 873 |
+
" X_test_fi = pd.DataFrame(X_test, columns=original_columns)[selected_features]\n",
|
| 874 |
+
"\n",
|
| 875 |
+
" accuracy, precision, recall, f1, conf_matrix, fpr, tpr, roc_auc = train_evaluate_model(\n",
|
| 876 |
+
" clf, X_res_fi, y_res, X_test_fi, y_test, clf_name=clf_name\n",
|
| 877 |
+
" )\n",
|
| 878 |
+
" metric_sums[clf_name]['accuracy'] += accuracy\n",
|
| 879 |
+
" metric_sums[clf_name]['precision'] += precision\n",
|
| 880 |
+
" metric_sums[clf_name]['recall'] += recall\n",
|
| 881 |
+
" metric_sums[clf_name]['f1'] += f1\n",
|
| 882 |
+
" conf_matrices[clf_name].append(conf_matrix)\n",
|
| 883 |
+
" roc_curves[clf_name].append((fpr, tpr))\n",
|
| 884 |
+
" roc_aucs[clf_name].append(roc_auc)\n",
|
| 885 |
+
" accuracy_scores[clf_name].append(accuracy)\n",
|
| 886 |
+
" precision_scores[clf_name].append(precision)\n",
|
| 887 |
+
" recall_scores[clf_name].append(recall)\n",
|
| 888 |
+
" f1_scores[clf_name].append(f1)"
|
| 889 |
+
]
|
| 890 |
+
},
|
| 891 |
+
{
|
| 892 |
+
"cell_type": "markdown",
|
| 893 |
+
"id": "c0158f58-e6d5-4adf-90e3-f1892294ad24",
|
| 894 |
+
"metadata": {},
|
| 895 |
+
"source": [
|
| 896 |
+
"### Post Training (classic/static)\n",
|
| 897 |
+
"Only run after one of the training methods above are done"
|
| 898 |
+
]
|
| 899 |
+
},
|
| 900 |
+
{
|
| 901 |
+
"cell_type": "code",
|
| 902 |
+
"execution_count": null,
|
| 903 |
+
"id": "c2ab5543-ad6e-4463-8d9a-35fe3e4e8eb4",
|
| 904 |
+
"metadata": {},
|
| 905 |
+
"outputs": [],
|
| 906 |
+
"source": [
|
| 907 |
+
"print('\\nAverage Metrics over 10 Random States:')\n",
|
| 908 |
+
"for clf_name, metrics in metric_sums.items():\n",
|
| 909 |
+
" avg_accuracy = metrics['accuracy'] / 10\n",
|
| 910 |
+
" avg_precision = metrics['precision'] / 10\n",
|
| 911 |
+
" avg_recall = metrics['recall'] / 10\n",
|
| 912 |
+
" avg_f1 = metrics['f1'] / 10\n",
|
| 913 |
+
" std_accuracy = np.std(accuracy_scores[clf_name])\n",
|
| 914 |
+
" std_precision = np.std(precision_scores[clf_name])\n",
|
| 915 |
+
" std_recall = np.std(recall_scores[clf_name])\n",
|
| 916 |
+
" std_f1 = np.std(f1_scores[clf_name])\n",
|
| 917 |
+
" avg_auc = np.mean(roc_aucs[clf_name])\n",
|
| 918 |
+
" print(f'{clf_name} - Accuracy: {avg_accuracy:.4f} ± {std_accuracy:.4f}, Precision: {avg_precision:.4f} ± {std_precision:.4f}, Recall: {avg_recall:.4f} ± {std_recall:.4f}, F1-Score: {avg_f1:.4f} ± {std_f1:.4f}, AUC: {avg_auc:.4f}')"
|
| 919 |
+
]
|
| 920 |
+
},
|
| 921 |
+
{
|
| 922 |
+
"cell_type": "code",
|
| 923 |
+
"execution_count": null,
|
| 924 |
+
"id": "21d5c872-3452-41ca-8f7c-5f4ceb184fba",
|
| 925 |
+
"metadata": {},
|
| 926 |
+
"outputs": [],
|
| 927 |
+
"source": [
|
| 928 |
+
"# Plot ROC Curves for each classifier in one graph\n",
|
| 929 |
+
"plot_combined_roc_curve(roc_curves, classifiers.keys())"
|
| 930 |
+
]
|
| 931 |
+
},
|
| 932 |
+
{
|
| 933 |
+
"cell_type": "code",
|
| 934 |
+
"execution_count": null,
|
| 935 |
+
"id": "26f7a715-efc8-44a6-8c53-e59b6268e551",
|
| 936 |
+
"metadata": {
|
| 937 |
+
"scrolled": true
|
| 938 |
+
},
|
| 939 |
+
"outputs": [],
|
| 940 |
+
"source": [
|
| 941 |
+
"# FN Curve\n",
|
| 942 |
+
"df = pd.DataFrame(accuracy_scores)\n",
|
| 943 |
+
"scores = [df[col].values for col in df.columns]\n",
|
| 944 |
+
"stat, p = friedmanchisquare(*scores)\n",
|
| 945 |
+
"print(f'Friedman Test Statistic: {stat}, p-value: {p}')\n",
|
| 946 |
+
"ranks = df.rank(axis=1, method='average', ascending=False)\n",
|
| 947 |
+
"average_ranks = ranks.mean().values\n",
|
| 948 |
+
"n_datasets = df.shape[0]\n",
|
| 949 |
+
"alpha = 0.05\n",
|
| 950 |
+
"cd = compute_CD(average_ranks, n_datasets, alpha='0.05')\n",
|
| 951 |
+
"print(f'Critical Difference: {cd}')\n",
|
| 952 |
+
"classifiers = [f\"{clf} ({rank:.2f})\" for clf, rank in zip(df.columns, average_ranks)]\n",
|
| 953 |
+
"plt.figure(figsize=(14, 8))\n",
|
| 954 |
+
"graph_ranks(average_ranks, classifiers, cd=cd, width=6, textspace=1)\n",
|
| 955 |
+
"plt.text(0.5, 1.19, f'Friedman-Nemenyi: {stat:.3f}', horizontalalignment='center', transform=plt.gca().transAxes, fontsize=14)\n",
|
| 956 |
+
"plt.text(0.5, 1.10, f'CD: {cd:.3f}', horizontalalignment='center', transform=plt.gca().transAxes, fontsize=14)\n",
|
| 957 |
+
"plt.tight_layout()"
|
| 958 |
+
]
|
| 959 |
+
},
|
| 960 |
+
{
|
| 961 |
+
"cell_type": "markdown",
|
| 962 |
+
"id": "88072d69-8bb5-46ab-9c79-6990db7cc8d2",
|
| 963 |
+
"metadata": {},
|
| 964 |
+
"source": [
|
| 965 |
+
"### Hyperparameter optimization (classic/static)"
|
| 966 |
+
]
|
| 967 |
+
},
|
| 968 |
+
{
|
| 969 |
+
"cell_type": "code",
|
| 970 |
+
"execution_count": null,
|
| 971 |
+
"id": "3a4a1546-c24f-4349-bf1b-75ba937700be",
|
| 972 |
+
"metadata": {
|
| 973 |
+
"scrolled": true
|
| 974 |
+
},
|
| 975 |
+
"outputs": [],
|
| 976 |
+
"source": [
|
| 977 |
+
"# Hyperparameter optimization classic\n",
|
| 978 |
+
"search_spaces = {\n",
|
| 979 |
+
" 'DecisionTree': {\n",
|
| 980 |
+
" 'criterion': Categorical(['gini', 'entropy']),\n",
|
| 981 |
+
" 'max_depth': Integer(1, 20),\n",
|
| 982 |
+
" 'min_samples_split': Integer(2, 10),\n",
|
| 983 |
+
" 'min_samples_leaf': Integer(1, 10)\n",
|
| 984 |
+
" },\n",
|
| 985 |
+
" 'LogisticRegression': {\n",
|
| 986 |
+
" 'C': Real(1e-6, 1e+6, prior='log-uniform'),\n",
|
| 987 |
+
" 'solver': Categorical(['lbfgs', 'liblinear']),\n",
|
| 988 |
+
" 'max_iter': Integer(100, 1000)\n",
|
| 989 |
+
" },\n",
|
| 990 |
+
" 'NaiveBayes': {\n",
|
| 991 |
+
" 'var_smoothing': Real(1e-9, 1e-2, prior='log-uniform')\n",
|
| 992 |
+
" },\n",
|
| 993 |
+
" 'KNeighbors': {\n",
|
| 994 |
+
" 'n_neighbors': Integer(1, 30),\n",
|
| 995 |
+
" 'weights': Categorical(['uniform', 'distance']),\n",
|
| 996 |
+
" 'metric': Categorical(['euclidean', 'manhattan', 'minkowski'])\n",
|
| 997 |
+
" },\n",
|
| 998 |
+
" 'MLP': {\n",
|
| 999 |
+
" 'hidden_layer_sizes': Integer(50, 200),\n",
|
| 1000 |
+
" 'alpha': Real(1e-6, 1e-2, prior='log-uniform'),\n",
|
| 1001 |
+
" 'learning_rate_init': Real(1e-4, 1e-2, prior='log-uniform')\n",
|
| 1002 |
+
" },\n",
|
| 1003 |
+
" 'SVC': {\n",
|
| 1004 |
+
" 'C': [0.1, 1, 10, 100, 1000],\n",
|
| 1005 |
+
" 'gamma': [1, 0.1, 0.01, 0.001, 0.0001],\n",
|
| 1006 |
+
" 'kernel': ['rbf']\n",
|
| 1007 |
+
" }\n",
|
| 1008 |
+
"}\n",
|
| 1009 |
+
"\n",
|
| 1010 |
+
"classifiers = {\n",
|
| 1011 |
+
" 'DecisionTree': DecisionTreeClassifier(random_state=0),\n",
|
| 1012 |
+
" 'LogisticRegression': LogisticRegression(max_iter=1000, random_state=0),\n",
|
| 1013 |
+
" 'NaiveBayes': GaussianNB(),\n",
|
| 1014 |
+
" 'KNeighbors': KNeighborsClassifier(),\n",
|
| 1015 |
+
" 'MLP': MLPClassifier(max_iter=1000, random_state=0),\n",
|
| 1016 |
+
" 'SVC': SVC(probability=True, random_state=0)\n",
|
| 1017 |
+
"}\n",
|
| 1018 |
+
"\n",
|
| 1019 |
+
"top_features_count = {\n",
|
| 1020 |
+
" 'DecisionTree': 200,\n",
|
| 1021 |
+
" 'LogisticRegression': 200,\n",
|
| 1022 |
+
" 'NaiveBayes': 200,\n",
|
| 1023 |
+
" 'KNeighbors': 200,\n",
|
| 1024 |
+
" 'MLP': 200,\n",
|
| 1025 |
+
" 'SVC': 200\n",
|
| 1026 |
+
"}\n",
|
| 1027 |
+
"\n",
|
| 1028 |
+
"random_state = 0\n",
|
| 1029 |
+
"print(f\"Processing for Random State: {random_state}\")\n",
|
| 1030 |
+
"\n",
|
| 1031 |
+
"X = splitted_dataset.drop('depression_category', axis=1)\n",
|
| 1032 |
+
"y = splitted_dataset['depression_category']\n",
|
| 1033 |
+
"\n",
|
| 1034 |
+
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=random_state)\n",
|
| 1035 |
+
"\n",
|
| 1036 |
+
"lof = LocalOutlierFactor()\n",
|
| 1037 |
+
"yhat = lof.fit_predict(X_train)\n",
|
| 1038 |
+
"mask = yhat != -1\n",
|
| 1039 |
+
"X_train, y_train = X_train[mask], y_train[mask]\n",
|
| 1040 |
+
"\n",
|
| 1041 |
+
"original_columns = X.columns.tolist()\n",
|
| 1042 |
+
"\n",
|
| 1043 |
+
"smote = SMOTE(random_state=random_state)\n",
|
| 1044 |
+
"X_res, y_res = smote.fit_resample(X_train, y_train)\n",
|
| 1045 |
+
"\n",
|
| 1046 |
+
"print(f\"Number of training labels after ROS: {y_res.value_counts()}\")\n",
|
| 1047 |
+
"\n",
|
| 1048 |
+
"print(f\"Number of test labels before resampling: {y_test.value_counts()}\") \n",
|
| 1049 |
+
"\n",
|
| 1050 |
+
"sampling_strategy_undersample = {0: 372}\n",
|
| 1051 |
+
"rus = RandomUnderSampler(sampling_strategy=sampling_strategy_undersample, random_state=random_state)\n",
|
| 1052 |
+
"X_test, y_test = rus.fit_resample(X_test, y_test)\n",
|
| 1053 |
+
"print(f\"Number of test labels after resampling: {y_test.value_counts()}\")\n",
|
| 1054 |
+
"\n",
|
| 1055 |
+
"scaler = MinMaxScaler()\n",
|
| 1056 |
+
"\n",
|
| 1057 |
+
"X_res = scaler.fit_transform(X_res)\n",
|
| 1058 |
+
"X_test = scaler.transform(X_test)\n",
|
| 1059 |
+
"\n",
|
| 1060 |
+
"X_res = pd.DataFrame(X_res, columns=original_columns)\n",
|
| 1061 |
+
"X_test = pd.DataFrame(X_test, columns=original_columns)\n",
|
| 1062 |
+
"\n",
|
| 1063 |
+
"corr_df = X_res.copy()\n",
|
| 1064 |
+
"corr_df['target'] = y_res\n",
|
| 1065 |
+
"\n",
|
| 1066 |
+
"corr_mat = corr_df.corr()\n",
|
| 1067 |
+
"target_correlation = corr_mat['target'].drop('target')\n",
|
| 1068 |
+
"\n",
|
| 1069 |
+
"for clf_name, clf in classifiers.items():\n",
|
| 1070 |
+
" print(f\"Optimizing {clf_name}\")\n",
|
| 1071 |
+
" \n",
|
| 1072 |
+
" top_features = target_correlation.abs().sort_values(ascending=False).head(top_features_count[clf_name]).index.tolist()\n",
|
| 1073 |
+
" \n",
|
| 1074 |
+
" X_res_fi = X_res[top_features]\n",
|
| 1075 |
+
" X_test_fi = X_test[top_features]\n",
|
| 1076 |
+
" \n",
|
| 1077 |
+
" opt = BayesSearchCV(clf, search_spaces[clf_name], n_iter=30, cv=3, random_state=random_state, n_jobs=-1, verbose = 30)\n",
|
| 1078 |
+
" opt.fit(X_res_fi, y_res)\n",
|
| 1079 |
+
" \n",
|
| 1080 |
+
" best_clf = opt.best_estimator_\n",
|
| 1081 |
+
" best_params = opt.best_params_\n",
|
| 1082 |
+
"\n",
|
| 1083 |
+
" print(f\"Best parameters for {clf_name}: {best_params}\")\n",
|
| 1084 |
+
" \n",
|
| 1085 |
+
" accuracy, precision, recall, f1, conf_matrix, fpr, tpr, roc_auc = train_evaluate_model(best_clf, X_res_fi, y_res, X_test_fi, y_test, clf_name=clf_name)\n",
|
| 1086 |
+
" print(f\"Best results for {clf_name}:\")\n",
|
| 1087 |
+
" print(f'Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}, AUC: {roc_auc:.4f}')\n",
|
| 1088 |
+
" print(conf_matrix)\n",
|
| 1089 |
+
" print() "
|
| 1090 |
+
]
|
| 1091 |
+
},
|
| 1092 |
+
{
|
| 1093 |
+
"cell_type": "code",
|
| 1094 |
+
"execution_count": null,
|
| 1095 |
+
"id": "2f473cc7-577d-4dcd-a8b1-be7f33200fbc",
|
| 1096 |
+
"metadata": {
|
| 1097 |
+
"scrolled": true
|
| 1098 |
+
},
|
| 1099 |
+
"outputs": [],
|
| 1100 |
+
"source": [
|
| 1101 |
+
"# Hyperparameter optimization static\n",
|
| 1102 |
+
"metric_sums = defaultdict(lambda: {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0})\n",
|
| 1103 |
+
"conf_matrices = defaultdict(list)\n",
|
| 1104 |
+
"accuracy_scores = defaultdict(list)\n",
|
| 1105 |
+
"precision_scores = defaultdict(list)\n",
|
| 1106 |
+
"recall_scores = defaultdict(list)\n",
|
| 1107 |
+
"f1_scores = defaultdict(list)\n",
|
| 1108 |
+
"\n",
|
| 1109 |
+
"classifiers = {\n",
|
| 1110 |
+
" 'RandomForest': RandomForestClassifier(),\n",
|
| 1111 |
+
" 'XGBoost': XGBClassifier(use_label_encoder=False, eval_metric='mlogloss'),\n",
|
| 1112 |
+
" 'AdaBoost': AdaBoostClassifier(),\n",
|
| 1113 |
+
" 'GradientBoosting': GradientBoostingClassifier(),\n",
|
| 1114 |
+
" 'CatBoost': CatBoostClassifier(verbose=0),\n",
|
| 1115 |
+
" 'LightGBM': LGBMClassifier()\n",
|
| 1116 |
+
"}\n",
|
| 1117 |
+
"\n",
|
| 1118 |
+
"num_features = {\n",
|
| 1119 |
+
" 'RandomForest': 150,\n",
|
| 1120 |
+
" 'XGBoost': 150,\n",
|
| 1121 |
+
" 'GradientBoosting': 150,\n",
|
| 1122 |
+
" 'AdaBoost': 150,\n",
|
| 1123 |
+
" 'CatBoost': 150,\n",
|
| 1124 |
+
" 'LightGBM': 150,\n",
|
| 1125 |
+
"}\n",
|
| 1126 |
+
"\n",
|
| 1127 |
+
"search_spaces = {\n",
|
| 1128 |
+
" 'RandomForest': {\n",
|
| 1129 |
+
" 'n_estimators': [100, 200, 300],\n",
|
| 1130 |
+
" 'criterion': ['gini', 'entropy'],\n",
|
| 1131 |
+
" 'max_depth': [None, 7, 15],\n",
|
| 1132 |
+
" 'bootstrap': [True, False]\n",
|
| 1133 |
+
" },\n",
|
| 1134 |
+
" 'XGBoost': {\n",
|
| 1135 |
+
" 'n_estimators': [100, 200, 300],\n",
|
| 1136 |
+
" 'max_depth': [5, 10],\n",
|
| 1137 |
+
" 'learning_rate': [0.01, 0.1, 0.2],\n",
|
| 1138 |
+
" 'gamma': [0, 0.2, 0.4],\n",
|
| 1139 |
+
" },\n",
|
| 1140 |
+
" 'GradientBoosting': {\n",
|
| 1141 |
+
" 'n_estimators': [100, 200, 300],\n",
|
| 1142 |
+
" 'learning_rate': [0.01, 0.1, 0.2],\n",
|
| 1143 |
+
" 'max_depth': [5, 10],\n",
|
| 1144 |
+
" 'subsample': [0.7, 0.9, 1.0],\n",
|
| 1145 |
+
" },\n",
|
| 1146 |
+
" 'AdaBoost': {\n",
|
| 1147 |
+
" 'n_estimators': [100, 200, 300],\n",
|
| 1148 |
+
" 'learning_rate': [0.1, 0.5, 1.0],\n",
|
| 1149 |
+
" 'algorithm': ['SAMME', 'SAMME.R']\n",
|
| 1150 |
+
" },\n",
|
| 1151 |
+
" 'CatBoost': {\n",
|
| 1152 |
+
" 'iterations': [100, 200, 300],\n",
|
| 1153 |
+
" 'depth': [5, 7, 9],\n",
|
| 1154 |
+
" 'learning_rate': [0.01, 0.1, 0.2],\n",
|
| 1155 |
+
" },\n",
|
| 1156 |
+
" 'LightGBM': {\n",
|
| 1157 |
+
" 'n_estimators': [100, 200, 300],\n",
|
| 1158 |
+
" 'num_leaves': [31, 63, 127],\n",
|
| 1159 |
+
" 'learning_rate': [0.01, 0.1, 0.2],\n",
|
| 1160 |
+
" 'subsample': [0.7, 0.9, 1.0],\n",
|
| 1161 |
+
" }\n",
|
| 1162 |
+
"}\n",
|
| 1163 |
+
"\n",
|
| 1164 |
+
"def hyperparameter_optimization(clf, search_space, X, y):\n",
|
| 1165 |
+
" combined_results = []\n",
|
| 1166 |
+
" for random_state in range(3):\n",
|
| 1167 |
+
" cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=random_state)\n",
|
| 1168 |
+
" opt = BayesSearchCV(clf, search_space, n_iter=30, cv=cv, random_state=random_state, n_jobs=-1, verbose=0)\n",
|
| 1169 |
+
" opt.fit(X, y)\n",
|
| 1170 |
+
" combined_results.append(opt.best_params_)\n",
|
| 1171 |
+
" best_params = pd.DataFrame(combined_results).mode().iloc[0].to_dict()\n",
|
| 1172 |
+
" return best_params\n",
|
| 1173 |
+
"\n",
|
| 1174 |
+
"for random_state in range(9,10):\n",
|
| 1175 |
+
" print(f\"Processing for Random State: {random_state}\")\n",
|
| 1176 |
+
"\n",
|
| 1177 |
+
" X = splitted_dataset.drop('depression_category', axis=1)\n",
|
| 1178 |
+
" y = splitted_dataset['depression_category']\n",
|
| 1179 |
+
" \n",
|
| 1180 |
+
" X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=random_state)\n",
|
| 1181 |
+
" \n",
|
| 1182 |
+
" lof = LocalOutlierFactor()\n",
|
| 1183 |
+
" yhat = lof.fit_predict(X_train)\n",
|
| 1184 |
+
" mask = yhat != -1\n",
|
| 1185 |
+
" X_train, y_train = X_train[mask], y_train[mask]\n",
|
| 1186 |
+
" \n",
|
| 1187 |
+
" original_columns = X.columns.tolist()\n",
|
| 1188 |
+
"\n",
|
| 1189 |
+
" smote = SMOTE(random_state=random_state)\n",
|
| 1190 |
+
" X_res, y_res = smote.fit_resample(X_train, y_train)\n",
|
| 1191 |
+
"\n",
|
| 1192 |
+
" print(f\"Number of training labels after ROS: {y_res.value_counts()}\")\n",
|
| 1193 |
+
"\n",
|
| 1194 |
+
" print(f\"Number of test labels before resampling: {y_test.value_counts()}\") \n",
|
| 1195 |
+
" \n",
|
| 1196 |
+
" sampling_strategy_undersample = {0: 372}\n",
|
| 1197 |
+
" rus = RandomUnderSampler(sampling_strategy=sampling_strategy_undersample, random_state=random_state)\n",
|
| 1198 |
+
" X_test, y_test = rus.fit_resample(X_test, y_test)\n",
|
| 1199 |
+
"\n",
|
| 1200 |
+
" print(f\"Number of test labels after resampling: {y_test.value_counts()}\")\n",
|
| 1201 |
+
"\n",
|
| 1202 |
+
" scaler = MinMaxScaler()\n",
|
| 1203 |
+
" \n",
|
| 1204 |
+
" X_res = scaler.fit_transform(X_res)\n",
|
| 1205 |
+
" X_test = scaler.transform(X_test)\n",
|
| 1206 |
+
" \n",
|
| 1207 |
+
" X_res = pd.DataFrame(X_res, columns=original_columns)\n",
|
| 1208 |
+
" X_test = pd.DataFrame(X_test, columns=original_columns)\n",
|
| 1209 |
+
"\n",
|
| 1210 |
+
" log_reg = LogisticRegression(C=0.09659168435718246, max_iter=100, solver='lbfgs', random_state=random_state)\n",
|
| 1211 |
+
" log_reg.fit(X_res, y_res)\n",
|
| 1212 |
+
" selector = SelectFromModel(log_reg, prefit=True)\n",
|
| 1213 |
+
" \n",
|
| 1214 |
+
" importance = np.abs(log_reg.coef_[0])\n",
|
| 1215 |
+
" indices = np.argsort(importance)[::-1]\n",
|
| 1216 |
+
" important_features = [original_columns[i] for i in indices[:300]]\n",
|
| 1217 |
+
"\n",
|
| 1218 |
+
" for clf_name, clf in classifiers.items():\n",
|
| 1219 |
+
" print(f\"Optimizing {clf_name}\")\n",
|
| 1220 |
+
" num_top_features = num_features[clf_name]\n",
|
| 1221 |
+
" selected_features = important_features[:num_top_features]\n",
|
| 1222 |
+
" \n",
|
| 1223 |
+
" X_res_fi = pd.DataFrame(X_res, columns=original_columns)[selected_features]\n",
|
| 1224 |
+
" \n",
|
| 1225 |
+
" best_params = hyperparameter_optimization(clf, search_spaces[clf_name], X_res_fi, y_res)\n",
|
| 1226 |
+
" if 'n_estimators' in best_params:\n",
|
| 1227 |
+
" best_params['n_estimators'] = int(best_params['n_estimators'])\n",
|
| 1228 |
+
" if 'max_depth' in best_params:\n",
|
| 1229 |
+
" best_params['max_depth'] = int(best_params['max_depth'])\n",
|
| 1230 |
+
" if 'iterations' in best_params:\n",
|
| 1231 |
+
" best_params['iterations'] = int(best_params['iterations'])\n",
|
| 1232 |
+
" clf.set_params(**best_params)\n",
|
| 1233 |
+
" print(f\"Best parameters for {clf_name}: {best_params}\")\n",
|
| 1234 |
+
"\n",
|
| 1235 |
+
" X_test_fi = pd.DataFrame(X_test, columns=original_columns)[selected_features]\n",
|
| 1236 |
+
" accuracy, precision, recall, f1, conf_matrix, fpr, tpr, roc_auc = train_evaluate_model(clf, X_res_fi, y_res, X_test_fi, y_test, clf_name=clf_name)\n",
|
| 1237 |
+
" metric_sums[clf_name]['accuracy'] += accuracy\n",
|
| 1238 |
+
" metric_sums[clf_name]['precision'] += precision\n",
|
| 1239 |
+
" metric_sums[clf_name]['recall'] += recall\n",
|
| 1240 |
+
" metric_sums[clf_name]['f1'] += f1\n",
|
| 1241 |
+
" conf_matrices[clf_name].append(conf_matrix)\n",
|
| 1242 |
+
" accuracy_scores[clf_name].append(accuracy)\n",
|
| 1243 |
+
" precision_scores[clf_name].append(precision)\n",
|
| 1244 |
+
" recall_scores[clf_name].append(recall)\n",
|
| 1245 |
+
" f1_scores[clf_name].append(f1)"
|
| 1246 |
+
]
|
| 1247 |
+
},
|
| 1248 |
+
{
|
| 1249 |
+
"cell_type": "markdown",
|
| 1250 |
+
"id": "65df93d4-12d3-4d68-b402-633354849dff",
|
| 1251 |
+
"metadata": {},
|
| 1252 |
+
"source": [
|
| 1253 |
+
"### DES Training (all)"
|
| 1254 |
+
]
|
| 1255 |
+
},
|
| 1256 |
+
{
|
| 1257 |
+
"cell_type": "code",
|
| 1258 |
+
"execution_count": null,
|
| 1259 |
+
"id": "709bb1f8-249e-4409-a537-c6bbe9a399f3",
|
| 1260 |
+
"metadata": {
|
| 1261 |
+
"scrolled": true
|
| 1262 |
+
},
|
| 1263 |
+
"outputs": [],
|
| 1264 |
+
"source": [
|
| 1265 |
+
"metric_sums_des = {\n",
|
| 1266 |
+
" 'KNORAE': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1267 |
+
" 'KNORAU': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1268 |
+
" 'KNOP': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1269 |
+
" 'DESMI': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1270 |
+
" 'METADES': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1271 |
+
" 'DESKNN': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1272 |
+
" 'DESP': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1273 |
+
" 'FIRE-KNORA-U': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1274 |
+
" 'FIRE-KNORA-E': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1275 |
+
" 'FIRE-METADES': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1276 |
+
" 'FIRE-DESKNN': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1277 |
+
" 'FIRE-DESP': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1278 |
+
" 'FIRE-KNOP': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1279 |
+
"}\n",
|
| 1280 |
+
"\n",
|
| 1281 |
+
"conf_matrices_des = {\n",
|
| 1282 |
+
" 'KNORAE': [],\n",
|
| 1283 |
+
" 'KNORAU': [],\n",
|
| 1284 |
+
" 'KNOP': [],\n",
|
| 1285 |
+
" 'DESMI': [],\n",
|
| 1286 |
+
" 'METADES': [],\n",
|
| 1287 |
+
" 'DESKNN': [],\n",
|
| 1288 |
+
" 'DESP': [],\n",
|
| 1289 |
+
" 'FIRE-KNORA-U': [],\n",
|
| 1290 |
+
" 'FIRE-KNORA-E': [],\n",
|
| 1291 |
+
" 'FIRE-METADES': [],\n",
|
| 1292 |
+
" 'FIRE-DESKNN': [],\n",
|
| 1293 |
+
" 'FIRE-DESP': [],\n",
|
| 1294 |
+
" 'FIRE-KNOP': [],\n",
|
| 1295 |
+
"}\n",
|
| 1296 |
+
"\n",
|
| 1297 |
+
"roc_curves = defaultdict(list)\n",
|
| 1298 |
+
"roc_aucs = defaultdict(list)\n",
|
| 1299 |
+
"accuracy_scores = defaultdict(list)\n",
|
| 1300 |
+
"precision_scores = defaultdict(list)\n",
|
| 1301 |
+
"recall_scores = defaultdict(list)\n",
|
| 1302 |
+
"f1_scores = defaultdict(list)\n",
|
| 1303 |
+
"feature_importance_runs = []\n",
|
| 1304 |
+
"\n",
|
| 1305 |
+
"# Uncomment wanted combinations\n",
|
| 1306 |
+
"base_classifiers = {\n",
|
| 1307 |
+
" # 'DecisionTree': DecisionTreeClassifier(\n",
|
| 1308 |
+
" # random_state=0, \n",
|
| 1309 |
+
" # criterion='gini', \n",
|
| 1310 |
+
" # max_depth=6, \n",
|
| 1311 |
+
" # min_samples_leaf=10, \n",
|
| 1312 |
+
" # min_samples_split=9\n",
|
| 1313 |
+
" # ),\n",
|
| 1314 |
+
" # 'LogisticRegression': LogisticRegression(\n",
|
| 1315 |
+
" # random_state=0, \n",
|
| 1316 |
+
" # C=0.09659168435718246, \n",
|
| 1317 |
+
" # max_iter=100, \n",
|
| 1318 |
+
" # solver='lbfgs'\n",
|
| 1319 |
+
" # ),\n",
|
| 1320 |
+
" # 'NaiveBayes': GaussianNB(\n",
|
| 1321 |
+
" # var_smoothing=0.0058873326349240295\n",
|
| 1322 |
+
" # ),\n",
|
| 1323 |
+
" # 'KNeighbors': KNeighborsClassifier(\n",
|
| 1324 |
+
" # metric='manhattan', \n",
|
| 1325 |
+
" # n_neighbors=15, \n",
|
| 1326 |
+
" # weights='uniform'\n",
|
| 1327 |
+
" # ),\n",
|
| 1328 |
+
" # 'MLP': MLPClassifier(\n",
|
| 1329 |
+
" # random_state=0, \n",
|
| 1330 |
+
" # max_iter=1000, \n",
|
| 1331 |
+
" # alpha=0.0003079393718075164, \n",
|
| 1332 |
+
" # hidden_layer_sizes=195, \n",
|
| 1333 |
+
" # learning_rate_init=0.0001675266159417717\n",
|
| 1334 |
+
" # ),\n",
|
| 1335 |
+
" # 'SVC': SVC(probability=True, kernel = 'rbf', C = 1.5, gamma = 'auto'),\n",
|
| 1336 |
+
" # 'RF': RandomForestClassifier(n_estimators=143, criterion='entropy', max_depth=15, random_state=0),\n",
|
| 1337 |
+
" 'XGB': XGBClassifier(n_estimators=200, max_depth=3, learning_rate=0.1, use_label_encoder=False, eval_metric='mlogloss', random_state=0),\n",
|
| 1338 |
+
" 'GB': GradientBoostingClassifier(n_estimators=300, max_depth=3, learning_rate=0.05),\n",
|
| 1339 |
+
" 'AB': AdaBoostClassifier(n_estimators=400, learning_rate=0.1),\n",
|
| 1340 |
+
" # 'CB': CatBoostClassifier(depth = 3, iterations = 168, learning_rate = 0.1, verbose = 0),\n",
|
| 1341 |
+
" # 'LGBM': LGBMClassifier(learning_rate = 0.1, max_depth = 3, n_estimators = 200) \n",
|
| 1342 |
+
"}\n",
|
| 1343 |
+
"\n",
|
| 1344 |
+
"random_state = 0\n",
|
| 1345 |
+
"\n",
|
| 1346 |
+
"for random_state in range(10):\n",
|
| 1347 |
+
" print(f\"Processing for Random State: {random_state}\")\n",
|
| 1348 |
+
"\n",
|
| 1349 |
+
" X = splitted_dataset.drop('depression_category', axis=1)\n",
|
| 1350 |
+
" y = splitted_dataset['depression_category']\n",
|
| 1351 |
+
" \n",
|
| 1352 |
+
" X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=random_state)\n",
|
| 1353 |
+
" \n",
|
| 1354 |
+
" lof = LocalOutlierFactor()\n",
|
| 1355 |
+
" yhat = lof.fit_predict(X_train)\n",
|
| 1356 |
+
" mask = yhat != -1\n",
|
| 1357 |
+
" X_train, y_train = X_train[mask], y_train[mask]\n",
|
| 1358 |
+
" \n",
|
| 1359 |
+
" original_columns = X.columns.tolist()\n",
|
| 1360 |
+
"\n",
|
| 1361 |
+
" smote = SMOTE(random_state=random_state)\n",
|
| 1362 |
+
" X_res, y_res = smote.fit_resample(X_train, y_train)\n",
|
| 1363 |
+
"\n",
|
| 1364 |
+
" print(f\"Number of test labels before resampling: {y_test.value_counts()}\")\n",
|
| 1365 |
+
" sampling_strategy_undersample = {0: 372}\n",
|
| 1366 |
+
" rus = RandomUnderSampler(sampling_strategy=sampling_strategy_undersample, random_state=random_state)\n",
|
| 1367 |
+
" X_test, y_test = rus.fit_resample(X_test, y_test) \n",
|
| 1368 |
+
"\n",
|
| 1369 |
+
" print(f\"Number of test labels after resampling: {y_test.value_counts()}\")\n",
|
| 1370 |
+
"\n",
|
| 1371 |
+
" scaler = MinMaxScaler()\n",
|
| 1372 |
+
" X_res = scaler.fit_transform(X_res)\n",
|
| 1373 |
+
" X_test = scaler.transform(X_test)\n",
|
| 1374 |
+
" \n",
|
| 1375 |
+
" X_res = pd.DataFrame(X_res, columns=original_columns)\n",
|
| 1376 |
+
" X_test = pd.DataFrame(X_test, columns=original_columns)\n",
|
| 1377 |
+
"\n",
|
| 1378 |
+
" xgb_fs = XGBClassifier(n_estimators=200, max_depth=3, learning_rate=0.1, use_label_encoder=False, eval_metric='mlogloss', random_state=random_state)\n",
|
| 1379 |
+
" xgb_fs.fit(X_res, y_res)\n",
|
| 1380 |
+
"\n",
|
| 1381 |
+
" feature_importances = xgb_fs.feature_importances_\n",
|
| 1382 |
+
" indices = np.argsort(feature_importances)[::-1]\n",
|
| 1383 |
+
" top_50_features = [original_columns[i] for i in indices[:50]]\n",
|
| 1384 |
+
" current_run_features = {original_columns[i]: feature_importances[i] for i in indices[:50]}\n",
|
| 1385 |
+
" \n",
|
| 1386 |
+
" feature_importance_runs.append(current_run_features)\n",
|
| 1387 |
+
"\n",
|
| 1388 |
+
" X_res_fi = X_res[top_50_features]\n",
|
| 1389 |
+
" X_test_fi = X_test[top_50_features]\n",
|
| 1390 |
+
" \n",
|
| 1391 |
+
" model_pool = list(base_classifiers.values())\n",
|
| 1392 |
+
" \n",
|
| 1393 |
+
" for clf in model_pool:\n",
|
| 1394 |
+
" clf.fit(X_res_fi, y_res)\n",
|
| 1395 |
+
" \n",
|
| 1396 |
+
" des_models = {\n",
|
| 1397 |
+
" 'KNORAE': KNORAE(pool_classifiers=model_pool, random_state=random_state),\n",
|
| 1398 |
+
" 'KNORAU': KNORAU(pool_classifiers=model_pool, random_state=random_state),\n",
|
| 1399 |
+
" 'DESMI': DESMI(pool_classifiers=model_pool, random_state=random_state),\n",
|
| 1400 |
+
" 'METADES': METADES(pool_classifiers=model_pool, random_state=random_state),\n",
|
| 1401 |
+
" 'DESKNN': DESKNN(pool_classifiers=model_pool, random_state=random_state),\n",
|
| 1402 |
+
" 'DESP': DESP(pool_classifiers=model_pool, random_state=random_state),\n",
|
| 1403 |
+
" 'KNOP': KNOP(pool_classifiers=model_pool, random_state=random_state, k=9),\n",
|
| 1404 |
+
" 'FIRE-KNORA-U': KNORAU(pool_classifiers=model_pool, DFP=True, k=9, random_state = random_state),\n",
|
| 1405 |
+
" 'FIRE-KNORA-E': KNORAE(pool_classifiers=model_pool, DFP=True, k=9, random_state = random_state),\n",
|
| 1406 |
+
" 'FIRE-METADES': METADES(pool_classifiers=model_pool, DFP=True, k=9, random_state = random_state),\n",
|
| 1407 |
+
" 'FIRE-DESKNN': DESKNN(pool_classifiers=model_pool, DFP=True, k=9, random_state = random_state),\n",
|
| 1408 |
+
" 'FIRE-DESP': DESP(pool_classifiers=model_pool, DFP=True, k=9, random_state = random_state),\n",
|
| 1409 |
+
" 'FIRE-KNOP': KNOP(pool_classifiers=model_pool, DFP=True, k=40, random_state = random_state)\n",
|
| 1410 |
+
" }\n",
|
| 1411 |
+
"\n",
|
| 1412 |
+
" for des_name, des_model in des_models.items():\n",
|
| 1413 |
+
" accuracy, precision, recall, f1, conf_matrix, fpr, tpr, roc_auc = train_evaluate_model(\n",
|
| 1414 |
+
" des_model, X_res_fi, y_res, X_test_fi, y_test, clf_name=des_name\n",
|
| 1415 |
+
" )\n",
|
| 1416 |
+
" metric_sums_des[des_name]['accuracy'] += accuracy\n",
|
| 1417 |
+
" metric_sums_des[des_name]['precision'] += precision\n",
|
| 1418 |
+
" metric_sums_des[des_name]['recall'] += recall\n",
|
| 1419 |
+
" metric_sums_des[des_name]['f1'] += f1\n",
|
| 1420 |
+
" conf_matrices_des[des_name].append(conf_matrix)\n",
|
| 1421 |
+
" roc_curves[des_name].append((fpr, tpr))\n",
|
| 1422 |
+
" roc_aucs[des_name].append(roc_auc)\n",
|
| 1423 |
+
" accuracy_scores[des_name].append(accuracy)\n",
|
| 1424 |
+
" precision_scores[des_name].append(precision)\n",
|
| 1425 |
+
" recall_scores[des_name].append(recall)\n",
|
| 1426 |
+
" f1_scores[des_name].append(f1)\n",
|
| 1427 |
+
"\n",
|
| 1428 |
+
" print(f'Confusion Matrix for {des_name} at Random State {random_state}:\\n{conf_matrix}\\n')"
|
| 1429 |
+
]
|
| 1430 |
+
},
|
| 1431 |
+
{
|
| 1432 |
+
"cell_type": "code",
|
| 1433 |
+
"execution_count": null,
|
| 1434 |
+
"id": "34b0ead4-1d8c-4d0a-b0c7-71cd71f6ab7d",
|
| 1435 |
+
"metadata": {},
|
| 1436 |
+
"outputs": [],
|
| 1437 |
+
"source": [
|
| 1438 |
+
"def plot_combined_roc_curve(roc_curves, classifier_names):\n",
|
| 1439 |
+
" plt.figure(figsize=(12, 8))\n",
|
| 1440 |
+
" mean_fpr = np.linspace(0, 1, 100)\n",
|
| 1441 |
+
" colors = plt.cm.get_cmap('tab20', len(classifier_names))\n",
|
| 1442 |
+
" \n",
|
| 1443 |
+
" for i, clf_name in enumerate(classifier_names):\n",
|
| 1444 |
+
" tprs = []\n",
|
| 1445 |
+
" for fpr, tpr in roc_curves[clf_name]:\n",
|
| 1446 |
+
" tprs.append(np.interp(mean_fpr, fpr, tpr))\n",
|
| 1447 |
+
" mean_tpr = np.mean(tprs, axis=0)\n",
|
| 1448 |
+
" mean_tpr[-1] = 1.0\n",
|
| 1449 |
+
" mean_auc = auc(mean_fpr, mean_tpr)\n",
|
| 1450 |
+
" plt.plot(mean_fpr, mean_tpr, color=colors(i), lw=2, linestyle='-', marker='o', markersize=4, \n",
|
| 1451 |
+
" label=f'{clf_name} (AUC = {mean_auc:.3f})')\n",
|
| 1452 |
+
"\n",
|
| 1453 |
+
" plt.plot([0, 1], [0, 1], color='grey', lw=2, linestyle='--')\n",
|
| 1454 |
+
" plt.xlim([0.0, 1.0])\n",
|
| 1455 |
+
" plt.ylim([0.0, 1.05])\n",
|
| 1456 |
+
" plt.xlabel('False Positive Rate', fontsize=26)\n",
|
| 1457 |
+
" plt.ylabel('True Positive Rate', fontsize=26)\n",
|
| 1458 |
+
" plt.xticks(fontsize=30) # Increase x-axis numbers font size\n",
|
| 1459 |
+
" plt.yticks(fontsize=30) # Increase y-axis numbers font size\n",
|
| 1460 |
+
" plt.legend(loc=\"center left\", bbox_to_anchor=(1.05, 0.5), fontsize=26, frameon=True, framealpha=0.9) # Place legend beside the plot\n",
|
| 1461 |
+
" plt.grid(True)\n",
|
| 1462 |
+
"\n",
|
| 1463 |
+
" filename='bonk.svg'\n",
|
| 1464 |
+
"\n",
|
| 1465 |
+
" plt.savefig(filename, format='svg', bbox_inches = 'tight')\n",
|
| 1466 |
+
" plt.show()\n",
|
| 1467 |
+
"\n",
|
| 1468 |
+
" display(FileLink(filename))"
|
| 1469 |
+
]
|
| 1470 |
+
},
|
| 1471 |
+
{
|
| 1472 |
+
"cell_type": "code",
|
| 1473 |
+
"execution_count": null,
|
| 1474 |
+
"id": "4cfd8d73-26fd-4dbc-a6ca-92a9643e1d27",
|
| 1475 |
+
"metadata": {},
|
| 1476 |
+
"outputs": [],
|
| 1477 |
+
"source": [
|
| 1478 |
+
"print('\\nAverage Metrics over 10 Random States:')\n",
|
| 1479 |
+
"for des_name, metrics in metric_sums_des.items():\n",
|
| 1480 |
+
" avg_accuracy = metrics['accuracy'] / 10\n",
|
| 1481 |
+
" avg_precision = metrics['precision'] / 10\n",
|
| 1482 |
+
" avg_recall = metrics['recall'] / 10\n",
|
| 1483 |
+
" avg_f1 = metrics['f1'] / 10\n",
|
| 1484 |
+
" std_accuracy = np.std(accuracy_scores[des_name])\n",
|
| 1485 |
+
" std_precision = np.std(precision_scores[des_name])\n",
|
| 1486 |
+
" std_recall = np.std(recall_scores[des_name])\n",
|
| 1487 |
+
" std_f1 = np.std(f1_scores[des_name])\n",
|
| 1488 |
+
" avg_auc = np.mean(roc_aucs[des_name])\n",
|
| 1489 |
+
" print(f'{des_name} - Accuracy: {avg_accuracy:.4f} ± {std_accuracy:.4f}, Precision: {avg_precision:.4f} ± {std_precision:.4f}, Recall: {avg_recall:.4f} ± {std_recall:.4f}, F1-Score: {avg_f1:.4f} ± {std_f1:.4f}, AUC: {avg_auc:.4f}')\n",
|
| 1490 |
+
"\n",
|
| 1491 |
+
"plot_combined_roc_curve(roc_curves, list(des_models.keys()))"
|
| 1492 |
+
]
|
| 1493 |
+
},
|
| 1494 |
+
{
|
| 1495 |
+
"cell_type": "code",
|
| 1496 |
+
"execution_count": null,
|
| 1497 |
+
"id": "32780b70-dae5-4f7d-9ae4-128dc782fad4",
|
| 1498 |
+
"metadata": {},
|
| 1499 |
+
"outputs": [],
|
| 1500 |
+
"source": [
|
| 1501 |
+
"df = pd.DataFrame(accuracy_scores)\n",
|
| 1502 |
+
"scores = [df[col].values for col in df.columns]\n",
|
| 1503 |
+
"\n",
|
| 1504 |
+
"stat, p = friedmanchisquare(*scores)\n",
|
| 1505 |
+
"print(f'Friedman Test Statistic: {stat}, p-value: {p}')\n",
|
| 1506 |
+
"\n",
|
| 1507 |
+
"ranks = df.rank(axis=1, method='average', ascending=False)\n",
|
| 1508 |
+
"average_ranks = ranks.mean().values\n",
|
| 1509 |
+
"\n",
|
| 1510 |
+
"n_datasets = df.shape[0]\n",
|
| 1511 |
+
"alpha = 0.05\n",
|
| 1512 |
+
"\n",
|
| 1513 |
+
"cd = compute_CD(average_ranks, n_datasets, alpha='0.05')\n",
|
| 1514 |
+
"print(f'Critical Difference: {cd}')\n",
|
| 1515 |
+
"\n",
|
| 1516 |
+
"classifiers = [f\"{clf} ({rank:.2f})\" for clf, rank in zip(df.columns, average_ranks)]\n",
|
| 1517 |
+
"\n",
|
| 1518 |
+
"plt.figure(figsize=(14, 10))\n",
|
| 1519 |
+
"\n",
|
| 1520 |
+
"graph_ranks(average_ranks, classifiers, cd=cd, width=6, textspace=1)\n",
|
| 1521 |
+
"plt.xlabel('Classifiers')\n",
|
| 1522 |
+
"\n",
|
| 1523 |
+
"plt.text(0.5, 1.19, f'Friedman-Nemenyi: {stat:.3f}', horizontalalignment='center', transform=plt.gca().transAxes, fontsize=16)\n",
|
| 1524 |
+
"plt.text(0.5, 1.10, f'CD: {cd:.3f}', horizontalalignment='center', transform=plt.gca().transAxes, fontsize=16)\n",
|
| 1525 |
+
"\n",
|
| 1526 |
+
"plt.tight_layout()"
|
| 1527 |
+
]
|
| 1528 |
+
},
|
| 1529 |
+
{
|
| 1530 |
+
"cell_type": "markdown",
|
| 1531 |
+
"id": "f25a2b29-129f-4314-b2f9-8aff9615d5d7",
|
| 1532 |
+
"metadata": {},
|
| 1533 |
+
"source": [
|
| 1534 |
+
"### Shap (will mostly be exported files)"
|
| 1535 |
+
]
|
| 1536 |
+
},
|
| 1537 |
+
{
|
| 1538 |
+
"cell_type": "code",
|
| 1539 |
+
"execution_count": null,
|
| 1540 |
+
"id": "ed3c43d0-68d2-4aac-ac44-25a28dc6dc75",
|
| 1541 |
+
"metadata": {},
|
| 1542 |
+
"outputs": [],
|
| 1543 |
+
"source": [
|
| 1544 |
+
"# Example with XGB\n",
|
| 1545 |
+
"\n",
|
| 1546 |
+
"random_state = 2\n",
|
| 1547 |
+
"print(f\"Processing for Random State: {random_state}\")\n",
|
| 1548 |
+
"\n",
|
| 1549 |
+
"X = splitted_dataset.drop('depression_category', axis=1)\n",
|
| 1550 |
+
"y = splitted_dataset['depression_category']\n",
|
| 1551 |
+
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=random_state)\n",
|
| 1552 |
+
"\n",
|
| 1553 |
+
"lof = LocalOutlierFactor()\n",
|
| 1554 |
+
"yhat = lof.fit_predict(X_train)\n",
|
| 1555 |
+
"mask = yhat != -1\n",
|
| 1556 |
+
"X_train, y_train = X_train[mask], y_train[mask]\n",
|
| 1557 |
+
"\n",
|
| 1558 |
+
"original_columns = X.columns.tolist()\n",
|
| 1559 |
+
"\n",
|
| 1560 |
+
"smote = SMOTE(random_state=random_state)\n",
|
| 1561 |
+
"X_res, y_res = smote.fit_resample(X_train, y_train)\n",
|
| 1562 |
+
"\n",
|
| 1563 |
+
"print(f\"Number of training labels after SMOTE: {y_res.value_counts()}\")\n",
|
| 1564 |
+
"print(f\"Number of test labels before resampling: {y_test.value_counts()}\")\n",
|
| 1565 |
+
"\n",
|
| 1566 |
+
"sampling_strategy_undersample = {0: 372}\n",
|
| 1567 |
+
"rus = RandomUnderSampler(sampling_strategy=sampling_strategy_undersample, random_state=random_state)\n",
|
| 1568 |
+
"X_test, y_test = rus.fit_resample(X_test, y_test)\n",
|
| 1569 |
+
"\n",
|
| 1570 |
+
"print(f\"Number of test labels after resampling: {y_test.value_counts()}\")\n",
|
| 1571 |
+
"\n",
|
| 1572 |
+
"# Normalization\n",
|
| 1573 |
+
"# scaler = MinMaxScaler()\n",
|
| 1574 |
+
"# X_res = scaler.fit_transform(X_res)\n",
|
| 1575 |
+
"# X_test = scaler.transform(X_test)\n",
|
| 1576 |
+
"\n",
|
| 1577 |
+
"X_res = pd.DataFrame(X_res, columns=original_columns)\n",
|
| 1578 |
+
"X_test = pd.DataFrame(X_test, columns=original_columns)\n",
|
| 1579 |
+
"\n",
|
| 1580 |
+
"# Train XGBoost model on all features\n",
|
| 1581 |
+
"model = xgb.XGBClassifier(n_estimators=200, max_depth=3, learning_rate=0.1, use_label_encoder=False, eval_metric='mlogloss', random_state=random_state)\n",
|
| 1582 |
+
"model.fit(X_res, y_res)\n",
|
| 1583 |
+
"\n",
|
| 1584 |
+
"y_pred = model.predict(X_test)\n",
|
| 1585 |
+
"\n",
|
| 1586 |
+
"accuracy = accuracy_score(y_test, y_pred)\n",
|
| 1587 |
+
"print(f'Accuracy: {accuracy:.4f}')\n",
|
| 1588 |
+
"\n",
|
| 1589 |
+
"explainer = shap.Explainer(model, X_res)\n",
|
| 1590 |
+
"shap_values = explainer(X_res)"
|
| 1591 |
+
]
|
| 1592 |
+
},
|
| 1593 |
+
{
|
| 1594 |
+
"cell_type": "code",
|
| 1595 |
+
"execution_count": null,
|
| 1596 |
+
"id": "2cb9c929-4e86-42df-bda2-48730004c154",
|
| 1597 |
+
"metadata": {},
|
| 1598 |
+
"outputs": [],
|
| 1599 |
+
"source": [
|
| 1600 |
+
"def plot_shap_waterfall(instance_index, filename):\n",
|
| 1601 |
+
" shap_value = shap_values[instance_index]\n",
|
| 1602 |
+
" plt.figure(figsize=(14, 8))\n",
|
| 1603 |
+
" \n",
|
| 1604 |
+
" shap.plots.waterfall(shap_value, show=False)\n",
|
| 1605 |
+
" \n",
|
| 1606 |
+
" ax = plt.gca()\n",
|
| 1607 |
+
" \n",
|
| 1608 |
+
" ax.tick_params(axis='both', which='major', labelsize=16)\n",
|
| 1609 |
+
" ax.set_xlabel(ax.get_xlabel(), fontsize=20)\n",
|
| 1610 |
+
" ax.set_ylabel(ax.get_ylabel(), fontsize=22)\n",
|
| 1611 |
+
" \n",
|
| 1612 |
+
" plt.tight_layout()\n",
|
| 1613 |
+
"\n",
|
| 1614 |
+
" plt.savefig(filename, format='svg')\n",
|
| 1615 |
+
" \n",
|
| 1616 |
+
" plt.close()\n",
|
| 1617 |
+
"\n",
|
| 1618 |
+
"plot_shap_waterfall(0, \"waterfall_plot_instance_0.svg\")\n",
|
| 1619 |
+
"plot_shap_waterfall(562, \"waterfall_plot_instance_562.svg\")"
|
| 1620 |
+
]
|
| 1621 |
+
},
|
| 1622 |
+
{
|
| 1623 |
+
"cell_type": "code",
|
| 1624 |
+
"execution_count": null,
|
| 1625 |
+
"id": "e111213e-caea-48a0-adb7-c6b2362eed4f",
|
| 1626 |
+
"metadata": {},
|
| 1627 |
+
"outputs": [],
|
| 1628 |
+
"source": [
|
| 1629 |
+
"import matplotlib.pyplot as plt\n",
|
| 1630 |
+
"\n",
|
| 1631 |
+
"plt.figure(figsize=(14, 8))\n",
|
| 1632 |
+
"shap.summary_plot(\n",
|
| 1633 |
+
" shap_values,\n",
|
| 1634 |
+
" X_res,\n",
|
| 1635 |
+
" plot_type=\"bar\",\n",
|
| 1636 |
+
" feature_names=original_columns,\n",
|
| 1637 |
+
" show=False\n",
|
| 1638 |
+
")\n",
|
| 1639 |
+
"\n",
|
| 1640 |
+
"ax = plt.gca()\n",
|
| 1641 |
+
"ax.tick_params(axis='both', which='major', labelsize=16)\n",
|
| 1642 |
+
"ax.set_xlabel(ax.get_xlabel(), fontsize=16)\n",
|
| 1643 |
+
"ax.set_ylabel(ax.get_ylabel(), fontsize=22)\n",
|
| 1644 |
+
"\n",
|
| 1645 |
+
"plt.savefig(\"shap_summary_plot.svg\", format='svg')\n",
|
| 1646 |
+
"plt.close()"
|
| 1647 |
+
]
|
| 1648 |
+
},
|
| 1649 |
+
{
|
| 1650 |
+
"cell_type": "code",
|
| 1651 |
+
"execution_count": null,
|
| 1652 |
+
"id": "4ae84637-a018-49a1-ad4e-522aa937bfad",
|
| 1653 |
+
"metadata": {},
|
| 1654 |
+
"outputs": [],
|
| 1655 |
+
"source": [
|
| 1656 |
+
"shap.initjs()\n",
|
| 1657 |
+
"\n",
|
| 1658 |
+
"fig, ax = plt.subplots(figsize=(14, 8))\n",
|
| 1659 |
+
"\n",
|
| 1660 |
+
"shap.summary_plot(\n",
|
| 1661 |
+
" shap_values,\n",
|
| 1662 |
+
" X_res,\n",
|
| 1663 |
+
" plot_type=\"dot\",\n",
|
| 1664 |
+
" feature_names=original_columns,\n",
|
| 1665 |
+
" show=False\n",
|
| 1666 |
+
")\n",
|
| 1667 |
+
"\n",
|
| 1668 |
+
"ax.tick_params(axis='both', which='major', labelsize=16)\n",
|
| 1669 |
+
"ax.set_xlabel(ax.get_xlabel(), fontsize=16)\n",
|
| 1670 |
+
"ax.set_ylabel(ax.get_ylabel(), fontsize=22)\n",
|
| 1671 |
+
"\n",
|
| 1672 |
+
"fig.savefig(\"shap_summary_dot_plot.svg\", format='svg', bbox_inches='tight')\n",
|
| 1673 |
+
"\n",
|
| 1674 |
+
"plt.close(fig)"
|
| 1675 |
+
]
|
| 1676 |
+
},
|
| 1677 |
+
{
|
| 1678 |
+
"cell_type": "code",
|
| 1679 |
+
"execution_count": null,
|
| 1680 |
+
"id": "05f0de1c-3850-4d77-9184-d593451022c1",
|
| 1681 |
+
"metadata": {},
|
| 1682 |
+
"outputs": [],
|
| 1683 |
+
"source": [
|
| 1684 |
+
"from sklearn.tree import DecisionTreeClassifier, export_text\n",
|
| 1685 |
+
"from sklearn import tree\n",
|
| 1686 |
+
"\n",
|
| 1687 |
+
"random_state = 5\n",
|
| 1688 |
+
"\n",
|
| 1689 |
+
"X = splitted_dataset.drop('depression_category', axis=1)\n",
|
| 1690 |
+
"y = splitted_dataset['depression_category']\n",
|
| 1691 |
+
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=random_state)\n",
|
| 1692 |
+
"\n",
|
| 1693 |
+
"lof = LocalOutlierFactor()\n",
|
| 1694 |
+
"yhat = lof.fit_predict(X_train)\n",
|
| 1695 |
+
"mask = yhat != -1\n",
|
| 1696 |
+
"X_train, y_train = X_train[mask], y_train[mask]\n",
|
| 1697 |
+
"\n",
|
| 1698 |
+
"original_columns = X.columns.tolist()\n",
|
| 1699 |
+
"\n",
|
| 1700 |
+
"smote = SMOTE(random_state=random_state)\n",
|
| 1701 |
+
"X_res, y_res = smote.fit_resample(X_train, y_train)\n",
|
| 1702 |
+
"\n",
|
| 1703 |
+
"print(f\"Number of training labels after ROS: {y_res.value_counts()}\")\n",
|
| 1704 |
+
"print(f\"Number of test labels before resampling: {y_test.value_counts()}\")\n",
|
| 1705 |
+
"sampling_strategy_undersample = {0: 372}\n",
|
| 1706 |
+
"rus = RandomUnderSampler(sampling_strategy=sampling_strategy_undersample, random_state=random_state)\n",
|
| 1707 |
+
"X_test, y_test = rus.fit_resample(X_test, y_test)\n",
|
| 1708 |
+
"\n",
|
| 1709 |
+
"print(f\"Number of test labels after resampling: {y_test.value_counts()}\")\n",
|
| 1710 |
+
"\n",
|
| 1711 |
+
"# Normalization\n",
|
| 1712 |
+
"# scaler = MinMaxScaler()\n",
|
| 1713 |
+
"# X_res = scaler.fit_transform(X_res)\n",
|
| 1714 |
+
"# X_test = scaler.transform(X_test)\n",
|
| 1715 |
+
"\n",
|
| 1716 |
+
"X_res = pd.DataFrame(X_res, columns=original_columns)\n",
|
| 1717 |
+
"X_test = pd.DataFrame(X_test, columns=original_columns)\n",
|
| 1718 |
+
"\n",
|
| 1719 |
+
"decision_tree_model = DecisionTreeClassifier(\n",
|
| 1720 |
+
" random_state=0, \n",
|
| 1721 |
+
" criterion='gini', \n",
|
| 1722 |
+
" max_depth=6, \n",
|
| 1723 |
+
" min_samples_leaf=10, \n",
|
| 1724 |
+
" min_samples_split=9\n",
|
| 1725 |
+
")\n",
|
| 1726 |
+
"decision_tree_model.fit(X_res, y_res)\n",
|
| 1727 |
+
"\n",
|
| 1728 |
+
"plt.figure(figsize=(20, 14))\n",
|
| 1729 |
+
"tree.plot_tree(\n",
|
| 1730 |
+
" decision_tree_model, \n",
|
| 1731 |
+
" feature_names=original_columns, \n",
|
| 1732 |
+
" class_names=['depression', 'normal'],\n",
|
| 1733 |
+
" filled=True, \n",
|
| 1734 |
+
" rounded=True, \n",
|
| 1735 |
+
" fontsize=10,\n",
|
| 1736 |
+
" max_depth = 3\n",
|
| 1737 |
+
")\n",
|
| 1738 |
+
"\n",
|
| 1739 |
+
"plt.savefig(\"decision_tree_plot.svg\", format='svg')\n",
|
| 1740 |
+
"plt.close()\n",
|
| 1741 |
+
"\n",
|
| 1742 |
+
"print(\"Decision Tree plot saved as 'decision_tree_plot.svg'\")"
|
| 1743 |
+
]
|
| 1744 |
+
},
|
| 1745 |
+
{
|
| 1746 |
+
"cell_type": "code",
|
| 1747 |
+
"execution_count": null,
|
| 1748 |
+
"id": "b46a92a6-f370-4df4-ac29-8f382fc1820b",
|
| 1749 |
+
"metadata": {
|
| 1750 |
+
"scrolled": true
|
| 1751 |
+
},
|
| 1752 |
+
"outputs": [],
|
| 1753 |
+
"source": [
|
| 1754 |
+
"tree_rules = export_text(decision_tree_model, feature_names=original_columns, max_depth=50)\n",
|
| 1755 |
+
"print(\"Decision rules for the tree (up to depth 3):\")\n",
|
| 1756 |
+
"print(tree_rules)\n",
|
| 1757 |
+
"\n",
|
| 1758 |
+
"node_indicator = decision_tree_model.decision_path(X_test)\n",
|
| 1759 |
+
"\n",
|
| 1760 |
+
"sample_id = 0\n",
|
| 1761 |
+
"node_index = node_indicator.indices[node_indicator.indptr[sample_id]:node_indicator.indptr[sample_id + 1]]\n",
|
| 1762 |
+
"\n",
|
| 1763 |
+
"print(f\"\\nDecision path for sample {sample_id}:\")\n",
|
| 1764 |
+
"for node_id in node_index:\n",
|
| 1765 |
+
" if X_test.iloc[sample_id, decision_tree_model.tree_.feature[node_id]] <= decision_tree_model.tree_.threshold[node_id]:\n",
|
| 1766 |
+
" threshold_sign = \"<=\"\n",
|
| 1767 |
+
" else:\n",
|
| 1768 |
+
" threshold_sign = \">\"\n",
|
| 1769 |
+
" print(f\"Node {node_id}: (X_test[{sample_id}, {decision_tree_model.tree_.feature[node_id]}] = {X_test.iloc[sample_id, decision_tree_model.tree_.feature[node_id]]}) \"\n",
|
| 1770 |
+
" f\"{threshold_sign} {decision_tree_model.tree_.threshold[node_id]}\")\n",
|
| 1771 |
+
"\n",
|
| 1772 |
+
"# Get prediction for a specific test sample\n",
|
| 1773 |
+
"predicted_class = decision_tree_model.predict([X_test.iloc[sample_id]])\n",
|
| 1774 |
+
"print(f\"\\nPredicted class for test sample {sample_id}: {predicted_class}\")"
|
| 1775 |
+
]
|
| 1776 |
+
}
|
| 1777 |
+
],
|
| 1778 |
+
"metadata": {
|
| 1779 |
+
"kernelspec": {
|
| 1780 |
+
"display_name": "Python 3 (ipykernel)",
|
| 1781 |
+
"language": "python",
|
| 1782 |
+
"name": "python3"
|
| 1783 |
+
},
|
| 1784 |
+
"language_info": {
|
| 1785 |
+
"codemirror_mode": {
|
| 1786 |
+
"name": "ipython",
|
| 1787 |
+
"version": 3
|
| 1788 |
+
},
|
| 1789 |
+
"file_extension": ".py",
|
| 1790 |
+
"mimetype": "text/x-python",
|
| 1791 |
+
"name": "python",
|
| 1792 |
+
"nbconvert_exporter": "python",
|
| 1793 |
+
"pygments_lexer": "ipython3",
|
| 1794 |
+
"version": "3.12.3"
|
| 1795 |
+
}
|
| 1796 |
+
},
|
| 1797 |
+
"nbformat": 4,
|
| 1798 |
+
"nbformat_minor": 5
|
| 1799 |
+
}
|
regressionLayer.ipynb
ADDED
|
@@ -0,0 +1,962 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "c4849e1e-1928-4e8b-a12f-37c786b50aca",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# Regression Layer"
|
| 9 |
+
]
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"cell_type": "code",
|
| 13 |
+
"execution_count": null,
|
| 14 |
+
"id": "dddcb2a8-73d3-4476-b88a-bc1494a2c830",
|
| 15 |
+
"metadata": {},
|
| 16 |
+
"outputs": [],
|
| 17 |
+
"source": [
|
| 18 |
+
"import pandas as pd\n",
|
| 19 |
+
"import numpy as np\n",
|
| 20 |
+
"import matplotlib.pyplot as plt\n",
|
| 21 |
+
"import seaborn as sns\n",
|
| 22 |
+
"from sklearn.model_selection import cross_val_score, KFold\n",
|
| 23 |
+
"from sklearn.impute import KNNImputer\n",
|
| 24 |
+
"from sklearn.pipeline import make_pipeline\n",
|
| 25 |
+
"from xgboost import XGBClassifier\n",
|
| 26 |
+
"from sklearn.impute import SimpleImputer\n",
|
| 27 |
+
"from sklearn.experimental import enable_iterative_imputer\n",
|
| 28 |
+
"from imblearn.pipeline import make_pipeline as make_pipeline_imb\n",
|
| 29 |
+
"from imblearn.over_sampling import SMOTE,SMOTENC\n",
|
| 30 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 31 |
+
"from collections import Counter\n",
|
| 32 |
+
"from sklearn.metrics import classification_report, accuracy_score, confusion_matrix\n",
|
| 33 |
+
"from sklearn.svm import SVC\n",
|
| 34 |
+
"from sklearn.linear_model import LogisticRegression\n",
|
| 35 |
+
"from sklearn.neighbors import LocalOutlierFactor\n",
|
| 36 |
+
"from sklearn.utils import resample\n",
|
| 37 |
+
"import warnings\n",
|
| 38 |
+
"from imblearn.over_sampling import RandomOverSampler\n",
|
| 39 |
+
"from imblearn.under_sampling import RandomUnderSampler\n",
|
| 40 |
+
"from sklearn.preprocessing import MinMaxScaler\n",
|
| 41 |
+
"from imblearn.pipeline import Pipeline\n",
|
| 42 |
+
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
|
| 43 |
+
"from sklearn.preprocessing import LabelEncoder, PowerTransformer\n",
|
| 44 |
+
"from collections import defaultdict\n",
|
| 45 |
+
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc\n",
|
| 46 |
+
"from sklearn.naive_bayes import GaussianNB\n",
|
| 47 |
+
"from sklearn.neural_network import MLPClassifier\n",
|
| 48 |
+
"import Orange\n",
|
| 49 |
+
"from scipy.stats import friedmanchisquare, rankdata\n",
|
| 50 |
+
"import shap\n",
|
| 51 |
+
"import scikit_posthocs as sp\n",
|
| 52 |
+
"from sklearn.feature_selection import SelectFromModel\n",
|
| 53 |
+
"from IPython.display import FileLink, display\n",
|
| 54 |
+
"import math\n",
|
| 55 |
+
"from sklearn.ensemble import RandomForestClassifier\n",
|
| 56 |
+
"from skopt.space import Integer, Real\n",
|
| 57 |
+
"from sklearn.model_selection import StratifiedKFold\n",
|
| 58 |
+
"from skopt import BayesSearchCV\n",
|
| 59 |
+
"import xgboost as xgb\n",
|
| 60 |
+
"from imblearn.over_sampling import SMOTE\n",
|
| 61 |
+
"from sklearn.tree import DecisionTreeClassifier, export_text\n",
|
| 62 |
+
"from sklearn import tree\n",
|
| 63 |
+
"from skopt.space import Real, Integer, Categorical\n",
|
| 64 |
+
"from skopt.callbacks import VerboseCallback\n",
|
| 65 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 66 |
+
"from sklearn.preprocessing import MinMaxScaler\n",
|
| 67 |
+
"from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor, ExtraTreesRegressor, AdaBoostRegressor, VotingRegressor\n",
|
| 68 |
+
"from catboost import CatBoostRegressor\n",
|
| 69 |
+
"from xgboost import XGBRegressor\n",
|
| 70 |
+
"from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
|
| 71 |
+
"from sklearn.feature_selection import SelectFromModel\n",
|
| 72 |
+
"from sklearn.neighbors import LocalOutlierFactor\n",
|
| 73 |
+
"from lightgbm import LGBMRegressor\n",
|
| 74 |
+
"from IPython.display import display, FileLink"
|
| 75 |
+
]
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
"cell_type": "markdown",
|
| 79 |
+
"id": "39d6683c-fd3f-4daa-afdc-2e7f83c3fce3",
|
| 80 |
+
"metadata": {},
|
| 81 |
+
"source": [
|
| 82 |
+
"### Preparation before training"
|
| 83 |
+
]
|
| 84 |
+
},
|
| 85 |
+
{
|
| 86 |
+
"cell_type": "code",
|
| 87 |
+
"execution_count": null,
|
| 88 |
+
"id": "ec17e47a-8d92-498d-8f28-d8259d6ebc4e",
|
| 89 |
+
"metadata": {},
|
| 90 |
+
"outputs": [],
|
| 91 |
+
"source": [
|
| 92 |
+
"# Call Dataset\n",
|
| 93 |
+
"pd.set_option('display.max_rows', 10)\n",
|
| 94 |
+
"initial_df = pd.read_csv('3labelv4Regression.csv')\n",
|
| 95 |
+
"initial_df.info()"
|
| 96 |
+
]
|
| 97 |
+
},
|
| 98 |
+
{
|
| 99 |
+
"cell_type": "code",
|
| 100 |
+
"execution_count": null,
|
| 101 |
+
"id": "d3751352-73bc-42c6-b6a5-84a3badf14ff",
|
| 102 |
+
"metadata": {},
|
| 103 |
+
"outputs": [],
|
| 104 |
+
"source": [
|
| 105 |
+
"# All categorical features except for label\n",
|
| 106 |
+
"cols = initial_df.columns\n",
|
| 107 |
+
"num_cols = initial_df._get_numeric_data().columns\n",
|
| 108 |
+
"categorical_features = list(set(cols) - set(num_cols))\n",
|
| 109 |
+
"categorical_features.remove('depression_category')\n",
|
| 110 |
+
"\n",
|
| 111 |
+
"# Label Encode all categorical, but keep missing values\n",
|
| 112 |
+
"le_initial_df = initial_df.copy()\n",
|
| 113 |
+
"dropped_labels = le_initial_df['depression_category']\n",
|
| 114 |
+
"le_initial_df = le_initial_df.drop('depression_category', axis = 1)\n",
|
| 115 |
+
"\n",
|
| 116 |
+
"for col in le_initial_df.columns:\n",
|
| 117 |
+
" if le_initial_df[col].dtype == 'object':\n",
|
| 118 |
+
" le_initial_df[col] = le_initial_df[col].fillna('missing')\n",
|
| 119 |
+
"\n",
|
| 120 |
+
" label_encoder = LabelEncoder()\n",
|
| 121 |
+
" le_initial_df[col] = label_encoder.fit_transform(le_initial_df[col])\n",
|
| 122 |
+
"\n",
|
| 123 |
+
" missing_value_index = np.where(label_encoder.classes_ == 'missing')[0]\n",
|
| 124 |
+
" \n",
|
| 125 |
+
" le_initial_df[col] = le_initial_df[col].replace(missing_value_index, np.nan)\n",
|
| 126 |
+
"\n",
|
| 127 |
+
"le_initial_df = pd.concat([le_initial_df, dropped_labels], axis = 1)"
|
| 128 |
+
]
|
| 129 |
+
},
|
| 130 |
+
{
|
| 131 |
+
"cell_type": "code",
|
| 132 |
+
"execution_count": null,
|
| 133 |
+
"id": "9fa21a95-21f1-4274-86ba-d7977813066b",
|
| 134 |
+
"metadata": {},
|
| 135 |
+
"outputs": [],
|
| 136 |
+
"source": [
|
| 137 |
+
"le_initial_df"
|
| 138 |
+
]
|
| 139 |
+
},
|
| 140 |
+
{
|
| 141 |
+
"cell_type": "code",
|
| 142 |
+
"execution_count": null,
|
| 143 |
+
"id": "0ba54a81-d4de-4884-99d3-29867eb7ea40",
|
| 144 |
+
"metadata": {},
|
| 145 |
+
"outputs": [],
|
| 146 |
+
"source": [
|
| 147 |
+
"# Seperate and Combine\n",
|
| 148 |
+
"le_df_normal = le_initial_df[le_initial_df['depression_category'] == 'normal']\n",
|
| 149 |
+
"le_df_mild = le_initial_df[le_initial_df['depression_category'] == 'mild']\n",
|
| 150 |
+
"le_df_moderatesevere = le_initial_df[le_initial_df['depression_category'] == 'moderatesevere']\n",
|
| 151 |
+
"\n",
|
| 152 |
+
"le_df_depression = pd.concat([le_df_normal, le_df_mild, le_df_moderatesevere], ignore_index = False)\n",
|
| 153 |
+
"\n",
|
| 154 |
+
"le_df_depression['depression_category'] = 'depression'\n",
|
| 155 |
+
"\n",
|
| 156 |
+
"# Check depression category counts\n",
|
| 157 |
+
"dataframes = [le_df_normal, le_df_mild, le_df_moderatesevere]\n",
|
| 158 |
+
"le_initial_df = pd.concat(dataframes, ignore_index=True)\n",
|
| 159 |
+
"label_counts = le_initial_df['depression_category'].value_counts()\n",
|
| 160 |
+
"label_counts"
|
| 161 |
+
]
|
| 162 |
+
},
|
| 163 |
+
{
|
| 164 |
+
"cell_type": "code",
|
| 165 |
+
"execution_count": null,
|
| 166 |
+
"id": "c517213f-f6bb-4298-99ee-3b29f6f7d0cb",
|
| 167 |
+
"metadata": {},
|
| 168 |
+
"outputs": [],
|
| 169 |
+
"source": [
|
| 170 |
+
"# Some outlier.\n",
|
| 171 |
+
"# threshold = int(0.8 * le_df_normal.shape[1])\n",
|
| 172 |
+
"# le_df_normal = le_df_normal.dropna(thresh = threshold)\n",
|
| 173 |
+
"# threshold = int(0.8 * le_df_depression.shape[1])\n",
|
| 174 |
+
"# le_df_depression = le_df_depression.dropna(thresh = threshold)\n",
|
| 175 |
+
"\n",
|
| 176 |
+
"# Check depression category counts\n",
|
| 177 |
+
"dataframes = [le_df_normal, le_df_mild, le_df_moderatesevere]\n",
|
| 178 |
+
"le_initial_df = pd.concat(dataframes, ignore_index=True)\n",
|
| 179 |
+
"label_counts = le_initial_df['depression_category'].value_counts()"
|
| 180 |
+
]
|
| 181 |
+
},
|
| 182 |
+
{
|
| 183 |
+
"cell_type": "code",
|
| 184 |
+
"execution_count": null,
|
| 185 |
+
"id": "8c5bdf52-c645-4a11-920a-579e28db1a50",
|
| 186 |
+
"metadata": {},
|
| 187 |
+
"outputs": [],
|
| 188 |
+
"source": [
|
| 189 |
+
"# Imputation\n",
|
| 190 |
+
"different_le_dfs = [le_df_normal, le_df_mild, le_df_moderatesevere]\n",
|
| 191 |
+
"imputed_le_dfs = []\n",
|
| 192 |
+
"from sklearn.impute import IterativeImputer\n",
|
| 193 |
+
"for le_df in different_le_dfs:\n",
|
| 194 |
+
" y = le_df['depression_category']\n",
|
| 195 |
+
" X = le_df.drop('depression_category', axis = 1)\n",
|
| 196 |
+
" \n",
|
| 197 |
+
" imputer = SimpleImputer(strategy='median')\n",
|
| 198 |
+
" imputed_data = imputer.fit_transform(X)\n",
|
| 199 |
+
" imputed_df = pd.DataFrame(imputed_data, columns = X.columns)\n",
|
| 200 |
+
"\n",
|
| 201 |
+
" imputed_df['depression_category'] = y.reset_index(drop = True)\n",
|
| 202 |
+
" imputed_le_dfs.append(imputed_df)\n",
|
| 203 |
+
"\n",
|
| 204 |
+
"concatenated_le_dfs = pd.concat(imputed_le_dfs, ignore_index = True)\n",
|
| 205 |
+
"concatenated_le_dfs"
|
| 206 |
+
]
|
| 207 |
+
},
|
| 208 |
+
{
|
| 209 |
+
"cell_type": "code",
|
| 210 |
+
"execution_count": null,
|
| 211 |
+
"id": "0e871a5f-beab-4f90-b8c4-e922ef86d0f0",
|
| 212 |
+
"metadata": {},
|
| 213 |
+
"outputs": [],
|
| 214 |
+
"source": [
|
| 215 |
+
"# Full label encode depression category\n",
|
| 216 |
+
"fully_LE_concatenated_le_dfs = concatenated_le_dfs.copy()\n",
|
| 217 |
+
"fully_LE_concatenated_le_dfs['depression_category'] = label_encoder.fit_transform(fully_LE_concatenated_le_dfs['depression_category'])\n",
|
| 218 |
+
"\n",
|
| 219 |
+
"# The dataset after category connect, imputation, and label encoding\n",
|
| 220 |
+
"splitted_dataset = fully_LE_concatenated_le_dfs.copy()\n",
|
| 221 |
+
"splitted_dataset = splitted_dataset.drop('depression_category', axis = 1)\n",
|
| 222 |
+
"splitted_dataset"
|
| 223 |
+
]
|
| 224 |
+
},
|
| 225 |
+
{
|
| 226 |
+
"cell_type": "code",
|
| 227 |
+
"execution_count": null,
|
| 228 |
+
"id": "3e602f4d-efb2-4ac9-99bb-c6fbeca791cc",
|
| 229 |
+
"metadata": {},
|
| 230 |
+
"outputs": [],
|
| 231 |
+
"source": [
|
| 232 |
+
"warnings.filterwarnings('ignore')"
|
| 233 |
+
]
|
| 234 |
+
},
|
| 235 |
+
{
|
| 236 |
+
"cell_type": "markdown",
|
| 237 |
+
"id": "68563dc9-38c9-4ea5-8c08-9de904f58f43",
|
| 238 |
+
"metadata": {},
|
| 239 |
+
"source": [
|
| 240 |
+
"### Regression Training"
|
| 241 |
+
]
|
| 242 |
+
},
|
| 243 |
+
{
|
| 244 |
+
"cell_type": "code",
|
| 245 |
+
"execution_count": null,
|
| 246 |
+
"id": "5e339a41-e484-4a11-84f0-770cbaacb09d",
|
| 247 |
+
"metadata": {
|
| 248 |
+
"scrolled": true
|
| 249 |
+
},
|
| 250 |
+
"outputs": [],
|
| 251 |
+
"source": [
|
| 252 |
+
"# Optimized parameters\n",
|
| 253 |
+
"regressors = {\n",
|
| 254 |
+
" 'CBR': CatBoostRegressor(verbose=0, iterations=2000, learning_rate=0.01, depth=5),\n",
|
| 255 |
+
" 'XGBR': XGBRegressor(learning_rate=0.04403027347366962, max_depth=3, n_estimators=238),\n",
|
| 256 |
+
" 'LGBMR': LGBMRegressor(learning_rate=0.02904035023286438, num_leaves=20, n_estimators=170),\n",
|
| 257 |
+
" 'GBR': GradientBoostingRegressor(learning_rate=0.03, max_depth=2, n_estimators=700),\n",
|
| 258 |
+
" 'RFR': RandomForestRegressor(max_depth=12, max_features=1.0, n_estimators=300),\n",
|
| 259 |
+
" 'ETR': ExtraTreesRegressor(max_depth=12, max_features=1.0, n_estimators=64),\n",
|
| 260 |
+
" 'ABR': AdaBoostRegressor(learning_rate=0.29915504677867777, n_estimators=92)\n",
|
| 261 |
+
"}\n",
|
| 262 |
+
"\n",
|
| 263 |
+
"# Default parameters\n",
|
| 264 |
+
"# regressors = {\n",
|
| 265 |
+
"# 'CBR': CatBoostRegressor(verbose=0),\n",
|
| 266 |
+
"# 'XGBR': XGBRegressor(),\n",
|
| 267 |
+
"# 'LGBMR': LGBMRegressor(),\n",
|
| 268 |
+
"# 'GBR': GradientBoostingRegressor(),\n",
|
| 269 |
+
"# 'RFR': RandomForestRegressor(),\n",
|
| 270 |
+
"# 'ETR': ExtraTreesRegressor(),\n",
|
| 271 |
+
"# 'ABR': AdaBoostRegressor()\n",
|
| 272 |
+
"# }\n",
|
| 273 |
+
"\n",
|
| 274 |
+
"voting_regressor = VotingRegressor(estimators=[\n",
|
| 275 |
+
" ('cbr', regressors['CBR']),\n",
|
| 276 |
+
" ('xgbr', regressors['XGBR']),\n",
|
| 277 |
+
" ('gbr', regressors['GBR']),\n",
|
| 278 |
+
" ('abr', regressors['ABR'])\n",
|
| 279 |
+
"])\n",
|
| 280 |
+
"\n",
|
| 281 |
+
"regressors['Voting'] = voting_regressor\n",
|
| 282 |
+
"\n",
|
| 283 |
+
"metric_sums = {name: {'rmse': 0, 'mae': 0, 'r2': 0} for name in regressors.keys()}\n",
|
| 284 |
+
"metric_stds = {name: {'rmse': 0, 'mae': 0, 'r2': 0} for name in regressors.keys()}\n",
|
| 285 |
+
"rmse_scores = {name: [] for name in regressors.keys()}\n",
|
| 286 |
+
"mae_scores = {name: [] for name in regressors.keys()}\n",
|
| 287 |
+
"r2_scores = {name: [] for name in regressors.keys()}\n",
|
| 288 |
+
"\n",
|
| 289 |
+
"for random_state in range(10):\n",
|
| 290 |
+
" print(f'Processing for Random State: {random_state}')\n",
|
| 291 |
+
"\n",
|
| 292 |
+
" X = splitted_dataset.drop('total_sum', axis=1)\n",
|
| 293 |
+
" y = splitted_dataset['total_sum']\n",
|
| 294 |
+
" X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=random_state)\n",
|
| 295 |
+
"\n",
|
| 296 |
+
" lof = LocalOutlierFactor()\n",
|
| 297 |
+
" yhat = lof.fit_predict(X_train)\n",
|
| 298 |
+
"\n",
|
| 299 |
+
" mask = yhat != -1\n",
|
| 300 |
+
" X_train, y_train = X_train[mask], y_train[mask]\n",
|
| 301 |
+
"\n",
|
| 302 |
+
" original_columns = X.columns.tolist()\n",
|
| 303 |
+
"\n",
|
| 304 |
+
" print(f\"Number of training labels after outlier removal: {len(y_train)}\")\n",
|
| 305 |
+
" print(f\"Number of test labels: {len(y_test)}\")\n",
|
| 306 |
+
"\n",
|
| 307 |
+
" scaler = MinMaxScaler()\n",
|
| 308 |
+
" X_train = scaler.fit_transform(X_train)\n",
|
| 309 |
+
" X_test = scaler.transform(X_test)\n",
|
| 310 |
+
" \n",
|
| 311 |
+
" X_train = pd.DataFrame(X_train, columns=original_columns)\n",
|
| 312 |
+
" X_test = pd.DataFrame(X_test, columns=original_columns)\n",
|
| 313 |
+
"\n",
|
| 314 |
+
" # Feature selection using XGBRegressor\n",
|
| 315 |
+
" xgb = XGBRegressor(random_state=random_state)\n",
|
| 316 |
+
" xgb.fit(X_train, y_train)\n",
|
| 317 |
+
" selector = SelectFromModel(xgb, prefit=True)\n",
|
| 318 |
+
"\n",
|
| 319 |
+
" importance = np.abs(xgb.feature_importances_)\n",
|
| 320 |
+
" indices = np.argsort(importance)[::-1]\n",
|
| 321 |
+
" important_features = [original_columns[i] for i in indices[:50]]\n",
|
| 322 |
+
"\n",
|
| 323 |
+
" for reg_name, reg in regressors.items():\n",
|
| 324 |
+
" selected_features = important_features\n",
|
| 325 |
+
" \n",
|
| 326 |
+
" X_train_fi = pd.DataFrame(X_train, columns=original_columns)[selected_features]\n",
|
| 327 |
+
" X_test_fi = pd.DataFrame(X_test, columns=original_columns)[selected_features]\n",
|
| 328 |
+
"\n",
|
| 329 |
+
" reg.fit(X_train_fi, y_train)\n",
|
| 330 |
+
" y_pred = reg.predict(X_test_fi)\n",
|
| 331 |
+
"\n",
|
| 332 |
+
" y_pred = np.round(y_pred)\n",
|
| 333 |
+
" \n",
|
| 334 |
+
" rmse = np.sqrt(mean_squared_error(y_test, y_pred))\n",
|
| 335 |
+
" mae = mean_absolute_error(y_test, y_pred)\n",
|
| 336 |
+
" r2 = r2_score(y_test, y_pred)\n",
|
| 337 |
+
"\n",
|
| 338 |
+
" metric_sums[reg_name]['rmse'] += rmse\n",
|
| 339 |
+
" metric_sums[reg_name]['mae'] += mae\n",
|
| 340 |
+
" metric_sums[reg_name]['r2'] += r2\n",
|
| 341 |
+
" rmse_scores[reg_name].append(rmse)\n",
|
| 342 |
+
" mae_scores[reg_name].append(mae)\n",
|
| 343 |
+
" r2_scores[reg_name].append(r2)"
|
| 344 |
+
]
|
| 345 |
+
},
|
| 346 |
+
{
|
| 347 |
+
"cell_type": "code",
|
| 348 |
+
"execution_count": null,
|
| 349 |
+
"id": "0929bf2c-fd41-4e0e-8070-4a6317f3efab",
|
| 350 |
+
"metadata": {},
|
| 351 |
+
"outputs": [],
|
| 352 |
+
"source": [
|
| 353 |
+
"# Calculate and print the average metrics and their standard deviations\n",
|
| 354 |
+
"for reg_name in regressors.keys():\n",
|
| 355 |
+
" avg_rmse = metric_sums[reg_name]['rmse'] / 10\n",
|
| 356 |
+
" avg_mae = metric_sums[reg_name]['mae'] / 10\n",
|
| 357 |
+
" avg_r2 = metric_sums[reg_name]['r2'] / 10\n",
|
| 358 |
+
" std_rmse = np.std(rmse_scores[reg_name])\n",
|
| 359 |
+
" std_mae = np.std(mae_scores[reg_name])\n",
|
| 360 |
+
" std_r2 = np.std(r2_scores[reg_name])\n",
|
| 361 |
+
" \n",
|
| 362 |
+
" print(f\"Regressor: {reg_name}\")\n",
|
| 363 |
+
" print(f\"Average RMSE: {avg_rmse} ± {std_rmse}\")\n",
|
| 364 |
+
" print(f\"Average MAE: {avg_mae} ± {std_mae}\")\n",
|
| 365 |
+
" print(f\"Average R2: {avg_r2} ± {std_r2}\")\n",
|
| 366 |
+
" print(\"------\")"
|
| 367 |
+
]
|
| 368 |
+
},
|
| 369 |
+
{
|
| 370 |
+
"cell_type": "markdown",
|
| 371 |
+
"id": "07f62248-70b4-479b-b782-3aec714111f9",
|
| 372 |
+
"metadata": {},
|
| 373 |
+
"source": [
|
| 374 |
+
"### Shap n FN"
|
| 375 |
+
]
|
| 376 |
+
},
|
| 377 |
+
{
|
| 378 |
+
"cell_type": "code",
|
| 379 |
+
"execution_count": null,
|
| 380 |
+
"id": "f6ce4639-53d9-4dcd-a49c-80555326f197",
|
| 381 |
+
"metadata": {},
|
| 382 |
+
"outputs": [],
|
| 383 |
+
"source": [
|
| 384 |
+
"# Preparation code to make CD diagram from older version of Orange\n",
|
| 385 |
+
"def compute_CD(avranks, n, alpha=\"0.05\", test=\"nemenyi\"):\n",
|
| 386 |
+
" \"\"\"\n",
|
| 387 |
+
" Returns critical difference for Nemenyi or Bonferroni-Dunn test\n",
|
| 388 |
+
" according to given alpha (either alpha=\"0.05\" or alpha=\"0.1\") for average\n",
|
| 389 |
+
" ranks and number of tested datasets N. Test can be either \"nemenyi\" for\n",
|
| 390 |
+
" for Nemenyi two tailed test or \"bonferroni-dunn\" for Bonferroni-Dunn test.\n",
|
| 391 |
+
"\n",
|
| 392 |
+
" This function is deprecated and will be removed in Orange 3.34.\n",
|
| 393 |
+
" \"\"\"\n",
|
| 394 |
+
" k = len(avranks)\n",
|
| 395 |
+
" d = {(\"nemenyi\", \"0.05\"): [0, 0, 1.959964, 2.343701, 2.569032, 2.727774,\n",
|
| 396 |
+
" 2.849705, 2.94832, 3.030879, 3.101730, 3.163684,\n",
|
| 397 |
+
" 3.218654, 3.268004, 3.312739, 3.353618, 3.39123,\n",
|
| 398 |
+
" 3.426041, 3.458425, 3.488685, 3.517073,\n",
|
| 399 |
+
" 3.543799],\n",
|
| 400 |
+
" (\"nemenyi\", \"0.1\"): [0, 0, 1.644854, 2.052293, 2.291341, 2.459516,\n",
|
| 401 |
+
" 2.588521, 2.692732, 2.779884, 2.854606, 2.919889,\n",
|
| 402 |
+
" 2.977768, 3.029694, 3.076733, 3.119693, 3.159199,\n",
|
| 403 |
+
" 3.195743, 3.229723, 3.261461, 3.291224, 3.319233],\n",
|
| 404 |
+
" (\"bonferroni-dunn\", \"0.05\"): [0, 0, 1.960, 2.241, 2.394, 2.498, 2.576,\n",
|
| 405 |
+
" 2.638, 2.690, 2.724, 2.773],\n",
|
| 406 |
+
" (\"bonferroni-dunn\", \"0.1\"): [0, 0, 1.645, 1.960, 2.128, 2.241, 2.326,\n",
|
| 407 |
+
" 2.394, 2.450, 2.498, 2.539]}\n",
|
| 408 |
+
" q = d[(test, alpha)]\n",
|
| 409 |
+
" cd = q[k] * (k * (k + 1) / (6.0 * n)) ** 0.5\n",
|
| 410 |
+
" return cd\n",
|
| 411 |
+
"\n",
|
| 412 |
+
"\n",
|
| 413 |
+
"def graph_ranks(avranks, names, cd=None, cdmethod=None, lowv=None, highv=None,\n",
|
| 414 |
+
" width=6, textspace=1, reverse=False, filename=None, **kwargs):\n",
|
| 415 |
+
" \"\"\"\n",
|
| 416 |
+
" Draws a CD graph, which is used to display the differences in methods'\n",
|
| 417 |
+
" performance. See Janez Demsar, Statistical Comparisons of Classifiers over\n",
|
| 418 |
+
" Multiple Data Sets, 7(Jan):1--30, 2006.\n",
|
| 419 |
+
"\n",
|
| 420 |
+
" Needs matplotlib to work.\n",
|
| 421 |
+
"\n",
|
| 422 |
+
" The image is ploted on `plt` imported using\n",
|
| 423 |
+
" `import matplotlib.pyplot as plt`.\n",
|
| 424 |
+
"\n",
|
| 425 |
+
" This function is deprecated and will be removed in Orange 3.34.\n",
|
| 426 |
+
"\n",
|
| 427 |
+
" Args:\n",
|
| 428 |
+
" avranks (list of float): average ranks of methods.\n",
|
| 429 |
+
" names (list of str): names of methods.\n",
|
| 430 |
+
" cd (float): Critical difference used for statistically significance of\n",
|
| 431 |
+
" difference between methods.\n",
|
| 432 |
+
" cdmethod (int, optional): the method that is compared with other methods\n",
|
| 433 |
+
" If omitted, show pairwise comparison of methods\n",
|
| 434 |
+
" lowv (int, optional): the lowest shown rank\n",
|
| 435 |
+
" highv (int, optional): the highest shown rank\n",
|
| 436 |
+
" width (int, optional): default width in inches (default: 6)\n",
|
| 437 |
+
" textspace (int, optional): space on figure sides (in inches) for the\n",
|
| 438 |
+
" method names (default: 1)\n",
|
| 439 |
+
" reverse (bool, optional): if set to `True`, the lowest rank is on the\n",
|
| 440 |
+
" right (default: `False`)\n",
|
| 441 |
+
" filename (str, optional): output file name (with extension). If not\n",
|
| 442 |
+
" given, the function does not write a file.\n",
|
| 443 |
+
" \"\"\"\n",
|
| 444 |
+
" try:\n",
|
| 445 |
+
" import matplotlib.pyplot as plt\n",
|
| 446 |
+
" from matplotlib.backends.backend_agg import FigureCanvasAgg\n",
|
| 447 |
+
" except ImportError:\n",
|
| 448 |
+
" raise ImportError(\"Function graph_ranks requires matplotlib.\")\n",
|
| 449 |
+
"\n",
|
| 450 |
+
" width = float(width)\n",
|
| 451 |
+
" textspace = float(textspace)\n",
|
| 452 |
+
"\n",
|
| 453 |
+
" def nth(l, n):\n",
|
| 454 |
+
" \"\"\"\n",
|
| 455 |
+
" Returns only nth elemnt in a list.\n",
|
| 456 |
+
" \"\"\"\n",
|
| 457 |
+
" n = lloc(l, n)\n",
|
| 458 |
+
" return [a[n] for a in l]\n",
|
| 459 |
+
"\n",
|
| 460 |
+
" def lloc(l, n):\n",
|
| 461 |
+
" \"\"\"\n",
|
| 462 |
+
" List location in list of list structure.\n",
|
| 463 |
+
" Enable the use of negative locations:\n",
|
| 464 |
+
" -1 is the last element, -2 second last...\n",
|
| 465 |
+
" \"\"\"\n",
|
| 466 |
+
" if n < 0:\n",
|
| 467 |
+
" return len(l[0]) + n\n",
|
| 468 |
+
" else:\n",
|
| 469 |
+
" return n\n",
|
| 470 |
+
"\n",
|
| 471 |
+
" def mxrange(lr):\n",
|
| 472 |
+
" \"\"\"\n",
|
| 473 |
+
" Multiple xranges. Can be used to traverse matrices.\n",
|
| 474 |
+
" This function is very slow due to unknown number of\n",
|
| 475 |
+
" parameters.\n",
|
| 476 |
+
"\n",
|
| 477 |
+
" >>> mxrange([3,5])\n",
|
| 478 |
+
" [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]\n",
|
| 479 |
+
"\n",
|
| 480 |
+
" >>> mxrange([[3,5,1],[9,0,-3]])\n",
|
| 481 |
+
" [(3, 9), (3, 6), (3, 3), (4, 9), (4, 6), (4, 3)]\n",
|
| 482 |
+
"\n",
|
| 483 |
+
" \"\"\"\n",
|
| 484 |
+
" if not len(lr):\n",
|
| 485 |
+
" yield ()\n",
|
| 486 |
+
" else:\n",
|
| 487 |
+
" # it can work with single numbers\n",
|
| 488 |
+
" index = lr[0]\n",
|
| 489 |
+
" if isinstance(index, int):\n",
|
| 490 |
+
" index = [index]\n",
|
| 491 |
+
" for a in range(*index):\n",
|
| 492 |
+
" for b in mxrange(lr[1:]):\n",
|
| 493 |
+
" yield tuple([a] + list(b))\n",
|
| 494 |
+
"\n",
|
| 495 |
+
" def print_figure(fig, *args, **kwargs):\n",
|
| 496 |
+
" canvas = FigureCanvasAgg(fig)\n",
|
| 497 |
+
" canvas.print_figure(*args, **kwargs)\n",
|
| 498 |
+
"\n",
|
| 499 |
+
" sums = avranks\n",
|
| 500 |
+
"\n",
|
| 501 |
+
" tempsort = sorted([(a, i) for i, a in enumerate(sums)], reverse=reverse)\n",
|
| 502 |
+
" ssums = nth(tempsort, 0)\n",
|
| 503 |
+
" sortidx = nth(tempsort, 1)\n",
|
| 504 |
+
" nnames = [names[x] for x in sortidx]\n",
|
| 505 |
+
"\n",
|
| 506 |
+
" if lowv is None:\n",
|
| 507 |
+
" lowv = min(1, int(math.floor(min(ssums))))\n",
|
| 508 |
+
" if highv is None:\n",
|
| 509 |
+
" highv = max(len(avranks), int(math.ceil(max(ssums))))\n",
|
| 510 |
+
"\n",
|
| 511 |
+
" cline = 0.4\n",
|
| 512 |
+
"\n",
|
| 513 |
+
" k = len(sums)\n",
|
| 514 |
+
"\n",
|
| 515 |
+
" lines = None\n",
|
| 516 |
+
"\n",
|
| 517 |
+
" linesblank = 0\n",
|
| 518 |
+
" scalewidth = width - 2 * textspace\n",
|
| 519 |
+
"\n",
|
| 520 |
+
" def rankpos(rank):\n",
|
| 521 |
+
" if not reverse:\n",
|
| 522 |
+
" a = rank - lowv\n",
|
| 523 |
+
" else:\n",
|
| 524 |
+
" a = highv - rank\n",
|
| 525 |
+
" return textspace + scalewidth / (highv - lowv) * a\n",
|
| 526 |
+
"\n",
|
| 527 |
+
" distanceh = 0.25\n",
|
| 528 |
+
"\n",
|
| 529 |
+
" if cd and cdmethod is None:\n",
|
| 530 |
+
" # get pairs of non significant methods\n",
|
| 531 |
+
"\n",
|
| 532 |
+
" def get_lines(sums, hsd):\n",
|
| 533 |
+
" # get all pairs\n",
|
| 534 |
+
" lsums = len(sums)\n",
|
| 535 |
+
" allpairs = [(i, j) for i, j in mxrange([[lsums], [lsums]]) if j > i]\n",
|
| 536 |
+
" # remove not significant\n",
|
| 537 |
+
" notSig = [(i, j) for i, j in allpairs\n",
|
| 538 |
+
" if abs(sums[i] - sums[j]) <= hsd]\n",
|
| 539 |
+
" # keep only longest\n",
|
| 540 |
+
"\n",
|
| 541 |
+
" def no_longer(ij_tuple, notSig):\n",
|
| 542 |
+
" i, j = ij_tuple\n",
|
| 543 |
+
" for i1, j1 in notSig:\n",
|
| 544 |
+
" if (i1 <= i and j1 > j) or (i1 < i and j1 >= j):\n",
|
| 545 |
+
" return False\n",
|
| 546 |
+
" return True\n",
|
| 547 |
+
"\n",
|
| 548 |
+
" longest = [(i, j) for i, j in notSig if no_longer((i, j), notSig)]\n",
|
| 549 |
+
"\n",
|
| 550 |
+
" return longest\n",
|
| 551 |
+
"\n",
|
| 552 |
+
" lines = get_lines(ssums, cd)\n",
|
| 553 |
+
" linesblank = 0.2 + 0.2 + (len(lines) - 1) * 0.1\n",
|
| 554 |
+
"\n",
|
| 555 |
+
" # add scale\n",
|
| 556 |
+
" distanceh = 0.25\n",
|
| 557 |
+
" cline += distanceh\n",
|
| 558 |
+
"\n",
|
| 559 |
+
" # calculate height needed height of an image\n",
|
| 560 |
+
" minnotsignificant = max(2 * 0.2, linesblank)\n",
|
| 561 |
+
" height = cline + ((k + 1) / 2) * 0.2 + minnotsignificant\n",
|
| 562 |
+
"\n",
|
| 563 |
+
" fig = plt.figure(figsize=(width, height))\n",
|
| 564 |
+
" fig.set_facecolor('white')\n",
|
| 565 |
+
" ax = fig.add_axes([0, 0, 1, 1]) # reverse y axis\n",
|
| 566 |
+
" ax.set_axis_off()\n",
|
| 567 |
+
"\n",
|
| 568 |
+
" hf = 1. / height # height factor\n",
|
| 569 |
+
" wf = 1. / width\n",
|
| 570 |
+
"\n",
|
| 571 |
+
" def hfl(l):\n",
|
| 572 |
+
" return [a * hf for a in l]\n",
|
| 573 |
+
"\n",
|
| 574 |
+
" def wfl(l):\n",
|
| 575 |
+
" return [a * wf for a in l]\n",
|
| 576 |
+
"\n",
|
| 577 |
+
"\n",
|
| 578 |
+
" # Upper left corner is (0,0).\n",
|
| 579 |
+
" ax.plot([0, 1], [0, 1], c=\"w\")\n",
|
| 580 |
+
" ax.set_xlim(0, 1)\n",
|
| 581 |
+
" ax.set_ylim(1, 0)\n",
|
| 582 |
+
"\n",
|
| 583 |
+
" def line(l, color='k', **kwargs):\n",
|
| 584 |
+
" \"\"\"\n",
|
| 585 |
+
" Input is a list of pairs of points.\n",
|
| 586 |
+
" \"\"\"\n",
|
| 587 |
+
" ax.plot(wfl(nth(l, 0)), hfl(nth(l, 1)), color=color, **kwargs)\n",
|
| 588 |
+
"\n",
|
| 589 |
+
" def text(x, y, s, *args, **kwargs):\n",
|
| 590 |
+
" ax.text(wf * x, hf * y, s, fontsize = 14, *args, **kwargs)\n",
|
| 591 |
+
"\n",
|
| 592 |
+
" line([(textspace, cline), (width - textspace, cline)], linewidth=0.7)\n",
|
| 593 |
+
"\n",
|
| 594 |
+
" bigtick = 0.1\n",
|
| 595 |
+
" smalltick = 0.05\n",
|
| 596 |
+
"\n",
|
| 597 |
+
" tick = None\n",
|
| 598 |
+
" for a in list(np.arange(lowv, highv, 0.5)) + [highv]:\n",
|
| 599 |
+
" tick = smalltick\n",
|
| 600 |
+
" if a == int(a):\n",
|
| 601 |
+
" tick = bigtick\n",
|
| 602 |
+
" line([(rankpos(a), cline - tick / 2),\n",
|
| 603 |
+
" (rankpos(a), cline)],\n",
|
| 604 |
+
" linewidth=0.7)\n",
|
| 605 |
+
"\n",
|
| 606 |
+
" for a in range(lowv, highv + 1):\n",
|
| 607 |
+
" text(rankpos(a), cline - tick / 2 - 0.05, str(a),\n",
|
| 608 |
+
" ha=\"center\", va=\"bottom\")\n",
|
| 609 |
+
"\n",
|
| 610 |
+
" k = len(ssums)\n",
|
| 611 |
+
"\n",
|
| 612 |
+
" for i in range(math.ceil(k / 2)):\n",
|
| 613 |
+
" chei = cline + minnotsignificant + i * 0.2\n",
|
| 614 |
+
" line([(rankpos(ssums[i]), cline),\n",
|
| 615 |
+
" (rankpos(ssums[i]), chei),\n",
|
| 616 |
+
" (textspace - 0.1, chei)],\n",
|
| 617 |
+
" linewidth=0.7)\n",
|
| 618 |
+
" text(textspace - 0.2, chei, nnames[i], ha=\"right\", va=\"center\")\n",
|
| 619 |
+
"\n",
|
| 620 |
+
" for i in range(math.ceil(k / 2), k):\n",
|
| 621 |
+
" chei = cline + minnotsignificant + (k - i - 1) * 0.2\n",
|
| 622 |
+
" line([(rankpos(ssums[i]), cline),\n",
|
| 623 |
+
" (rankpos(ssums[i]), chei),\n",
|
| 624 |
+
" (textspace + scalewidth + 0.1, chei)],\n",
|
| 625 |
+
" linewidth=0.7)\n",
|
| 626 |
+
" text(textspace + scalewidth + 0.2, chei, nnames[i],\n",
|
| 627 |
+
" ha=\"left\", va=\"center\")\n",
|
| 628 |
+
"\n",
|
| 629 |
+
" if cd and cdmethod is None:\n",
|
| 630 |
+
" # upper scale\n",
|
| 631 |
+
" if not reverse:\n",
|
| 632 |
+
" begin, end = rankpos(lowv), rankpos(lowv + cd)\n",
|
| 633 |
+
" else:\n",
|
| 634 |
+
" begin, end = rankpos(highv), rankpos(highv - cd)\n",
|
| 635 |
+
"\n",
|
| 636 |
+
" line([(begin, distanceh), (end, distanceh)], linewidth=0.7)\n",
|
| 637 |
+
" line([(begin, distanceh + bigtick / 2),\n",
|
| 638 |
+
" (begin, distanceh - bigtick / 2)],\n",
|
| 639 |
+
" linewidth=0.7)\n",
|
| 640 |
+
" line([(end, distanceh + bigtick / 2),\n",
|
| 641 |
+
" (end, distanceh - bigtick / 2)],\n",
|
| 642 |
+
" linewidth=0.7)\n",
|
| 643 |
+
" text((begin + end) / 2, distanceh - 0.05, \"CD\",\n",
|
| 644 |
+
" ha=\"center\", va=\"bottom\")\n",
|
| 645 |
+
"\n",
|
| 646 |
+
" # no-significance lines\n",
|
| 647 |
+
" def draw_lines(lines, side=0.05, height=0.1):\n",
|
| 648 |
+
" start = cline + 0.2\n",
|
| 649 |
+
" for l, r in lines:\n",
|
| 650 |
+
" line([(rankpos(ssums[l]) - side, start),\n",
|
| 651 |
+
" (rankpos(ssums[r]) + side, start)],\n",
|
| 652 |
+
" linewidth=2.5)\n",
|
| 653 |
+
" start += height\n",
|
| 654 |
+
"\n",
|
| 655 |
+
" draw_lines(lines)\n",
|
| 656 |
+
"\n",
|
| 657 |
+
" elif cd:\n",
|
| 658 |
+
" begin = rankpos(avranks[cdmethod] - cd)\n",
|
| 659 |
+
" end = rankpos(avranks[cdmethod] + cd)\n",
|
| 660 |
+
" line([(begin, cline), (end, cline)],\n",
|
| 661 |
+
" linewidth=2.5)\n",
|
| 662 |
+
" line([(begin, cline + bigtick / 2),\n",
|
| 663 |
+
" (begin, cline - bigtick / 2)],\n",
|
| 664 |
+
" linewidth=2.5)\n",
|
| 665 |
+
" line([(end, cline + bigtick / 2),\n",
|
| 666 |
+
" (end, cline - bigtick / 2)],\n",
|
| 667 |
+
" linewidth=2.5)\n",
|
| 668 |
+
"\n",
|
| 669 |
+
" if filename:\n",
|
| 670 |
+
" print_figure(fig, filename, **kwargs)"
|
| 671 |
+
]
|
| 672 |
+
},
|
| 673 |
+
{
|
| 674 |
+
"cell_type": "code",
|
| 675 |
+
"execution_count": null,
|
| 676 |
+
"id": "eeeb1424-63f5-4338-bdc9-5e36e8e4ac09",
|
| 677 |
+
"metadata": {},
|
| 678 |
+
"outputs": [],
|
| 679 |
+
"source": [
|
| 680 |
+
"# FN\n",
|
| 681 |
+
"df = pd.DataFrame(rmse_scores)\n",
|
| 682 |
+
"df\n",
|
| 683 |
+
"\n",
|
| 684 |
+
"scores = [df[col].values for col in df.columns]\n",
|
| 685 |
+
"\n",
|
| 686 |
+
"stat, p = friedmanchisquare(*scores)\n",
|
| 687 |
+
"print(f'Friedman Test Statistic: {stat}, p-value: {p}')\n",
|
| 688 |
+
"\n",
|
| 689 |
+
"ranks = df.rank(axis=1, method='average')\n",
|
| 690 |
+
"average_ranks = ranks.mean().values\n",
|
| 691 |
+
"\n",
|
| 692 |
+
"n_datasets = df.shape[0]\n",
|
| 693 |
+
"alpha = 0.05\n",
|
| 694 |
+
"\n",
|
| 695 |
+
"from scikit_posthocs import posthoc_nemenyi_friedman\n",
|
| 696 |
+
"cd = np.sqrt((len(df.columns) * (len(df.columns) + 1)) / (6 * n_datasets)) * np.sqrt(2 / alpha)\n",
|
| 697 |
+
"print(f'Critical Difference: {cd}')\n",
|
| 698 |
+
"\n",
|
| 699 |
+
"classifiers = [f\"{clf} ({rank:.2f})\" for clf, rank in zip(df.columns, average_ranks)]\n",
|
| 700 |
+
"\n",
|
| 701 |
+
"plt.figure(figsize=(16, 10))\n",
|
| 702 |
+
"\n",
|
| 703 |
+
"graph_ranks(average_ranks, classifiers, cd=cd, width=6, textspace=1)\n",
|
| 704 |
+
"plt.xlabel('Classifiers')\n",
|
| 705 |
+
"\n",
|
| 706 |
+
"plt.text(0.5, 1.19, f'Friedman-Nemenyi: {stat:.3f}', horizontalalignment='center', transform=plt.gca().transAxes, fontsize=16)\n",
|
| 707 |
+
"plt.text(0.5, 1.10, f'CD: {cd:.3f}', horizontalalignment='center', transform=plt.gca().transAxes, fontsize=16)\n",
|
| 708 |
+
"\n",
|
| 709 |
+
"plt.tight_layout()"
|
| 710 |
+
]
|
| 711 |
+
},
|
| 712 |
+
{
|
| 713 |
+
"cell_type": "code",
|
| 714 |
+
"execution_count": null,
|
| 715 |
+
"id": "d88fdb63-9223-48f5-b4c3-78c5ab76938b",
|
| 716 |
+
"metadata": {},
|
| 717 |
+
"outputs": [],
|
| 718 |
+
"source": [
|
| 719 |
+
"# SHAP\n",
|
| 720 |
+
"import shap\n",
|
| 721 |
+
"import matplotlib.pyplot as plt\n",
|
| 722 |
+
"import pandas as pd\n",
|
| 723 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 724 |
+
"from sklearn.preprocessing import MinMaxScaler\n",
|
| 725 |
+
"from catboost import CatBoostRegressor\n",
|
| 726 |
+
"from sklearn.feature_selection import SelectFromModel\n",
|
| 727 |
+
"from sklearn.neighbors import LocalOutlierFactor\n",
|
| 728 |
+
"from xgboost import XGBRegressor\n",
|
| 729 |
+
"\n",
|
| 730 |
+
"cbr = CatBoostRegressor(verbose=0, iterations=2000, learning_rate=0.01, depth=5)\n",
|
| 731 |
+
"\n",
|
| 732 |
+
"X = splitted_dataset.drop('total_sum', axis=1)\n",
|
| 733 |
+
"y = splitted_dataset['total_sum']\n",
|
| 734 |
+
"random_state = 0\n",
|
| 735 |
+
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=random_state)\n",
|
| 736 |
+
"\n",
|
| 737 |
+
"lof = LocalOutlierFactor()\n",
|
| 738 |
+
"yhat = lof.fit_predict(X_train)\n",
|
| 739 |
+
"\n",
|
| 740 |
+
"mask = yhat != -1\n",
|
| 741 |
+
"X_train, y_train = X_train[mask], y_train[mask]\n",
|
| 742 |
+
"\n",
|
| 743 |
+
"original_columns = X.columns.tolist()\n",
|
| 744 |
+
"\n",
|
| 745 |
+
"X_train = pd.DataFrame(X_train, columns=original_columns)\n",
|
| 746 |
+
"X_test = pd.DataFrame(X_test, columns=original_columns)\n",
|
| 747 |
+
"\n",
|
| 748 |
+
"xgb = XGBRegressor(random_state=random_state)\n",
|
| 749 |
+
"xgb.fit(X_train, y_train)\n",
|
| 750 |
+
"selector = SelectFromModel(xgb, prefit=True)\n",
|
| 751 |
+
"\n",
|
| 752 |
+
"importance = np.abs(xgb.feature_importances_)\n",
|
| 753 |
+
"indices = np.argsort(importance)[::-1]\n",
|
| 754 |
+
"important_features = [original_columns[i] for i in indices[:50]]\n",
|
| 755 |
+
"\n",
|
| 756 |
+
"X_train_fi = X_train[important_features]\n",
|
| 757 |
+
"X_test_fi = X_test[important_features]\n",
|
| 758 |
+
"\n",
|
| 759 |
+
"cbr.fit(X_train_fi, y_train)\n",
|
| 760 |
+
"\n",
|
| 761 |
+
"# Compute SHAP values using shap.Explainer\n",
|
| 762 |
+
"explainer = shap.Explainer(cbr, X_train_fi)\n",
|
| 763 |
+
"shap_values = explainer(X_train_fi)\n",
|
| 764 |
+
"plt.figure(figsize=(12, 8))\n",
|
| 765 |
+
"shap.summary_plot(shap_values, X_train_fi, plot_type=\"bar\", feature_names=important_features, show=False)\n",
|
| 766 |
+
"plt.savefig(\"shap_summary_plot.svg\", format='svg') # Save the plot as SVG\n",
|
| 767 |
+
"plt.close()\n",
|
| 768 |
+
"\n",
|
| 769 |
+
"display(FileLink(\"shap_summary_plot.svg\"))\n",
|
| 770 |
+
"\n",
|
| 771 |
+
"sorted_indices = np.argsort(y_train.values)\n",
|
| 772 |
+
"low_value_index = sorted_indices[0]\n",
|
| 773 |
+
"high_value_index = sorted_indices[-1]\n",
|
| 774 |
+
"\n",
|
| 775 |
+
"print(f\"Array position with low target value: {low_value_index}, Target Value: {y_train.values[low_value_index]}\")\n",
|
| 776 |
+
"print(f\"Array position with high target value: {high_value_index}, Target Value: {y_train.values[high_value_index]}\")"
|
| 777 |
+
]
|
| 778 |
+
},
|
| 779 |
+
{
|
| 780 |
+
"cell_type": "code",
|
| 781 |
+
"execution_count": null,
|
| 782 |
+
"id": "f6cce3ab-2480-4e87-a4ad-36054e12553a",
|
| 783 |
+
"metadata": {},
|
| 784 |
+
"outputs": [],
|
| 785 |
+
"source": [
|
| 786 |
+
"# Function to plot SHAP waterfall plot for a specific instance and save as SVG\n",
|
| 787 |
+
"def plot_shap_waterfall(instance_index, filename):\n",
|
| 788 |
+
" shap_value = shap_values[instance_index]\n",
|
| 789 |
+
" \n",
|
| 790 |
+
" plt.figure(figsize=(14, 8))\n",
|
| 791 |
+
" \n",
|
| 792 |
+
" shap.plots.waterfall(shap_value, show=False)\n",
|
| 793 |
+
" \n",
|
| 794 |
+
" plt.tight_layout()\n",
|
| 795 |
+
"\n",
|
| 796 |
+
" plt.savefig(filename, format='svg')\n",
|
| 797 |
+
" \n",
|
| 798 |
+
" plt.close()\n",
|
| 799 |
+
"\n",
|
| 800 |
+
"plot_shap_waterfall(482, \"waterfall_plot_instance_0.svg\")\n",
|
| 801 |
+
"plot_shap_waterfall(70, \"waterfall_plot_instance_1.svg\")\n",
|
| 802 |
+
"\n",
|
| 803 |
+
"display(FileLink(\"waterfall_plot_instance_0.svg\"))\n",
|
| 804 |
+
"display(FileLink(\"waterfall_plot_instance_1.svg\"))"
|
| 805 |
+
]
|
| 806 |
+
},
|
| 807 |
+
{
|
| 808 |
+
"cell_type": "markdown",
|
| 809 |
+
"id": "176f6725-ac13-444e-8d6e-fe3cc7a63b95",
|
| 810 |
+
"metadata": {},
|
| 811 |
+
"source": [
|
| 812 |
+
"### Hyperparameter Optimization"
|
| 813 |
+
]
|
| 814 |
+
},
|
| 815 |
+
{
|
| 816 |
+
"cell_type": "code",
|
| 817 |
+
"execution_count": null,
|
| 818 |
+
"id": "213e4d7c-5719-47bc-831d-46809d65fb49",
|
| 819 |
+
"metadata": {},
|
| 820 |
+
"outputs": [],
|
| 821 |
+
"source": [
|
| 822 |
+
"# Models\n",
|
| 823 |
+
"regressors = {\n",
|
| 824 |
+
" 'CBR': CatBoostRegressor(verbose=0),\n",
|
| 825 |
+
" #'XGBR': XGBRegressor(),\n",
|
| 826 |
+
" #'LGBMR': LGBMRegressor(),\n",
|
| 827 |
+
" # 'GBR': GradientBoostingRegressor(),\n",
|
| 828 |
+
" #'RFR': RandomForestRegressor(),\n",
|
| 829 |
+
" # 'ETR': ExtraTreesRegressor(),\n",
|
| 830 |
+
" # 'ABR': AdaBoostRegressor()\n",
|
| 831 |
+
"}\n",
|
| 832 |
+
"\n",
|
| 833 |
+
"# Define parameter grids for the models\n",
|
| 834 |
+
"param_grids = {\n",
|
| 835 |
+
" 'CBR': {\n",
|
| 836 |
+
" 'iterations': Integer(100, 500),\n",
|
| 837 |
+
" 'learning_rate': Real(0.01, 0.1),\n",
|
| 838 |
+
" 'depth': Integer(3, 10),\n",
|
| 839 |
+
" },\n",
|
| 840 |
+
" # 'GBR': {\n",
|
| 841 |
+
" # 'n_estimators': Integer(50, 300),\n",
|
| 842 |
+
" # 'learning_rate': Real(0.01, 0.1),\n",
|
| 843 |
+
" # 'max_depth': Integer(3, 10)\n",
|
| 844 |
+
" # },\n",
|
| 845 |
+
" # 'RFR': {\n",
|
| 846 |
+
" # 'n_estimators': Integer(50, 300),\n",
|
| 847 |
+
" # 'max_depth': Integer(3, 20)\n",
|
| 848 |
+
" # },\n",
|
| 849 |
+
" # 'XGBR': {\n",
|
| 850 |
+
" # 'n_estimators': Integer(50, 300),\n",
|
| 851 |
+
" # 'learning_rate': Real(0.01, 0.1),\n",
|
| 852 |
+
" # 'max_depth': Integer(3, 10),\n",
|
| 853 |
+
" # },\n",
|
| 854 |
+
" # 'LGBMR': {\n",
|
| 855 |
+
" # 'n_estimators': Integer(50, 300),\n",
|
| 856 |
+
" # 'learning_rate': Real(0.01, 0.1),\n",
|
| 857 |
+
" # 'num_leaves': Integer(20, 50),\n",
|
| 858 |
+
" # },\n",
|
| 859 |
+
" # 'ETR': {\n",
|
| 860 |
+
" # 'n_estimators': Integer(50, 300),\n",
|
| 861 |
+
" # 'max_depth': Integer(3, 20)\n",
|
| 862 |
+
" # },\n",
|
| 863 |
+
"# 'ABR': {\n",
|
| 864 |
+
"# 'n_estimators': Integer(50, 300),\n",
|
| 865 |
+
"# 'learning_rate': Real(0.01, 1.0)\n",
|
| 866 |
+
"# }\n",
|
| 867 |
+
"}\n",
|
| 868 |
+
"\n",
|
| 869 |
+
"# Function to perform hyperparameter tuning\n",
|
| 870 |
+
"def hyperparameter_tuning(model, param_grid, X_train, y_train):\n",
|
| 871 |
+
" bayes_search = BayesSearchCV(model, search_spaces=param_grid, n_iter=50, cv=5, scoring='neg_mean_squared_error', n_jobs=-1, random_state=0)\n",
|
| 872 |
+
" bayes_search.fit(X_train, y_train)\n",
|
| 873 |
+
" return bayes_search.best_estimator_\n",
|
| 874 |
+
"\n",
|
| 875 |
+
"metric_sums = {name: {'rmse': 0, 'mae': 0, 'r2': 0} for name in regressors.keys()}\n",
|
| 876 |
+
"metric_stds = {name: {'rmse': 0, 'mae': 0, 'r2': 0} for name in regressors.keys()}\n",
|
| 877 |
+
"rmse_scores = {name: [] for name in regressors.keys()}\n",
|
| 878 |
+
"mae_scores = {name: [] for name in regressors.keys()}\n",
|
| 879 |
+
"r2_scores = {name: [] for name in regressors.keys()}\n",
|
| 880 |
+
"\n",
|
| 881 |
+
"random_state = 5\n",
|
| 882 |
+
"print(f'Processing for Random State: {random_state}')\n",
|
| 883 |
+
"\n",
|
| 884 |
+
"X = splitted_dataset.drop('total_sum', axis=1)\n",
|
| 885 |
+
"y = splitted_dataset['total_sum']\n",
|
| 886 |
+
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=random_state)\n",
|
| 887 |
+
"\n",
|
| 888 |
+
"lof = LocalOutlierFactor()\n",
|
| 889 |
+
"yhat = lof.fit_predict(X_train)\n",
|
| 890 |
+
"\n",
|
| 891 |
+
"mask = yhat != -1\n",
|
| 892 |
+
"X_train, y_train = X_train[mask], y_train[mask]\n",
|
| 893 |
+
"\n",
|
| 894 |
+
"original_columns = X.columns.tolist()\n",
|
| 895 |
+
"\n",
|
| 896 |
+
"print(f\"Number of training labels after outlier removal: {len(y_train)}\")\n",
|
| 897 |
+
"print(f\"Number of test labels: {len(y_test)}\")\n",
|
| 898 |
+
"\n",
|
| 899 |
+
"scaler = MinMaxScaler()\n",
|
| 900 |
+
"X_train = scaler.fit_transform(X_train)\n",
|
| 901 |
+
"X_test = scaler.transform(X_test)\n",
|
| 902 |
+
"\n",
|
| 903 |
+
"X_train = pd.DataFrame(X_train, columns=original_columns)\n",
|
| 904 |
+
"X_test = pd.DataFrame(X_test, columns=original_columns)\n",
|
| 905 |
+
"\n",
|
| 906 |
+
"xgb = XGBRegressor(random_state=random_state)\n",
|
| 907 |
+
"xgb.fit(X_train, y_train)\n",
|
| 908 |
+
"selector = SelectFromModel(xgb, prefit=True)\n",
|
| 909 |
+
"\n",
|
| 910 |
+
"importance = np.abs(xgb.feature_importances_)\n",
|
| 911 |
+
"indices = np.argsort(importance)[::-1]\n",
|
| 912 |
+
"important_features = [original_columns[i] for i in indices[:50]]\n",
|
| 913 |
+
"\n",
|
| 914 |
+
"X_train = X_train[important_features]\n",
|
| 915 |
+
"X_test = X_test[important_features]\n",
|
| 916 |
+
"\n",
|
| 917 |
+
"best_models = {}\n",
|
| 918 |
+
"for model_name, param_grid in param_grids.items():\n",
|
| 919 |
+
" if model_name == 'GBR':\n",
|
| 920 |
+
" model = GradientBoostingRegressor()\n",
|
| 921 |
+
" elif model_name == 'RFR':\n",
|
| 922 |
+
" model = RandomForestRegressor()\n",
|
| 923 |
+
" elif model_name == 'XGBR':\n",
|
| 924 |
+
" model = XGBRegressor()\n",
|
| 925 |
+
" elif model_name == 'LGBMR':\n",
|
| 926 |
+
" model = LGBMRegressor()\n",
|
| 927 |
+
" elif model_name == 'ETR':\n",
|
| 928 |
+
" model = ExtraTreesRegressor()\n",
|
| 929 |
+
" elif model_name == 'ABR':\n",
|
| 930 |
+
" model = AdaBoostRegressor()\n",
|
| 931 |
+
" elif model_name == 'CBR':\n",
|
| 932 |
+
" model = CatBoostRegressor(verbose=0)\n",
|
| 933 |
+
" \n",
|
| 934 |
+
" print(f\"Optimizing {model_name}...\")\n",
|
| 935 |
+
" best_model = hyperparameter_tuning(model, param_grid, X_train, y_train)\n",
|
| 936 |
+
" best_models[model_name] = best_model\n",
|
| 937 |
+
" print(f\"Best parameters for {model_name}: {best_model.get_params()}\")"
|
| 938 |
+
]
|
| 939 |
+
}
|
| 940 |
+
],
|
| 941 |
+
"metadata": {
|
| 942 |
+
"kernelspec": {
|
| 943 |
+
"display_name": "Python 3 (ipykernel)",
|
| 944 |
+
"language": "python",
|
| 945 |
+
"name": "python3"
|
| 946 |
+
},
|
| 947 |
+
"language_info": {
|
| 948 |
+
"codemirror_mode": {
|
| 949 |
+
"name": "ipython",
|
| 950 |
+
"version": 3
|
| 951 |
+
},
|
| 952 |
+
"file_extension": ".py",
|
| 953 |
+
"mimetype": "text/x-python",
|
| 954 |
+
"name": "python",
|
| 955 |
+
"nbconvert_exporter": "python",
|
| 956 |
+
"pygments_lexer": "ipython3",
|
| 957 |
+
"version": "3.12.3"
|
| 958 |
+
}
|
| 959 |
+
},
|
| 960 |
+
"nbformat": 4,
|
| 961 |
+
"nbformat_minor": 5
|
| 962 |
+
}
|
severityPredictionLayer.ipynb
ADDED
|
@@ -0,0 +1,1858 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "83e18217-bbe8-4a9f-bab9-9be999d44761",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# Severity Prediction Layer"
|
| 9 |
+
]
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"cell_type": "code",
|
| 13 |
+
"execution_count": null,
|
| 14 |
+
"id": "dddcb2a8-73d3-4476-b88a-bc1494a2c830",
|
| 15 |
+
"metadata": {},
|
| 16 |
+
"outputs": [],
|
| 17 |
+
"source": [
|
| 18 |
+
"import pandas as pd\n",
|
| 19 |
+
"import numpy as np\n",
|
| 20 |
+
"import matplotlib.pyplot as plt\n",
|
| 21 |
+
"import seaborn as sns\n",
|
| 22 |
+
"from sklearn.model_selection import cross_val_score, KFold\n",
|
| 23 |
+
"from sklearn.impute import KNNImputer\n",
|
| 24 |
+
"from sklearn.pipeline import make_pipeline\n",
|
| 25 |
+
"from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, HistGradientBoostingClassifier\n",
|
| 26 |
+
"from xgboost import XGBClassifier\n",
|
| 27 |
+
"from sklearn.impute import SimpleImputer\n",
|
| 28 |
+
"from sklearn.experimental import enable_iterative_imputer\n",
|
| 29 |
+
"from imblearn.pipeline import make_pipeline as make_pipeline_imb\n",
|
| 30 |
+
"from imblearn.over_sampling import SMOTE,SMOTENC\n",
|
| 31 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 32 |
+
"from collections import Counter\n",
|
| 33 |
+
"from sklearn.metrics import classification_report, accuracy_score, confusion_matrix\n",
|
| 34 |
+
"from sklearn.ensemble import GradientBoostingClassifier\n",
|
| 35 |
+
"from sklearn.ensemble import VotingClassifier\n",
|
| 36 |
+
"from sklearn.svm import SVC\n",
|
| 37 |
+
"from sklearn.linear_model import LogisticRegression\n",
|
| 38 |
+
"from sklearn.tree import DecisionTreeClassifier\n",
|
| 39 |
+
"from sklearn.ensemble import BaggingClassifier\n",
|
| 40 |
+
"from sklearn.neighbors import KNeighborsClassifier\n",
|
| 41 |
+
"from sklearn.ensemble import ExtraTreesClassifier\n",
|
| 42 |
+
"from deslib.dcs import APosteriori\n",
|
| 43 |
+
"from deslib.des import KNORAE, KNORAU, KNOP, DESMI\n",
|
| 44 |
+
"from sklearn.neighbors import LocalOutlierFactor\n",
|
| 45 |
+
"from sklearn.utils import resample\n",
|
| 46 |
+
"import warnings\n",
|
| 47 |
+
"from imblearn.over_sampling import RandomOverSampler\n",
|
| 48 |
+
"from imblearn.under_sampling import RandomUnderSampler\n",
|
| 49 |
+
"from sklearn.preprocessing import MinMaxScaler\n",
|
| 50 |
+
"from imblearn.pipeline import Pipeline\n",
|
| 51 |
+
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
|
| 52 |
+
"from sklearn.preprocessing import LabelEncoder, PowerTransformer\n",
|
| 53 |
+
"from collections import defaultdict\n",
|
| 54 |
+
"from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, AdaBoostClassifier, VotingClassifier\n",
|
| 55 |
+
"from catboost import CatBoostClassifier\n",
|
| 56 |
+
"from lightgbm import LGBMClassifier\n",
|
| 57 |
+
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc\n",
|
| 58 |
+
"from sklearn.naive_bayes import GaussianNB\n",
|
| 59 |
+
"from sklearn.neural_network import MLPClassifier\n",
|
| 60 |
+
"import Orange\n",
|
| 61 |
+
"from scipy.stats import friedmanchisquare, rankdata\n",
|
| 62 |
+
"import shap\n",
|
| 63 |
+
"import scikit_posthocs as sp\n",
|
| 64 |
+
"from sklearn.feature_selection import SelectFromModel\n",
|
| 65 |
+
"from IPython.display import FileLink, display\n",
|
| 66 |
+
"import math\n",
|
| 67 |
+
"from sklearn.ensemble import RandomForestClassifier\n",
|
| 68 |
+
"from skopt.space import Integer, Real\n",
|
| 69 |
+
"from sklearn.model_selection import StratifiedKFold\n",
|
| 70 |
+
"from skopt import BayesSearchCV\n",
|
| 71 |
+
"import xgboost as xgb\n",
|
| 72 |
+
"from imblearn.over_sampling import SMOTE\n",
|
| 73 |
+
"from sklearn.tree import DecisionTreeClassifier, export_text\n",
|
| 74 |
+
"from sklearn import tree\n",
|
| 75 |
+
"from skopt.space import Real, Integer, Categorical\n",
|
| 76 |
+
"from skopt.callbacks import VerboseCallback\n",
|
| 77 |
+
"from deslib.des.knora_e import KNORAE\n",
|
| 78 |
+
"from deslib.des.knora_u import KNORAU\n",
|
| 79 |
+
"from deslib.des.knop import KNOP\n",
|
| 80 |
+
"from deslib.des.meta_des import METADES\n",
|
| 81 |
+
"from deslib.des.des_knn import DESKNN\n",
|
| 82 |
+
"from deslib.des.des_p import DESP"
|
| 83 |
+
]
|
| 84 |
+
},
|
| 85 |
+
{
|
| 86 |
+
"cell_type": "markdown",
|
| 87 |
+
"id": "39d6683c-fd3f-4daa-afdc-2e7f83c3fce3",
|
| 88 |
+
"metadata": {},
|
| 89 |
+
"source": [
|
| 90 |
+
"### Preparation before training"
|
| 91 |
+
]
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"cell_type": "code",
|
| 95 |
+
"execution_count": null,
|
| 96 |
+
"id": "ec17e47a-8d92-498d-8f28-d8259d6ebc4e",
|
| 97 |
+
"metadata": {},
|
| 98 |
+
"outputs": [],
|
| 99 |
+
"source": [
|
| 100 |
+
"# Call Dataset\n",
|
| 101 |
+
"pd.set_option('display.max_rows', 10)\n",
|
| 102 |
+
"initial_df = pd.read_csv('3labelv4Classification.csv')\n",
|
| 103 |
+
"initial_df.info()"
|
| 104 |
+
]
|
| 105 |
+
},
|
| 106 |
+
{
|
| 107 |
+
"cell_type": "code",
|
| 108 |
+
"execution_count": null,
|
| 109 |
+
"id": "d3751352-73bc-42c6-b6a5-84a3badf14ff",
|
| 110 |
+
"metadata": {},
|
| 111 |
+
"outputs": [],
|
| 112 |
+
"source": [
|
| 113 |
+
"# All categorical features except for label\n",
|
| 114 |
+
"cols = initial_df.columns\n",
|
| 115 |
+
"num_cols = initial_df._get_numeric_data().columns\n",
|
| 116 |
+
"categorical_features = list(set(cols) - set(num_cols))\n",
|
| 117 |
+
"categorical_features.remove('depression_category')\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"# Label Encode all categorical, but keep missing values\n",
|
| 120 |
+
"le_initial_df = initial_df.copy()\n",
|
| 121 |
+
"dropped_labels = le_initial_df['depression_category']\n",
|
| 122 |
+
"le_initial_df = le_initial_df.drop('depression_category', axis = 1)\n",
|
| 123 |
+
"\n",
|
| 124 |
+
"for col in le_initial_df.columns:\n",
|
| 125 |
+
" if le_initial_df[col].dtype == 'object':\n",
|
| 126 |
+
" le_initial_df[col] = le_initial_df[col].fillna('missing')\n",
|
| 127 |
+
"\n",
|
| 128 |
+
" label_encoder = LabelEncoder()\n",
|
| 129 |
+
" le_initial_df[col] = label_encoder.fit_transform(le_initial_df[col])\n",
|
| 130 |
+
"\n",
|
| 131 |
+
" missing_value_index = np.where(label_encoder.classes_ == 'missing')[0]\n",
|
| 132 |
+
" \n",
|
| 133 |
+
" le_initial_df[col] = le_initial_df[col].replace(missing_value_index, np.nan)\n",
|
| 134 |
+
"\n",
|
| 135 |
+
"le_initial_df = pd.concat([le_initial_df, dropped_labels], axis = 1)"
|
| 136 |
+
]
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"cell_type": "code",
|
| 140 |
+
"execution_count": null,
|
| 141 |
+
"id": "9fa21a95-21f1-4274-86ba-d7977813066b",
|
| 142 |
+
"metadata": {},
|
| 143 |
+
"outputs": [],
|
| 144 |
+
"source": [
|
| 145 |
+
"le_initial_df"
|
| 146 |
+
]
|
| 147 |
+
},
|
| 148 |
+
{
|
| 149 |
+
"cell_type": "code",
|
| 150 |
+
"execution_count": null,
|
| 151 |
+
"id": "0ba54a81-d4de-4884-99d3-29867eb7ea40",
|
| 152 |
+
"metadata": {},
|
| 153 |
+
"outputs": [],
|
| 154 |
+
"source": [
|
| 155 |
+
"# Seperate and Combine\n",
|
| 156 |
+
"le_df_mild = le_initial_df[le_initial_df['depression_category'] == 'mild']\n",
|
| 157 |
+
"le_df_moderatesevere = le_initial_df[le_initial_df['depression_category'] == 'moderatesevere']\n",
|
| 158 |
+
"\n",
|
| 159 |
+
"# Check depression category counts\n",
|
| 160 |
+
"dataframes = [le_df_mild, le_df_moderatesevere]\n",
|
| 161 |
+
"le_initial_df = pd.concat(dataframes, ignore_index=True)\n",
|
| 162 |
+
"label_counts = le_initial_df['depression_category'].value_counts()\n",
|
| 163 |
+
"label_counts"
|
| 164 |
+
]
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
"cell_type": "code",
|
| 168 |
+
"execution_count": null,
|
| 169 |
+
"id": "8c5bdf52-c645-4a11-920a-579e28db1a50",
|
| 170 |
+
"metadata": {},
|
| 171 |
+
"outputs": [],
|
| 172 |
+
"source": [
|
| 173 |
+
"# Imputation\n",
|
| 174 |
+
"different_le_dfs = [le_df_mild, le_df_moderatesevere]\n",
|
| 175 |
+
"imputed_le_dfs = []\n",
|
| 176 |
+
"from sklearn.impute import IterativeImputer\n",
|
| 177 |
+
"for le_df in different_le_dfs:\n",
|
| 178 |
+
" y = le_df['depression_category']\n",
|
| 179 |
+
" X = le_df.drop('depression_category', axis = 1)\n",
|
| 180 |
+
" \n",
|
| 181 |
+
" imputer = SimpleImputer(strategy='median')\n",
|
| 182 |
+
" imputed_data = imputer.fit_transform(X)\n",
|
| 183 |
+
" imputed_df = pd.DataFrame(imputed_data, columns = X.columns)\n",
|
| 184 |
+
"\n",
|
| 185 |
+
" imputed_df['depression_category'] = y.reset_index(drop = True)\n",
|
| 186 |
+
" imputed_le_dfs.append(imputed_df)\n",
|
| 187 |
+
"\n",
|
| 188 |
+
"concatenated_le_dfs = pd.concat(imputed_le_dfs, ignore_index = True)\n",
|
| 189 |
+
"concatenated_le_dfs"
|
| 190 |
+
]
|
| 191 |
+
},
|
| 192 |
+
{
|
| 193 |
+
"cell_type": "code",
|
| 194 |
+
"execution_count": null,
|
| 195 |
+
"id": "0e871a5f-beab-4f90-b8c4-e922ef86d0f0",
|
| 196 |
+
"metadata": {},
|
| 197 |
+
"outputs": [],
|
| 198 |
+
"source": [
|
| 199 |
+
"# Full label encode depression category\n",
|
| 200 |
+
"fully_LE_concatenated_le_dfs = concatenated_le_dfs.copy()\n",
|
| 201 |
+
"fully_LE_concatenated_le_dfs['depression_category'] = label_encoder.fit_transform(fully_LE_concatenated_le_dfs['depression_category'])\n",
|
| 202 |
+
"\n",
|
| 203 |
+
"# The dataset after category connect, imputation, and label encoding\n",
|
| 204 |
+
"splitted_dataset = fully_LE_concatenated_le_dfs.copy()\n",
|
| 205 |
+
"splitted_dataset"
|
| 206 |
+
]
|
| 207 |
+
},
|
| 208 |
+
{
|
| 209 |
+
"cell_type": "markdown",
|
| 210 |
+
"id": "198fdc3c-ccbc-4ac5-8621-c0c5dc771cbb",
|
| 211 |
+
"metadata": {},
|
| 212 |
+
"source": [
|
| 213 |
+
"### Setup for training"
|
| 214 |
+
]
|
| 215 |
+
},
|
| 216 |
+
{
|
| 217 |
+
"cell_type": "code",
|
| 218 |
+
"execution_count": null,
|
| 219 |
+
"id": "101c5549-7bc3-4573-867b-f09776e254db",
|
| 220 |
+
"metadata": {
|
| 221 |
+
"jupyter": {
|
| 222 |
+
"source_hidden": true
|
| 223 |
+
}
|
| 224 |
+
},
|
| 225 |
+
"outputs": [],
|
| 226 |
+
"source": [
|
| 227 |
+
"def plot_combined_roc_curve(roc_curves, classifier_names):\n",
|
| 228 |
+
" plt.figure(figsize=(12, 8))\n",
|
| 229 |
+
" mean_fpr = np.linspace(0, 1, 100)\n",
|
| 230 |
+
" colors = plt.cm.get_cmap('tab20', len(classifier_names))\n",
|
| 231 |
+
" \n",
|
| 232 |
+
" for i, clf_name in enumerate(classifier_names):\n",
|
| 233 |
+
" tprs = []\n",
|
| 234 |
+
" for fpr, tpr in roc_curves[clf_name]:\n",
|
| 235 |
+
" tprs.append(np.interp(mean_fpr, fpr, tpr))\n",
|
| 236 |
+
" mean_tpr = np.mean(tprs, axis=0)\n",
|
| 237 |
+
" mean_tpr[-1] = 1.0\n",
|
| 238 |
+
" mean_auc = auc(mean_fpr, mean_tpr)\n",
|
| 239 |
+
" plt.plot(mean_fpr, mean_tpr, color=colors(i), lw=2, linestyle='-', marker='o', markersize=4, \n",
|
| 240 |
+
" label=f'{clf_name} (AUC = {mean_auc:.3f})')\n",
|
| 241 |
+
"\n",
|
| 242 |
+
" plt.plot([0, 1], [0, 1], color='grey', lw=2, linestyle='--')\n",
|
| 243 |
+
" plt.xlim([0.0, 1.0])\n",
|
| 244 |
+
" plt.ylim([0.0, 1.05])\n",
|
| 245 |
+
" plt.xlabel('False Positive Rate', fontsize=26)\n",
|
| 246 |
+
" plt.ylabel('True Positive Rate', fontsize=26)\n",
|
| 247 |
+
" plt.xticks(fontsize=30)\n",
|
| 248 |
+
" plt.yticks(fontsize=30)\n",
|
| 249 |
+
" plt.legend(loc=\"lower right\", fontsize=22, frameon=True, framealpha=0.9)\n",
|
| 250 |
+
" plt.grid(True)\n",
|
| 251 |
+
"\n",
|
| 252 |
+
" filename='bonk.svg'\n",
|
| 253 |
+
"\n",
|
| 254 |
+
" plt.savefig(filename, format='svg')\n",
|
| 255 |
+
" plt.show()\n",
|
| 256 |
+
"\n",
|
| 257 |
+
" display(FileLink(filename))\n",
|
| 258 |
+
"\n",
|
| 259 |
+
"# Preparation code to make CD diagram from older version of Orange\n",
|
| 260 |
+
"def compute_CD(avranks, n, alpha=\"0.05\", test=\"nemenyi\"):\n",
|
| 261 |
+
" \"\"\"\n",
|
| 262 |
+
" Returns critical difference for Nemenyi or Bonferroni-Dunn test\n",
|
| 263 |
+
" according to given alpha (either alpha=\"0.05\" or alpha=\"0.1\") for average\n",
|
| 264 |
+
" ranks and number of tested datasets N. Test can be either \"nemenyi\" for\n",
|
| 265 |
+
" for Nemenyi two tailed test or \"bonferroni-dunn\" for Bonferroni-Dunn test.\n",
|
| 266 |
+
"\n",
|
| 267 |
+
" This function is deprecated and will be removed in Orange 3.34.\n",
|
| 268 |
+
" \"\"\"\n",
|
| 269 |
+
" k = len(avranks)\n",
|
| 270 |
+
" d = {(\"nemenyi\", \"0.05\"): [0, 0, 1.959964, 2.343701, 2.569032, 2.727774,\n",
|
| 271 |
+
" 2.849705, 2.94832, 3.030879, 3.101730, 3.163684,\n",
|
| 272 |
+
" 3.218654, 3.268004, 3.312739, 3.353618, 3.39123,\n",
|
| 273 |
+
" 3.426041, 3.458425, 3.488685, 3.517073,\n",
|
| 274 |
+
" 3.543799],\n",
|
| 275 |
+
" (\"nemenyi\", \"0.1\"): [0, 0, 1.644854, 2.052293, 2.291341, 2.459516,\n",
|
| 276 |
+
" 2.588521, 2.692732, 2.779884, 2.854606, 2.919889,\n",
|
| 277 |
+
" 2.977768, 3.029694, 3.076733, 3.119693, 3.159199,\n",
|
| 278 |
+
" 3.195743, 3.229723, 3.261461, 3.291224, 3.319233],\n",
|
| 279 |
+
" (\"bonferroni-dunn\", \"0.05\"): [0, 0, 1.960, 2.241, 2.394, 2.498, 2.576,\n",
|
| 280 |
+
" 2.638, 2.690, 2.724, 2.773],\n",
|
| 281 |
+
" (\"bonferroni-dunn\", \"0.1\"): [0, 0, 1.645, 1.960, 2.128, 2.241, 2.326,\n",
|
| 282 |
+
" 2.394, 2.450, 2.498, 2.539]}\n",
|
| 283 |
+
" q = d[(test, alpha)]\n",
|
| 284 |
+
" cd = q[k] * (k * (k + 1) / (6.0 * n)) ** 0.5\n",
|
| 285 |
+
" return cd\n",
|
| 286 |
+
"\n",
|
| 287 |
+
"\n",
|
| 288 |
+
"def graph_ranks(avranks, names, cd=None, cdmethod=None, lowv=None, highv=None,\n",
|
| 289 |
+
" width=6, textspace=1, reverse=False, filename=None, **kwargs):\n",
|
| 290 |
+
" \"\"\"\n",
|
| 291 |
+
" Draws a CD graph, which is used to display the differences in methods'\n",
|
| 292 |
+
" performance. See Janez Demsar, Statistical Comparisons of Classifiers over\n",
|
| 293 |
+
" Multiple Data Sets, 7(Jan):1--30, 2006.\n",
|
| 294 |
+
"\n",
|
| 295 |
+
" Needs matplotlib to work.\n",
|
| 296 |
+
"\n",
|
| 297 |
+
" The image is ploted on `plt` imported using\n",
|
| 298 |
+
" `import matplotlib.pyplot as plt`.\n",
|
| 299 |
+
"\n",
|
| 300 |
+
" This function is deprecated and will be removed in Orange 3.34.\n",
|
| 301 |
+
"\n",
|
| 302 |
+
" Args:\n",
|
| 303 |
+
" avranks (list of float): average ranks of methods.\n",
|
| 304 |
+
" names (list of str): names of methods.\n",
|
| 305 |
+
" cd (float): Critical difference used for statistically significance of\n",
|
| 306 |
+
" difference between methods.\n",
|
| 307 |
+
" cdmethod (int, optional): the method that is compared with other methods\n",
|
| 308 |
+
" If omitted, show pairwise comparison of methods\n",
|
| 309 |
+
" lowv (int, optional): the lowest shown rank\n",
|
| 310 |
+
" highv (int, optional): the highest shown rank\n",
|
| 311 |
+
" width (int, optional): default width in inches (default: 6)\n",
|
| 312 |
+
" textspace (int, optional): space on figure sides (in inches) for the\n",
|
| 313 |
+
" method names (default: 1)\n",
|
| 314 |
+
" reverse (bool, optional): if set to `True`, the lowest rank is on the\n",
|
| 315 |
+
" right (default: `False`)\n",
|
| 316 |
+
" filename (str, optional): output file name (with extension). If not\n",
|
| 317 |
+
" given, the function does not write a file.\n",
|
| 318 |
+
" \"\"\"\n",
|
| 319 |
+
" try:\n",
|
| 320 |
+
" import matplotlib.pyplot as plt\n",
|
| 321 |
+
" from matplotlib.backends.backend_agg import FigureCanvasAgg\n",
|
| 322 |
+
" except ImportError:\n",
|
| 323 |
+
" raise ImportError(\"Function graph_ranks requires matplotlib.\")\n",
|
| 324 |
+
"\n",
|
| 325 |
+
" width = float(width)\n",
|
| 326 |
+
" textspace = float(textspace)\n",
|
| 327 |
+
"\n",
|
| 328 |
+
" def nth(l, n):\n",
|
| 329 |
+
" \"\"\"\n",
|
| 330 |
+
" Returns only nth elemnt in a list.\n",
|
| 331 |
+
" \"\"\"\n",
|
| 332 |
+
" n = lloc(l, n)\n",
|
| 333 |
+
" return [a[n] for a in l]\n",
|
| 334 |
+
"\n",
|
| 335 |
+
" def lloc(l, n):\n",
|
| 336 |
+
" \"\"\"\n",
|
| 337 |
+
" List location in list of list structure.\n",
|
| 338 |
+
" Enable the use of negative locations:\n",
|
| 339 |
+
" -1 is the last element, -2 second last...\n",
|
| 340 |
+
" \"\"\"\n",
|
| 341 |
+
" if n < 0:\n",
|
| 342 |
+
" return len(l[0]) + n\n",
|
| 343 |
+
" else:\n",
|
| 344 |
+
" return n\n",
|
| 345 |
+
"\n",
|
| 346 |
+
" def mxrange(lr):\n",
|
| 347 |
+
" \"\"\"\n",
|
| 348 |
+
" Multiple xranges. Can be used to traverse matrices.\n",
|
| 349 |
+
" This function is very slow due to unknown number of\n",
|
| 350 |
+
" parameters.\n",
|
| 351 |
+
"\n",
|
| 352 |
+
" >>> mxrange([3,5])\n",
|
| 353 |
+
" [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]\n",
|
| 354 |
+
"\n",
|
| 355 |
+
" >>> mxrange([[3,5,1],[9,0,-3]])\n",
|
| 356 |
+
" [(3, 9), (3, 6), (3, 3), (4, 9), (4, 6), (4, 3)]\n",
|
| 357 |
+
"\n",
|
| 358 |
+
" \"\"\"\n",
|
| 359 |
+
" if not len(lr):\n",
|
| 360 |
+
" yield ()\n",
|
| 361 |
+
" else:\n",
|
| 362 |
+
" # it can work with single numbers\n",
|
| 363 |
+
" index = lr[0]\n",
|
| 364 |
+
" if isinstance(index, int):\n",
|
| 365 |
+
" index = [index]\n",
|
| 366 |
+
" for a in range(*index):\n",
|
| 367 |
+
" for b in mxrange(lr[1:]):\n",
|
| 368 |
+
" yield tuple([a] + list(b))\n",
|
| 369 |
+
"\n",
|
| 370 |
+
" def print_figure(fig, *args, **kwargs):\n",
|
| 371 |
+
" canvas = FigureCanvasAgg(fig)\n",
|
| 372 |
+
" canvas.print_figure(*args, **kwargs)\n",
|
| 373 |
+
"\n",
|
| 374 |
+
" sums = avranks\n",
|
| 375 |
+
"\n",
|
| 376 |
+
" tempsort = sorted([(a, i) for i, a in enumerate(sums)], reverse=reverse)\n",
|
| 377 |
+
" ssums = nth(tempsort, 0)\n",
|
| 378 |
+
" sortidx = nth(tempsort, 1)\n",
|
| 379 |
+
" nnames = [names[x] for x in sortidx]\n",
|
| 380 |
+
"\n",
|
| 381 |
+
" if lowv is None:\n",
|
| 382 |
+
" lowv = min(1, int(math.floor(min(ssums))))\n",
|
| 383 |
+
" if highv is None:\n",
|
| 384 |
+
" highv = max(len(avranks), int(math.ceil(max(ssums))))\n",
|
| 385 |
+
"\n",
|
| 386 |
+
" cline = 0.4\n",
|
| 387 |
+
"\n",
|
| 388 |
+
" k = len(sums)\n",
|
| 389 |
+
"\n",
|
| 390 |
+
" lines = None\n",
|
| 391 |
+
"\n",
|
| 392 |
+
" linesblank = 0\n",
|
| 393 |
+
" scalewidth = width - 2 * textspace\n",
|
| 394 |
+
"\n",
|
| 395 |
+
" def rankpos(rank):\n",
|
| 396 |
+
" if not reverse:\n",
|
| 397 |
+
" a = rank - lowv\n",
|
| 398 |
+
" else:\n",
|
| 399 |
+
" a = highv - rank\n",
|
| 400 |
+
" return textspace + scalewidth / (highv - lowv) * a\n",
|
| 401 |
+
"\n",
|
| 402 |
+
" distanceh = 0.25\n",
|
| 403 |
+
"\n",
|
| 404 |
+
" if cd and cdmethod is None:\n",
|
| 405 |
+
" # get pairs of non significant methods\n",
|
| 406 |
+
"\n",
|
| 407 |
+
" def get_lines(sums, hsd):\n",
|
| 408 |
+
" # get all pairs\n",
|
| 409 |
+
" lsums = len(sums)\n",
|
| 410 |
+
" allpairs = [(i, j) for i, j in mxrange([[lsums], [lsums]]) if j > i]\n",
|
| 411 |
+
" # remove not significant\n",
|
| 412 |
+
" notSig = [(i, j) for i, j in allpairs\n",
|
| 413 |
+
" if abs(sums[i] - sums[j]) <= hsd]\n",
|
| 414 |
+
" # keep only longest\n",
|
| 415 |
+
"\n",
|
| 416 |
+
" def no_longer(ij_tuple, notSig):\n",
|
| 417 |
+
" i, j = ij_tuple\n",
|
| 418 |
+
" for i1, j1 in notSig:\n",
|
| 419 |
+
" if (i1 <= i and j1 > j) or (i1 < i and j1 >= j):\n",
|
| 420 |
+
" return False\n",
|
| 421 |
+
" return True\n",
|
| 422 |
+
"\n",
|
| 423 |
+
" longest = [(i, j) for i, j in notSig if no_longer((i, j), notSig)]\n",
|
| 424 |
+
"\n",
|
| 425 |
+
" return longest\n",
|
| 426 |
+
"\n",
|
| 427 |
+
" lines = get_lines(ssums, cd)\n",
|
| 428 |
+
" linesblank = 0.2 + 0.2 + (len(lines) - 1) * 0.1\n",
|
| 429 |
+
"\n",
|
| 430 |
+
" # add scale\n",
|
| 431 |
+
" distanceh = 0.25\n",
|
| 432 |
+
" cline += distanceh\n",
|
| 433 |
+
"\n",
|
| 434 |
+
" # calculate height needed height of an image\n",
|
| 435 |
+
" minnotsignificant = max(2 * 0.2, linesblank)\n",
|
| 436 |
+
" height = cline + ((k + 1) / 2) * 0.2 + minnotsignificant\n",
|
| 437 |
+
"\n",
|
| 438 |
+
" fig = plt.figure(figsize=(width, height))\n",
|
| 439 |
+
" fig.set_facecolor('white')\n",
|
| 440 |
+
" ax = fig.add_axes([0, 0, 1, 1]) # reverse y axis\n",
|
| 441 |
+
" ax.set_axis_off()\n",
|
| 442 |
+
"\n",
|
| 443 |
+
" hf = 1. / height # height factor\n",
|
| 444 |
+
" wf = 1. / width\n",
|
| 445 |
+
"\n",
|
| 446 |
+
" def hfl(l):\n",
|
| 447 |
+
" return [a * hf for a in l]\n",
|
| 448 |
+
"\n",
|
| 449 |
+
" def wfl(l):\n",
|
| 450 |
+
" return [a * wf for a in l]\n",
|
| 451 |
+
"\n",
|
| 452 |
+
"\n",
|
| 453 |
+
" # Upper left corner is (0,0).\n",
|
| 454 |
+
" ax.plot([0, 1], [0, 1], c=\"w\")\n",
|
| 455 |
+
" ax.set_xlim(0, 1)\n",
|
| 456 |
+
" ax.set_ylim(1, 0)\n",
|
| 457 |
+
"\n",
|
| 458 |
+
" def line(l, color='k', **kwargs):\n",
|
| 459 |
+
" \"\"\"\n",
|
| 460 |
+
" Input is a list of pairs of points.\n",
|
| 461 |
+
" \"\"\"\n",
|
| 462 |
+
" ax.plot(wfl(nth(l, 0)), hfl(nth(l, 1)), color=color, **kwargs)\n",
|
| 463 |
+
"\n",
|
| 464 |
+
" def text(x, y, s, *args, **kwargs):\n",
|
| 465 |
+
" ax.text(wf * x, hf * y, s, fontsize = 14, *args, **kwargs)\n",
|
| 466 |
+
"\n",
|
| 467 |
+
" line([(textspace, cline), (width - textspace, cline)], linewidth=0.7)\n",
|
| 468 |
+
"\n",
|
| 469 |
+
" bigtick = 0.1\n",
|
| 470 |
+
" smalltick = 0.05\n",
|
| 471 |
+
"\n",
|
| 472 |
+
" tick = None\n",
|
| 473 |
+
" for a in list(np.arange(lowv, highv, 0.5)) + [highv]:\n",
|
| 474 |
+
" tick = smalltick\n",
|
| 475 |
+
" if a == int(a):\n",
|
| 476 |
+
" tick = bigtick\n",
|
| 477 |
+
" line([(rankpos(a), cline - tick / 2),\n",
|
| 478 |
+
" (rankpos(a), cline)],\n",
|
| 479 |
+
" linewidth=0.7)\n",
|
| 480 |
+
"\n",
|
| 481 |
+
" for a in range(lowv, highv + 1):\n",
|
| 482 |
+
" text(rankpos(a), cline - tick / 2 - 0.05, str(a),\n",
|
| 483 |
+
" ha=\"center\", va=\"bottom\")\n",
|
| 484 |
+
"\n",
|
| 485 |
+
" k = len(ssums)\n",
|
| 486 |
+
"\n",
|
| 487 |
+
" for i in range(math.ceil(k / 2)):\n",
|
| 488 |
+
" chei = cline + minnotsignificant + i * 0.2\n",
|
| 489 |
+
" line([(rankpos(ssums[i]), cline),\n",
|
| 490 |
+
" (rankpos(ssums[i]), chei),\n",
|
| 491 |
+
" (textspace - 0.1, chei)],\n",
|
| 492 |
+
" linewidth=0.7)\n",
|
| 493 |
+
" text(textspace - 0.2, chei, nnames[i], ha=\"right\", va=\"center\")\n",
|
| 494 |
+
"\n",
|
| 495 |
+
" for i in range(math.ceil(k / 2), k):\n",
|
| 496 |
+
" chei = cline + minnotsignificant + (k - i - 1) * 0.2\n",
|
| 497 |
+
" line([(rankpos(ssums[i]), cline),\n",
|
| 498 |
+
" (rankpos(ssums[i]), chei),\n",
|
| 499 |
+
" (textspace + scalewidth + 0.1, chei)],\n",
|
| 500 |
+
" linewidth=0.7)\n",
|
| 501 |
+
" text(textspace + scalewidth + 0.2, chei, nnames[i],\n",
|
| 502 |
+
" ha=\"left\", va=\"center\")\n",
|
| 503 |
+
"\n",
|
| 504 |
+
" if cd and cdmethod is None:\n",
|
| 505 |
+
" # upper scale\n",
|
| 506 |
+
" if not reverse:\n",
|
| 507 |
+
" begin, end = rankpos(lowv), rankpos(lowv + cd)\n",
|
| 508 |
+
" else:\n",
|
| 509 |
+
" begin, end = rankpos(highv), rankpos(highv - cd)\n",
|
| 510 |
+
"\n",
|
| 511 |
+
" line([(begin, distanceh), (end, distanceh)], linewidth=0.7)\n",
|
| 512 |
+
" line([(begin, distanceh + bigtick / 2),\n",
|
| 513 |
+
" (begin, distanceh - bigtick / 2)],\n",
|
| 514 |
+
" linewidth=0.7)\n",
|
| 515 |
+
" line([(end, distanceh + bigtick / 2),\n",
|
| 516 |
+
" (end, distanceh - bigtick / 2)],\n",
|
| 517 |
+
" linewidth=0.7)\n",
|
| 518 |
+
" text((begin + end) / 2, distanceh - 0.05, \"CD\",\n",
|
| 519 |
+
" ha=\"center\", va=\"bottom\")\n",
|
| 520 |
+
"\n",
|
| 521 |
+
" # no-significance lines\n",
|
| 522 |
+
" def draw_lines(lines, side=0.05, height=0.1):\n",
|
| 523 |
+
" start = cline + 0.2\n",
|
| 524 |
+
" for l, r in lines:\n",
|
| 525 |
+
" line([(rankpos(ssums[l]) - side, start),\n",
|
| 526 |
+
" (rankpos(ssums[r]) + side, start)],\n",
|
| 527 |
+
" linewidth=2.5)\n",
|
| 528 |
+
" start += height\n",
|
| 529 |
+
"\n",
|
| 530 |
+
" draw_lines(lines)\n",
|
| 531 |
+
"\n",
|
| 532 |
+
" elif cd:\n",
|
| 533 |
+
" begin = rankpos(avranks[cdmethod] - cd)\n",
|
| 534 |
+
" end = rankpos(avranks[cdmethod] + cd)\n",
|
| 535 |
+
" line([(begin, cline), (end, cline)],\n",
|
| 536 |
+
" linewidth=2.5)\n",
|
| 537 |
+
" line([(begin, cline + bigtick / 2),\n",
|
| 538 |
+
" (begin, cline - bigtick / 2)],\n",
|
| 539 |
+
" linewidth=2.5)\n",
|
| 540 |
+
" line([(end, cline + bigtick / 2),\n",
|
| 541 |
+
" (end, cline - bigtick / 2)],\n",
|
| 542 |
+
" linewidth=2.5)\n",
|
| 543 |
+
"\n",
|
| 544 |
+
" if filename:\n",
|
| 545 |
+
" print_figure(fig, filename, **kwargs)\n",
|
| 546 |
+
"\n",
|
| 547 |
+
"def train_evaluate_model(clf, X_train, y_train, X_test, y_test, clf_name='Classifier'):\n",
|
| 548 |
+
" clf.fit(X_train, y_train)\n",
|
| 549 |
+
" y_pred = clf.predict(X_test)\n",
|
| 550 |
+
" \n",
|
| 551 |
+
" accuracy = accuracy_score(y_test, y_pred)\n",
|
| 552 |
+
" precision = precision_score(y_test, y_pred, average='weighted')\n",
|
| 553 |
+
" recall = recall_score(y_test, y_pred, average='weighted')\n",
|
| 554 |
+
" f1 = f1_score(y_test, y_pred, average='weighted')\n",
|
| 555 |
+
" conf_matrix = confusion_matrix(y_test, y_pred)\n",
|
| 556 |
+
" \n",
|
| 557 |
+
" if hasattr(clf, 'predict_proba'):\n",
|
| 558 |
+
" y_score = clf.predict_proba(X_test)[:, 1]\n",
|
| 559 |
+
" else:\n",
|
| 560 |
+
" y_score = clf.decision_function(X_test)\n",
|
| 561 |
+
" \n",
|
| 562 |
+
" fpr, tpr, _ = roc_curve(y_test, y_score)\n",
|
| 563 |
+
" roc_auc = auc(fpr, tpr)\n",
|
| 564 |
+
" \n",
|
| 565 |
+
" print(f'{clf_name} - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}')\n",
|
| 566 |
+
" return accuracy, precision, recall, f1, conf_matrix, fpr, tpr, roc_auc"
|
| 567 |
+
]
|
| 568 |
+
},
|
| 569 |
+
{
|
| 570 |
+
"cell_type": "code",
|
| 571 |
+
"execution_count": null,
|
| 572 |
+
"id": "3e602f4d-efb2-4ac9-99bb-c6fbeca791cc",
|
| 573 |
+
"metadata": {
|
| 574 |
+
"jupyter": {
|
| 575 |
+
"source_hidden": true
|
| 576 |
+
}
|
| 577 |
+
},
|
| 578 |
+
"outputs": [],
|
| 579 |
+
"source": [
|
| 580 |
+
"warnings.filterwarnings('ignore')"
|
| 581 |
+
]
|
| 582 |
+
},
|
| 583 |
+
{
|
| 584 |
+
"cell_type": "markdown",
|
| 585 |
+
"id": "773b2524-af11-464f-97a9-f4add85be0d2",
|
| 586 |
+
"metadata": {},
|
| 587 |
+
"source": [
|
| 588 |
+
"### Training (classic/static)\n",
|
| 589 |
+
"In order to run classical/static, make sure to uncomment the one you need. \"Post Training\" is after one of these classical/static is done."
|
| 590 |
+
]
|
| 591 |
+
},
|
| 592 |
+
{
|
| 593 |
+
"cell_type": "markdown",
|
| 594 |
+
"id": "f5b8873c-13d3-43b7-8902-14b73e8d5409",
|
| 595 |
+
"metadata": {},
|
| 596 |
+
"source": [
|
| 597 |
+
"#### Classical Classifiers"
|
| 598 |
+
]
|
| 599 |
+
},
|
| 600 |
+
{
|
| 601 |
+
"cell_type": "code",
|
| 602 |
+
"execution_count": null,
|
| 603 |
+
"id": "264648a0-b0d8-4a63-9167-cea3a4ed1e18",
|
| 604 |
+
"metadata": {
|
| 605 |
+
"scrolled": true
|
| 606 |
+
},
|
| 607 |
+
"outputs": [],
|
| 608 |
+
"source": [
|
| 609 |
+
"# Optimized Classifiers Det\n",
|
| 610 |
+
"# classifiers = {\n",
|
| 611 |
+
"# 'DT': DecisionTreeClassifier(\n",
|
| 612 |
+
"# random_state=0, \n",
|
| 613 |
+
"# criterion='gini', \n",
|
| 614 |
+
"# max_depth=6, \n",
|
| 615 |
+
"# min_samples_leaf=10, \n",
|
| 616 |
+
"# min_samples_split=9\n",
|
| 617 |
+
"# ),\n",
|
| 618 |
+
"# 'LR': LogisticRegression(\n",
|
| 619 |
+
"# random_state=0, \n",
|
| 620 |
+
"# C=0.09659168435718246, \n",
|
| 621 |
+
"# max_iter=100, \n",
|
| 622 |
+
"# solver='lbfgs'\n",
|
| 623 |
+
"# ),\n",
|
| 624 |
+
"# 'NB': GaussianNB(\n",
|
| 625 |
+
"# var_smoothing=0.0058873326349240295\n",
|
| 626 |
+
"# ),\n",
|
| 627 |
+
"# 'KN': KNeighborsClassifier(\n",
|
| 628 |
+
"# metric='manhattan', \n",
|
| 629 |
+
"# n_neighbors=8, \n",
|
| 630 |
+
"# weights='uniform'\n",
|
| 631 |
+
"# ),\n",
|
| 632 |
+
"# 'MLP': MLPClassifier(\n",
|
| 633 |
+
"# random_state=0, \n",
|
| 634 |
+
"# max_iter=1000, \n",
|
| 635 |
+
"# alpha=0.0003079393718075164, \n",
|
| 636 |
+
"# hidden_layer_sizes=195, \n",
|
| 637 |
+
"# learning_rate_init=0.0001675266159417717\n",
|
| 638 |
+
"# ),\n",
|
| 639 |
+
"# 'SVC': SVC(probability=True, kernel = 'rbf', C = 0.95, gamma = 'scale')}\n",
|
| 640 |
+
"\n",
|
| 641 |
+
"# Optimized Classifiers Sevpred\n",
|
| 642 |
+
"# classifiers = {\n",
|
| 643 |
+
"# 'DT': DecisionTreeClassifier(\n",
|
| 644 |
+
"# random_state=0, \n",
|
| 645 |
+
"# criterion='entropy', \n",
|
| 646 |
+
"# max_depth=20, \n",
|
| 647 |
+
"# min_samples_leaf=8, \n",
|
| 648 |
+
"# min_samples_split=6\n",
|
| 649 |
+
"# ),\n",
|
| 650 |
+
"# 'LR': LogisticRegression(\n",
|
| 651 |
+
"# random_state=0, \n",
|
| 652 |
+
"# C=2.2095350994035026, \n",
|
| 653 |
+
"# max_iter=152, \n",
|
| 654 |
+
"# solver='lbfgs'\n",
|
| 655 |
+
"# ),\n",
|
| 656 |
+
"# 'NB': GaussianNB(\n",
|
| 657 |
+
"# var_smoothing=0.00995456588724228\n",
|
| 658 |
+
"# ),\n",
|
| 659 |
+
"# 'KN': KNeighborsClassifier(\n",
|
| 660 |
+
"# metric='manhattan', \n",
|
| 661 |
+
"# n_neighbors=1, \n",
|
| 662 |
+
"# weights='uniform'\n",
|
| 663 |
+
"# ),\n",
|
| 664 |
+
"# 'MLP': MLPClassifier(\n",
|
| 665 |
+
"# random_state=0, \n",
|
| 666 |
+
"# max_iter=1000, \n",
|
| 667 |
+
"# alpha=8.512480164062713e-06, \n",
|
| 668 |
+
"# hidden_layer_sizes=87, \n",
|
| 669 |
+
"# learning_rate_init=0.002859975932024275\n",
|
| 670 |
+
"# ),\n",
|
| 671 |
+
"# 'SVC': SVC(\n",
|
| 672 |
+
"# probability=True, \n",
|
| 673 |
+
"# kernel='rbf', \n",
|
| 674 |
+
"# C=100, \n",
|
| 675 |
+
"# gamma=0.1\n",
|
| 676 |
+
"# )\n",
|
| 677 |
+
"# }\n",
|
| 678 |
+
"\n",
|
| 679 |
+
"# Default classifiers\n",
|
| 680 |
+
"# classifiers = {\n",
|
| 681 |
+
"# 'DecisionTree': DecisionTreeClassifier(random_state=0),\n",
|
| 682 |
+
"# 'LogisticRegression': LogisticRegression(max_iter=1000, random_state=0),\n",
|
| 683 |
+
"# 'NaiveBayes': GaussianNB(),\n",
|
| 684 |
+
"# 'KNeighbors': KNeighborsClassifier(),\n",
|
| 685 |
+
"# 'MLP': MLPClassifier(max_iter=1000, random_state=0),\n",
|
| 686 |
+
"# 'SVC': SVC(probability=True, random_state=0)\n",
|
| 687 |
+
"# }\n",
|
| 688 |
+
"\n",
|
| 689 |
+
"# Main\n",
|
| 690 |
+
"# Initialize\n",
|
| 691 |
+
"metric_sums = defaultdict(lambda: {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0})\n",
|
| 692 |
+
"conf_matrices = defaultdict(list)\n",
|
| 693 |
+
"roc_curves = defaultdict(list)\n",
|
| 694 |
+
"roc_aucs = defaultdict(list)\n",
|
| 695 |
+
"accuracy_scores = defaultdict(list)\n",
|
| 696 |
+
"precision_scores = defaultdict(list)\n",
|
| 697 |
+
"recall_scores = defaultdict(list)\n",
|
| 698 |
+
"f1_scores = defaultdict(list)\n",
|
| 699 |
+
"\n",
|
| 700 |
+
"# Loop over 10 different random states\n",
|
| 701 |
+
"for random_state in range(10):\n",
|
| 702 |
+
" print(f\"Processing for Random State: {random_state}\")\n",
|
| 703 |
+
"\n",
|
| 704 |
+
" # Splitting the data\n",
|
| 705 |
+
" X = splitted_dataset.drop('depression_category', axis=1)\n",
|
| 706 |
+
" y = splitted_dataset['depression_category']\n",
|
| 707 |
+
" \n",
|
| 708 |
+
" X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=random_state)\n",
|
| 709 |
+
" \n",
|
| 710 |
+
" # Identify outliers in the training dataset\n",
|
| 711 |
+
" lof = LocalOutlierFactor()\n",
|
| 712 |
+
" yhat = lof.fit_predict(X_train)\n",
|
| 713 |
+
" # Select all rows that are not outliers\n",
|
| 714 |
+
" mask = yhat != -1\n",
|
| 715 |
+
" X_train, y_train = X_train[mask], y_train[mask]\n",
|
| 716 |
+
" \n",
|
| 717 |
+
" original_columns = X.columns.tolist()\n",
|
| 718 |
+
"\n",
|
| 719 |
+
" # SMOTE\n",
|
| 720 |
+
" smote = SMOTE(random_state=random_state)\n",
|
| 721 |
+
" X_res, y_res = smote.fit_resample(X_train, y_train)\n",
|
| 722 |
+
"\n",
|
| 723 |
+
" print(f\"Number of training labels after ROS: {y_res.value_counts()}\")\n",
|
| 724 |
+
"\n",
|
| 725 |
+
" print(f\"Number of test labels before resampling: {y_test.value_counts()}\") \n",
|
| 726 |
+
" \n",
|
| 727 |
+
" sampling_strategy_undersample = {0: 372}\n",
|
| 728 |
+
" rus = RandomUnderSampler(sampling_strategy=sampling_strategy_undersample, random_state=random_state)\n",
|
| 729 |
+
" X_test, y_test = rus.fit_resample(X_test, y_test)\n",
|
| 730 |
+
" print(f\"Number of test labels after resampling: {y_test.value_counts()}\")\n",
|
| 731 |
+
"\n",
|
| 732 |
+
" # Normalization\n",
|
| 733 |
+
" scaler = MinMaxScaler()\n",
|
| 734 |
+
" \n",
|
| 735 |
+
" X_res = scaler.fit_transform(X_res)\n",
|
| 736 |
+
" X_test = scaler.transform(X_test)\n",
|
| 737 |
+
" \n",
|
| 738 |
+
" X_res = pd.DataFrame(X_res, columns=original_columns)\n",
|
| 739 |
+
" X_test = pd.DataFrame(X_test, columns=original_columns)\n",
|
| 740 |
+
"\n",
|
| 741 |
+
" # Correlation Feat Analysis\n",
|
| 742 |
+
" corr_df = X_res.copy()\n",
|
| 743 |
+
" corr_df['target'] = y_res\n",
|
| 744 |
+
" \n",
|
| 745 |
+
" corr_mat = corr_df.corr()\n",
|
| 746 |
+
" target_correlation = corr_mat['target'].drop('target')\n",
|
| 747 |
+
" top_features = target_correlation.abs().sort_values(ascending=False).head(200).index.tolist()\n",
|
| 748 |
+
" \n",
|
| 749 |
+
" # Only take top features\n",
|
| 750 |
+
" X_res_fi = X_res[top_features]\n",
|
| 751 |
+
" X_test_fi = X_test[top_features]\n",
|
| 752 |
+
"\n",
|
| 753 |
+
" # Evaluate classifiers\n",
|
| 754 |
+
" for clf_name, clf in classifiers.items():\n",
|
| 755 |
+
" # Ensure the random state for classifiers is consistent\n",
|
| 756 |
+
" if hasattr(clf, 'random_state'):\n",
|
| 757 |
+
" clf.set_params(random_state=random_state)\n",
|
| 758 |
+
" accuracy, precision, recall, f1, conf_matrix, fpr, tpr, roc_auc = train_evaluate_model(clf, X_res_fi, y_res, X_test_fi, y_test, clf_name=clf_name)\n",
|
| 759 |
+
" metric_sums[clf_name]['accuracy'] += accuracy\n",
|
| 760 |
+
" metric_sums[clf_name]['precision'] += precision\n",
|
| 761 |
+
" metric_sums[clf_name]['recall'] += recall\n",
|
| 762 |
+
" metric_sums[clf_name]['f1'] += f1\n",
|
| 763 |
+
" conf_matrices[clf_name].append(conf_matrix)\n",
|
| 764 |
+
" roc_curves[clf_name].append((fpr, tpr))\n",
|
| 765 |
+
" roc_aucs[clf_name].append(roc_auc)\n",
|
| 766 |
+
" accuracy_scores[clf_name].append(accuracy)\n",
|
| 767 |
+
" precision_scores[clf_name].append(precision)\n",
|
| 768 |
+
" recall_scores[clf_name].append(recall)\n",
|
| 769 |
+
" f1_scores[clf_name].append(f1)"
|
| 770 |
+
]
|
| 771 |
+
},
|
| 772 |
+
{
|
| 773 |
+
"cell_type": "markdown",
|
| 774 |
+
"id": "0ec7e67e-a086-48a7-83de-26b54ef03899",
|
| 775 |
+
"metadata": {},
|
| 776 |
+
"source": [
|
| 777 |
+
"#### Static Classifiers"
|
| 778 |
+
]
|
| 779 |
+
},
|
| 780 |
+
{
|
| 781 |
+
"cell_type": "code",
|
| 782 |
+
"execution_count": null,
|
| 783 |
+
"id": "e9e20dd2-bdd8-4626-be89-8bb866212de9",
|
| 784 |
+
"metadata": {
|
| 785 |
+
"scrolled": true
|
| 786 |
+
},
|
| 787 |
+
"outputs": [],
|
| 788 |
+
"source": [
|
| 789 |
+
"# Initialize\n",
|
| 790 |
+
"metric_sums = defaultdict(lambda: {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0})\n",
|
| 791 |
+
"conf_matrices = defaultdict(list)\n",
|
| 792 |
+
"roc_curves = defaultdict(list)\n",
|
| 793 |
+
"roc_aucs = defaultdict(list)\n",
|
| 794 |
+
"accuracy_scores = defaultdict(list)\n",
|
| 795 |
+
"precision_scores = defaultdict(list)\n",
|
| 796 |
+
"recall_scores = defaultdict(list)\n",
|
| 797 |
+
"f1_scores = defaultdict(list)\n",
|
| 798 |
+
"\n",
|
| 799 |
+
"# Optimized Classifiers Detection\n",
|
| 800 |
+
"# classifiers = {\n",
|
| 801 |
+
"# 'RF': RandomForestClassifier(n_estimators=143, criterion='entropy', max_depth=15, random_state=0),\n",
|
| 802 |
+
"# 'XGB': XGBClassifier(n_estimators=200, max_depth=3, learning_rate=0.1, use_label_encoder=False, eval_metric='mlogloss', random_state=0),\n",
|
| 803 |
+
"# 'GB': GradientBoostingClassifier(n_estimators=300, max_depth=3, learning_rate=0.05),\n",
|
| 804 |
+
" # 'AB': AdaBoostClassifier(n_estimators=400, learning_rate=0.1),\n",
|
| 805 |
+
" # 'CB': CatBoostClassifier(depth = 3, iterations = 168, learning_rate = 0.1, verbose = 0),\n",
|
| 806 |
+
" # 'LGBM': LGBMClassifier(learning_rate = 0.1, max_depth = 3, n_estimators = 200) \n",
|
| 807 |
+
"# }\n",
|
| 808 |
+
"\n",
|
| 809 |
+
"\n",
|
| 810 |
+
"# Optimized Classifiers SevPred\n",
|
| 811 |
+
"# classifiers = {\n",
|
| 812 |
+
"# 'RF': RandomForestClassifier(\n",
|
| 813 |
+
"# n_estimators=300, \n",
|
| 814 |
+
"# criterion='entropy', \n",
|
| 815 |
+
"# max_depth=15, \n",
|
| 816 |
+
"# bootstrap=False, \n",
|
| 817 |
+
"# random_state=0\n",
|
| 818 |
+
"# ),\n",
|
| 819 |
+
"# 'XGB': XGBClassifier(\n",
|
| 820 |
+
"# n_estimators=300, \n",
|
| 821 |
+
"# max_depth=5, \n",
|
| 822 |
+
"# learning_rate=0.2, \n",
|
| 823 |
+
"# gamma=0.0, \n",
|
| 824 |
+
"# use_label_encoder=False, \n",
|
| 825 |
+
"# eval_metric='mlogloss', \n",
|
| 826 |
+
"# random_state=0\n",
|
| 827 |
+
"# ),\n",
|
| 828 |
+
"# 'GB': GradientBoostingClassifier(\n",
|
| 829 |
+
"# n_estimators=100, \n",
|
| 830 |
+
"# max_depth=5, \n",
|
| 831 |
+
"# learning_rate=0.1, \n",
|
| 832 |
+
"# subsample=0.7\n",
|
| 833 |
+
"# ),\n",
|
| 834 |
+
"# 'AB': AdaBoostClassifier(\n",
|
| 835 |
+
"# n_estimators=300, \n",
|
| 836 |
+
"# learning_rate=0.5, \n",
|
| 837 |
+
"# algorithm='SAMME'\n",
|
| 838 |
+
"# ),\n",
|
| 839 |
+
"# 'CB': CatBoostClassifier(\n",
|
| 840 |
+
"# depth=4,\n",
|
| 841 |
+
"# iterations=180,\n",
|
| 842 |
+
"# learning_rate=0.09,\n",
|
| 843 |
+
"# verbose=0\n",
|
| 844 |
+
"# ),\n",
|
| 845 |
+
"# 'LGBM': LGBMClassifier(\n",
|
| 846 |
+
"# learning_rate=0.08,\n",
|
| 847 |
+
"# max_depth=4,\n",
|
| 848 |
+
"# n_estimators=220\n",
|
| 849 |
+
"# )\n",
|
| 850 |
+
"# }\n",
|
| 851 |
+
"\n",
|
| 852 |
+
"\n",
|
| 853 |
+
"# Default Classifiers\n",
|
| 854 |
+
"# classifiers = {\n",
|
| 855 |
+
"# 'RandomForest': RandomForestClassifier(n_estimators=100, criterion='entropy', max_depth=7, random_state=0),\n",
|
| 856 |
+
"# 'XGBoost': XGBClassifier(n_estimators=100, max_depth=7, use_label_encoder=False, eval_metric='mlogloss', random_state=0),\n",
|
| 857 |
+
"# 'GradientBoosting': GradientBoostingClassifier(n_estimators=100, random_state=0),\n",
|
| 858 |
+
"# 'AdaBoost': AdaBoostClassifier(n_estimators=100, random_state=0),\n",
|
| 859 |
+
"# 'CatBoost': CatBoostClassifier(n_estimators=100, verbose=0, random_state=0),\n",
|
| 860 |
+
"# 'LightGBM': LGBMClassifier(n_estimators=100, random_state=0)\n",
|
| 861 |
+
"# }\n",
|
| 862 |
+
"\n",
|
| 863 |
+
"# voting_clf = VotingClassifier(estimators=[\n",
|
| 864 |
+
"# ('rf', classifiers['RF']),\n",
|
| 865 |
+
"# ('xgb', classifiers['XGB']),\n",
|
| 866 |
+
"# ('gb', classifiers['GB']),\n",
|
| 867 |
+
"# ('ada', classifiers['AB']),\n",
|
| 868 |
+
"# ('cat', classifiers['CB']),\n",
|
| 869 |
+
"# ('lgbm', classifiers['LGBM'])\n",
|
| 870 |
+
"# ], voting='soft', n_jobs=1)\n",
|
| 871 |
+
"\n",
|
| 872 |
+
"# classifiers['Vot'] = voting_clf\n",
|
| 873 |
+
"\n",
|
| 874 |
+
"# num_features = {\n",
|
| 875 |
+
"# 'RF': 150,\n",
|
| 876 |
+
"# 'XGB': 150,\n",
|
| 877 |
+
"# 'GB': 150,\n",
|
| 878 |
+
" # 'AB': 150,\n",
|
| 879 |
+
" # 'CB': 150,\n",
|
| 880 |
+
" # 'LGBM': 150,\n",
|
| 881 |
+
" # 'Vot': 150\n",
|
| 882 |
+
"# }\n",
|
| 883 |
+
"\n",
|
| 884 |
+
"for random_state in range(10):\n",
|
| 885 |
+
" print(f\"Processing for Random State: {random_state}\")\n",
|
| 886 |
+
"\n",
|
| 887 |
+
" X = splitted_dataset.drop('depression_category', axis=1)\n",
|
| 888 |
+
" y = splitted_dataset['depression_category']\n",
|
| 889 |
+
" X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=random_state)\n",
|
| 890 |
+
" \n",
|
| 891 |
+
" lof = LocalOutlierFactor()\n",
|
| 892 |
+
" yhat = lof.fit_predict(X_train)\n",
|
| 893 |
+
" mask = yhat != -1\n",
|
| 894 |
+
" X_train, y_train = X_train[mask], y_train[mask]\n",
|
| 895 |
+
" \n",
|
| 896 |
+
" original_columns = X.columns.tolist()\n",
|
| 897 |
+
"\n",
|
| 898 |
+
" smote = SMOTE(random_state=random_state)\n",
|
| 899 |
+
" X_res, y_res = smote.fit_resample(X_train, y_train)\n",
|
| 900 |
+
"\n",
|
| 901 |
+
" print(f\"Number of training labels after ROS: {y_res.value_counts()}\")\n",
|
| 902 |
+
"\n",
|
| 903 |
+
" print(f\"Number of test labels before resampling: {y_test.value_counts()}\") \n",
|
| 904 |
+
" \n",
|
| 905 |
+
" sampling_strategy_undersample = {0: 155}\n",
|
| 906 |
+
" rus = RandomUnderSampler(sampling_strategy=sampling_strategy_undersample, random_state=random_state)\n",
|
| 907 |
+
" X_test, y_test = rus.fit_resample(X_test, y_test)\n",
|
| 908 |
+
"\n",
|
| 909 |
+
" print(f\"Number of test labels after resampling: {y_test.value_counts()}\")\n",
|
| 910 |
+
"\n",
|
| 911 |
+
" scaler = MinMaxScaler()\n",
|
| 912 |
+
" \n",
|
| 913 |
+
" X_res = scaler.fit_transform(X_res)\n",
|
| 914 |
+
" X_test = scaler.transform(X_test)\n",
|
| 915 |
+
" \n",
|
| 916 |
+
" X_res = pd.DataFrame(X_res, columns=original_columns)\n",
|
| 917 |
+
" X_test = pd.DataFrame(X_test, columns=original_columns)\n",
|
| 918 |
+
"\n",
|
| 919 |
+
" log_reg = LogisticRegression(C=0.09659168435718246, max_iter=100, solver='lbfgs', random_state=random_state)\n",
|
| 920 |
+
" log_reg.fit(X_res, y_res)\n",
|
| 921 |
+
" selector = SelectFromModel(log_reg, prefit=True)\n",
|
| 922 |
+
" \n",
|
| 923 |
+
" importance = np.abs(log_reg.coef_[0])\n",
|
| 924 |
+
" indices = np.argsort(importance)[::-1]\n",
|
| 925 |
+
" important_features = [original_columns[i] for i in indices]\n",
|
| 926 |
+
" \n",
|
| 927 |
+
" for clf_name, clf in classifiers.items():\n",
|
| 928 |
+
" num_top_features = num_features[clf_name]\n",
|
| 929 |
+
" selected_features = important_features[:num_top_features]\n",
|
| 930 |
+
" \n",
|
| 931 |
+
" X_res_fi = pd.DataFrame(X_res, columns=original_columns)[selected_features]\n",
|
| 932 |
+
" X_test_fi = pd.DataFrame(X_test, columns=original_columns)[selected_features]\n",
|
| 933 |
+
"\n",
|
| 934 |
+
" accuracy, precision, recall, f1, conf_matrix, fpr, tpr, roc_auc = train_evaluate_model(\n",
|
| 935 |
+
" clf, X_res_fi, y_res, X_test_fi, y_test, clf_name=clf_name\n",
|
| 936 |
+
" )\n",
|
| 937 |
+
" metric_sums[clf_name]['accuracy'] += accuracy\n",
|
| 938 |
+
" metric_sums[clf_name]['precision'] += precision\n",
|
| 939 |
+
" metric_sums[clf_name]['recall'] += recall\n",
|
| 940 |
+
" metric_sums[clf_name]['f1'] += f1\n",
|
| 941 |
+
" conf_matrices[clf_name].append(conf_matrix)\n",
|
| 942 |
+
" roc_curves[clf_name].append((fpr, tpr))\n",
|
| 943 |
+
" roc_aucs[clf_name].append(roc_auc)\n",
|
| 944 |
+
" accuracy_scores[clf_name].append(accuracy)\n",
|
| 945 |
+
" precision_scores[clf_name].append(precision)\n",
|
| 946 |
+
" recall_scores[clf_name].append(recall)\n",
|
| 947 |
+
" f1_scores[clf_name].append(f1)"
|
| 948 |
+
]
|
| 949 |
+
},
|
| 950 |
+
{
|
| 951 |
+
"cell_type": "markdown",
|
| 952 |
+
"id": "c0158f58-e6d5-4adf-90e3-f1892294ad24",
|
| 953 |
+
"metadata": {},
|
| 954 |
+
"source": [
|
| 955 |
+
"### Post Training (classic/static)\n",
|
| 956 |
+
"Only run after one of the training methods above are done"
|
| 957 |
+
]
|
| 958 |
+
},
|
| 959 |
+
{
|
| 960 |
+
"cell_type": "code",
|
| 961 |
+
"execution_count": null,
|
| 962 |
+
"id": "c2ab5543-ad6e-4463-8d9a-35fe3e4e8eb4",
|
| 963 |
+
"metadata": {},
|
| 964 |
+
"outputs": [],
|
| 965 |
+
"source": [
|
| 966 |
+
"print('\\nAverage Metrics over 10 Random States:')\n",
|
| 967 |
+
"for clf_name, metrics in metric_sums.items():\n",
|
| 968 |
+
" avg_accuracy = metrics['accuracy'] / 10\n",
|
| 969 |
+
" avg_precision = metrics['precision'] / 10\n",
|
| 970 |
+
" avg_recall = metrics['recall'] / 10\n",
|
| 971 |
+
" avg_f1 = metrics['f1'] / 10\n",
|
| 972 |
+
" std_accuracy = np.std(accuracy_scores[clf_name])\n",
|
| 973 |
+
" std_precision = np.std(precision_scores[clf_name])\n",
|
| 974 |
+
" std_recall = np.std(recall_scores[clf_name])\n",
|
| 975 |
+
" std_f1 = np.std(f1_scores[clf_name])\n",
|
| 976 |
+
" avg_auc = np.mean(roc_aucs[clf_name])\n",
|
| 977 |
+
" print(f'{clf_name} - Accuracy: {avg_accuracy:.4f} ± {std_accuracy:.4f}, Precision: {avg_precision:.4f} ± {std_precision:.4f}, Recall: {avg_recall:.4f} ± {std_recall:.4f}, F1-Score: {avg_f1:.4f} ± {std_f1:.4f}, AUC: {avg_auc:.4f}')"
|
| 978 |
+
]
|
| 979 |
+
},
|
| 980 |
+
{
|
| 981 |
+
"cell_type": "code",
|
| 982 |
+
"execution_count": null,
|
| 983 |
+
"id": "21d5c872-3452-41ca-8f7c-5f4ceb184fba",
|
| 984 |
+
"metadata": {},
|
| 985 |
+
"outputs": [],
|
| 986 |
+
"source": [
|
| 987 |
+
"# Plot ROC Curves for each classifier in one graph\n",
|
| 988 |
+
"plot_combined_roc_curve(roc_curves, classifiers.keys())"
|
| 989 |
+
]
|
| 990 |
+
},
|
| 991 |
+
{
|
| 992 |
+
"cell_type": "code",
|
| 993 |
+
"execution_count": null,
|
| 994 |
+
"id": "26f7a715-efc8-44a6-8c53-e59b6268e551",
|
| 995 |
+
"metadata": {
|
| 996 |
+
"scrolled": true
|
| 997 |
+
},
|
| 998 |
+
"outputs": [],
|
| 999 |
+
"source": [
|
| 1000 |
+
"# FN Curve\n",
|
| 1001 |
+
"df = pd.DataFrame(accuracy_scores)\n",
|
| 1002 |
+
"scores = [df[col].values for col in df.columns]\n",
|
| 1003 |
+
"stat, p = friedmanchisquare(*scores)\n",
|
| 1004 |
+
"print(f'Friedman Test Statistic: {stat}, p-value: {p}')\n",
|
| 1005 |
+
"ranks = df.rank(axis=1, method='average', ascending=False)\n",
|
| 1006 |
+
"average_ranks = ranks.mean().values\n",
|
| 1007 |
+
"n_datasets = df.shape[0]\n",
|
| 1008 |
+
"alpha = 0.05\n",
|
| 1009 |
+
"cd = compute_CD(average_ranks, n_datasets, alpha='0.05')\n",
|
| 1010 |
+
"print(f'Critical Difference: {cd}')\n",
|
| 1011 |
+
"classifiers = [f\"{clf} ({rank:.2f})\" for clf, rank in zip(df.columns, average_ranks)]\n",
|
| 1012 |
+
"plt.figure(figsize=(14, 8))\n",
|
| 1013 |
+
"graph_ranks(average_ranks, classifiers, cd=cd, width=6, textspace=1)\n",
|
| 1014 |
+
"plt.text(0.5, 1.19, f'Friedman-Nemenyi: {stat:.3f}', horizontalalignment='center', transform=plt.gca().transAxes, fontsize=14)\n",
|
| 1015 |
+
"plt.text(0.5, 1.10, f'CD: {cd:.3f}', horizontalalignment='center', transform=plt.gca().transAxes, fontsize=14)\n",
|
| 1016 |
+
"plt.tight_layout()"
|
| 1017 |
+
]
|
| 1018 |
+
},
|
| 1019 |
+
{
|
| 1020 |
+
"cell_type": "markdown",
|
| 1021 |
+
"id": "88072d69-8bb5-46ab-9c79-6990db7cc8d2",
|
| 1022 |
+
"metadata": {},
|
| 1023 |
+
"source": [
|
| 1024 |
+
"### Hyperparameter optimization (classic/static)"
|
| 1025 |
+
]
|
| 1026 |
+
},
|
| 1027 |
+
{
|
| 1028 |
+
"cell_type": "code",
|
| 1029 |
+
"execution_count": null,
|
| 1030 |
+
"id": "3a4a1546-c24f-4349-bf1b-75ba937700be",
|
| 1031 |
+
"metadata": {
|
| 1032 |
+
"scrolled": true
|
| 1033 |
+
},
|
| 1034 |
+
"outputs": [],
|
| 1035 |
+
"source": [
|
| 1036 |
+
"# Hyperparameter optimization classic\n",
|
| 1037 |
+
"search_spaces = {\n",
|
| 1038 |
+
" 'DecisionTree': {\n",
|
| 1039 |
+
" 'criterion': Categorical(['gini', 'entropy']),\n",
|
| 1040 |
+
" 'max_depth': Integer(1, 20),\n",
|
| 1041 |
+
" 'min_samples_split': Integer(2, 10),\n",
|
| 1042 |
+
" 'min_samples_leaf': Integer(1, 10)\n",
|
| 1043 |
+
" },\n",
|
| 1044 |
+
" 'LogisticRegression': {\n",
|
| 1045 |
+
" 'C': Real(1e-6, 1e+6, prior='log-uniform'),\n",
|
| 1046 |
+
" 'solver': Categorical(['lbfgs', 'liblinear']),\n",
|
| 1047 |
+
" 'max_iter': Integer(100, 1000)\n",
|
| 1048 |
+
" },\n",
|
| 1049 |
+
" 'NaiveBayes': {\n",
|
| 1050 |
+
" 'var_smoothing': Real(1e-9, 1e-2, prior='log-uniform')\n",
|
| 1051 |
+
" },\n",
|
| 1052 |
+
" 'KNeighbors': {\n",
|
| 1053 |
+
" 'n_neighbors': Integer(1, 30),\n",
|
| 1054 |
+
" 'weights': Categorical(['uniform', 'distance']),\n",
|
| 1055 |
+
" 'metric': Categorical(['euclidean', 'manhattan', 'minkowski'])\n",
|
| 1056 |
+
" },\n",
|
| 1057 |
+
" 'MLP': {\n",
|
| 1058 |
+
" 'hidden_layer_sizes': Integer(50, 200),\n",
|
| 1059 |
+
" 'alpha': Real(1e-6, 1e-2, prior='log-uniform'),\n",
|
| 1060 |
+
" 'learning_rate_init': Real(1e-4, 1e-2, prior='log-uniform')\n",
|
| 1061 |
+
" },\n",
|
| 1062 |
+
" 'SVC': {\n",
|
| 1063 |
+
" 'C': [0.1, 1, 10, 100, 1000],\n",
|
| 1064 |
+
" 'gamma': [1, 0.1, 0.01, 0.001, 0.0001],\n",
|
| 1065 |
+
" 'kernel': ['rbf']\n",
|
| 1066 |
+
" }\n",
|
| 1067 |
+
"}\n",
|
| 1068 |
+
"\n",
|
| 1069 |
+
"classifiers = {\n",
|
| 1070 |
+
" 'DecisionTree': DecisionTreeClassifier(random_state=0),\n",
|
| 1071 |
+
" 'LogisticRegression': LogisticRegression(max_iter=1000, random_state=0),\n",
|
| 1072 |
+
" 'NaiveBayes': GaussianNB(),\n",
|
| 1073 |
+
" 'KNeighbors': KNeighborsClassifier(),\n",
|
| 1074 |
+
" 'MLP': MLPClassifier(max_iter=1000, random_state=0),\n",
|
| 1075 |
+
" 'SVC': SVC(probability=True, random_state=0)\n",
|
| 1076 |
+
"}\n",
|
| 1077 |
+
"\n",
|
| 1078 |
+
"top_features_count = {\n",
|
| 1079 |
+
" 'DecisionTree': 200,\n",
|
| 1080 |
+
" 'LogisticRegression': 200,\n",
|
| 1081 |
+
" 'NaiveBayes': 200,\n",
|
| 1082 |
+
" 'KNeighbors': 200,\n",
|
| 1083 |
+
" 'MLP': 200,\n",
|
| 1084 |
+
" 'SVC': 200\n",
|
| 1085 |
+
"}\n",
|
| 1086 |
+
"\n",
|
| 1087 |
+
"random_state = 0\n",
|
| 1088 |
+
"print(f\"Processing for Random State: {random_state}\")\n",
|
| 1089 |
+
"\n",
|
| 1090 |
+
"X = splitted_dataset.drop('depression_category', axis=1)\n",
|
| 1091 |
+
"y = splitted_dataset['depression_category']\n",
|
| 1092 |
+
"\n",
|
| 1093 |
+
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=random_state)\n",
|
| 1094 |
+
"\n",
|
| 1095 |
+
"lof = LocalOutlierFactor()\n",
|
| 1096 |
+
"yhat = lof.fit_predict(X_train)\n",
|
| 1097 |
+
"mask = yhat != -1\n",
|
| 1098 |
+
"X_train, y_train = X_train[mask], y_train[mask]\n",
|
| 1099 |
+
"\n",
|
| 1100 |
+
"original_columns = X.columns.tolist()\n",
|
| 1101 |
+
"\n",
|
| 1102 |
+
"smote = SMOTE(random_state=random_state)\n",
|
| 1103 |
+
"X_res, y_res = smote.fit_resample(X_train, y_train)\n",
|
| 1104 |
+
"\n",
|
| 1105 |
+
"print(f\"Number of training labels after ROS: {y_res.value_counts()}\")\n",
|
| 1106 |
+
"\n",
|
| 1107 |
+
"print(f\"Number of test labels before resampling: {y_test.value_counts()}\") \n",
|
| 1108 |
+
"\n",
|
| 1109 |
+
"sampling_strategy_undersample = {0: 155}\n",
|
| 1110 |
+
"rus = RandomUnderSampler(sampling_strategy=sampling_strategy_undersample, random_state=random_state)\n",
|
| 1111 |
+
"X_test, y_test = rus.fit_resample(X_test, y_test)\n",
|
| 1112 |
+
"print(f\"Number of test labels after resampling: {y_test.value_counts()}\")\n",
|
| 1113 |
+
"\n",
|
| 1114 |
+
"scaler = MinMaxScaler()\n",
|
| 1115 |
+
"\n",
|
| 1116 |
+
"X_res = scaler.fit_transform(X_res)\n",
|
| 1117 |
+
"X_test = scaler.transform(X_test)\n",
|
| 1118 |
+
"\n",
|
| 1119 |
+
"X_res = pd.DataFrame(X_res, columns=original_columns)\n",
|
| 1120 |
+
"X_test = pd.DataFrame(X_test, columns=original_columns)\n",
|
| 1121 |
+
"\n",
|
| 1122 |
+
"corr_df = X_res.copy()\n",
|
| 1123 |
+
"corr_df['target'] = y_res\n",
|
| 1124 |
+
"\n",
|
| 1125 |
+
"corr_mat = corr_df.corr()\n",
|
| 1126 |
+
"target_correlation = corr_mat['target'].drop('target')\n",
|
| 1127 |
+
"\n",
|
| 1128 |
+
"for clf_name, clf in classifiers.items():\n",
|
| 1129 |
+
" print(f\"Optimizing {clf_name}\")\n",
|
| 1130 |
+
" \n",
|
| 1131 |
+
" top_features = target_correlation.abs().sort_values(ascending=False).head(top_features_count[clf_name]).index.tolist()\n",
|
| 1132 |
+
" \n",
|
| 1133 |
+
" X_res_fi = X_res[top_features]\n",
|
| 1134 |
+
" X_test_fi = X_test[top_features]\n",
|
| 1135 |
+
" \n",
|
| 1136 |
+
" opt = BayesSearchCV(clf, search_spaces[clf_name], n_iter=30, cv=3, random_state=random_state, n_jobs=-1, verbose = 30)\n",
|
| 1137 |
+
" opt.fit(X_res_fi, y_res)\n",
|
| 1138 |
+
" \n",
|
| 1139 |
+
" best_clf = opt.best_estimator_\n",
|
| 1140 |
+
" best_params = opt.best_params_\n",
|
| 1141 |
+
"\n",
|
| 1142 |
+
" print(f\"Best parameters for {clf_name}: {best_params}\")\n",
|
| 1143 |
+
" \n",
|
| 1144 |
+
" accuracy, precision, recall, f1, conf_matrix, fpr, tpr, roc_auc = train_evaluate_model(best_clf, X_res_fi, y_res, X_test_fi, y_test, clf_name=clf_name)\n",
|
| 1145 |
+
" print(f\"Best results for {clf_name}:\")\n",
|
| 1146 |
+
" print(f'Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}, AUC: {roc_auc:.4f}')\n",
|
| 1147 |
+
" print(conf_matrix)\n",
|
| 1148 |
+
" print()"
|
| 1149 |
+
]
|
| 1150 |
+
},
|
| 1151 |
+
{
|
| 1152 |
+
"cell_type": "code",
|
| 1153 |
+
"execution_count": null,
|
| 1154 |
+
"id": "2f473cc7-577d-4dcd-a8b1-be7f33200fbc",
|
| 1155 |
+
"metadata": {
|
| 1156 |
+
"scrolled": true
|
| 1157 |
+
},
|
| 1158 |
+
"outputs": [],
|
| 1159 |
+
"source": [
|
| 1160 |
+
"# Hyperparameter optimization static\n",
|
| 1161 |
+
"metric_sums = defaultdict(lambda: {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0})\n",
|
| 1162 |
+
"conf_matrices = defaultdict(list)\n",
|
| 1163 |
+
"accuracy_scores = defaultdict(list)\n",
|
| 1164 |
+
"precision_scores = defaultdict(list)\n",
|
| 1165 |
+
"recall_scores = defaultdict(list)\n",
|
| 1166 |
+
"f1_scores = defaultdict(list)\n",
|
| 1167 |
+
"\n",
|
| 1168 |
+
"classifiers = {\n",
|
| 1169 |
+
" # 'RandomForest': RandomForestClassifier(),\n",
|
| 1170 |
+
" # 'XGBoost': XGBClassifier(use_label_encoder=False, eval_metric='mlogloss'),\n",
|
| 1171 |
+
" # 'AdaBoost': AdaBoostClassifier(),\n",
|
| 1172 |
+
" # 'GradientBoosting': GradientBoostingClassifier(),\n",
|
| 1173 |
+
" 'CatBoost': CatBoostClassifier(verbose=0),\n",
|
| 1174 |
+
" 'LightGBM': LGBMClassifier()\n",
|
| 1175 |
+
"}\n",
|
| 1176 |
+
"\n",
|
| 1177 |
+
"num_features = {\n",
|
| 1178 |
+
" # 'RandomForest': 150,\n",
|
| 1179 |
+
" # 'XGBoost': 150,\n",
|
| 1180 |
+
" # 'GradientBoosting': 150,\n",
|
| 1181 |
+
" # 'AdaBoost': 150,\n",
|
| 1182 |
+
" 'CatBoost': 150,\n",
|
| 1183 |
+
" 'LightGBM': 150,\n",
|
| 1184 |
+
"}\n",
|
| 1185 |
+
"\n",
|
| 1186 |
+
"search_spaces = {\n",
|
| 1187 |
+
" # 'RandomForest': {\n",
|
| 1188 |
+
" # 'n_estimators': [100, 200, 300],\n",
|
| 1189 |
+
" # 'criterion': ['gini', 'entropy'],\n",
|
| 1190 |
+
" # 'max_depth': [None, 7, 15],\n",
|
| 1191 |
+
" # 'bootstrap': [True, False]\n",
|
| 1192 |
+
" # },\n",
|
| 1193 |
+
" # 'XGBoost': {\n",
|
| 1194 |
+
" # 'n_estimators': [100, 200, 300],\n",
|
| 1195 |
+
" # 'max_depth': [5, 10],\n",
|
| 1196 |
+
" # 'learning_rate': [0.01, 0.1, 0.2],\n",
|
| 1197 |
+
" # 'gamma': [0, 0.2, 0.4],\n",
|
| 1198 |
+
" # },\n",
|
| 1199 |
+
" # 'GradientBoosting': {\n",
|
| 1200 |
+
" # 'n_estimators': [100, 200, 300],\n",
|
| 1201 |
+
" # 'learning_rate': [0.01, 0.1, 0.2],\n",
|
| 1202 |
+
" # 'max_depth': [5, 10],\n",
|
| 1203 |
+
" # 'subsample': [0.7, 0.9, 1.0],\n",
|
| 1204 |
+
" # },\n",
|
| 1205 |
+
" # 'AdaBoost': {\n",
|
| 1206 |
+
" # 'n_estimators': [100, 200, 300],\n",
|
| 1207 |
+
" # 'learning_rate': [0.1, 0.5, 1.0],\n",
|
| 1208 |
+
" # 'algorithm': ['SAMME', 'SAMME.R']\n",
|
| 1209 |
+
" # },\n",
|
| 1210 |
+
" 'CatBoost': {\n",
|
| 1211 |
+
" 'iterations': [100, 200, 300],\n",
|
| 1212 |
+
" 'depth': [5, 7, 9],\n",
|
| 1213 |
+
" # 'learning_rate': [0.01, 0.1, 0.2],\n",
|
| 1214 |
+
" },\n",
|
| 1215 |
+
" 'LightGBM': {\n",
|
| 1216 |
+
" 'n_estimators': [100, 200, 300],\n",
|
| 1217 |
+
" 'num_leaves': [31, 63, 127],\n",
|
| 1218 |
+
" # 'learning_rate': [0.01, 0.1, 0.2],\n",
|
| 1219 |
+
" # 'subsample': [0.7, 0.9, 1.0],\n",
|
| 1220 |
+
" }\n",
|
| 1221 |
+
"}\n",
|
| 1222 |
+
"\n",
|
| 1223 |
+
"def hyperparameter_optimization(clf, search_space, X, y):\n",
|
| 1224 |
+
" combined_results = []\n",
|
| 1225 |
+
" for random_state in range(3):\n",
|
| 1226 |
+
" cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=random_state)\n",
|
| 1227 |
+
" opt = BayesSearchCV(clf, search_space, n_iter=30, cv=cv, random_state=random_state, n_jobs=-1, verbose=0)\n",
|
| 1228 |
+
" opt.fit(X, y)\n",
|
| 1229 |
+
" combined_results.append(opt.best_params_)\n",
|
| 1230 |
+
" best_params = pd.DataFrame(combined_results).mode().iloc[0].to_dict()\n",
|
| 1231 |
+
" return best_params\n",
|
| 1232 |
+
"\n",
|
| 1233 |
+
"for random_state in range(9,10):\n",
|
| 1234 |
+
" print(f\"Processing for Random State: {random_state}\")\n",
|
| 1235 |
+
"\n",
|
| 1236 |
+
" X = splitted_dataset.drop('depression_category', axis=1)\n",
|
| 1237 |
+
" y = splitted_dataset['depression_category']\n",
|
| 1238 |
+
" \n",
|
| 1239 |
+
" X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=random_state)\n",
|
| 1240 |
+
" \n",
|
| 1241 |
+
" lof = LocalOutlierFactor()\n",
|
| 1242 |
+
" yhat = lof.fit_predict(X_train)\n",
|
| 1243 |
+
" mask = yhat != -1\n",
|
| 1244 |
+
" X_train, y_train = X_train[mask], y_train[mask]\n",
|
| 1245 |
+
" \n",
|
| 1246 |
+
" original_columns = X.columns.tolist()\n",
|
| 1247 |
+
"\n",
|
| 1248 |
+
" smote = SMOTE(random_state=random_state)\n",
|
| 1249 |
+
" X_res, y_res = smote.fit_resample(X_train, y_train)\n",
|
| 1250 |
+
"\n",
|
| 1251 |
+
" print(f\"Number of training labels after ROS: {y_res.value_counts()}\")\n",
|
| 1252 |
+
"\n",
|
| 1253 |
+
" print(f\"Number of test labels before resampling: {y_test.value_counts()}\") \n",
|
| 1254 |
+
" \n",
|
| 1255 |
+
" sampling_strategy_undersample = {0: 155}\n",
|
| 1256 |
+
" rus = RandomUnderSampler(sampling_strategy=sampling_strategy_undersample, random_state=random_state)\n",
|
| 1257 |
+
" X_test, y_test = rus.fit_resample(X_test, y_test)\n",
|
| 1258 |
+
"\n",
|
| 1259 |
+
" print(f\"Number of test labels after resampling: {y_test.value_counts()}\")\n",
|
| 1260 |
+
"\n",
|
| 1261 |
+
" scaler = MinMaxScaler()\n",
|
| 1262 |
+
" \n",
|
| 1263 |
+
" X_res = scaler.fit_transform(X_res)\n",
|
| 1264 |
+
" X_test = scaler.transform(X_test)\n",
|
| 1265 |
+
" \n",
|
| 1266 |
+
" X_res = pd.DataFrame(X_res, columns=original_columns)\n",
|
| 1267 |
+
" X_test = pd.DataFrame(X_test, columns=original_columns)\n",
|
| 1268 |
+
"\n",
|
| 1269 |
+
" log_reg = LogisticRegression(C=0.09659168435718246, max_iter=100, solver='lbfgs', random_state=random_state)\n",
|
| 1270 |
+
" log_reg.fit(X_res, y_res)\n",
|
| 1271 |
+
" selector = SelectFromModel(log_reg, prefit=True)\n",
|
| 1272 |
+
" \n",
|
| 1273 |
+
" importance = np.abs(log_reg.coef_[0])\n",
|
| 1274 |
+
" indices = np.argsort(importance)[::-1]\n",
|
| 1275 |
+
" important_features = [original_columns[i] for i in indices[:300]]\n",
|
| 1276 |
+
"\n",
|
| 1277 |
+
" for clf_name, clf in classifiers.items():\n",
|
| 1278 |
+
" print(f\"Optimizing {clf_name}\")\n",
|
| 1279 |
+
" num_top_features = num_features[clf_name]\n",
|
| 1280 |
+
" selected_features = important_features[:num_top_features]\n",
|
| 1281 |
+
" \n",
|
| 1282 |
+
" X_res_fi = pd.DataFrame(X_res, columns=original_columns)[selected_features]\n",
|
| 1283 |
+
" \n",
|
| 1284 |
+
" best_params = hyperparameter_optimization(clf, search_spaces[clf_name], X_res_fi, y_res)\n",
|
| 1285 |
+
" if 'n_estimators' in best_params:\n",
|
| 1286 |
+
" best_params['n_estimators'] = int(best_params['n_estimators'])\n",
|
| 1287 |
+
" if 'max_depth' in best_params:\n",
|
| 1288 |
+
" best_params['max_depth'] = int(best_params['max_depth'])\n",
|
| 1289 |
+
" if 'iterations' in best_params:\n",
|
| 1290 |
+
" best_params['iterations'] = int(best_params['iterations'])\n",
|
| 1291 |
+
" clf.set_params(**best_params)\n",
|
| 1292 |
+
" print(f\"Best parameters for {clf_name}: {best_params}\")\n",
|
| 1293 |
+
"\n",
|
| 1294 |
+
" X_test_fi = pd.DataFrame(X_test, columns=original_columns)[selected_features]\n",
|
| 1295 |
+
" accuracy, precision, recall, f1, conf_matrix, fpr, tpr, roc_auc = train_evaluate_model(clf, X_res_fi, y_res, X_test_fi, y_test, clf_name=clf_name)\n",
|
| 1296 |
+
" metric_sums[clf_name]['accuracy'] += accuracy\n",
|
| 1297 |
+
" metric_sums[clf_name]['precision'] += precision\n",
|
| 1298 |
+
" metric_sums[clf_name]['recall'] += recall\n",
|
| 1299 |
+
" metric_sums[clf_name]['f1'] += f1\n",
|
| 1300 |
+
" conf_matrices[clf_name].append(conf_matrix)\n",
|
| 1301 |
+
" accuracy_scores[clf_name].append(accuracy)\n",
|
| 1302 |
+
" precision_scores[clf_name].append(precision)\n",
|
| 1303 |
+
" recall_scores[clf_name].append(recall)\n",
|
| 1304 |
+
" f1_scores[clf_name].append(f1)"
|
| 1305 |
+
]
|
| 1306 |
+
},
|
| 1307 |
+
{
|
| 1308 |
+
"cell_type": "markdown",
|
| 1309 |
+
"id": "65df93d4-12d3-4d68-b402-633354849dff",
|
| 1310 |
+
"metadata": {},
|
| 1311 |
+
"source": [
|
| 1312 |
+
"### DES Training (all)"
|
| 1313 |
+
]
|
| 1314 |
+
},
|
| 1315 |
+
{
|
| 1316 |
+
"cell_type": "code",
|
| 1317 |
+
"execution_count": null,
|
| 1318 |
+
"id": "709bb1f8-249e-4409-a537-c6bbe9a399f3",
|
| 1319 |
+
"metadata": {
|
| 1320 |
+
"scrolled": true
|
| 1321 |
+
},
|
| 1322 |
+
"outputs": [],
|
| 1323 |
+
"source": [
|
| 1324 |
+
"metric_sums_des = {\n",
|
| 1325 |
+
" 'KNORAE': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1326 |
+
" 'KNORAU': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1327 |
+
" 'KNOP': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1328 |
+
" 'DESMI': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1329 |
+
" 'METADES': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1330 |
+
" 'DESKNN': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1331 |
+
" 'DESP': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1332 |
+
" 'FIRE-KNORA-U': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1333 |
+
" 'FIRE-KNORA-E': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1334 |
+
" 'FIRE-METADES': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1335 |
+
" 'FIRE-DESKNN': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1336 |
+
" 'FIRE-DESP': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1337 |
+
" 'FIRE-KNOP': {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0},\n",
|
| 1338 |
+
"}\n",
|
| 1339 |
+
"\n",
|
| 1340 |
+
"conf_matrices_des = {\n",
|
| 1341 |
+
" 'KNORAE': [],\n",
|
| 1342 |
+
" 'KNORAU': [],\n",
|
| 1343 |
+
" 'KNOP': [],\n",
|
| 1344 |
+
" 'DESMI': [],\n",
|
| 1345 |
+
" 'METADES': [],\n",
|
| 1346 |
+
" 'DESKNN': [],\n",
|
| 1347 |
+
" 'DESP': [],\n",
|
| 1348 |
+
" 'FIRE-KNORA-U': [],\n",
|
| 1349 |
+
" 'FIRE-KNORA-E': [],\n",
|
| 1350 |
+
" 'FIRE-METADES': [],\n",
|
| 1351 |
+
" 'FIRE-DESKNN': [],\n",
|
| 1352 |
+
" 'FIRE-DESP': [],\n",
|
| 1353 |
+
" 'FIRE-KNOP': [],\n",
|
| 1354 |
+
"}\n",
|
| 1355 |
+
"\n",
|
| 1356 |
+
"roc_curves = defaultdict(list)\n",
|
| 1357 |
+
"roc_aucs = defaultdict(list)\n",
|
| 1358 |
+
"accuracy_scores = defaultdict(list)\n",
|
| 1359 |
+
"precision_scores = defaultdict(list)\n",
|
| 1360 |
+
"recall_scores = defaultdict(list)\n",
|
| 1361 |
+
"f1_scores = defaultdict(list)\n",
|
| 1362 |
+
"feature_importance_runs = []\n",
|
| 1363 |
+
"\n",
|
| 1364 |
+
"# Uncomment wanted combinations\n",
|
| 1365 |
+
"# base_classifiers = {\n",
|
| 1366 |
+
" # 'DecisionTree': DecisionTreeClassifier(\n",
|
| 1367 |
+
" # random_state=0, \n",
|
| 1368 |
+
" # criterion='gini', \n",
|
| 1369 |
+
" # max_depth=6, \n",
|
| 1370 |
+
" # min_samples_leaf=10, \n",
|
| 1371 |
+
" # min_samples_split=9\n",
|
| 1372 |
+
" # ),\n",
|
| 1373 |
+
" # 'LogisticRegression': LogisticRegression(\n",
|
| 1374 |
+
" # random_state=0, \n",
|
| 1375 |
+
" # C=0.09659168435718246, \n",
|
| 1376 |
+
" # max_iter=100, \n",
|
| 1377 |
+
" # solver='lbfgs'\n",
|
| 1378 |
+
" # ),\n",
|
| 1379 |
+
" # 'NaiveBayes': GaussianNB(\n",
|
| 1380 |
+
" # var_smoothing=0.0058873326349240295\n",
|
| 1381 |
+
" # ),\n",
|
| 1382 |
+
" # 'KNeighbors': KNeighborsClassifier(\n",
|
| 1383 |
+
" # metric='manhattan', \n",
|
| 1384 |
+
" # n_neighbors=15, \n",
|
| 1385 |
+
" # weights='uniform'\n",
|
| 1386 |
+
" # ),\n",
|
| 1387 |
+
" # 'MLP': MLPClassifier(\n",
|
| 1388 |
+
" # random_state=0, \n",
|
| 1389 |
+
" # max_iter=1000, \n",
|
| 1390 |
+
" # alpha=0.0003079393718075164, \n",
|
| 1391 |
+
" # hidden_layer_sizes=195, \n",
|
| 1392 |
+
" # learning_rate_init=0.0001675266159417717\n",
|
| 1393 |
+
" # ),\n",
|
| 1394 |
+
" # 'SVC': SVC(probability=True, kernel = 'rbf', C = 1.5, gamma = 'auto'),\n",
|
| 1395 |
+
" # 'RF': RandomForestClassifier(n_estimators=143, criterion='entropy', max_depth=15, random_state=0),\n",
|
| 1396 |
+
" # 'XGB': XGBClassifier(n_estimators=200, max_depth=3, learning_rate=0.1, use_label_encoder=False, eval_metric='mlogloss', random_state=0),\n",
|
| 1397 |
+
" # 'GB': GradientBoostingClassifier(n_estimators=300, max_depth=3, learning_rate=0.05),\n",
|
| 1398 |
+
" # 'AB': AdaBoostClassifier(n_estimators=400, learning_rate=0.1),\n",
|
| 1399 |
+
" # 'CB': CatBoostClassifier(depth = 3, iterations = 168, learning_rate = 0.1, verbose = 0),\n",
|
| 1400 |
+
" # 'LGBM': LGBMClassifier(learning_rate = 0.1, max_depth = 3, n_estimators = 200) \n",
|
| 1401 |
+
"# }\n",
|
| 1402 |
+
"\n",
|
| 1403 |
+
"random_state = 0\n",
|
| 1404 |
+
"\n",
|
| 1405 |
+
"for random_state in range(10):\n",
|
| 1406 |
+
" print(f\"Processing for Random State: {random_state}\")\n",
|
| 1407 |
+
"\n",
|
| 1408 |
+
" X = splitted_dataset.drop('depression_category', axis=1)\n",
|
| 1409 |
+
" y = splitted_dataset['depression_category']\n",
|
| 1410 |
+
" \n",
|
| 1411 |
+
" X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=random_state)\n",
|
| 1412 |
+
" \n",
|
| 1413 |
+
" lof = LocalOutlierFactor()\n",
|
| 1414 |
+
" yhat = lof.fit_predict(X_train)\n",
|
| 1415 |
+
" mask = yhat != -1\n",
|
| 1416 |
+
" X_train, y_train = X_train[mask], y_train[mask]\n",
|
| 1417 |
+
" \n",
|
| 1418 |
+
" original_columns = X.columns.tolist()\n",
|
| 1419 |
+
"\n",
|
| 1420 |
+
" smote = SMOTE(random_state=random_state)\n",
|
| 1421 |
+
" X_res, y_res = smote.fit_resample(X_train, y_train)\n",
|
| 1422 |
+
"\n",
|
| 1423 |
+
" print(f\"Number of test labels before resampling: {y_test.value_counts()}\")\n",
|
| 1424 |
+
" sampling_strategy_undersample = {0: 155}\n",
|
| 1425 |
+
" rus = RandomUnderSampler(sampling_strategy=sampling_strategy_undersample, random_state=random_state)\n",
|
| 1426 |
+
" X_test, y_test = rus.fit_resample(X_test, y_test) \n",
|
| 1427 |
+
"\n",
|
| 1428 |
+
" print(f\"Number of test labels after resampling: {y_test.value_counts()}\")\n",
|
| 1429 |
+
"\n",
|
| 1430 |
+
" scaler = MinMaxScaler()\n",
|
| 1431 |
+
" X_res = scaler.fit_transform(X_res)\n",
|
| 1432 |
+
" X_test = scaler.transform(X_test)\n",
|
| 1433 |
+
" \n",
|
| 1434 |
+
" X_res = pd.DataFrame(X_res, columns=original_columns)\n",
|
| 1435 |
+
" X_test = pd.DataFrame(X_test, columns=original_columns)\n",
|
| 1436 |
+
"\n",
|
| 1437 |
+
" ada_fs = AdaBoostClassifier(n_estimators=100, random_state = random_state)\n",
|
| 1438 |
+
" ada_fs.fit(X_res, y_res)\n",
|
| 1439 |
+
"\n",
|
| 1440 |
+
" feature_importances = ada_fs.feature_importances_\n",
|
| 1441 |
+
" indices = np.argsort(feature_importances)[::-1]\n",
|
| 1442 |
+
" top_50_features = [original_columns[i] for i in indices[:50]]\n",
|
| 1443 |
+
" current_run_features = {original_columns[i]: feature_importances[i] for i in indices[:50]}\n",
|
| 1444 |
+
" \n",
|
| 1445 |
+
" feature_importance_runs.append(current_run_features)\n",
|
| 1446 |
+
"\n",
|
| 1447 |
+
" X_res_fi = X_res[top_50_features]\n",
|
| 1448 |
+
" X_test_fi = X_test[top_50_features]\n",
|
| 1449 |
+
" \n",
|
| 1450 |
+
" model_pool = list(base_classifiers.values())\n",
|
| 1451 |
+
" \n",
|
| 1452 |
+
" for clf in model_pool:\n",
|
| 1453 |
+
" clf.fit(X_res_fi, y_res)\n",
|
| 1454 |
+
" \n",
|
| 1455 |
+
" des_models = {\n",
|
| 1456 |
+
" 'KNORAE': KNORAE(pool_classifiers=model_pool, random_state=random_state),\n",
|
| 1457 |
+
" 'KNORAU': KNORAU(pool_classifiers=model_pool, random_state=random_state),\n",
|
| 1458 |
+
" 'DESMI': DESMI(pool_classifiers=model_pool, random_state=random_state),\n",
|
| 1459 |
+
" 'METADES': METADES(pool_classifiers=model_pool, random_state=random_state),\n",
|
| 1460 |
+
" 'DESKNN': DESKNN(pool_classifiers=model_pool, random_state=random_state),\n",
|
| 1461 |
+
" 'DESP': DESP(pool_classifiers=model_pool, random_state=random_state),\n",
|
| 1462 |
+
" 'KNOP': KNOP(pool_classifiers=model_pool, random_state=random_state, k=9),\n",
|
| 1463 |
+
" 'FIRE-KNORA-U': KNORAU(pool_classifiers=model_pool, DFP=True, k=9, random_state = random_state),\n",
|
| 1464 |
+
" 'FIRE-KNORA-E': KNORAE(pool_classifiers=model_pool, DFP=True, k=9, random_state = random_state),\n",
|
| 1465 |
+
" 'FIRE-METADES': METADES(pool_classifiers=model_pool, DFP=True, k=9, random_state = random_state),\n",
|
| 1466 |
+
" 'FIRE-DESKNN': DESKNN(pool_classifiers=model_pool, DFP=True, k=9, random_state = random_state),\n",
|
| 1467 |
+
" 'FIRE-DESP': DESP(pool_classifiers=model_pool, DFP=True, k=9, random_state = random_state),\n",
|
| 1468 |
+
" 'FIRE-KNOP': KNOP(pool_classifiers=model_pool, DFP=True, k=40, random_state = random_state)\n",
|
| 1469 |
+
" }\n",
|
| 1470 |
+
"\n",
|
| 1471 |
+
" for des_name, des_model in des_models.items():\n",
|
| 1472 |
+
" accuracy, precision, recall, f1, conf_matrix, fpr, tpr, roc_auc = train_evaluate_model(\n",
|
| 1473 |
+
" des_model, X_res_fi, y_res, X_test_fi, y_test, clf_name=des_name\n",
|
| 1474 |
+
" )\n",
|
| 1475 |
+
" metric_sums_des[des_name]['accuracy'] += accuracy\n",
|
| 1476 |
+
" metric_sums_des[des_name]['precision'] += precision\n",
|
| 1477 |
+
" metric_sums_des[des_name]['recall'] += recall\n",
|
| 1478 |
+
" metric_sums_des[des_name]['f1'] += f1\n",
|
| 1479 |
+
" conf_matrices_des[des_name].append(conf_matrix)\n",
|
| 1480 |
+
" roc_curves[des_name].append((fpr, tpr))\n",
|
| 1481 |
+
" roc_aucs[des_name].append(roc_auc)\n",
|
| 1482 |
+
" accuracy_scores[des_name].append(accuracy)\n",
|
| 1483 |
+
" precision_scores[des_name].append(precision)\n",
|
| 1484 |
+
" recall_scores[des_name].append(recall)\n",
|
| 1485 |
+
" f1_scores[des_name].append(f1)\n",
|
| 1486 |
+
"\n",
|
| 1487 |
+
" print(f'Confusion Matrix for {des_name} at Random State {random_state}:\\n{conf_matrix}\\n')"
|
| 1488 |
+
]
|
| 1489 |
+
},
|
| 1490 |
+
{
|
| 1491 |
+
"cell_type": "code",
|
| 1492 |
+
"execution_count": null,
|
| 1493 |
+
"id": "34b0ead4-1d8c-4d0a-b0c7-71cd71f6ab7d",
|
| 1494 |
+
"metadata": {},
|
| 1495 |
+
"outputs": [],
|
| 1496 |
+
"source": [
|
| 1497 |
+
"def plot_combined_roc_curve(roc_curves, classifier_names):\n",
|
| 1498 |
+
" plt.figure(figsize=(12, 8))\n",
|
| 1499 |
+
" mean_fpr = np.linspace(0, 1, 100)\n",
|
| 1500 |
+
" colors = plt.cm.get_cmap('tab20', len(classifier_names))\n",
|
| 1501 |
+
" \n",
|
| 1502 |
+
" for i, clf_name in enumerate(classifier_names):\n",
|
| 1503 |
+
" tprs = []\n",
|
| 1504 |
+
" for fpr, tpr in roc_curves[clf_name]:\n",
|
| 1505 |
+
" tprs.append(np.interp(mean_fpr, fpr, tpr))\n",
|
| 1506 |
+
" mean_tpr = np.mean(tprs, axis=0)\n",
|
| 1507 |
+
" mean_tpr[-1] = 1.0\n",
|
| 1508 |
+
" mean_auc = auc(mean_fpr, mean_tpr)\n",
|
| 1509 |
+
" plt.plot(mean_fpr, mean_tpr, color=colors(i), lw=2, linestyle='-', marker='o', markersize=4, \n",
|
| 1510 |
+
" label=f'{clf_name} (AUC = {mean_auc:.3f})')\n",
|
| 1511 |
+
"\n",
|
| 1512 |
+
" plt.plot([0, 1], [0, 1], color='grey', lw=2, linestyle='--')\n",
|
| 1513 |
+
" plt.xlim([0.0, 1.0])\n",
|
| 1514 |
+
" plt.ylim([0.0, 1.05])\n",
|
| 1515 |
+
" plt.xlabel('False Positive Rate', fontsize=26)\n",
|
| 1516 |
+
" plt.ylabel('True Positive Rate', fontsize=26)\n",
|
| 1517 |
+
" plt.xticks(fontsize=30) # Increase x-axis numbers font size\n",
|
| 1518 |
+
" plt.yticks(fontsize=30) # Increase y-axis numbers font size\n",
|
| 1519 |
+
" plt.legend(loc=\"center left\", bbox_to_anchor=(1.05, 0.5), fontsize=26, frameon=True, framealpha=0.9) # Place legend beside the plot\n",
|
| 1520 |
+
" plt.grid(True)\n",
|
| 1521 |
+
"\n",
|
| 1522 |
+
" filename='bonk.svg'\n",
|
| 1523 |
+
"\n",
|
| 1524 |
+
" plt.savefig(filename, format='svg', bbox_inches = 'tight')\n",
|
| 1525 |
+
" plt.show()\n",
|
| 1526 |
+
"\n",
|
| 1527 |
+
" display(FileLink(filename))"
|
| 1528 |
+
]
|
| 1529 |
+
},
|
| 1530 |
+
{
|
| 1531 |
+
"cell_type": "code",
|
| 1532 |
+
"execution_count": null,
|
| 1533 |
+
"id": "4cfd8d73-26fd-4dbc-a6ca-92a9643e1d27",
|
| 1534 |
+
"metadata": {},
|
| 1535 |
+
"outputs": [],
|
| 1536 |
+
"source": [
|
| 1537 |
+
"print('\\nAverage Metrics over 10 Random States:')\n",
|
| 1538 |
+
"for des_name, metrics in metric_sums_des.items():\n",
|
| 1539 |
+
" avg_accuracy = metrics['accuracy'] / 10\n",
|
| 1540 |
+
" avg_precision = metrics['precision'] / 10\n",
|
| 1541 |
+
" avg_recall = metrics['recall'] / 10\n",
|
| 1542 |
+
" avg_f1 = metrics['f1'] / 10\n",
|
| 1543 |
+
" std_accuracy = np.std(accuracy_scores[des_name])\n",
|
| 1544 |
+
" std_precision = np.std(precision_scores[des_name])\n",
|
| 1545 |
+
" std_recall = np.std(recall_scores[des_name])\n",
|
| 1546 |
+
" std_f1 = np.std(f1_scores[des_name])\n",
|
| 1547 |
+
" avg_auc = np.mean(roc_aucs[des_name])\n",
|
| 1548 |
+
" print(f'{des_name} - Accuracy: {avg_accuracy:.4f} ± {std_accuracy:.4f}, Precision: {avg_precision:.4f} ± {std_precision:.4f}, Recall: {avg_recall:.4f} ± {std_recall:.4f}, F1-Score: {avg_f1:.4f} ± {std_f1:.4f}, AUC: {avg_auc:.4f}')\n",
|
| 1549 |
+
"\n",
|
| 1550 |
+
"plot_combined_roc_curve(roc_curves, list(des_models.keys()))"
|
| 1551 |
+
]
|
| 1552 |
+
},
|
| 1553 |
+
{
|
| 1554 |
+
"cell_type": "code",
|
| 1555 |
+
"execution_count": null,
|
| 1556 |
+
"id": "32780b70-dae5-4f7d-9ae4-128dc782fad4",
|
| 1557 |
+
"metadata": {},
|
| 1558 |
+
"outputs": [],
|
| 1559 |
+
"source": [
|
| 1560 |
+
"df = pd.DataFrame(accuracy_scores)\n",
|
| 1561 |
+
"scores = [df[col].values for col in df.columns]\n",
|
| 1562 |
+
"\n",
|
| 1563 |
+
"stat, p = friedmanchisquare(*scores)\n",
|
| 1564 |
+
"print(f'Friedman Test Statistic: {stat}, p-value: {p}')\n",
|
| 1565 |
+
"\n",
|
| 1566 |
+
"ranks = df.rank(axis=1, method='average', ascending=False)\n",
|
| 1567 |
+
"average_ranks = ranks.mean().values\n",
|
| 1568 |
+
"\n",
|
| 1569 |
+
"n_datasets = df.shape[0]\n",
|
| 1570 |
+
"alpha = 0.05\n",
|
| 1571 |
+
"\n",
|
| 1572 |
+
"cd = compute_CD(average_ranks, n_datasets, alpha='0.05')\n",
|
| 1573 |
+
"print(f'Critical Difference: {cd}')\n",
|
| 1574 |
+
"\n",
|
| 1575 |
+
"classifiers = [f\"{clf} ({rank:.2f})\" for clf, rank in zip(df.columns, average_ranks)]\n",
|
| 1576 |
+
"\n",
|
| 1577 |
+
"plt.figure(figsize=(14, 10))\n",
|
| 1578 |
+
"\n",
|
| 1579 |
+
"graph_ranks(average_ranks, classifiers, cd=cd, width=6, textspace=1)\n",
|
| 1580 |
+
"plt.xlabel('Classifiers')\n",
|
| 1581 |
+
"\n",
|
| 1582 |
+
"plt.text(0.5, 1.19, f'Friedman-Nemenyi: {stat:.3f}', horizontalalignment='center', transform=plt.gca().transAxes, fontsize=16)\n",
|
| 1583 |
+
"plt.text(0.5, 1.10, f'CD: {cd:.3f}', horizontalalignment='center', transform=plt.gca().transAxes, fontsize=16)\n",
|
| 1584 |
+
"\n",
|
| 1585 |
+
"plt.tight_layout()"
|
| 1586 |
+
]
|
| 1587 |
+
},
|
| 1588 |
+
{
|
| 1589 |
+
"cell_type": "markdown",
|
| 1590 |
+
"id": "f25a2b29-129f-4314-b2f9-8aff9615d5d7",
|
| 1591 |
+
"metadata": {},
|
| 1592 |
+
"source": [
|
| 1593 |
+
"### Shap (will mostly be exported files)"
|
| 1594 |
+
]
|
| 1595 |
+
},
|
| 1596 |
+
{
|
| 1597 |
+
"cell_type": "code",
|
| 1598 |
+
"execution_count": null,
|
| 1599 |
+
"id": "ed3c43d0-68d2-4aac-ac44-25a28dc6dc75",
|
| 1600 |
+
"metadata": {},
|
| 1601 |
+
"outputs": [],
|
| 1602 |
+
"source": [
|
| 1603 |
+
"# Example with XGB\n",
|
| 1604 |
+
"\n",
|
| 1605 |
+
"random_state = 2\n",
|
| 1606 |
+
"print(f\"Processing for Random State: {random_state}\")\n",
|
| 1607 |
+
"\n",
|
| 1608 |
+
"X = splitted_dataset.drop('depression_category', axis=1)\n",
|
| 1609 |
+
"y = splitted_dataset['depression_category']\n",
|
| 1610 |
+
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=random_state)\n",
|
| 1611 |
+
"\n",
|
| 1612 |
+
"lof = LocalOutlierFactor()\n",
|
| 1613 |
+
"yhat = lof.fit_predict(X_train)\n",
|
| 1614 |
+
"mask = yhat != -1\n",
|
| 1615 |
+
"X_train, y_train = X_train[mask], y_train[mask]\n",
|
| 1616 |
+
"\n",
|
| 1617 |
+
"original_columns = X.columns.tolist()\n",
|
| 1618 |
+
"\n",
|
| 1619 |
+
"smote = SMOTE(random_state=random_state)\n",
|
| 1620 |
+
"X_res, y_res = smote.fit_resample(X_train, y_train)\n",
|
| 1621 |
+
"\n",
|
| 1622 |
+
"print(f\"Number of training labels after SMOTE: {y_res.value_counts()}\")\n",
|
| 1623 |
+
"print(f\"Number of test labels before resampling: {y_test.value_counts()}\")\n",
|
| 1624 |
+
"\n",
|
| 1625 |
+
"sampling_strategy_undersample = {0: 155}\n",
|
| 1626 |
+
"rus = RandomUnderSampler(sampling_strategy=sampling_strategy_undersample, random_state=random_state)\n",
|
| 1627 |
+
"X_test, y_test = rus.fit_resample(X_test, y_test)\n",
|
| 1628 |
+
"\n",
|
| 1629 |
+
"print(f\"Number of test labels after resampling: {y_test.value_counts()}\")\n",
|
| 1630 |
+
"\n",
|
| 1631 |
+
"# Normalization\n",
|
| 1632 |
+
"# scaler = MinMaxScaler()\n",
|
| 1633 |
+
"# X_res = scaler.fit_transform(X_res)\n",
|
| 1634 |
+
"# X_test = scaler.transform(X_test)\n",
|
| 1635 |
+
"\n",
|
| 1636 |
+
"X_res = pd.DataFrame(X_res, columns=original_columns)\n",
|
| 1637 |
+
"X_test = pd.DataFrame(X_test, columns=original_columns)\n",
|
| 1638 |
+
"\n",
|
| 1639 |
+
"# Train XGBoost model on all features\n",
|
| 1640 |
+
"model = xgb.XGBClassifier(n_estimators=200, max_depth=3, learning_rate=0.1, use_label_encoder=False, eval_metric='mlogloss', random_state=random_state)\n",
|
| 1641 |
+
"model.fit(X_res, y_res)\n",
|
| 1642 |
+
"\n",
|
| 1643 |
+
"y_pred = model.predict(X_test)\n",
|
| 1644 |
+
"\n",
|
| 1645 |
+
"accuracy = accuracy_score(y_test, y_pred)\n",
|
| 1646 |
+
"print(f'Accuracy: {accuracy:.4f}')\n",
|
| 1647 |
+
"\n",
|
| 1648 |
+
"explainer = shap.Explainer(model, X_res)\n",
|
| 1649 |
+
"shap_values = explainer(X_res)"
|
| 1650 |
+
]
|
| 1651 |
+
},
|
| 1652 |
+
{
|
| 1653 |
+
"cell_type": "code",
|
| 1654 |
+
"execution_count": null,
|
| 1655 |
+
"id": "2cb9c929-4e86-42df-bda2-48730004c154",
|
| 1656 |
+
"metadata": {},
|
| 1657 |
+
"outputs": [],
|
| 1658 |
+
"source": [
|
| 1659 |
+
"def plot_shap_waterfall(instance_index, filename):\n",
|
| 1660 |
+
" shap_value = shap_values[instance_index]\n",
|
| 1661 |
+
" plt.figure(figsize=(14, 8))\n",
|
| 1662 |
+
" \n",
|
| 1663 |
+
" shap.plots.waterfall(shap_value, show=False)\n",
|
| 1664 |
+
" \n",
|
| 1665 |
+
" ax = plt.gca()\n",
|
| 1666 |
+
" \n",
|
| 1667 |
+
" ax.tick_params(axis='both', which='major', labelsize=16)\n",
|
| 1668 |
+
" ax.set_xlabel(ax.get_xlabel(), fontsize=20)\n",
|
| 1669 |
+
" ax.set_ylabel(ax.get_ylabel(), fontsize=22)\n",
|
| 1670 |
+
" \n",
|
| 1671 |
+
" plt.tight_layout()\n",
|
| 1672 |
+
"\n",
|
| 1673 |
+
" plt.savefig(filename, format='svg')\n",
|
| 1674 |
+
" \n",
|
| 1675 |
+
" plt.close()\n",
|
| 1676 |
+
"\n",
|
| 1677 |
+
"plot_shap_waterfall(0, \"waterfall_plot_instance_0.svg\")\n",
|
| 1678 |
+
"plot_shap_waterfall(562, \"waterfall_plot_instance_562.svg\")"
|
| 1679 |
+
]
|
| 1680 |
+
},
|
| 1681 |
+
{
|
| 1682 |
+
"cell_type": "code",
|
| 1683 |
+
"execution_count": null,
|
| 1684 |
+
"id": "e111213e-caea-48a0-adb7-c6b2362eed4f",
|
| 1685 |
+
"metadata": {},
|
| 1686 |
+
"outputs": [],
|
| 1687 |
+
"source": [
|
| 1688 |
+
"import matplotlib.pyplot as plt\n",
|
| 1689 |
+
"\n",
|
| 1690 |
+
"plt.figure(figsize=(14, 8))\n",
|
| 1691 |
+
"shap.summary_plot(\n",
|
| 1692 |
+
" shap_values,\n",
|
| 1693 |
+
" X_res,\n",
|
| 1694 |
+
" plot_type=\"bar\",\n",
|
| 1695 |
+
" feature_names=original_columns,\n",
|
| 1696 |
+
" show=False\n",
|
| 1697 |
+
")\n",
|
| 1698 |
+
"\n",
|
| 1699 |
+
"ax = plt.gca()\n",
|
| 1700 |
+
"ax.tick_params(axis='both', which='major', labelsize=16)\n",
|
| 1701 |
+
"ax.set_xlabel(ax.get_xlabel(), fontsize=16)\n",
|
| 1702 |
+
"ax.set_ylabel(ax.get_ylabel(), fontsize=22)\n",
|
| 1703 |
+
"\n",
|
| 1704 |
+
"plt.savefig(\"shap_summary_plot.svg\", format='svg')\n",
|
| 1705 |
+
"plt.close()"
|
| 1706 |
+
]
|
| 1707 |
+
},
|
| 1708 |
+
{
|
| 1709 |
+
"cell_type": "code",
|
| 1710 |
+
"execution_count": null,
|
| 1711 |
+
"id": "4ae84637-a018-49a1-ad4e-522aa937bfad",
|
| 1712 |
+
"metadata": {},
|
| 1713 |
+
"outputs": [],
|
| 1714 |
+
"source": [
|
| 1715 |
+
"shap.initjs()\n",
|
| 1716 |
+
"\n",
|
| 1717 |
+
"fig, ax = plt.subplots(figsize=(14, 8))\n",
|
| 1718 |
+
"\n",
|
| 1719 |
+
"shap.summary_plot(\n",
|
| 1720 |
+
" shap_values,\n",
|
| 1721 |
+
" X_res,\n",
|
| 1722 |
+
" plot_type=\"dot\",\n",
|
| 1723 |
+
" feature_names=original_columns,\n",
|
| 1724 |
+
" show=False\n",
|
| 1725 |
+
")\n",
|
| 1726 |
+
"\n",
|
| 1727 |
+
"ax.tick_params(axis='both', which='major', labelsize=16)\n",
|
| 1728 |
+
"ax.set_xlabel(ax.get_xlabel(), fontsize=16)\n",
|
| 1729 |
+
"ax.set_ylabel(ax.get_ylabel(), fontsize=22)\n",
|
| 1730 |
+
"\n",
|
| 1731 |
+
"fig.savefig(\"shap_summary_dot_plot.svg\", format='svg', bbox_inches='tight')\n",
|
| 1732 |
+
"\n",
|
| 1733 |
+
"plt.close(fig)"
|
| 1734 |
+
]
|
| 1735 |
+
},
|
| 1736 |
+
{
|
| 1737 |
+
"cell_type": "code",
|
| 1738 |
+
"execution_count": null,
|
| 1739 |
+
"id": "05f0de1c-3850-4d77-9184-d593451022c1",
|
| 1740 |
+
"metadata": {},
|
| 1741 |
+
"outputs": [],
|
| 1742 |
+
"source": [
|
| 1743 |
+
"from sklearn.tree import DecisionTreeClassifier, export_text\n",
|
| 1744 |
+
"from sklearn import tree\n",
|
| 1745 |
+
"\n",
|
| 1746 |
+
"random_state = 5\n",
|
| 1747 |
+
"\n",
|
| 1748 |
+
"X = splitted_dataset.drop('depression_category', axis=1)\n",
|
| 1749 |
+
"y = splitted_dataset['depression_category']\n",
|
| 1750 |
+
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=random_state)\n",
|
| 1751 |
+
"\n",
|
| 1752 |
+
"lof = LocalOutlierFactor()\n",
|
| 1753 |
+
"yhat = lof.fit_predict(X_train)\n",
|
| 1754 |
+
"mask = yhat != -1\n",
|
| 1755 |
+
"X_train, y_train = X_train[mask], y_train[mask]\n",
|
| 1756 |
+
"\n",
|
| 1757 |
+
"original_columns = X.columns.tolist()\n",
|
| 1758 |
+
"\n",
|
| 1759 |
+
"smote = SMOTE(random_state=random_state)\n",
|
| 1760 |
+
"X_res, y_res = smote.fit_resample(X_train, y_train)\n",
|
| 1761 |
+
"\n",
|
| 1762 |
+
"print(f\"Number of training labels after ROS: {y_res.value_counts()}\")\n",
|
| 1763 |
+
"print(f\"Number of test labels before resampling: {y_test.value_counts()}\")\n",
|
| 1764 |
+
"sampling_strategy_undersample = {0: 155}\n",
|
| 1765 |
+
"rus = RandomUnderSampler(sampling_strategy=sampling_strategy_undersample, random_state=random_state)\n",
|
| 1766 |
+
"X_test, y_test = rus.fit_resample(X_test, y_test)\n",
|
| 1767 |
+
"\n",
|
| 1768 |
+
"print(f\"Number of test labels after resampling: {y_test.value_counts()}\")\n",
|
| 1769 |
+
"\n",
|
| 1770 |
+
"# Normalization\n",
|
| 1771 |
+
"# scaler = MinMaxScaler()\n",
|
| 1772 |
+
"# X_res = scaler.fit_transform(X_res)\n",
|
| 1773 |
+
"# X_test = scaler.transform(X_test)\n",
|
| 1774 |
+
"\n",
|
| 1775 |
+
"X_res = pd.DataFrame(X_res, columns=original_columns)\n",
|
| 1776 |
+
"X_test = pd.DataFrame(X_test, columns=original_columns)\n",
|
| 1777 |
+
"\n",
|
| 1778 |
+
"decision_tree_model = DecisionTreeClassifier(\n",
|
| 1779 |
+
" random_state=0, \n",
|
| 1780 |
+
" criterion='gini', \n",
|
| 1781 |
+
" max_depth=6, \n",
|
| 1782 |
+
" min_samples_leaf=10, \n",
|
| 1783 |
+
" min_samples_split=9\n",
|
| 1784 |
+
")\n",
|
| 1785 |
+
"decision_tree_model.fit(X_res, y_res)\n",
|
| 1786 |
+
"\n",
|
| 1787 |
+
"plt.figure(figsize=(20, 14))\n",
|
| 1788 |
+
"tree.plot_tree(\n",
|
| 1789 |
+
" decision_tree_model, \n",
|
| 1790 |
+
" feature_names=original_columns, \n",
|
| 1791 |
+
" class_names=['depression', 'normal'],\n",
|
| 1792 |
+
" filled=True, \n",
|
| 1793 |
+
" rounded=True, \n",
|
| 1794 |
+
" fontsize=10,\n",
|
| 1795 |
+
" max_depth = 3\n",
|
| 1796 |
+
")\n",
|
| 1797 |
+
"\n",
|
| 1798 |
+
"plt.savefig(\"decision_tree_plot.svg\", format='svg')\n",
|
| 1799 |
+
"plt.close()\n",
|
| 1800 |
+
"\n",
|
| 1801 |
+
"print(\"Decision Tree plot saved as 'decision_tree_plot.svg'\")"
|
| 1802 |
+
]
|
| 1803 |
+
},
|
| 1804 |
+
{
|
| 1805 |
+
"cell_type": "code",
|
| 1806 |
+
"execution_count": null,
|
| 1807 |
+
"id": "b46a92a6-f370-4df4-ac29-8f382fc1820b",
|
| 1808 |
+
"metadata": {
|
| 1809 |
+
"scrolled": true
|
| 1810 |
+
},
|
| 1811 |
+
"outputs": [],
|
| 1812 |
+
"source": [
|
| 1813 |
+
"tree_rules = export_text(decision_tree_model, feature_names=original_columns, max_depth=50)\n",
|
| 1814 |
+
"print(\"Decision rules for the tree (up to depth 3):\")\n",
|
| 1815 |
+
"print(tree_rules)\n",
|
| 1816 |
+
"\n",
|
| 1817 |
+
"node_indicator = decision_tree_model.decision_path(X_test)\n",
|
| 1818 |
+
"\n",
|
| 1819 |
+
"sample_id = 0\n",
|
| 1820 |
+
"node_index = node_indicator.indices[node_indicator.indptr[sample_id]:node_indicator.indptr[sample_id + 1]]\n",
|
| 1821 |
+
"\n",
|
| 1822 |
+
"print(f\"\\nDecision path for sample {sample_id}:\")\n",
|
| 1823 |
+
"for node_id in node_index:\n",
|
| 1824 |
+
" if X_test.iloc[sample_id, decision_tree_model.tree_.feature[node_id]] <= decision_tree_model.tree_.threshold[node_id]:\n",
|
| 1825 |
+
" threshold_sign = \"<=\"\n",
|
| 1826 |
+
" else:\n",
|
| 1827 |
+
" threshold_sign = \">\"\n",
|
| 1828 |
+
" print(f\"Node {node_id}: (X_test[{sample_id}, {decision_tree_model.tree_.feature[node_id]}] = {X_test.iloc[sample_id, decision_tree_model.tree_.feature[node_id]]}) \"\n",
|
| 1829 |
+
" f\"{threshold_sign} {decision_tree_model.tree_.threshold[node_id]}\")\n",
|
| 1830 |
+
"\n",
|
| 1831 |
+
"# Get prediction for a specific test sample\n",
|
| 1832 |
+
"predicted_class = decision_tree_model.predict([X_test.iloc[sample_id]])\n",
|
| 1833 |
+
"print(f\"\\nPredicted class for test sample {sample_id}: {predicted_class}\")"
|
| 1834 |
+
]
|
| 1835 |
+
}
|
| 1836 |
+
],
|
| 1837 |
+
"metadata": {
|
| 1838 |
+
"kernelspec": {
|
| 1839 |
+
"display_name": "Python 3 (ipykernel)",
|
| 1840 |
+
"language": "python",
|
| 1841 |
+
"name": "python3"
|
| 1842 |
+
},
|
| 1843 |
+
"language_info": {
|
| 1844 |
+
"codemirror_mode": {
|
| 1845 |
+
"name": "ipython",
|
| 1846 |
+
"version": 3
|
| 1847 |
+
},
|
| 1848 |
+
"file_extension": ".py",
|
| 1849 |
+
"mimetype": "text/x-python",
|
| 1850 |
+
"name": "python",
|
| 1851 |
+
"nbconvert_exporter": "python",
|
| 1852 |
+
"pygments_lexer": "ipython3",
|
| 1853 |
+
"version": "3.12.3"
|
| 1854 |
+
}
|
| 1855 |
+
},
|
| 1856 |
+
"nbformat": 4,
|
| 1857 |
+
"nbformat_minor": 5
|
| 1858 |
+
}
|