{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Random Forest Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using cuda device (RF runs on CPU; this is only for loading tensors).\n", "šŸ“¦ Loading datasets...\n", "āœ… Data loaded successfully!\n", "Train: X=(169692, 1280), y=(169692,)\n", "Val: X=(21211, 1280), y=(21211,)\n", "Test: X=(21212, 1280), y=(21212,)\n", "Classes: 5 | Train distribution: [33938 33939 33938 33938 33939]\n", "CSV (incremental): training_runs/smote_study/results_rf_best_hparam_smote/sweep_metrics.csv\n", "Using uniform weights (no balancing)\n", "Total combos in grid: 12\n", "Resume: found 11 completed combos in existing CSV.\n", "[1/12] SKIP (done) param_id=171f36f06824 params={'n_estimators': 200, 'max_depth': None, 'min_samples_split': 2, 'min_samples_leaf': 1, 'criterion': 'gini', 'max_features': 0.2, 'bootstrap': False}\n", "[2/12] SKIP (done) param_id=ae75a831a2c3 params={'n_estimators': 200, 'max_depth': None, 'min_samples_split': 2, 'min_samples_leaf': 1, 'criterion': 'gini', 'max_features': 0.5, 'bootstrap': False}\n", "[3/12] SKIP (done) param_id=a47bef876928 params={'n_estimators': 200, 'max_depth': None, 'min_samples_split': 2, 'min_samples_leaf': 1, 'criterion': 'gini', 'max_features': 0.8, 'bootstrap': False}\n", "[4/12] SKIP (done) param_id=ac9388907d7d params={'n_estimators': 300, 'max_depth': None, 'min_samples_split': 2, 'min_samples_leaf': 1, 'criterion': 'gini', 'max_features': 0.2, 'bootstrap': False}\n", "[5/12] SKIP (done) param_id=c19c1670d890 params={'n_estimators': 300, 'max_depth': None, 'min_samples_split': 2, 'min_samples_leaf': 1, 'criterion': 'gini', 'max_features': 0.5, 'bootstrap': False}\n", "[6/12] SKIP (done) param_id=245179ee1abc params={'n_estimators': 300, 'max_depth': None, 'min_samples_split': 2, 'min_samples_leaf': 1, 'criterion': 'gini', 'max_features': 0.8, 'bootstrap': 
False}\n", "[7/12] SKIP (done) param_id=ceaf6cca42d9 params={'n_estimators': 600, 'max_depth': None, 'min_samples_split': 2, 'min_samples_leaf': 1, 'criterion': 'gini', 'max_features': 0.2, 'bootstrap': False}\n", "[8/12] SKIP (done) param_id=b8a3c667d38c params={'n_estimators': 600, 'max_depth': None, 'min_samples_split': 2, 'min_samples_leaf': 1, 'criterion': 'gini', 'max_features': 0.5, 'bootstrap': False}\n", "[9/12] SKIP (done) param_id=92d962d1834e params={'n_estimators': 600, 'max_depth': None, 'min_samples_split': 2, 'min_samples_leaf': 1, 'criterion': 'gini', 'max_features': 0.8, 'bootstrap': False}\n", "[10/12] SKIP (done) param_id=48d0301d6872 params={'n_estimators': 800, 'max_depth': None, 'min_samples_split': 2, 'min_samples_leaf': 1, 'criterion': 'gini', 'max_features': 0.2, 'bootstrap': False}\n", "[11/12] SKIP (done) param_id=e470342d984f params={'n_estimators': 800, 'max_depth': None, 'min_samples_split': 2, 'min_samples_leaf': 1, 'criterion': 'gini', 'max_features': 0.5, 'bootstrap': False}\n", "\n", "[12/12] RUN param_id=60ef71cfde09 params={'n_estimators': 800, 'max_depth': None, 'min_samples_split': 2, 'min_samples_leaf': 1, 'criterion': 'gini', 'max_features': 0.8, 'bootstrap': False}\n", "VAL acc=0.9679 bal_acc=0.9679 macro_f1=0.9681 mcc=0.9601 macro_auc=0.9966\n", "TEST acc=0.9687 bal_acc=0.9687 macro_f1=0.9689 mcc=0.9611 macro_auc=0.9969\n", "\n", "Sweep finished. 
Wall time: 7709.1 sec\n", "CSV saved incrementally at: training_runs/smote_study/results_rf_best_hparam_smote/sweep_metrics.csv\n", "\n", "==========================================================================================\n", "BEST PARAMS by val_macro_f1 = 0.977406\n", "{'n_estimators': 800, 'max_depth': None, 'min_samples_split': 2, 'min_samples_leaf': 1, 'criterion': 'gini', 'max_features': 0.2, 'bootstrap': False}\n", "==========================================================================================\n", "\n", "Training best RF again (for saving)...\n", "Saved prediction outputs to: training_runs/smote_study/results_rf_best_hparam_smote/rf_best_model_20260326-111232/test_outputs.npz\n", "\n", "āœ… Saved BEST model: training_runs/smote_study/results_rf_best_hparam_smote/rf_best_model_20260326-111232/random_forest.joblib\n", "āœ… Saved BEST metadata: training_runs/smote_study/results_rf_best_hparam_smote/rf_best_model_20260326-111232/metadata_rf.json\n", "Log file closed.\n", "\n", "Done.\n" ] } ], "source": [ "\"\"\"\n", "RandomForest hyperparameter sweep (multiclass).\n", "\n", "What it does\n", "------------\n", "1) Sweeps over a manual RF param grid.\n", "2) For EACH hyperparameter combo:\n", " - trains on train split\n", " - evaluates metrics on val + test\n", " - appends ONE ROW immediately to a CSV (append mode) => crash-safe\n", "3) Resume capability:\n", " - if CSV exists, it skips param combos already completed (using stable param_id)\n", "4) After sweep:\n", " - selects best hyperparams by SELECT_BEST_BY (default val_macro_f1)\n", " - retrains best model and saves ONLY that model + metadata\n", "\n", "Outputs\n", "-------\n", "- results_dir/\n", " - sweep_metrics.csv (incrementally appended; resume-safe)\n", " - rf_sweep_log_.txt\n", " - rf_best_model_/\n", " - random_forest.joblib\n", " - metadata_rf.json\n", " - test_outputs.npz\n", "\"\"\"\n", "\n", "import os\n", "import sys\n", "import json\n", "import time\n", "import hashlib\n", 
"import random\n", "import itertools\n", "import contextlib\n", "from datetime import datetime\n", "\n", "import torch\n", "import numpy as np\n", "import pandas as pd\n", "\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.metrics import (\n", " accuracy_score,\n", " balanced_accuracy_score,\n", " f1_score,\n", " matthews_corrcoef,\n", " roc_auc_score,\n", ")\n", "from sklearn.utils.class_weight import compute_class_weight\n", "import joblib\n", "\n", "\n", "# =====================================================\n", "# USER CONFIG\n", "# =====================================================\n", "DATA_PATH = \"../Data/multiclass_data_no_SMOTE\"\n", "RESULTS_DIR = \"../Results/results_rf_hparam_sweep\"\n", "\n", "USE_BALANCED_WEIGHTS = False\n", "GLOBAL_SEED = 42\n", "\n", "# Options (must exist as a CSV column):\n", "# \"val_macro_f1\", \"val_bal_acc\", \"val_macro_auc_ovr\", \"val_mcc\"\n", "SELECT_BEST_BY = \"val_macro_f1\"\n", "\n", "# Class names for metadata only (optional)\n", "CLASS_ID_TO_NAME = {\n", " 0: \"phosphate\",\n", " 1: \"sulfate\",\n", " 2: \"chloride\",\n", " 3: \"nitrate\",\n", " 4: \"carbonate\",\n", "}\n", "\n", "# =====================================================\n", "# PARAM GRID\n", "# =====================================================\n", "PARAM_GRID = {\n", " \"n_estimators\": [200],\n", " \"max_depth\": [None],\n", " \"min_samples_split\": [2],\n", " \"min_samples_leaf\": [1],\n", " \"criterion\": [\"gini\"],\n", " \"max_features\": [0.2], \n", " \"bootstrap\": [False],\n", "}\n", "\n", "\n", "# =====================================================\n", "# REPRODUCIBILITY\n", "# =====================================================\n", "def seed_everything(seed: int = 42):\n", " random.seed(seed)\n", " np.random.seed(seed)\n", " torch.manual_seed(seed)\n", " torch.cuda.manual_seed_all(seed)\n", "\n", "seed_everything(GLOBAL_SEED)\n", "\n", "\n", "# =====================================================\n", 
"# IO SETUP\n", "# =====================================================\n", "os.makedirs(RESULTS_DIR, exist_ok=True)\n", "\n", "ts = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n", "LOG_PATH = os.path.join(RESULTS_DIR, f\"rf_sweep_log_{ts}.txt\")\n", "\n", "# Stable CSV path for resume\n", "CSV_PATH = os.path.join(RESULTS_DIR, \"sweep_metrics.csv\")\n", "\n", "BEST_MODEL_DIR = os.path.join(RESULTS_DIR, f\"rf_best_model_{ts}\")\n", "os.makedirs(BEST_MODEL_DIR, exist_ok=True)\n", "\n", "\n", "# =====================================================\n", "# TEE LOGGER\n", "# =====================================================\n", "class Tee:\n", " def __init__(self, *streams):\n", " self.streams = streams\n", " def write(self, data):\n", " for s in self.streams:\n", " s.write(data)\n", " def flush(self):\n", " for s in self.streams:\n", " try:\n", " s.flush()\n", " except Exception:\n", " pass\n", "\n", "\n", "# =====================================================\n", "# LOAD DATA\n", "# =====================================================\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"Using {device} device (RF runs on CPU; this is only for loading tensors).\")\n", "\n", "print(\"šŸ“¦ Loading datasets...\")\n", "train = torch.load(os.path.join(DATA_PATH, \"training_data.pt\"), map_location=\"cpu\")\n", "val = torch.load(os.path.join(DATA_PATH, \"validation_data.pt\"), map_location=\"cpu\")\n", "test = torch.load(os.path.join(DATA_PATH, \"test_data.pt\"), map_location=\"cpu\")\n", "print(\"āœ… Data loaded successfully!\")\n", "\n", "def to_np(t):\n", " return t.detach().cpu().numpy() if torch.is_tensor(t) else np.array(t)\n", "\n", "X_train, y_train = to_np(train[\"X_train\"]), to_np(train[\"y_train\"]).ravel().astype(int)\n", "X_val, y_val = to_np(val[\"X_val\"]), to_np(val[\"y_val\"]).ravel().astype(int)\n", "X_test, y_test = to_np(test[\"X_test\"]), to_np(test[\"y_test\"]).ravel().astype(int)\n", "\n", "classes = 
np.unique(y_train)\n", "n_classes = len(classes)\n", "\n", "print(f\"Train: X={X_train.shape}, y={y_train.shape}\")\n", "print(f\"Val: X={X_val.shape}, y={y_val.shape}\")\n", "print(f\"Test: X={X_test.shape}, y={y_test.shape}\")\n", "print(f\"Classes: {n_classes} | Train distribution: {np.bincount(y_train)}\")\n", "print(f\"CSV (incremental): {CSV_PATH}\")\n", "\n", "\n", "# =====================================================\n", "# CLASS WEIGHTS\n", "# =====================================================\n", "if USE_BALANCED_WEIGHTS:\n", " cw = compute_class_weight(class_weight=\"balanced\", classes=classes, y=y_train)\n", " class_weight_dict = {int(c): float(w) for c, w in zip(classes, cw)}\n", " print(\"Using balanced class weights:\", class_weight_dict)\n", "else:\n", " class_weight_dict = None\n", " print(\"Using uniform weights (no balancing)\")\n", "\n", "# =====================================================\n", "# PARAM UTILITIES\n", "# =====================================================\n", "def stable_param_id(params: dict) -> str:\n", " s = json.dumps(params, sort_keys=True, default=str)\n", " return hashlib.md5(s.encode(\"utf-8\")).hexdigest()[:12]\n", "\n", "def load_completed_param_ids(csv_path: str) -> set:\n", " if not os.path.exists(csv_path):\n", " return set()\n", " try:\n", " df = pd.read_csv(csv_path)\n", " if \"param_id\" not in df.columns:\n", " return set()\n", " return set(df[\"param_id\"].astype(str).tolist())\n", " except Exception:\n", " return set()\n", "\n", "# =====================================================\n", "# ROC-AUC HELPERS\n", "# =====================================================\n", "def safe_auc(y_true, y_proba, average=\"macro\"):\n", " try:\n", " return roc_auc_score(y_true, y_proba, multi_class=\"ovr\", average=average)\n", " except Exception:\n", " return np.nan\n", "\n", "def save_prediction_outputs(save_dir, y_true, y_pred, y_proba, class_names=None, prefix=\"test\"):\n", " os.makedirs(save_dir, 
exist_ok=True)\n",
"\n",
"    save_path = os.path.join(save_dir, f\"{prefix}_outputs.npz\")\n",
"\n",
"    if class_names is None:\n",
"        class_names = np.array([f\"C{i}\" for i in range(y_proba.shape[1])], dtype=object)\n",
"    else:\n",
"        class_names = np.array(class_names, dtype=object)\n",
"\n",
"    np.savez_compressed(\n",
"        save_path,\n",
"        y_true=np.asarray(y_true),\n",
"        y_pred=np.asarray(y_pred),\n",
"        y_proba=np.asarray(y_proba),\n",
"        class_names=class_names\n",
"    )\n",
"    print(f\"Saved prediction outputs to: {save_path}\")\n",
"\n",
"# =====================================================\n",
"# METRICS\n",
"# =====================================================\n",
"def per_class_f1_dict(y_true, y_pred, n_classes):\n",
"    \"\"\"Per-class F1 keyed 'f1_c0'..'f1_c{n_classes-1}' (labels pinned to 0..n_classes-1).\"\"\"\n",
"    scores = f1_score(y_true, y_pred, average=None, labels=np.arange(n_classes))\n",
"    return {f\"f1_c{i}\": float(scores[i]) for i in range(n_classes)}\n",
"\n",
"def compute_metrics_with_per_class(y_true, y_pred, y_proba, n_classes):\n",
"    \"\"\"Aggregate accuracy / F1 / MCC / OvR-AUC metrics plus per-class F1 as plain floats.\"\"\"\n",
"    out = {}\n",
"    out[\"acc\"] = float(accuracy_score(y_true, y_pred))\n",
"    out[\"bal_acc\"] = float(balanced_accuracy_score(y_true, y_pred))\n",
"    out[\"macro_f1\"] = float(f1_score(y_true, y_pred, average=\"macro\"))\n",
"    out[\"weighted_f1\"] = float(f1_score(y_true, y_pred, average=\"weighted\"))\n",
"    out[\"micro_f1\"] = float(f1_score(y_true, y_pred, average=\"micro\"))\n",
"    out[\"mcc\"] = float(matthews_corrcoef(y_true, y_pred))\n",
"    out[\"macro_auc_ovr\"] = float(safe_auc(y_true, y_proba, average=\"macro\"))\n",
"    out[\"weighted_auc_ovr\"] = float(safe_auc(y_true, y_proba, average=\"weighted\"))\n",
"    out.update(per_class_f1_dict(y_true, y_pred, n_classes))\n",
"    return out\n",
"\n",
"\n",
"# =====================================================\n",
"# CSV APPEND (CRASH-SAFE)\n",
"# =====================================================\n",
"def append_row_to_csv(csv_path: str, row: dict):\n",
"    \"\"\"Append one result row to csv_path, writing the header on first use.\n",
"\n",
"    A header-less append writes values by position, so a row whose keys differ\n",
"    from the on-disk header (e.g. a failed run: no metric columns, extra 'error'\n",
"    column) would land under the wrong column names. Rows are therefore aligned\n",
"    to the existing header; if a row brings new columns, the file is rewritten\n",
"    once with the widened header.\n",
"    \"\"\"\n",
"    row_df = pd.DataFrame([row])\n",
"    if os.path.exists(csv_path) and os.path.getsize(csv_path) == 0:\n",
"        os.remove(csv_path)  # zero-byte leftover of an interrupted first write: start fresh\n",
"    if os.path.exists(csv_path):\n",
"        existing_cols = list(pd.read_csv(csv_path, nrows=0).columns)\n",
"        if any(col not in existing_cols for col in row_df.columns):\n",
"            # Read as text so existing ids/values are rewritten exactly as stored.\n",
"            existing = pd.read_csv(csv_path, dtype=str, keep_default_na=False)\n",
"            pd.concat([existing, row_df], ignore_index=True).to_csv(csv_path, index=False)\n",
"            return\n",
"        # Same or smaller key set: order values by the on-disk header (missing -> empty).\n",
"        row_df = row_df.reindex(columns=existing_cols)\n",
"    if not os.path.exists(csv_path):\n",
"        row_df.to_csv(csv_path, 
index=False)\n", " else:\n", " row_df.to_csv(csv_path, mode=\"a\", header=False, index=False)\n", "\n", "\n", "# =====================================================\n", "# BUILD PARAM COMBOS\n", "# =====================================================\n", "keys = list(PARAM_GRID.keys())\n", "values = list(PARAM_GRID.values())\n", "param_combinations = [dict(zip(keys, combo)) for combo in itertools.product(*values)]\n", "print(f\"Total combos in grid: {len(param_combinations)}\")\n", "\n", "\n", "# =====================================================\n", "# MAIN SWEEP (RESUME-CAPABLE)\n", "# =====================================================\n", "log_f = open(LOG_PATH, \"w\", encoding=\"utf-8\")\n", "tee = Tee(sys.stdout, log_f)\n", "\n", "try:\n", " with contextlib.redirect_stdout(tee), contextlib.redirect_stderr(tee):\n", " completed = load_completed_param_ids(CSV_PATH)\n", " print(f\"Resume: found {len(completed)} completed combos in existing CSV.\")\n", "\n", " t0 = time.time()\n", "\n", " for idx, params in enumerate(param_combinations, 1):\n", " # āœ… FIX: ensure correct dtype for max_features\n", " if isinstance(params[\"max_features\"], str) and params[\"max_features\"] not in [\"sqrt\", \"log2\"]:\n", " params[\"max_features\"] = float(params[\"max_features\"])\n", " pid = stable_param_id(params)\n", " if pid in completed:\n", " print(f\"[{idx}/{len(param_combinations)}] SKIP (done) param_id={pid} params={params}\")\n", " continue\n", "\n", " print(f\"\\n[{idx}/{len(param_combinations)}] RUN param_id={pid} params={params}\")\n", "\n", " row = {\n", " \"timestamp\": datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\"),\n", " \"param_id\": pid,\n", " \"use_balanced_weights\": bool(USE_BALANCED_WEIGHTS),\n", " \"seed\": int(GLOBAL_SEED),\n", " **params\n", " }\n", "\n", " try:\n", " model = RandomForestClassifier(\n", " random_state=GLOBAL_SEED,\n", " n_jobs=-1,\n", " class_weight=class_weight_dict,\n", " **params\n", " )\n", " model.fit(X_train, y_train)\n", 
"\n", " # VAL\n", " y_val_pred = model.predict(X_val)\n", " y_val_proba = model.predict_proba(X_val)\n", " val_metrics = compute_metrics_with_per_class(y_val, y_val_pred, y_val_proba, n_classes)\n", "\n", " # TEST\n", " y_test_pred = model.predict(X_test)\n", " y_test_proba = model.predict_proba(X_test)\n", " test_metrics = compute_metrics_with_per_class(y_test, y_test_pred, y_test_proba, n_classes)\n", "\n", " row.update({f\"val_{k}\": v for k, v in val_metrics.items()})\n", " row.update({f\"test_{k}\": v for k, v in test_metrics.items()})\n", "\n", " append_row_to_csv(CSV_PATH, row)\n", " completed.add(pid)\n", "\n", " print(\n", " f\"VAL acc={row['val_acc']:.4f} bal_acc={row['val_bal_acc']:.4f} \"\n", " f\"macro_f1={row['val_macro_f1']:.4f} mcc={row['val_mcc']:.4f} \"\n", " f\"macro_auc={row['val_macro_auc_ovr']:.4f}\"\n", " )\n", " print(\n", " f\"TEST acc={row['test_acc']:.4f} bal_acc={row['test_bal_acc']:.4f} \"\n", " f\"macro_f1={row['test_macro_f1']:.4f} mcc={row['test_mcc']:.4f} \"\n", " f\"macro_auc={row['test_macro_auc_ovr']:.4f}\"\n", " )\n", "\n", " except Exception as e:\n", " # Append an error row too, so you know it failed.\n", " row[\"error\"] = str(e)\n", " append_row_to_csv(CSV_PATH, row)\n", " completed.add(pid)\n", " print(f\"āŒ FAILED param_id={pid} error={e}\")\n", "\n", " print(f\"\\nSweep finished. 
Wall time: {time.time()-t0:.1f} sec\")\n",
"        print(f\"CSV saved incrementally at: {CSV_PATH}\")\n",
"\n",
"        # =====================================================\n",
"        # SELECT BEST + TRAIN AGAIN + SAVE ONLY BEST MODEL\n",
"        # =====================================================\n",
"        df = pd.read_csv(CSV_PATH)\n",
"\n",
"        # Only successful runs (failed combos have a non-empty 'error' cell)\n",
"        if \"error\" in df.columns:\n",
"            df_ok = df[df[\"error\"].isna()]\n",
"        else:\n",
"            df_ok = df\n",
"\n",
"        if df_ok.empty:\n",
"            raise RuntimeError(\"No successful runs found in CSV; cannot select best model.\")\n",
"\n",
"        if SELECT_BEST_BY not in df_ok.columns:\n",
"            raise RuntimeError(f\"SELECT_BEST_BY='{SELECT_BEST_BY}' not found in CSV columns.\")\n",
"\n",
"        # A NaN score (e.g. an AUC that could not be computed) must not be selected.\n",
"        df_ok = df_ok.dropna(subset=[SELECT_BEST_BY])\n",
"        if df_ok.empty:\n",
"            raise RuntimeError(f\"All successful runs have NaN '{SELECT_BEST_BY}'; cannot select best model.\")\n",
"\n",
"        df_ok = df_ok.sort_values(SELECT_BEST_BY, ascending=False)\n",
"        best_row = df_ok.iloc[0].to_dict()\n",
"\n",
"        best_params = {k: best_row[k] for k in keys}\n",
"\n",
"        # normalize types for sklearn: the CSV round-trip turns None into NaN,\n",
"        # ints into floats (when a column has gaps) and mixed columns into str.\n",
"        def _int_or_fraction(value):\n",
"            \"\"\"Return an int for integral values, otherwise the float (sklearn fraction).\"\"\"\n",
"            num = float(value)\n",
"            return int(num) if num.is_integer() else num\n",
"\n",
"        def _normalize_max_features(value):\n",
"            \"\"\"Map a CSV-restored max_features back to None, 'sqrt'/'log2', an int or a float.\"\"\"\n",
"            if isinstance(value, str):\n",
"                if value in (\"sqrt\", \"log2\"):\n",
"                    return value\n",
"                value = float(value)\n",
"            if pd.isna(value):\n",
"                return None  # None is written to the CSV as an empty cell\n",
"            num = float(value)\n",
"            # sklearn: int = absolute feature count, float = fraction of features\n",
"            return int(num) if num.is_integer() and num > 1 else num\n",
"\n",
"        best_params[\"n_estimators\"] = int(best_params[\"n_estimators\"])\n",
"        best_params[\"min_samples_split\"] = _int_or_fraction(best_params[\"min_samples_split\"])\n",
"        best_params[\"min_samples_leaf\"] = _int_or_fraction(best_params[\"min_samples_leaf\"])\n",
"        # max_depth may be NaN in CSV when None; handle it:\n",
"        if pd.isna(best_params[\"max_depth\"]):\n",
"            best_params[\"max_depth\"] = None\n",
"        else:\n",
"            best_params[\"max_depth\"] = int(best_params[\"max_depth\"])\n",
"        # 'sqrt'/'log2' must stay strings and None must stay None: float() on either\n",
"        # raises or produces NaN, which sklearn rejects.\n",
"        best_params[\"max_features\"] = _normalize_max_features(best_params[\"max_features\"])\n",
"        # bool(\"False\") is True, so parse textual booleans explicitly\n",
"        bootstrap_value = best_params[\"bootstrap\"]\n",
"        if isinstance(bootstrap_value, str):\n",
"            bootstrap_value = bootstrap_value.strip().lower() == \"true\"\n",
"        best_params[\"bootstrap\"] = bool(bootstrap_value)\n",
"\n",
"        print(\"\\n\" + \"=\" * 90)\n",
"        print(f\"BEST PARAMS by {SELECT_BEST_BY} = {best_row[SELECT_BEST_BY]:.6f}\")\n",
"        print(best_params)\n",
"        print(\"=\" * 90)\n",
"\n",
"        print(\"\\nTraining best RF again (for 
saving)...\")\n", " final_model = RandomForestClassifier(\n", " random_state=GLOBAL_SEED,\n", " n_jobs=-1,\n", " class_weight=class_weight_dict,\n", " **best_params\n", " )\n", " final_model.fit(X_train, y_train)\n", "\n", " model_path = os.path.join(BEST_MODEL_DIR, \"random_forest.joblib\")\n", " joblib.dump(final_model, model_path)\n", "\n", " # =====================================================\n", " # Save prediction outputs for confusion matrix / ROC\n", " # =====================================================\n", " y_pred = final_model.predict(X_test)\n", " y_proba = final_model.predict_proba(X_test)\n", "\n", " save_prediction_outputs(\n", " save_dir=BEST_MODEL_DIR,\n", " y_true=y_test,\n", " y_pred=y_pred,\n", " y_proba=y_proba,\n", " class_names=[CLASS_ID_TO_NAME[i] for i in range(n_classes)],\n", " prefix=\"test\"\n", " )\n", "\n", " meta = {\n", " \"data_path\": DATA_PATH,\n", " \"results_csv\": CSV_PATH,\n", " \"select_best_by\": SELECT_BEST_BY,\n", " \"best_score\": float(best_row[SELECT_BEST_BY]),\n", " \"best_param_id\": str(best_row[\"param_id\"]),\n", " \"best_params\": best_params,\n", " \"use_balanced_weights\": bool(USE_BALANCED_WEIGHTS),\n", " \"seed\": int(GLOBAL_SEED),\n", " \"feature_dim\": int(X_train.shape[1]),\n", " \"n_classes\": int(n_classes),\n", " \"class_id_to_name\": {str(k): v for k, v in CLASS_ID_TO_NAME.items()},\n", " \"timestamp_saved\": ts,\n", " }\n", " meta_path = os.path.join(BEST_MODEL_DIR, \"metadata_rf.json\")\n", " with open(meta_path, \"w\", encoding=\"utf-8\") as jf:\n", " json.dump(meta, jf, indent=2)\n", "\n", " print(f\"\\nāœ… Saved BEST model: {model_path}\")\n", " print(f\"āœ… Saved BEST metadata: {meta_path}\")\n", "\n", "finally:\n", " log_f.close()\n", " print(\"Log file closed.\")\n", "\n", "print(\"\\nDone.\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# =====================================================\n", "# šŸ“Š CONFUSION MATRIX (FULLY 
CUSTOMIZABLE - SINGLE CELL)\n",
"# =====================================================\n",
"import glob\n",
"import os\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.metrics import confusion_matrix\n",
"\n",
"# =========================\n",
"# USER INPUT\n",
"# =========================\n",
"# The sweep cell saves into a timestamped <RESULTS_DIR>/rf_best_model_<ts>/ folder\n",
"# (never a fixed best_model/ folder), so load the newest run that has test outputs.\n",
"# Keep CM_RESULTS_DIR in sync with RESULTS_DIR in the sweep cell.\n",
"CM_RESULTS_DIR = \"../Results/results_rf_hparam_sweep\"\n",
"_npz_candidates = sorted(glob.glob(os.path.join(CM_RESULTS_DIR, \"rf_best_model_*\", \"test_outputs.npz\")))\n",
"if not _npz_candidates:\n",
"    raise FileNotFoundError(f\"No rf_best_model_*/test_outputs.npz under {CM_RESULTS_DIR}; run the sweep cell first.\")\n",
"NPZ_PATH = _npz_candidates[-1]  # YYYYmmdd-HHMMSS suffixes sort chronologically\n",
"SAVE_PATH = os.path.join(os.path.dirname(NPZ_PATH), \"confusion_matrix.png\")\n",
"\n",
"# =========================\n",
"# STYLE OPTIONS\n",
"# =========================\n",
"FONT_FAMILY = \"DejaVu Sans\"\n",
"TITLE_SIZE = 16\n",
"LABEL_SIZE = 12\n",
"TICK_SIZE = 12\n",
"ANNOT_SIZE = 12\n",
"\n",
"FIGSIZE = (6, 5)\n",
"DPI = 300\n",
"\n",
"CMAP = \"Blues\"  # try: \"viridis\", \"magma\", \"coolwarm\"\n",
"SHOW_VALUES = True\n",
"FORMAT = \".2f\"\n",
"\n",
"NORMALIZE = True  # True = normalized, False = raw counts\n",
"\n",
"# Grid / borders\n",
"SHOW_GRID = False\n",
"LINEWIDTH = 0.4\n",
"LINECOLOR = \"black\"\n",
"\n",
"# Colorbar\n",
"SHOW_CBAR = True\n",
"CBAR_LABEL = \"\"\n",
"\n",
"# Tick rotation\n",
"XTICK_ROT = 45\n",
"YTICK_ROT = 0\n",
"\n",
"# =========================\n",
"# LOAD DATA\n",
"# =========================\n",
"# allow_pickle is needed for the object-dtype class_names array; only load files\n",
"# produced by this notebook (unpickling can execute code).\n",
"with np.load(NPZ_PATH, allow_pickle=True) as data:\n",
"    y_true = data[\"y_true\"]\n",
"    y_pred = data[\"y_pred\"]\n",
"    class_names = data[\"class_names\"]\n",
"\n",
"# =========================\n",
"# CONFUSION MATRIX\n",
"# =========================\n",
"# Pin label order to 0..K-1 so rows/columns line up with class_names even if a\n",
"# class never occurs in y_true or y_pred.\n",
"cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))\n",
"\n",
"if NORMALIZE:\n",
"    # Row-normalize; rows without true samples stay 0 instead of dividing by zero.\n",
"    row_sums = cm.sum(axis=1, keepdims=True)\n",
"    cm = np.divide(cm.astype(float), row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums > 0)\n",
"    title = \"Normalized Confusion Matrix\"\n",
"    vmin, vmax = 0.0, 1.0\n",
"else:\n",
"    title = \"Confusion Matrix\"\n",
"    FORMAT = \"d\"\n",
"    vmin, vmax = None, None\n",
"\n",
"# =========================\n",
"# PLOT SETTINGS\n",
"# =========================\n",
"plt.rcParams[\"font.family\"] = 
FONT_FAMILY\n",
"\n",
"plt.figure(figsize=FIGSIZE, dpi=DPI)\n",
"\n",
"ax = sns.heatmap(\n",
"    cm,\n",
"    annot=SHOW_VALUES,\n",
"    fmt=FORMAT,\n",
"    cmap=CMAP,\n",
"    xticklabels=class_names,\n",
"    yticklabels=class_names,\n",
"    cbar=SHOW_CBAR,\n",
"    linewidths=LINEWIDTH if SHOW_GRID else 0,\n",
"    linecolor=LINECOLOR,\n",
"    vmin=vmin,\n",
"    vmax=vmax,\n",
"    annot_kws={\"size\": ANNOT_SIZE}\n",
")\n",
"\n",
"# Labels and title\n",
"plt.xlabel(\"Predicted Label\", fontsize=LABEL_SIZE)\n",
"plt.ylabel(\"True Label\", fontsize=LABEL_SIZE)\n",
"plt.title(title, fontsize=TITLE_SIZE)\n",
"\n",
"# Tick formatting\n",
"plt.xticks(rotation=XTICK_ROT, ha=\"right\", fontsize=TICK_SIZE)\n",
"plt.yticks(rotation=YTICK_ROT, fontsize=TICK_SIZE)\n",
"\n",
"# Colorbar label\n",
"if SHOW_CBAR:\n",
"    cbar = ax.collections[0].colorbar\n",
"    cbar.set_label(CBAR_LABEL, fontsize=LABEL_SIZE)\n",
"\n",
"plt.tight_layout()\n",
"\n",
"# =========================\n",
"# SAVE (optional)\n",
"# =========================\n",
"if SAVE_PATH:\n",
"    plt.savefig(SAVE_PATH, dpi=DPI, bbox_inches=\"tight\")\n",
"    print(f\"Saved figure → {SAVE_PATH}\")\n",
"\n",
"plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# =====================================================\n",
"# ROC CURVE: per-class dotted lines + macro mean black\n",
"# with bootstrap std-dev shaded band\n",
"# =====================================================\n",
"import glob\n",
"import os\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.metrics import roc_curve, auc\n",
"from sklearn.preprocessing import label_binarize\n",
"\n",
"# =========================\n",
"# USER INPUT\n",
"# =========================\n",
"# The sweep cell writes to a timestamped <RESULTS_DIR>/rf_best_model_<ts>/ folder\n",
"# (never a fixed best_model/ folder), so load the newest run that has test outputs.\n",
"# Keep ROC_RESULTS_DIR in sync with RESULTS_DIR in the sweep cell.\n",
"ROC_RESULTS_DIR = \"../Results/results_rf_hparam_sweep\"\n",
"_npz_candidates = sorted(glob.glob(os.path.join(ROC_RESULTS_DIR, \"rf_best_model_*\", \"test_outputs.npz\")))\n",
"if not _npz_candidates:\n",
"    raise FileNotFoundError(f\"No rf_best_model_*/test_outputs.npz under {ROC_RESULTS_DIR}; run the sweep cell first.\")\n",
"NPZ_PATH = _npz_candidates[-1]  # YYYYmmdd-HHMMSS suffixes sort chronologically\n",
"SAVE_PATH = os.path.join(os.path.dirname(NPZ_PATH), \"roc_auc.png\")\n",
"\n",
"N_BOOTSTRAPS = 300\n",
"RANDOM_SEED = 42\n",
"\n",
"# =========================\n",
"# STYLE 
OPTIONS\n", "# =========================\n", "FONT_FAMILY = \"DejaVu Sans\"\n", "FIGSIZE = (7, 6)\n", "DPI = 600\n", "\n", "TITLE_SIZE = 18\n", "LABEL_SIZE = 16\n", "TICK_SIZE = 14\n", "LEGEND_SIZE = 10\n", "\n", "MACRO_LINEWIDTH = 1.5\n", "CLASS_LINEWIDTH = 1.5\n", "STD_ALPHA = 0.25\n", "\n", "MACRO_COLOR = \"black\"\n", "STD_COLOR = \"gray\"\n", "CLASS_COLORS = [\"#1f77b4\", \"#ff7f0e\", \"#2ca02c\", \"#d62728\", \"#9467bd\"]\n", "\n", "# =========================\n", "# LOAD DATA\n", "# =========================\n", "data = np.load(NPZ_PATH, allow_pickle=True)\n", "y_true = data[\"y_true\"]\n", "y_proba = data[\"y_proba\"]\n", "class_names = data[\"class_names\"]\n", "\n", "n_classes = y_proba.shape[1]\n", "\n", "# =========================\n", "# BINARIZE LABELS\n", "# =========================\n", "y_bin = label_binarize(y_true, classes=np.arange(n_classes))\n", "\n", "# =========================\n", "# HELPERS\n", "# =========================\n", "def compute_per_class_roc(y_bin_local, y_score_local):\n", " fpr_dict, tpr_dict, auc_dict = {}, {}, {}\n", " for i in range(y_score_local.shape[1]):\n", " fpr_dict[i], tpr_dict[i], _ = roc_curve(y_bin_local[:, i], y_score_local[:, i])\n", " auc_dict[i] = auc(fpr_dict[i], tpr_dict[i])\n", " return fpr_dict, tpr_dict, auc_dict\n", "\n", "def compute_macro_roc(fpr_dict, tpr_dict, n_cls):\n", " all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in range(n_cls)]))\n", " mean_tpr = np.zeros_like(all_fpr)\n", " for i in range(n_cls):\n", " mean_tpr += np.interp(all_fpr, fpr_dict[i], tpr_dict[i])\n", " mean_tpr /= n_cls\n", " macro_auc_val = auc(all_fpr, mean_tpr)\n", " return all_fpr, mean_tpr, macro_auc_val\n", "\n", "# =========================\n", "# BASE ROC\n", "# =========================\n", "fpr_base, tpr_base, auc_base = compute_per_class_roc(y_bin, y_proba)\n", "macro_fpr, macro_tpr, macro_auc = compute_macro_roc(fpr_base, tpr_base, n_classes)\n", "\n", "# =========================\n", "# BOOTSTRAP FOR MACRO 
STD BAND\n", "# =========================\n", "rng = np.random.RandomState(RANDOM_SEED)\n", "macro_tprs_boot = []\n", "\n", "for _ in range(N_BOOTSTRAPS):\n", " idx = rng.choice(len(y_true), size=len(y_true), replace=True)\n", "\n", " # need all classes present for valid multiclass bootstrap\n", " if len(np.unique(y_true[idx])) < n_classes:\n", " continue\n", "\n", " y_true_bs = y_true[idx]\n", " y_proba_bs = y_proba[idx]\n", " y_bin_bs = label_binarize(y_true_bs, classes=np.arange(n_classes))\n", "\n", " try:\n", " fpr_bs, tpr_bs, _ = compute_per_class_roc(y_bin_bs, y_proba_bs)\n", " macro_fpr_bs, macro_tpr_bs, _ = compute_macro_roc(fpr_bs, tpr_bs, n_classes)\n", "\n", " interp_macro_tpr = np.interp(macro_fpr, macro_fpr_bs, macro_tpr_bs)\n", " interp_macro_tpr[0] = 0.0\n", " interp_macro_tpr[-1] = 1.0\n", " macro_tprs_boot.append(interp_macro_tpr)\n", " except Exception:\n", " continue\n", "\n", "macro_tprs_boot = np.array(macro_tprs_boot)\n", "\n", "if len(macro_tprs_boot) > 0:\n", " macro_mean_tpr = macro_tprs_boot.mean(axis=0)\n", " macro_std_tpr = macro_tprs_boot.std(axis=0)\n", " macro_lower = np.maximum(macro_mean_tpr - macro_std_tpr, 0)\n", " macro_upper = np.minimum(macro_mean_tpr + macro_std_tpr, 1)\n", "else:\n", " macro_mean_tpr = macro_tpr.copy()\n", " macro_lower = macro_tpr.copy()\n", " macro_upper = macro_tpr.copy()\n", "\n", "# =========================\n", "# PLOT\n", "# =========================\n", "plt.rcParams[\"font.family\"] = FONT_FAMILY\n", "plt.figure(figsize=FIGSIZE, dpi=DPI)\n", "\n", "# per-class dotted ROC curves\n", "for i in range(n_classes):\n", " plt.plot(\n", " fpr_base[i],\n", " tpr_base[i],\n", " linestyle=\":\",\n", " linewidth=CLASS_LINEWIDTH,\n", " color=CLASS_COLORS[i % len(CLASS_COLORS)],\n", " label=f\"{class_names[i]} (AUC = {auc_base[i]:.3f})\"\n", " )\n", "\n", "# macro mean solid black line\n", "plt.plot(\n", " macro_fpr,\n", " macro_tpr,\n", " color=MACRO_COLOR,\n", " linewidth=MACRO_LINEWIDTH,\n", " 
label=f\"Macro-average (AUC = {macro_auc:.3f})\"\n", ")\n", "\n", "# macro std-dev shaded band\n", "plt.fill_between(\n", " macro_fpr,\n", " macro_lower,\n", " macro_upper,\n", " color=STD_COLOR,\n", " alpha=STD_ALPHA,\n", " label=\"Macro ±1 std\"\n", ")\n", "\n", "# random baseline\n", "plt.plot([0, 1], [0, 1], linestyle=\"--\", color=\"black\", linewidth=1)\n", "\n", "plt.xlabel(\"False Positive Rate\", fontsize=LABEL_SIZE)\n", "plt.ylabel(\"True Positive Rate\", fontsize=LABEL_SIZE)\n", "plt.title(\"ROC Curve: Macro-Average and Per-Class\", fontsize=TITLE_SIZE)\n", "\n", "plt.xticks(fontsize=TICK_SIZE)\n", "plt.yticks(fontsize=TICK_SIZE)\n", "plt.legend(fontsize=LEGEND_SIZE, loc=\"lower right\", frameon=True)\n", "plt.grid(alpha=0.0)\n", "\n", "plt.tight_layout()\n", "\n", "if SAVE_PATH is not None:\n", " plt.savefig(SAVE_PATH, dpi=DPI, bbox_inches=\"tight\")\n", " print(f\"Saved → {SAVE_PATH}\")\n", "\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "pp_env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.15" } }, "nbformat": 4, "nbformat_minor": 2 }