{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MLP model code" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"\"\"\"\n",
"MLP hyperparameter sweep for multiclass classification.\n",
"\n",
"What it does\n",
"------------\n",
"1) Sweeps over a manual grid of hyperparameters (every combination in PARAM_GRID).\n",
"2) For EACH hyperparameter combo:\n",
"   - trains with early stopping on validation macro-F1\n",
"   - evaluates metrics on validation + test\n",
"   - immediately appends ONE ROW to a CSV => crash-safe\n",
"3) If re-run after a crash:\n",
"   - loads the existing CSV and SKIPS combos already completed (resume capability)\n",
"4) After the sweep completes:\n",
"   - selects the best params by SELECT_BEST_BY\n",
"   - retrains the best model and saves ONLY that model + metadata\n",
"5) Also saves the best model's latent representation of the test set,\n",
"   taken after the BatchNorm + ReLU of the fc3 block.\n",
"\n",
"Outputs\n",
"-------\n",
"- results_dir/\n",
"  - sweep_metrics.csv\n",
"  - sweep_log_<timestamp>.txt\n",
"  - best_model_<timestamp>/\n",
"    - mlp_state_dict.pt\n",
"    - metadata.json\n",
"    - test_outputs.npz\n",
"    - test_latent_post_bn_relu_fc3.pt\n",
"\"\"\"\n",
"\n",
"import os\n",
"import sys\n",
"import json\n",
"import time\n",
"import copy\n",
"import hashlib\n",
"import random\n",
"import itertools\n",
"import contextlib\n",
"from datetime import datetime\n",
"\n",
"import torch\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import TensorDataset, DataLoader\n",
"\n",
"from sklearn.metrics import (\n",
"    accuracy_score,\n",
"    balanced_accuracy_score,\n",
"    f1_score,\n",
"    matthews_corrcoef,\n",
"    roc_auc_score,\n",
")\n",
"from sklearn.utils.class_weight import compute_class_weight\n",
"\n",
"\n",
"# =====================================================\n",
"# USER CONFIG\n",
"# 
=====================================================\n",
"DATA_PATH = \"../Data/multiclass_data_no_SMOTE\"  # Path where training_data.pt, validation_data.pt, test_data.pt are located\n",
"RESULTS_DIR = \"../Results/results_mlp_hparam_sweep\"  # Directory to save sweep results and best model\n",
"\n",
"# Whether to use class-balanced weights in the loss function.\n",
"# NOTE(review): this was previously justified by \"SMOTE on train\", but DATA_PATH points at\n",
"# a *no_SMOTE* split -- confirm, and consider enabling this if the classes are imbalanced.\n",
"USE_BALANCED_WEIGHTS = False\n",
"GLOBAL_SEED = 42\n",
"\n",
"# Metric used to pick the best combo.\n",
"# Options: \"val_macro_f1\", \"val_bal_acc\", \"val_macro_auc_ovr\", \"val_mcc\"\n",
"SELECT_BEST_BY = \"val_macro_f1\"\n",
"\n",
"HIDDEN_SIZES = (512, 256, 128)\n",
"DATALOADER_WORKERS = 0\n",
"\n",
"CLASS_ID_TO_NAME = {\n",
"    0: \"phosphate\",\n",
"    1: \"sulfate\",\n",
"    2: \"chloride\",\n",
"    3: \"nitrate\",\n",
"    4: \"carbonate\",\n",
"}\n",
"\n",
"# =====================================================\n",
"# PARAM GRID\n",
"# =====================================================\n",
"# Every combination of these lists is trained once (itertools.product below).\n",
"PARAM_GRID = {\n",
"    \"lr\": [0.001],\n",
"    \"weight_decay\": [1e-4],\n",
"    \"dropout\": [0.2],\n",
"    \"batch_size\": [128],\n",
"    \"epochs\": [200],\n",
"    \"patience\": [10],\n",
"    \"use_bn\": [True],\n",
"}\n",
"\n",
"# =====================================================\n",
"# REPRODUCIBILITY\n",
"# =====================================================\n",
"def seed_everything(seed: int = 42, deterministic: bool = True):\n",
"    \"\"\"Seed Python, NumPy and PyTorch (CPU and all CUDA devices).\n",
"\n",
"    Args:\n",
"        seed: Seed applied to every RNG.\n",
"        deterministic: If True, also request deterministic cuDNN kernels and\n",
"            disable cuDNN autotuning, so that retraining the best combo on GPU\n",
"            is closer to the run that was scored during the sweep.\n",
"    \"\"\"\n",
"    random.seed(seed)\n",
"    np.random.seed(seed)\n",
"    torch.manual_seed(seed)\n",
"    torch.cuda.manual_seed_all(seed)\n",
"    if deterministic:\n",
"        torch.backends.cudnn.deterministic = True\n",
"        torch.backends.cudnn.benchmark = False\n",
"\n",
"seed_everything(GLOBAL_SEED)\n",
"\n",
"\n",
"# =====================================================\n",
"# IO SETUP\n",
"# =====================================================\n",
"os.makedirs(RESULTS_DIR, exist_ok=True)\n",
"\n",
"ts = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
"LOG_PATH = os.path.join(RESULTS_DIR, f\"sweep_log_{ts}.txt\")\n",
"CSV_PATH = os.path.join(RESULTS_DIR, \"sweep_metrics.csv\")\n",
"\n",
"# Each run gets its own timestamped folder for the retrained best model.\n",
"BEST_MODEL_DIR = os.path.join(RESULTS_DIR, f\"best_model_{ts}\")\n",
"os.makedirs(BEST_MODEL_DIR, exist_ok=True)\n",
"\n",
"\n",
"# =====================================================\n",
"# TEE LOGGER\n",
"# =====================================================\n",
"class Tee:\n",
"    \"\"\"File-like object that duplicates every write to several streams.\n",
"\n",
"    Used with contextlib.redirect_stdout / redirect_stderr so that output\n",
"    printed during the sweep goes both to the notebook and to the log file.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, *streams):\n",
"        self.streams = streams\n",
"\n",
"    def write(self, data):\n",
"        for s in self.streams:\n",
"            s.write(data)\n",
"        # Text-stream contract: write() returns the number of characters written.\n",
"        return len(data)\n",
"\n",
"    def flush(self):\n",
"        for s in self.streams:\n",
"            try:\n",
"                s.flush()\n",
"            except Exception:\n",
"                # Best effort: one odd/closed stream must not abort the sweep.\n",
"                pass\n",
"\n",
"    def isatty(self):\n",
"        # Some libraries probe this before writing; a log file is not a TTY.\n",
"        return False\n",
"\n",
"\n",
"# =====================================================\n",
"# LOAD DATA\n",
"# =====================================================\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using {device} device\")\n",
"\n",
"print(\"📦 Loading datasets...\")\n",
"train = torch.load(os.path.join(DATA_PATH, \"training_data.pt\"), map_location=\"cpu\")\n",
"val = torch.load(os.path.join(DATA_PATH, \"validation_data.pt\"), map_location=\"cpu\")\n",
"test = torch.load(os.path.join(DATA_PATH, \"test_data.pt\"), map_location=\"cpu\")\n",
"print(\"✅ Data loaded successfully!\")\n",
"\n",
"def to_np(t):\n",
"    \"\"\"Convert a torch tensor (on any device) or an array-like to a NumPy array.\"\"\"\n",
"    return t.detach().cpu().numpy() if torch.is_tensor(t) else np.array(t)\n",
"\n",
"X_train, y_train = to_np(train[\"X_train\"]), to_np(train[\"y_train\"]).ravel().astype(int)\n",
"X_val, y_val = to_np(val[\"X_val\"]), to_np(val[\"y_val\"]).ravel().astype(int)\n",
"X_test, y_test = to_np(test[\"X_test\"]), to_np(test[\"y_test\"]).ravel().astype(int)\n",
"\n",
"classes = np.unique(y_train)\n",
"n_classes = len(classes)\n",
"input_size = X_train.shape[1]\n",
"\n",
"# The output layer size, class weights, per-class F1 and CLASS_ID_TO_NAME all index\n",
"# classes as 0..n_classes-1, so fail fast if the labels are not laid out like that.\n",
"if not np.array_equal(classes, np.arange(n_classes)):\n",
"    raise ValueError(f\"Training labels must be contiguous 0..K-1, got {classes.tolist()}\")\n",
"for _split_name, _y_split in ((\"val\", y_val), (\"test\", y_test)):\n",
"    if _y_split.size and (_y_split.min() < 0 or _y_split.max() >= n_classes):\n",
"        raise ValueError(f\"{_split_name} labels fall outside 0..{n_classes - 1}\")\n",
"\n",
"print(f\"Train: X={X_train.shape}, y={y_train.shape}\")\n",
"print(f\"Val: X={X_val.shape}, y={y_val.shape}\")\n",
"print(f\"Test: X={X_test.shape}, y={y_test.shape}\")\n",
"print(f\"Classes: {n_classes} | Train distribution: {np.bincount(y_train)}\")\n",
"print(f\"CSV (incremental): {CSV_PATH}\")\n",
"\n",
"\n",
"# =====================================================\n",
"# CLASS WEIGHTS\n",
"# =====================================================\n",
"if USE_BALANCED_WEIGHTS:\n",
"    # sklearn \"balanced\": n_samples / (n_classes * count_per_class).\n",
"    class_weights = compute_class_weight(\n",
"        class_weight=\"balanced\",\n",
"        classes=np.arange(n_classes),\n",
"        y=y_train\n",
"    )\n",
"    print(\"Class weights:\", class_weights)\n",
"else:\n",
"    class_weights = None\n",
"    print(\"Using uniform weights (no class weighting)\")\n",
"\n",
"\n",
"# =====================================================\n",
"# DATALOADER\n",
"# =====================================================\n",
"def make_loader(X, y, batch_size=256, shuffle=False, drop_last=False):\n",
"    \"\"\"Wrap NumPy arrays in a DataLoader yielding (float32 features, int64 labels).\n",
"\n",
"    drop_last is exposed so training can skip a trailing batch that BatchNorm\n",
"    cannot handle (a single sample); it defaults to False as before.\n",
"    \"\"\"\n",
"    X_t = torch.tensor(X, dtype=torch.float32)\n",
"    y_t = torch.tensor(y, dtype=torch.long)\n",
"    ds = TensorDataset(X_t, y_t)\n",
"    pin = (device.type == \"cuda\")  # pinned host memory only helps host->GPU copies\n",
"    return DataLoader(\n",
"        ds,\n",
"        batch_size=batch_size,\n",
"        shuffle=shuffle,\n",
"        drop_last=drop_last,\n",
"        num_workers=DATALOADER_WORKERS,\n",
"        pin_memory=pin\n",
"    )\n",
"\n",
"\n",
"# =====================================================\n",
"# MODEL\n",
"# =====================================================\n",
"class MLP(nn.Module):\n",
"    \"\"\"MLP with three hidden blocks (Linear -> BN -> ReLU -> Dropout) and a linear head.\n",
"\n",
"    BatchNorm layers are replaced by nn.Identity when use_bn=False.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_dim, h1, h2, h3, out_dim, p_drop=0.3, use_bn=True):\n",
"        super().__init__()\n",
"        self.fc1 = nn.Linear(in_dim, h1)\n",
"        self.bn1 = nn.BatchNorm1d(h1) if use_bn else nn.Identity()\n",
"\n",
"        self.fc2 = nn.Linear(h1, h2)\n",
"        self.bn2 = nn.BatchNorm1d(h2) if use_bn else nn.Identity()\n",
"\n",
"        self.fc3 = nn.Linear(h2, h3)\n",
"        self.bn3 = nn.BatchNorm1d(h3) if use_bn else nn.Identity()\n",
"\n",
"        self.out = nn.Linear(h3, out_dim)\n",
"        self.drop = nn.Dropout(p_drop)\n",
"\n",
"    def forward_features(self, x, apply_dropout=False):\n",
"        \"\"\"Return the fc3-block activations (post-BN + post-ReLU), i.e. the latent space.\n",
"\n",
"        Dropout after fc3 is applied only when apply_dropout=True (as forward() does).\n",
"        \"\"\"\n",
"        x = self.fc1(x)\n",
"        x = self.bn1(x)\n",
"        x = F.relu(x)\n",
"        x = self.drop(x)\n",
"\n",
"        x = self.fc2(x)\n",
"        x = self.bn2(x)\n",
"        x = F.relu(x)\n",
"        x = self.drop(x)\n",
"\n",
"        x = self.fc3(x)\n",
"        x = self.bn3(x)\n",
"        x = F.relu(x)\n",
"\n",
"        if apply_dropout:\n",
"            x = self.drop(x)\n",
"\n",
"        return x\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return unnormalized class logits.\"\"\"\n",
"        x = self.forward_features(x, apply_dropout=True)\n",
"        return self.out(x)\n",
"\n",
"\n",
"# =====================================================\n",
"# METRICS\n",
"# =====================================================\n",
"def safe_multiclass_auc(y_true, y_proba, average=\"macro\"):\n",
"    \"\"\"One-vs-rest ROC AUC, or NaN when sklearn cannot compute it (e.g. a class is absent).\"\"\"\n",
"    try:\n",
"        return roc_auc_score(y_true, y_proba, multi_class=\"ovr\", average=average)\n",
"    except Exception:\n",
"        return np.nan\n",
"\n",
"def predict_proba(model, X, batch_size=1024):\n",
"    \"\"\"Return softmax class probabilities for X as an (n_samples, n_classes) array.\"\"\"\n",
"    # Labels are irrelevant here; dummy zeros let us reuse make_loader.\n",
"    loader = make_loader(X, np.zeros(len(X), dtype=int), batch_size=batch_size, shuffle=False)\n",
"    model.eval()\n",
"    probs_all = []\n",
"    with torch.no_grad():\n",
"        for xb, _ in loader:\n",
"            xb = xb.to(device)\n",
"            logits = model(xb)\n",
"            probs = F.softmax(logits, dim=1).cpu().numpy()\n",
"            probs_all.append(probs)\n",
"    return np.vstack(probs_all)\n",
"\n",
"def per_class_f1_dict(y_true, y_pred, n_classes):\n",
"    \"\"\"Per-class F1 keyed f1_c0 .. f1_c<n_classes-1>.\"\"\"\n",
"    scores = f1_score(y_true, y_pred, average=None, labels=np.arange(n_classes))\n",
"    return {f\"f1_c{i}\": float(scores[i]) for i in range(n_classes)}\n",
"\n",
"def compute_metrics_with_per_class(y_true, y_pred, y_proba, n_classes):\n",
"    \"\"\"Return aggregate metrics plus per-class F1 as a flat dict of floats.\"\"\"\n",
"    out = {}\n",
"    out[\"acc\"] = float(accuracy_score(y_true, y_pred))\n",
"    out[\"bal_acc\"] = float(balanced_accuracy_score(y_true, y_pred))\n",
"    out[\"macro_f1\"] = float(f1_score(y_true, y_pred, average=\"macro\"))\n",
"    out[\"weighted_f1\"] = float(f1_score(y_true, y_pred, average=\"weighted\"))\n",
"    out[\"micro_f1\"] = float(f1_score(y_true, y_pred, average=\"micro\"))\n",
"    out[\"mcc\"] = float(matthews_corrcoef(y_true, y_pred))\n",
"    out[\"macro_auc_ovr\"] = float(safe_multiclass_auc(y_true, y_proba, average=\"macro\"))\n",
"    out[\"weighted_auc_ovr\"] = float(safe_multiclass_auc(y_true, y_proba, average=\"weighted\"))\n",
"    out.update(per_class_f1_dict(y_true, y_pred, n_classes))\n",
"    return out\n",
"\n",
"def save_prediction_outputs(save_dir, y_true, y_pred, y_proba, class_names=None, prefix=\"test\"):\n",
"    \"\"\"Write y_true / y_pred / y_proba / class_names to <save_dir>/<prefix>_outputs.npz.\n",
"\n",
"    class_names are stored as a plain unicode array (not dtype=object), so the\n",
"    file can also be read without allow_pickle=True.\n",
"    \"\"\"\n",
"    os.makedirs(save_dir, exist_ok=True)\n",
"    save_path = os.path.join(save_dir, f\"{prefix}_outputs.npz\")\n",
"\n",
"    if class_names is None:\n",
"        class_names = [f\"C{i}\" for i in range(y_proba.shape[1])]\n",
"    class_names = np.asarray([str(c) for c in class_names])\n",
"\n",
"    np.savez_compressed(\n",
"        save_path,\n",
"        y_true=np.asarray(y_true),\n",
"        y_pred=np.asarray(y_pred),\n",
"        y_proba=np.asarray(y_proba),\n",
"        class_names=class_names\n",
"    )\n",
"    print(f\"Saved prediction outputs to: {save_path}\")\n",
"\n",
"def collect_mlp_predictions(model, loader, device):\n",
"    \"\"\"Run model over loader; return (y_true, y_pred, y_proba) as NumPy arrays.\"\"\"\n",
"    model.eval()\n",
"\n",
"    all_true = []\n",
"    all_pred = []\n",
"    all_proba = []\n",
"\n",
"    with torch.no_grad():\n",
"        for xb, yb in loader:\n",
"            xb = xb.to(device)\n",
"            yb = yb.to(device)\n",
"\n",
"            logits = model(xb)\n",
"            proba = torch.softmax(logits, dim=1)\n",
"            pred = torch.argmax(proba, dim=1)\n",
"\n",
"            all_true.append(yb.cpu().numpy())\n",
"            all_pred.append(pred.cpu().numpy())\n",
"            all_proba.append(proba.cpu().numpy())\n",
"\n",
"    y_true = np.concatenate(all_true)\n",
"    y_pred = np.concatenate(all_pred)\n",
"    y_proba = np.concatenate(all_proba)\n",
"\n",
"    return y_true, y_pred, y_proba\n",
"\n",
"\n",
"# =====================================================\n",
"# LATENT SPACE EXTRACTION + SAVE\n",
"# =====================================================\n",
"def collect_latent_features(model, loader, device):\n",
"    \"\"\"\n",
"    Collect post-BN + post-ReLU latent features from model.forward_features().\n",
"\n",
"    Returns (embeddings, labels) as NumPy arrays in loader order.\n",
"    \"\"\"\n",
"    model.eval()\n",
"\n",
"    all_embeddings = []\n",
"    all_labels = []\n",
"\n",
"    with torch.no_grad():\n",
"        for xb, yb in loader:\n",
"            xb = xb.to(device)\n",
"            emb = model.forward_features(xb, apply_dropout=False)\n",
"            all_embeddings.append(emb.cpu())\n",
"            all_labels.append(yb.cpu())\n",
"\n",
"    embeddings = torch.cat(all_embeddings, dim=0).numpy()\n",
"    labels = torch.cat(all_labels, dim=0).numpy().astype(int)\n",
"\n",
"    return embeddings, labels\n",
"\n",
"def save_latent_space_pt(\n",
"    model,\n",
"    X,\n",
"    y,\n",
"    device,\n",
"    save_path,\n",
"    batch_size=1024,\n",
"    is_synthetic=None,\n",
"    latent_name=\"post_bn_relu_fc3\"\n",
"):\n",
"    \"\"\"\n",
"    Save post-ReLU latent space outputs as a .pt file.\n",
"\n",
"    The saved dict has keys embeddings (float32), labels (int64), latent_name\n",
"    and, when is_synthetic is given, is_synthetic (int64).\n",
"\n",
"    Raises:\n",
"        ValueError: if is_synthetic does not have one entry per sample.\n",
"    \"\"\"\n",
"    loader = make_loader(X, y, batch_size=batch_size, shuffle=False)\n",
"    embeddings, labels = collect_latent_features(\n",
"        model=model,\n",
"        loader=loader,\n",
"        device=device\n",
"    )\n",
"\n",
"    save_obj = {\n",
"        \"embeddings\": torch.tensor(embeddings, dtype=torch.float32),\n",
"        \"labels\": torch.tensor(labels, dtype=torch.long),\n",
"        \"latent_name\": latent_name,\n",
"    }\n",
"\n",
"    if is_synthetic is not None:\n",
"        is_synthetic = np.asarray(is_synthetic).ravel().astype(int)\n",
"        if len(is_synthetic) != len(labels):\n",
"            raise ValueError(\"Length of is_synthetic must match number of samples.\")\n",
"        save_obj[\"is_synthetic\"] = torch.tensor(is_synthetic, dtype=torch.long)\n",
"\n",
"    torch.save(save_obj, save_path)\n",
"    print(f\"Saved latent space .pt file to: {save_path}\")\n",
"    print(f\"  latent_name : {latent_name}\")\n",
"    print(f\"  embeddings  : {save_obj['embeddings'].shape}\")\n",
"    print(f\"  labels      : {save_obj['labels'].shape}\")\n",
"    if \"is_synthetic\" in save_obj:\n",
"        print(f\"  is_synthetic: {save_obj['is_synthetic'].shape}\")\n",
"\n",
"\n",
"# =====================================================\n",
"# PARAM UTILITIES\n",
"# =====================================================\n",
"def stable_param_id(params: dict) -> str:\n",
"    \"\"\"Deterministic 12-hex-char id for a hyperparameter dict (independent of key order).\"\"\"\n",
"    s = json.dumps(params, sort_keys=True)\n",
"    return hashlib.md5(s.encode(\"utf-8\")).hexdigest()[:12]\n",
"\n",
"def load_completed_param_ids(csv_path: str) -> set:\n",
"    \"\"\"Return the param_ids that already have a *successful* row in csv_path.\n",
"\n",
"    Rows whose error column is filled are ignored, so combos that failed\n",
"    (e.g. CUDA out-of-memory) are retried on the next run. param_id is read\n",
"    as text so all-digit hex ids are not turned into numbers. An unreadable\n",
"    CSV is reported instead of being silently treated as empty.\n",
"    \"\"\"\n",
"    if not os.path.exists(csv_path):\n",
"        return set()\n",
"    try:\n",
"        df = pd.read_csv(csv_path, dtype={\"param_id\": str})\n",
"    except Exception as e:\n",
"        print(f\"WARNING: could not read {csv_path} ({e}); no combos will be skipped.\")\n",
"        return set()\n",
"    if \"param_id\" not in df.columns:\n",
"        return set()\n",
"    if \"error\" in df.columns:\n",
"        df = df[df[\"error\"].isna()]\n",
"    return set(df[\"param_id\"].astype(str).tolist())\n",
"\n",
"\n",
"# =====================================================\n",
"# TRAINING\n",
"# =====================================================\n",
"def train_one_model(X_tr, y_tr, X_va, y_va, params, class_weights=None, seed=42):\n",
"    \"\"\"Train one MLP with AdamW + cosine LR schedule and early stopping.\n",
"\n",
"    Early stopping tracks validation macro-F1; training stops after `patience`\n",
"    epochs without improvement and the returned model carries the weights of\n",
"    the best epoch seen.\n",
"    \"\"\"\n",
"    seed_everything(seed)\n",
"\n",
"    model = MLP(\n",
"        input_size,\n",
"        HIDDEN_SIZES[0], HIDDEN_SIZES[1], HIDDEN_SIZES[2],\n",
"        n_classes,\n",
"        p_drop=params[\"dropout\"],\n",
"        use_bn=params[\"use_bn\"]\n",
"    ).to(device)\n",
"\n",
"    batch_size = int(params[\"batch_size\"])\n",
"    # BatchNorm1d raises in train mode on a batch of one sample; drop such a trailing batch.\n",
"    drop_last = bool(params[\"use_bn\"]) and len(X_tr) % batch_size == 1\n",
"    train_loader = make_loader(X_tr, y_tr, batch_size=batch_size, shuffle=True, drop_last=drop_last)\n",
"    val_loader = make_loader(X_va, y_va, batch_size=1024, shuffle=False)\n",
"\n",
"    if class_weights is not None:\n",
"        cw = torch.tensor(class_weights, dtype=torch.float32, device=device)\n",
"        criterion = nn.CrossEntropyLoss(weight=cw)\n",
"    else:\n",
"        criterion = nn.CrossEntropyLoss()\n",
"\n",
"    optimizer = torch.optim.AdamW(\n",
"        model.parameters(),\n",
"        lr=params[\"lr\"],\n",
"        weight_decay=params[\"weight_decay\"]\n",
"    )\n",
"    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=params[\"epochs\"])\n",
"\n",
"    best_val_macro_f1 = -1.0\n",
"    best_state = None\n",
"    epochs_no_improve = 0\n",
"    patience = int(params[\"patience\"])\n",
"\n",
"    for _epoch in range(1, int(params[\"epochs\"]) + 1):\n",
"        model.train()\n",
"        for xb, yb in train_loader:\n",
"            xb, yb = xb.to(device), yb.to(device)\n",
"            optimizer.zero_grad(set_to_none=True)\n",
"            logits = model(xb)\n",
"            loss = criterion(logits, yb)\n",
"            loss.backward()\n",
"            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)\n",
"            optimizer.step()\n",
"\n",
"        scheduler.step()\n",
"\n",
"        # Validation pass: macro-F1 of argmax predictions drives early stopping.\n",
"        model.eval()\n",
"        yv_pred = []\n",
"        with torch.no_grad():\n",
"            for xb, _ in val_loader:\n",
"                xb = xb.to(device)\n",
"                logits = model(xb)\n",
"                yv_pred.append(logits.argmax(dim=1).cpu().numpy())\n",
"        yv_pred = np.concatenate(yv_pred)\n",
"        val_macro_f1 = f1_score(y_va, yv_pred, average=\"macro\")\n",
"\n",
"        if val_macro_f1 > best_val_macro_f1:\n",
"            best_val_macro_f1 = val_macro_f1\n",
"            best_state = copy.deepcopy(model.state_dict())\n",
"            epochs_no_improve = 0\n",
"        else:\n",
"            epochs_no_improve += 1\n",
"\n",
"        if epochs_no_improve >= patience:\n",
"            break\n",
"\n",
"    if best_state is not None:\n",
"        model.load_state_dict(best_state)\n",
"\n",
"    return model\n",
"\n",
"\n",
"# =====================================================\n",
"# CSV APPEND (CRASH-SAFE)\n",
"# =====================================================\n",
"def append_row_to_csv(csv_path: str, row: dict):\n",
"    \"\"\"Append one result row to csv_path, keeping every value under its own column.\n",
"\n",
"    Success rows and error rows have different keys, so a blind header-less\n",
"    append would shift values into the wrong columns (or make the CSV\n",
"    unparsable). Instead:\n",
"      * no (or empty) file      -> write header + row\n",
"      * row keys within header  -> append, reordered to the header's columns\n",
"      * row adds new columns    -> rewrite the file with the union of columns,\n",
"                                   via a temp file + os.replace\n",
"    \"\"\"\n",
"    row_df = pd.DataFrame([row])\n",
"\n",
"    if not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0:\n",
"        row_df.to_csv(csv_path, index=False)\n",
"        return\n",
"\n",
"    header = list(pd.read_csv(csv_path, nrows=0).columns)\n",
"    if set(row_df.columns) <= set(header):\n",
"        row_df.reindex(columns=header).to_csv(csv_path, mode=\"a\", header=False, index=False)\n",
"    else:\n",
"        existing = pd.read_csv(csv_path, dtype={\"param_id\": str})\n",
"        combined = pd.concat([existing, row_df], ignore_index=True, sort=False)\n",
"        tmp_path = csv_path + \".tmp\"\n",
"        combined.to_csv(tmp_path, index=False)\n",
"        os.replace(tmp_path, csv_path)\n",
"\n",
"\n",
"# =====================================================\n",
"# BUILD PARAM COMBOS\n",
"# =====================================================\n",
"keys = list(PARAM_GRID.keys())\n",
"values = list(PARAM_GRID.values())\n",
"param_combinations = [dict(zip(keys, combo)) for combo in itertools.product(*values)]\n",
"print(f\"Total combos in grid: {len(param_combinations)}\")\n",
"\n",
"\n",
"# =====================================================\n",
"# RUN SWEEP (RESUME-CAPABLE)\n",
"# =====================================================\n",
"log_f = open(LOG_PATH, \"w\", encoding=\"utf-8\")\n",
"tee = Tee(sys.stdout, log_f)\n",
"\n",
"try:\n",
"    with contextlib.redirect_stdout(tee), contextlib.redirect_stderr(tee):\n",
"        completed = load_completed_param_ids(CSV_PATH)\n",
"        print(f\"Resume: found {len(completed)} completed combos in existing CSV.\")\n",
"\n",
"        t0 = time.time()\n",
"\n",
"        for idx, params in enumerate(param_combinations, 1):\n",
"            pid = stable_param_id(params)\n",
"            if pid in completed:\n",
"                print(f\"[{idx}/{len(param_combinations)}] SKIP (done) param_id={pid} params={params}\")\n",
"                continue\n",
"\n",
"            print(f\"\\n[{idx}/{len(param_combinations)}] RUN param_id={pid} params={params}\")\n",
"\n",
"            row = {\n",
"                \"timestamp\": datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\"),\n",
"                \"param_id\": pid,\n",
"                \"use_balanced_weights\": bool(USE_BALANCED_WEIGHTS),\n",
"                \"seed\": int(GLOBAL_SEED),\n",
"                **params\n",
"            }\n",
"\n",
"            model = None\n",
"            try:\n",
"                model = train_one_model(\n",
"                    X_train, y_train, X_val, y_val,\n",
"                    params,\n",
"                    class_weights=class_weights,\n",
"                    seed=GLOBAL_SEED\n",
"                )\n",
"\n",
"                val_proba = predict_proba(model, X_val, batch_size=1024)\n",
"                val_pred = val_proba.argmax(axis=1)\n",
"                val_metrics = compute_metrics_with_per_class(y_val, val_pred, val_proba, n_classes)\n",
"\n",
"                test_proba = predict_proba(model, X_test, batch_size=1024)\n",
"                test_pred = test_proba.argmax(axis=1)\n",
"                test_metrics = compute_metrics_with_per_class(y_test, test_pred, test_proba, n_classes)\n",
"\n",
"                row.update({f\"val_{k}\": v for k, v in val_metrics.items()})\n",
"                row.update({f\"test_{k}\": v for k, v in test_metrics.items()})\n",
"\n",
"                append_row_to_csv(CSV_PATH, row)\n",
"                completed.add(pid)\n",
"\n",
"                print(\n",
"                    f\"VAL acc={row['val_acc']:.4f} bal_acc={row['val_bal_acc']:.4f} \"\n",
"                    f\"macro_f1={row['val_macro_f1']:.4f} mcc={row['val_mcc']:.4f}\"\n",
"                )\n",
"                print(\n",
"                    f\"TEST acc={row['test_acc']:.4f} bal_acc={row['test_bal_acc']:.4f} \"\n",
"                    f\"macro_f1={row['test_macro_f1']:.4f} mcc={row['test_mcc']:.4f}\"\n",
"                )\n",
"\n",
"            except Exception as e:\n",
"                row[\"error\"] = str(e)\n",
"                append_row_to_csv(CSV_PATH, row)\n",
"                completed.add(pid)\n",
"                print(f\"❌ FAILED param_id={pid} error={e}\")\n",
"\n",
"            finally:\n",
"                # Release GPU memory after every combo, including failed ones (e.g. OOM).\n",
"                del model\n",
"                if device.type == \"cuda\":\n",
"                    torch.cuda.empty_cache()\n",
"\n",
"        print(f\"\\nSweep finished. Wall time: {time.time()-t0:.1f} sec\")\n",
"        print(f\"CSV saved incrementally at: {CSV_PATH}\")\n",
"\n",
"        # =====================================================\n",
"        # SELECT BEST + TRAIN AGAIN + SAVE ONLY BEST MODEL\n",
"        # =====================================================\n",
"        df = pd.read_csv(CSV_PATH, dtype={\"param_id\": str})\n",
"\n",
"        if \"error\" in df.columns:\n",
"            df_ok = df[df[\"error\"].isna()]\n",
"        else:\n",
"            df_ok = df\n",
"\n",
"        if df_ok.empty:\n",
"            raise RuntimeError(\"No successful runs found in CSV; cannot select best model.\")\n",
"\n",
"        if SELECT_BEST_BY not in df_ok.columns:\n",
"            raise RuntimeError(f\"SELECT_BEST_BY='{SELECT_BEST_BY}' not found in CSV columns.\")\n",
"\n",
"        # NaN scores (e.g. an undefined AUC) sort last, so they are picked only if nothing else exists.\n",
"        df_ok = df_ok.sort_values(SELECT_BEST_BY, ascending=False)\n",
"        best_row = df_ok.iloc[0].to_dict()\n",
"\n",
"        # The CSV round-trip loses Python types; cast every hyperparameter back explicitly.\n",
"        best_params = {k: best_row[k] for k in keys}\n",
"        best_params[\"epochs\"] = int(best_params[\"epochs\"])\n",
"        best_params[\"patience\"] = int(best_params[\"patience\"])\n",
"        best_params[\"batch_size\"] = int(best_params[\"batch_size\"])\n",
"        use_bn_raw = best_params[\"use_bn\"]\n",
"        # bool(\"False\") is True, so parse string values explicitly.\n",
"        best_params[\"use_bn\"] = (\n",
"            use_bn_raw.strip().lower() == \"true\" if isinstance(use_bn_raw, str) else bool(use_bn_raw)\n",
"        )\n",
"        best_params[\"lr\"] = float(best_params[\"lr\"])\n",
"        best_params[\"weight_decay\"] = float(best_params[\"weight_decay\"])\n",
"        best_params[\"dropout\"] = float(best_params[\"dropout\"])\n",
"\n",
"        print(\"\\n\" + \"=\" * 90)\n",
"        print(f\"BEST PARAMS by {SELECT_BEST_BY} = {best_row[SELECT_BEST_BY]:.6f}\")\n",
"        print(best_params)\n",
"        print(\"=\" * 90)\n",
"\n",
"        print(\"\\nTraining best model again (for saving)...\")\n",
"        best_model = train_one_model(\n",
"            X_train, y_train, X_val, y_val,\n",
"            best_params,\n",
"            class_weights=class_weights,\n",
"            seed=GLOBAL_SEED\n",
"        )\n",
"\n",
"        model_path = os.path.join(BEST_MODEL_DIR, \"mlp_state_dict.pt\")\n",
"        torch.save(best_model.state_dict(), model_path)\n",
"\n",
"        # =====================================================\n",
"        # SAVE LATENT SPACE (.pt) FOR BEST MODEL\n",
"        # =====================================================\n",
"        # to_np accepts tensors and array-likes alike.\n",
"        test_is_synthetic = to_np(test[\"is_synthetic\"]) if \"is_synthetic\" in test else None\n",
"\n",
"        latent_test_path = os.path.join(BEST_MODEL_DIR, \"test_latent_post_bn_relu_fc3.pt\")\n",
"        save_latent_space_pt(\n",
"            model=best_model,\n",
"            X=X_test,\n",
"            y=y_test,\n",
"            device=device,\n",
"            save_path=latent_test_path,\n",
"            batch_size=1024,\n",
"            is_synthetic=test_is_synthetic,\n",
"            latent_name=\"post_bn_relu_fc3\"\n",
"        )\n",
"\n",
"        # =====================================================\n",
"        # SAVE PREDICTION OUTPUTS\n",
"        # =====================================================\n",
"        test_loader = make_loader(X_test, y_test, batch_size=1024, shuffle=False)\n",
"\n",
"        y_true, y_pred, y_proba = collect_mlp_predictions(best_model, test_loader, device)\n",
"\n",
"        save_prediction_outputs(\n",
"            save_dir=BEST_MODEL_DIR,\n",
"            y_true=y_true,\n",
"            y_pred=y_pred,\n",
"            y_proba=y_proba,\n",
"            # Fall back to a generic name for any class id missing from CLASS_ID_TO_NAME.\n",
"            class_names=[CLASS_ID_TO_NAME.get(i, f\"C{i}\") for i in range(n_classes)],\n",
"            prefix=\"test\"\n",
"        )\n",
"\n",
"        meta = {\n",
"            \"data_path\": DATA_PATH,\n",
"            \"results_csv\": CSV_PATH,\n",
"            \"select_best_by\": SELECT_BEST_BY,\n",
"            \"best_score\": float(best_row[SELECT_BEST_BY]),\n",
"            \"best_param_id\": str(best_row[\"param_id\"]),\n",
"            \"best_params\": best_params,\n",
"            \"use_balanced_weights\": bool(USE_BALANCED_WEIGHTS),\n",
"            \"seed\": int(GLOBAL_SEED),\n",
"            \"feature_dim\": int(input_size),\n",
"            \"n_classes\": int(n_classes),\n",
"            \"architecture\": {\n",
"                \"hidden_sizes\": [int(x) for x in HIDDEN_SIZES]\n",
"            },\n",
"            \"class_id_to_name\": {str(k): v for k, v in CLASS_ID_TO_NAME.items()},\n",
"            \"latent_file\": \"test_latent_post_bn_relu_fc3.pt\",\n",
"            \"latent_name\": \"post_bn_relu_fc3\",\n",
"            \"timestamp_saved\": ts,\n",
"        }\n",
"        meta_path = os.path.join(BEST_MODEL_DIR, \"metadata.json\")\n",
"        with open(meta_path, \"w\", encoding=\"utf-8\") as jf:\n",
"            json.dump(meta, jf, indent=2)\n",
"\n",
"        print(f\"\\n✅ Saved BEST model weights: {model_path}\")\n",
"        print(f\"✅ Saved BEST metadata: {meta_path}\")\n",
"        print(f\"✅ Saved BEST latent space: {latent_test_path}\")\n",
"\n",
"finally:\n",
"    log_f.close()\n",
"    print(\"Log file closed.\")"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Confusion matrix plot" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# =====================================================\n",
"# 📊 CONFUSION MATRIX (FULLY CUSTOMIZABLE - SINGLE CELL)\n",
"# =====================================================\n",
"import os\n",
"import glob\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.metrics import confusion_matrix\n",
"\n",
"# =========================\n",
"# 🔧 USER INPUT\n",
"# =========================\n",
"SWEEP_RESULTS_DIR = \"../Results/results_mlp_hparam_sweep\"\n",
"# The sweep writes to best_model_<timestamp>/ (not best_model/). Set BEST_RUN_DIR\n",
"# to a specific run folder, or leave it as None to use the newest run that has outputs.\n",
"BEST_RUN_DIR = None\n",
"\n",
"if BEST_RUN_DIR is None:\n",
"    # Timestamps are YYYYmmdd-HHMMSS, so lexicographic order is chronological order.\n",
"    _npz_candidates = sorted(glob.glob(os.path.join(SWEEP_RESULTS_DIR, \"best_model_*\", \"test_outputs.npz\")))\n",
"    if not _npz_candidates:\n",
"        raise FileNotFoundError(\n",
"            f\"No best_model_*/test_outputs.npz under {SWEEP_RESULTS_DIR}; run the sweep cell first.\"\n",
"        )\n",
"    BEST_RUN_DIR = os.path.dirname(_npz_candidates[-1])\n",
"\n",
"NPZ_PATH = os.path.join(BEST_RUN_DIR, \"test_outputs.npz\")\n",
"SAVE_PATH = os.path.join(BEST_RUN_DIR, \"confusion_matrix.png\")\n",
"print(f\"Using outputs from: {BEST_RUN_DIR}\")\n",
"\n",
"# =========================\n",
"# 🎨 STYLE OPTIONS\n",
"# =========================\n",
"FONT_FAMILY = \"DejaVu Sans\"\n",
"TITLE_SIZE = 16\n",
"LABEL_SIZE = 12\n",
"TICK_SIZE = 12\n",
"ANNOT_SIZE = 12\n",
"\n",
"FIGSIZE = (6, 5)\n",
"DPI = 300\n",
"\n",
"CMAP = \"Blues\"  # try: \"viridis\", \"magma\", \"coolwarm\"\n",
"SHOW_VALUES = True\n",
"FORMAT = \".2f\"  # annotation format for the normalized matrix\n",
"\n",
"NORMALIZE = True  # True = normalized, False = raw counts\n",
"\n",
"# Grid / borders\n",
"SHOW_GRID = False\n",
"LINEWIDTH = 0.4\n",
"LINECOLOR = \"black\"\n",
"\n",
"# Colorbar\n",
"SHOW_CBAR = True\n",
"CBAR_LABEL = \"\"\n",
"\n",
"# Tick rotation\n",
"XTICK_ROT = 45\n",
"YTICK_ROT = 0\n",
"\n",
"# =========================\n",
"# 📂 LOAD DATA\n",
"# =========================\n",
"# allow_pickle=True keeps older outputs (object-dtype class_names) loadable; only open trusted files.\n",
"data = np.load(NPZ_PATH, allow_pickle=True)\n",
"\n",
"y_true = data[\"y_true\"]\n",
"y_pred = data[\"y_pred\"]\n",
"class_names = data[\"class_names\"]\n",
"\n",
"# 
=========================\n",
"# 📊 CONFUSION MATRIX\n",
"# =========================\n",
"# Fix the label set so the matrix is always K x K and lines up with class_names,\n",
"# even when some class never occurs in y_true or y_pred.\n",
"cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))\n",
"\n",
"if NORMALIZE:\n",
"    # Row-normalize (each row = one true class); rows without samples stay all-zero.\n",
"    row_sums = cm.sum(axis=1, keepdims=True)\n",
"    cm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums > 0)\n",
"    title = \"Normalized Confusion Matrix\"\n",
"    fmt = FORMAT\n",
"    vmin, vmax = 0.0, 1.0\n",
"else:\n",
"    title = \"Confusion Matrix\"\n",
"    fmt = \"d\"  # raw counts are integers\n",
"    vmin, vmax = None, None\n",
"\n",
"# =========================\n",
"# 🎨 PLOT SETTINGS\n",
"# =========================\n",
"plt.rcParams[\"font.family\"] = FONT_FAMILY\n",
"\n",
"plt.figure(figsize=FIGSIZE, dpi=DPI)\n",
"\n",
"ax = sns.heatmap(\n",
"    cm,\n",
"    annot=SHOW_VALUES,\n",
"    fmt=fmt,\n",
"    cmap=CMAP,\n",
"    xticklabels=class_names,\n",
"    yticklabels=class_names,\n",
"    cbar=SHOW_CBAR,\n",
"    linewidths=LINEWIDTH if SHOW_GRID else 0,\n",
"    linecolor=LINECOLOR,\n",
"    vmin=vmin,\n",
"    vmax=vmax,\n",
"    annot_kws={\"size\": ANNOT_SIZE}\n",
")\n",
"\n",
"# Labels and title\n",
"plt.xlabel(\"Predicted Label\", fontsize=LABEL_SIZE)\n",
"plt.ylabel(\"True Label\", fontsize=LABEL_SIZE)\n",
"plt.title(title, fontsize=TITLE_SIZE)\n",
"\n",
"# Tick formatting\n",
"plt.xticks(rotation=XTICK_ROT, ha=\"right\", fontsize=TICK_SIZE)\n",
"plt.yticks(rotation=YTICK_ROT, fontsize=TICK_SIZE)\n",
"\n",
"# Colorbar label\n",
"if SHOW_CBAR:\n",
"    cbar = ax.collections[0].colorbar\n",
"    cbar.set_label(CBAR_LABEL, fontsize=LABEL_SIZE)\n",
"\n",
"plt.tight_layout()\n",
"\n",
"# =========================\n",
"# 💾 SAVE (optional)\n",
"# =========================\n",
"if SAVE_PATH:\n",
"    plt.savefig(SAVE_PATH, dpi=DPI, bbox_inches=\"tight\")\n",
"    print(f\"Saved figure → {SAVE_PATH}\")\n",
"\n",
"plt.show()"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# ROC Curve plot" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# =====================================================\n",
"# ROC CURVE: per-class dotted lines + macro mean black\n",
"# with bootstrap std-dev shaded band\n",
"# =====================================================\n",
"import os\n",
"import glob\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.metrics import roc_curve, auc\n",
"\n",
"# =========================\n",
"# USER INPUT\n",
"# =========================\n",
"SWEEP_RESULTS_DIR = \"../Results/results_mlp_hparam_sweep\"\n",
"# The sweep writes to best_model_<timestamp>/ (not best_model/). Set BEST_RUN_DIR\n",
"# to a specific run folder, or leave it as None to use the newest run that has outputs.\n",
"BEST_RUN_DIR = None\n",
"\n",
"if BEST_RUN_DIR is None:\n",
"    # Timestamps are YYYYmmdd-HHMMSS, so lexicographic order is chronological order.\n",
"    _npz_candidates = sorted(glob.glob(os.path.join(SWEEP_RESULTS_DIR, \"best_model_*\", \"test_outputs.npz\")))\n",
"    if not _npz_candidates:\n",
"        raise FileNotFoundError(\n",
"            f\"No best_model_*/test_outputs.npz under {SWEEP_RESULTS_DIR}; run the sweep cell first.\"\n",
"        )\n",
"    BEST_RUN_DIR = os.path.dirname(_npz_candidates[-1])\n",
"\n",
"NPZ_PATH = os.path.join(BEST_RUN_DIR, \"test_outputs.npz\")\n",
"SAVE_PATH = os.path.join(BEST_RUN_DIR, \"roc_auc.png\")\n",
"print(f\"Using outputs from: {BEST_RUN_DIR}\")\n",
"\n",
"N_BOOTSTRAPS = 300  # test-set resamples used for the macro-ROC std band\n",
"RANDOM_SEED = 42\n",
"\n",
"# =========================\n",
"# STYLE OPTIONS\n",
"# =========================\n",
"FONT_FAMILY = \"DejaVu Sans\"\n",
"FIGSIZE = (7, 6)\n",
"DPI = 600\n",
"\n",
"TITLE_SIZE = 18\n",
"LABEL_SIZE = 16\n",
"TICK_SIZE = 14\n",
"LEGEND_SIZE = 10\n",
"\n",
"MACRO_LINEWIDTH = 1.5\n",
"CLASS_LINEWIDTH = 1.5\n",
"STD_ALPHA = 0.25\n",
"\n",
"MACRO_COLOR = \"black\"\n",
"STD_COLOR = \"gray\"\n",
"CLASS_COLORS = [\"#1f77b4\", \"#ff7f0e\", \"#2ca02c\", \"#d62728\", \"#9467bd\"]\n",
"\n",
"# =========================\n",
"# LOAD DATA\n",
"# =========================\n",
"# allow_pickle=True keeps older outputs (object-dtype class_names) loadable; only open trusted files.\n",
"data = np.load(NPZ_PATH, allow_pickle=True)\n",
"y_true = data[\"y_true\"]\n",
"y_proba = data[\"y_proba\"]\n",
"class_names = data[\"class_names\"]\n",
"\n",
"n_classes = y_proba.shape[1]\n",
"\n",
"# =========================\n",
"# BINARIZE LABELS\n",
"# =========================\n",
"def one_hot(labels, n_cls):\n",
"    \"\"\"One-hot encode integer labels 0..n_cls-1 as an (n_samples, n_cls) int matrix.\n",
"\n",
"    Unlike sklearn's label_binarize, this keeps one column per class even when\n",
"    n_cls == 2, so column i always corresponds to class i.\n",
"    \"\"\"\n",
"    return np.eye(n_cls, dtype=int)[np.asarray(labels, dtype=int)]\n",
"\n",
"y_bin = one_hot(y_true, n_classes)\n",
"\n",
"# =========================\n",
"# HELPERS\n",
"# =========================\n",
"def compute_per_class_roc(y_bin_local, y_score_local):\n",
"    \"\"\"One-vs-rest ROC per class; returns (fpr, tpr, auc) dicts keyed by class index.\"\"\"\n",
"    fpr_dict, tpr_dict, auc_dict = {}, {}, {}\n",
"    for i in range(y_score_local.shape[1]):\n",
"        fpr_dict[i], tpr_dict[i], _ = roc_curve(y_bin_local[:, i], y_score_local[:, i])\n",
"        auc_dict[i] = auc(fpr_dict[i], tpr_dict[i])\n",
"    return fpr_dict, tpr_dict, auc_dict\n",
"\n",
"def compute_macro_roc(fpr_dict, tpr_dict, n_cls):\n",
"    \"\"\"Macro-average ROC: mean of per-class TPRs interpolated on the union of FPR points.\"\"\"\n",
"    all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in range(n_cls)]))\n",
"    mean_tpr = np.zeros_like(all_fpr)\n",
"    for i in range(n_cls):\n",
"        mean_tpr += np.interp(all_fpr, fpr_dict[i], tpr_dict[i])\n",
"    mean_tpr /= n_cls\n",
"    macro_auc_val = auc(all_fpr, mean_tpr)\n",
"    return all_fpr, mean_tpr, macro_auc_val\n",
"\n",
"# =========================\n",
"# BASE ROC\n",
"# =========================\n",
"fpr_base, tpr_base, auc_base = compute_per_class_roc(y_bin, y_proba)\n",
"macro_fpr, macro_tpr, macro_auc = compute_macro_roc(fpr_base, tpr_base, n_classes)\n",
"\n",
"# =========================\n",
"# BOOTSTRAP FOR MACRO STD BAND\n",
"# =========================\n",
"rng = np.random.RandomState(RANDOM_SEED)\n",
"macro_tprs_boot = []\n",
"\n",
"for _ in range(N_BOOTSTRAPS):\n",
"    idx = rng.choice(len(y_true), size=len(y_true), replace=True)\n",
"\n",
"    # need all classes present for valid multiclass bootstrap\n",
"    if len(np.unique(y_true[idx])) < n_classes:\n",
"        continue\n",
"\n",
"    y_true_bs = y_true[idx]\n",
"    y_proba_bs = y_proba[idx]\n",
"    y_bin_bs = one_hot(y_true_bs, n_classes)\n",
"\n",
"    try:\n",
"        fpr_bs, tpr_bs, _ = compute_per_class_roc(y_bin_bs, y_proba_bs)\n",
"        macro_fpr_bs, macro_tpr_bs, _ = compute_macro_roc(fpr_bs, tpr_bs, n_classes)\n",
"\n",
"        # Resample each bootstrap curve onto the base FPR grid so the curves can be averaged.\n",
"        interp_macro_tpr = np.interp(macro_fpr, macro_fpr_bs, macro_tpr_bs)\n",
"        interp_macro_tpr[0] = 0.0\n",
"        interp_macro_tpr[-1] = 1.0\n",
"        macro_tprs_boot.append(interp_macro_tpr)\n",
"    except Exception:\n",
"        continue\n",
"\n",
"macro_tprs_boot = np.array(macro_tprs_boot)\n",
"\n",
"if len(macro_tprs_boot) > 0:\n",
"    macro_mean_tpr = macro_tprs_boot.mean(axis=0)\n",
"    macro_std_tpr = macro_tprs_boot.std(axis=0)\n",
"    macro_lower = np.maximum(macro_mean_tpr - macro_std_tpr, 0)\n",
"    macro_upper = np.minimum(macro_mean_tpr + macro_std_tpr, 1)\n",
"else:\n",
"    # No usable bootstrap sample: collapse the band onto the base curve.\n",
"    macro_mean_tpr = macro_tpr.copy()\n",
"    macro_std_tpr = np.zeros_like(macro_tpr)\n",
"    macro_lower = macro_tpr.copy()\n",
"    macro_upper = macro_tpr.copy()\n",
"\n",
"# =========================\n",
"# PLOT\n",
"# =========================\n",
"plt.rcParams[\"font.family\"] = FONT_FAMILY\n",
"plt.figure(figsize=FIGSIZE, dpi=DPI)\n",
"\n",
"# per-class dotted ROC curves\n",
"for i in range(n_classes):\n",
"    plt.plot(\n",
"        fpr_base[i],\n",
"        tpr_base[i],\n",
"        linestyle=\":\",\n",
"        linewidth=CLASS_LINEWIDTH,\n",
"        color=CLASS_COLORS[i % len(CLASS_COLORS)],\n",
"        label=f\"{class_names[i]} (AUC = {auc_base[i]:.3f})\"\n",
"    )\n",
"\n",
"# macro mean solid black line\n",
"plt.plot(\n",
"    macro_fpr,\n",
"    macro_tpr,\n",
"    color=MACRO_COLOR,\n",
"    linewidth=MACRO_LINEWIDTH,\n",
"    label=f\"Macro-average (AUC = {macro_auc:.3f})\"\n",
")\n",
"\n",
"# macro std-dev shaded band (bootstrap mean ± 1 std, clipped to [0, 1])\n",
"plt.fill_between(\n",
"    macro_fpr,\n",
"    macro_lower,\n",
"    macro_upper,\n",
"    color=STD_COLOR,\n",
"    alpha=STD_ALPHA,\n",
"    label=\"Macro ±1 std\"\n",
")\n",
"\n",
"# random baseline\n",
"plt.plot([0, 1], [0, 1], linestyle=\"--\", color=\"black\", linewidth=1)\n",
"\n",
"plt.xlabel(\"False Positive Rate\", fontsize=LABEL_SIZE)\n",
"plt.ylabel(\"True Positive Rate\", fontsize=LABEL_SIZE)\n",
"plt.title(\"ROC Curve: Macro-Average and Per-Class\", fontsize=TITLE_SIZE)\n",
"\n",
"plt.xticks(fontsize=TICK_SIZE)\n",
"plt.yticks(fontsize=TICK_SIZE)\n",
"plt.legend(fontsize=LEGEND_SIZE, loc=\"lower right\", frameon=True)\n",
"plt.grid(False)  # previously grid(alpha=0.0): an invisible grid\n",
"\n",
"plt.tight_layout()\n",
"\n",
"if SAVE_PATH is not None:\n",
"    plt.savefig(SAVE_PATH, dpi=DPI, bbox_inches=\"tight\")\n",
"    print(f\"Saved → {SAVE_PATH}\")\n",
"\n",
"plt.show()"
] } ],
"metadata": { "kernelspec": { "display_name": "pp_env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.15" } },
"nbformat": 4,
"nbformat_minor": 2
}