{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multiclass ion-binding-site datasets\n",
    "\n",
    "Loads per-ion binding-site embeddings (PO4, SO4, CL, NO3, CO3), removes duplicate embedding vectors, plots a t-SNE projection and saves stratified 80/10/10 train/validation/test splits under four oversampling strategies:\n",
    "\n",
    "- **A** - no oversampling\n",
    "- **B** - SMOTE applied before splitting\n",
    "- **C** - SMOTE applied to the training split only\n",
    "- **D** - Borderline-SMOTE applied to the training split only\n",
    "\n",
    "**Note:** the stored cell outputs come from an earlier run (they report `../Data_split/...` save paths, while the code below writes to `../Data/...`). Restart the kernel and run all cells to refresh them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import argparse\n",
    "import pickle\n",
    "import pdb\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.font_manager as fm\n",
    "\n",
    "from sklearn.metrics import accuracy_score, confusion_matrix, classification_report\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "from imblearn.over_sampling import SMOTE, BorderlineSMOTE\n",
    "\n",
    "sys.path.append(\"../Utils\")\n",
    "# Explicit import instead of `import *`: check_dir is the only phosbind_utils name used below.\n",
    "from phosbind_utils import check_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cuda device\n"
     ]
    }
   ],
   "source": [
    "# Pick the first available backend: CUDA, then Apple MPS, otherwise CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = \"mps\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "print(f\"Using {device} device\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single seed shared by t-SNE, the train/val/test splits and the SMOTE samplers.\n",
    "SEED = 42\n",
    "\n",
    "ion_names_in_a_list = [\"Phosphate\", \"Sulfate\", \"Chloride\", \"Nitrate\", \"Carbonate\"]\n",
    "ion_symbols_in_a_list = [\"PO4\", \"SO4\", \"CL\", \"NO3\", \"CO3\"]\n",
    "\n",
    "# Integer class id of each ion symbol = its position in ion_symbols_in_a_list\n",
    "# (PO4=0, SO4=1, CL=2, NO3=3, CO3=4).\n",
    "ION_CLASS_ID = {symbol: class_id for class_id, symbol in enumerate(ion_symbols_in_a_list)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_parser():\n",
    "    \"\"\"Build the argument parser that holds the notebook's run parameters.\n",
    "\n",
    "    Of these options only `distance` is read by the cells below (it selects the\n",
    "    embeddings directory and file name).\n",
    "    \"\"\"\n",
    "    parser = argparse.ArgumentParser(\n",
    "        description=\"Generate t-SNE plot with specified parameters\"\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--ion_symbols\",\n",
    "        type=str,\n",
    "        help=\"In case of multi-ion mode, specify the list of ions you want to plot. Default is just K.\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--input_file\",\n",
    "        type=str,\n",
    "        default=None,\n",
    "        required=False,\n",
    "        help=\"Load the arguments from this input file\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--ion_names\",\n",
    "        type=str,\n",
    "        help=\"Specify ion name in a list. Default is just Potassium\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--distance\",\n",
    "        # float, not int: the default 5.0 is formatted into the embeddings path as \"5.0\",\n",
    "        # so values given on the command line must be formatted the same way.\n",
    "        type=float,\n",
    "        default=5.0,\n",
    "        help=\"Specify distance (e.g., 4)\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--number_of_embeddings\",\n",
    "        type=int,\n",
    "        default=None,\n",
    "        required=False,\n",
    "        help=\"Specify number of embeddings in case of importing specific dataset sizes only.\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--logfile\",\n",
    "        type=str,\n",
    "        default=\"Build23-for-LS.log\",\n",
    "        help=\"Specify logfile name. Default is Build23-for-LS.log\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--eta\",\n",
    "        type=float,\n",
    "        default=3.0,\n",
    "        help=\"Specify eta. Default is 3\",\n",
    "    )\n",
    "    return parser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "parser = create_parser()\n",
    "# Parse an explicit empty argv: use the defaults and ignore the Jupyter kernel's own command line.\n",
    "args = parser.parse_args([])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load one pickle of binding-site records per ion, in ion_names_in_a_list order.\n",
    "# The list is (re)created here so that re-running this cell does not append the data twice.\n",
    "# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.\n",
    "embeddings_data = []\n",
    "for ion_name in ion_names_in_a_list:\n",
    "    embeddings_path = (\n",
    "        f\"../Features_extraction/esm2/Distance-{args.distance}/{ion_name}/\"\n",
    "        f\"{ion_name}BindingSiteEmbeddings-Distance{args.distance}angstroms-complete.pkl\"\n",
    "    )\n",
    "    with open(embeddings_path, \"rb\") as handle:\n",
    "        embeddings_data.append(pickle.load(handle))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "before: {np.str_('CL'): np.int64(12622), np.str_('CO3'): np.int64(205), np.str_('NO3'): np.int64(1654), np.str_('PO4'): np.int64(6089), np.str_('SO4'): np.int64(49914)}\n"
     ]
    }
   ],
   "source": [
    "# Unpack positions 0-5 of every record into parallel per-site lists (named below) and\n",
    "# stack all embeddings (position 5) into one tensor, ion by ion.\n",
    "# Edit this cell to add/remove descriptors taken from embeddings_data.\n",
    "assert len(embeddings_data) == len(ion_symbols_in_a_list), \"expected one embeddings file per ion\"\n",
    "\n",
    "site_num, pdb_name, pdb_id, CN_atom, CN_residues, averaged_tensor_embedding, ion_data = [], [], [], [], [], [], []\n",
    "stacked_tensorX0 = torch.tensor([])\n",
    "\n",
    "for ion_symbol, ion_records in zip(ion_symbols_in_a_list, embeddings_data):\n",
    "    tensors_for_current_ion = []\n",
    "    for item in ion_records:\n",
    "        site_num.append(item[0])\n",
    "        pdb_name.append(item[1])\n",
    "        pdb_id.append(item[2])\n",
    "        CN_atom.append(item[3])\n",
    "        CN_residues.append(item[4])\n",
    "        averaged_tensor_embedding.append(item[5])\n",
    "        tensors_for_current_ion.append(item[5])\n",
    "        ion_data.append(ion_symbol)\n",
    "\n",
    "    stacked_tensorX0 = torch.cat((stacked_tensorX0, torch.stack(tensors_for_current_ion)), dim=0)\n",
    "\n",
    "unique, counts = np.unique(ion_data, return_counts=True)\n",
    "before = dict(zip(unique, counts))\n",
    "print(f\"before: {before}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Remove duplicate binding sites"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "after: {np.str_('CL'): np.int64(10535), np.str_('CO3'): np.int64(173), np.str_('NO3'): np.int64(1503), np.str_('PO4'): np.int64(5250), np.str_('SO4'): np.int64(42423)}\n",
      "diff: {np.str_('CL'): np.int64(2087), np.str_('CO3'): np.int64(32), np.str_('NO3'): np.int64(151), np.str_('PO4'): np.int64(839), np.str_('SO4'): np.int64(7491)}\n"
     ]
    }
   ],
   "source": [
    "# Keep the first occurrence of every distinct embedding vector, in the original order,\n",
    "# and apply the same row selection to all per-site descriptor lists.\n",
    "_, first_occurrence_idx = np.unique(stacked_tensorX0.numpy(), axis=0, return_index=True)\n",
    "keep_idx = np.sort(first_occurrence_idx)\n",
    "\n",
    "stacked_tensorX0_unique = stacked_tensorX0[keep_idx.tolist()]\n",
    "site_num_unique = np.array(site_num)[keep_idx]\n",
    "pdb_name_unique = np.array(pdb_name)[keep_idx]\n",
    "pdb_id_unique = np.array(pdb_id)[keep_idx]\n",
    "CN_atom_unique = np.array(CN_atom)[keep_idx]\n",
    "CN_residues_unique = np.array(CN_residues)[keep_idx]\n",
    "ion_data_unique = np.array(ion_data)[keep_idx]\n",
    "averaged_tensor_embedding_unique = np.stack(averaged_tensor_embedding)[keep_idx]\n",
    "stacked_tensorX1 = stacked_tensorX0_unique.reshape(stacked_tensorX0_unique.shape[0], stacked_tensorX0_unique.shape[1], 1)\n",
    "\n",
    "# One-column array of class ids; a symbol missing from ION_CLASS_ID gets the extra id len(ION_CLASS_ID) (= 5).\n",
    "y_combined = np.array([[ION_CLASS_ID.get(symbol, len(ION_CLASS_ID))] for symbol in ion_data])\n",
    "y_combined_filtered = y_combined[keep_idx]\n",
    "\n",
    "unique, counts = np.unique(ion_data_unique, return_counts=True)\n",
    "after = dict(zip(unique, counts))\n",
    "print(f\"after: {after}\")\n",
    "\n",
    "# Number of duplicate rows removed per ion.\n",
    "diff = {key: before[key] - after[key] for key in before}\n",
    "print(f\"diff: {diff}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Class 0: 5250 samples\n",
      "Class 1: 42423 samples\n",
      "Class 2: 10535 samples\n",
      "Class 3: 1503 samples\n",
      "Class 4: 173 samples\n"
     ]
    }
   ],
   "source": [
    "unique, counts = np.unique(y_combined_filtered, return_counts=True)\n",
    "for cls, cnt in zip(unique, counts):\n",
    "    print(f\"Class {cls}: {cnt} samples\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## t-SNE projection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert data to numpy if needed\n",
    "X_plot = stacked_tensorX0_unique.cpu().numpy() if isinstance(stacked_tensorX0_unique, torch.Tensor) else stacked_tensorX0_unique\n",
    "y_plot = y_combined_filtered.ravel()\n",
    "\n",
    "print(\"Computing t-SNE...\")\n",
    "# The iteration count is left at the library default (1000): the keyword was renamed from\n",
    "# n_iter to max_iter in scikit-learn 1.5, so omitting it works on old and new versions alike.\n",
    "tsne = TSNE(n_components=2, random_state=SEED, perplexity=30)\n",
    "X_tsne = tsne.fit_transform(X_plot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colour per class id\n",
    "class_colors = {\n",
    "    0: \"#1f77b4\",  # Blue - PO4\n",
    "    1: \"#ff7f0e\",  # Orange - SO4\n",
    "    2: \"#2ca02c\",  # Green - CL\n",
    "    3: \"#d62728\",  # Red - NO3\n",
    "    4: \"#9467bd\",  # Purple - CO3\n",
    "}\n",
    "class_names = {class_id: symbol for symbol, class_id in ION_CLASS_ID.items()}\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 8), dpi=300)\n",
    "\n",
    "for cls in sorted(np.unique(y_plot)):\n",
    "    mask = y_plot == cls\n",
    "    ax.scatter(\n",
    "        X_tsne[mask, 0],\n",
    "        X_tsne[mask, 1],\n",
    "        c=class_colors[cls],\n",
    "        label=class_names[cls],\n",
    "        s=30,\n",
    "        alpha=0.7,\n",
    "        edgecolors=\"black\",\n",
    "        linewidth=0.3,\n",
    "    )\n",
    "\n",
    "ax.set_xlabel(\"t-SNE 1\", fontsize=12)\n",
    "ax.set_ylabel(\"t-SNE 2\", fontsize=12)\n",
    "ax.set_title(\"t-SNE Projection of Ion Binding Sites\", fontsize=14, fontweight=\"bold\")\n",
    "ax.legend(title=\"Ion Type\", fontsize=10, title_fontsize=11)\n",
    "ax.grid(True, alpha=0.3)\n",
    "fig.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print(\"✅ t-SNE plot completed!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset helpers\n",
    "\n",
    "Shared by sections A-D so that every strategy uses the same split, tensor dtypes and file layout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def class_distribution(y):\n",
    "    \"\"\"Return {class label: number of samples} for a label array.\"\"\"\n",
    "    return dict(zip(*np.unique(y, return_counts=True)))\n",
    "\n",
    "\n",
    "def dataset_as_numpy(features, labels):\n",
    "    \"\"\"Return (X, y) as NumPy arrays; y is flattened to 1-D and cast to int.\"\"\"\n",
    "    if isinstance(features, torch.Tensor):\n",
    "        X = features.detach().cpu().numpy()\n",
    "    else:\n",
    "        X = np.asarray(features)\n",
    "    y = np.asarray(labels).ravel().astype(int)\n",
    "    return X, y\n",
    "\n",
    "\n",
    "def split_80_10_10(X, y, *extras, seed=SEED):\n",
    "    \"\"\"Stratified, shuffled 80/10/10 train/val/test split.\n",
    "\n",
    "    `extras` are additional per-sample arrays split alongside X and y.\n",
    "    Returns three tuples (train, val, test), each laid out as (X, y, *extras).\n",
    "    \"\"\"\n",
    "    # train_test_split returns [a_train, a_test, b_train, b_test, ...]:\n",
    "    # even positions are the first part, odd positions the second.\n",
    "    parts = train_test_split(\n",
    "        X, y, *extras, test_size=0.2, shuffle=True, stratify=y, random_state=seed\n",
    "    )\n",
    "    train, holdout = tuple(parts[0::2]), tuple(parts[1::2])\n",
    "    parts = train_test_split(\n",
    "        *holdout, test_size=0.5, shuffle=True, stratify=holdout[1], random_state=seed\n",
    "    )\n",
    "    return train, tuple(parts[0::2]), tuple(parts[1::2])\n",
    "\n",
    "\n",
    "def oversample_with_flag(sampler, X, y):\n",
    "    \"\"\"Resample (X, y) with `sampler.fit_resample` and flag the generated rows.\n",
    "\n",
    "    Returns (X_resampled, y_resampled, is_synthetic) where is_synthetic is 0 for the\n",
    "    first len(X) rows and 1 for every row after them.\n",
    "    \"\"\"\n",
    "    X_resampled, y_resampled = sampler.fit_resample(X, y)\n",
    "    n_original = X.shape[0]\n",
    "    # The flag is only valid if the sampler returns the original rows first and unchanged,\n",
    "    # with the generated rows appended after them; fail loudly if that is not the case.\n",
    "    assert np.array_equal(y_resampled[:n_original], y) and np.array_equal(X_resampled[:n_original], X), \\\n",
    "        \"sampler did not return the original samples first; is_synthetic would be wrong\"\n",
    "    is_synthetic = np.zeros(X_resampled.shape[0], dtype=int)\n",
    "    is_synthetic[n_original:] = 1\n",
    "    return X_resampled, y_resampled, is_synthetic\n",
    "\n",
    "\n",
    "def to_tensors(X, y, is_synthetic=None):\n",
    "    \"\"\"Convert one split to torch tensors: float32 features, int64 labels, optional int64 flag.\"\"\"\n",
    "    tensors = (torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.long))\n",
    "    if is_synthetic is not None:\n",
    "        tensors += (torch.tensor(is_synthetic, dtype=torch.long),)\n",
    "    return tensors\n",
    "\n",
    "\n",
    "def save_splits(save_dir, train, val, test):\n",
    "    \"\"\"Save the three splits as training_data.pt, validation_data.pt and test_data.pt in save_dir.\n",
    "\n",
    "    Each split is a tuple (X, y) or (X, y, is_synthetic). Files hold a dict with keys\n",
    "    X_<split>, y_<split> and, when a flag is given, is_synthetic.\n",
    "    \"\"\"\n",
    "    check_dir(save_dir)\n",
    "    for split, file_name, tensors in (\n",
    "        (\"train\", \"training_data.pt\", train),\n",
    "        (\"val\", \"validation_data.pt\", val),\n",
    "        (\"test\", \"test_data.pt\", test),\n",
    "    ):\n",
    "        payload = {f\"X_{split}\": tensors[0], f\"y_{split}\": tensors[1]}\n",
    "        if len(tensors) > 2:\n",
    "            payload[\"is_synthetic\"] = tensors[2]\n",
    "        torch.save(payload, f\"{save_dir}/{file_name}\")\n",
    "\n",
    "\n",
    "def save_train_only_oversampled(X, y, sampler, sampler_label, save_dir, seed=SEED):\n",
    "    \"\"\"Split first (80/10/10), oversample only the training split, then save and summarise.\n",
    "\n",
    "    Validation and test splits keep only original samples (is_synthetic all 0).\n",
    "    `sampler_label` is the sampler name used in the printed summary.\n",
    "    Returns the (train, val, test) tensor tuples, each (X, y, is_synthetic).\n",
    "    \"\"\"\n",
    "    print(\"Original shape:\", X.shape)\n",
    "    print(\"Original class distribution:\", class_distribution(y))\n",
    "\n",
    "    (X_train, y_train), (X_val, y_val), (X_test, y_test) = split_80_10_10(X, y, seed=seed)\n",
    "\n",
    "    print(f\"\\nBefore {sampler_label}:\")\n",
    "    print(\"Train shape:\", X_train.shape, \"| class distribution:\", class_distribution(y_train))\n",
    "    print(\"Val shape: \", X_val.shape, \"| class distribution:\", class_distribution(y_val))\n",
    "    print(\"Test shape: \", X_test.shape, \"| class distribution:\", class_distribution(y_test))\n",
    "\n",
    "    X_train_res, y_train_res, is_synthetic_train = oversample_with_flag(sampler, X_train, y_train)\n",
    "    is_synthetic_val = np.zeros(len(y_val), dtype=int)\n",
    "    is_synthetic_test = np.zeros(len(y_test), dtype=int)\n",
    "\n",
    "    print(f\"\\nAfter {sampler_label} on train only:\")\n",
    "    print(\"Train shape:\", X_train_res.shape)\n",
    "    print(\"Train class distribution:\", class_distribution(y_train_res))\n",
    "    print(\"Synthetic train samples:\", np.sum(is_synthetic_train))\n",
    "    print(\"Synthetic val samples:\", np.sum(is_synthetic_val))\n",
    "    print(\"Synthetic test samples:\", np.sum(is_synthetic_test))\n",
    "\n",
    "    train = to_tensors(X_train_res, y_train_res, is_synthetic_train)\n",
    "    val = to_tensors(X_val, y_val, is_synthetic_val)\n",
    "    test = to_tensors(X_test, y_test, is_synthetic_test)\n",
    "    save_splits(save_dir, train, val, test)\n",
    "\n",
    "    print(f\"\\n✅ Saved split-first, {sampler_label}-on-train-only dataset with is_synthetic flag\")\n",
    "    print(\"Saved to:\", save_dir)\n",
    "\n",
    "    print(\"\\nTorch shapes:\")\n",
    "    print(\"Train:\", *(t.shape for t in train))\n",
    "    print(\"Val: \", *(t.shape for t in val))\n",
    "    print(\"Test: \", *(t.shape for t in test))\n",
    "    return train, val, test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# De-duplicated dataset shared by sections A-D.\n",
    "X, y = dataset_as_numpy(stacked_tensorX0_unique, y_combined_filtered)\n",
    "assert X.shape[0] == y.shape[0], \"features and labels must have the same number of rows\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SECTION A: Save dataset.pt (NO SMOTE)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original dataset shape: (59884, 1280) (59884,)\n",
      "Class distribution: {np.int64(0): np.int64(5250), np.int64(1): np.int64(42423), np.int64(2): np.int64(10535), np.int64(3): np.int64(1503), np.int64(4): np.int64(173)}\n",
      "✅ Saved NO-SMOTE dataset to: ../Data_split/multiclass_data_no_SMOTE\n",
      "Train shape: torch.Size([47907, 1280]) torch.Size([47907])\n",
      "Val shape: torch.Size([5988, 1280]) torch.Size([5988])\n",
      "Test shape: torch.Size([5989, 1280]) torch.Size([5989])\n",
      "Train class distribution: {np.int64(0): np.int64(4200), np.int64(1): np.int64(33938), np.int64(2): np.int64(8428), np.int64(3): np.int64(1202), np.int64(4): np.int64(139)}\n",
      "Val class distribution: {np.int64(0): np.int64(525), np.int64(1): np.int64(4242), np.int64(2): np.int64(1053), np.int64(3): np.int64(151), np.int64(4): np.int64(17)}\n",
      "Test class distribution: {np.int64(0): np.int64(525), np.int64(1): np.int64(4243), np.int64(2): np.int64(1054), np.int64(3): np.int64(150), np.int64(4): np.int64(17)}\n"
     ]
    }
   ],
   "source": [
    "print(\"Original dataset shape:\", X.shape, y.shape)\n",
    "print(\"Class distribution:\", class_distribution(y))\n",
    "\n",
    "# 80-10-10 split, then convert back to torch tensors\n",
    "train_np, val_np, test_np = split_80_10_10(X, y)\n",
    "X_train, y_train = to_tensors(*train_np)\n",
    "X_val, y_val = to_tensors(*val_np)\n",
    "X_test, y_test = to_tensors(*test_np)\n",
    "\n",
    "# Save location\n",
    "data_save_location = \"../Data/multiclass_data_no_SMOTE\"\n",
    "save_splits(data_save_location, (X_train, y_train), (X_val, y_val), (X_test, y_test))\n",
    "\n",
    "print(\"✅ Saved NO-SMOTE dataset to:\", data_save_location)\n",
    "print(\"Train shape:\", X_train.shape, y_train.shape)\n",
    "print(\"Val shape: \", X_val.shape, y_val.shape)\n",
    "print(\"Test shape: \", X_test.shape, y_test.shape)\n",
    "\n",
    "print(\"Train class distribution:\", class_distribution(y_train.numpy()))\n",
    "print(\"Val class distribution: \", class_distribution(y_val.numpy()))\n",
    "print(\"Test class distribution: \", class_distribution(y_test.numpy()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SECTION B: SMOTE BEFORE SPLIT + is_synthetic LABEL\n",
    "\n",
    "**Caution:** here SMOTE runs on the full dataset *before* the split, so the validation and test files contain synthetic samples, and synthetic samples in any split may have been interpolated from originals that ended up in another split. Metrics computed on these validation/test files are therefore optimistic; use the `is_synthetic` flag to filter, or prefer the split-first datasets of sections C and D for evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original shape: (59884, 1280)\n",
      "Original class distribution: {np.int64(0): np.int64(5250), np.int64(1): np.int64(42423), np.int64(2): np.int64(10535), np.int64(3): np.int64(1503), np.int64(4): np.int64(173)}\n",
      "After SMOTE shape: (212115, 1280)\n",
      "Synthetic samples: 152231\n",
      "✅ Saved SMOTE dataset with synthetic labels\n",
      "\n",
      "Train synthetic count: 121768\n",
      "Val synthetic count: 15232\n",
      "Test synthetic count: 15231\n"
     ]
    }
   ],
   "source": [
    "print(\"Original shape:\", X.shape)\n",
    "print(\"Original class distribution:\", class_distribution(y))\n",
    "\n",
    "# Apply SMOTE to the whole dataset and flag the generated rows\n",
    "X_smote, y_smote, is_synthetic = oversample_with_flag(SMOTE(random_state=SEED), X, y)\n",
    "\n",
    "print(\"After SMOTE shape:\", X_smote.shape)\n",
    "print(\"Synthetic samples:\", np.sum(is_synthetic))\n",
    "\n",
    "# Split (80-10-10), carrying the flag along, then convert to torch\n",
    "train_np, val_np, test_np = split_80_10_10(X_smote, y_smote, is_synthetic)\n",
    "X_train, y_train, syn_train = to_tensors(*train_np)\n",
    "X_val, y_val, syn_val = to_tensors(*val_np)\n",
    "X_test, y_test, syn_test = to_tensors(*test_np)\n",
    "\n",
    "# Save\n",
    "data_save_location = \"../Data/multiclass_data_SMOTE\"\n",
    "save_splits(\n",
    "    data_save_location,\n",
    "    (X_train, y_train, syn_train),\n",
    "    (X_val, y_val, syn_val),\n",
    "    (X_test, y_test, syn_test),\n",
    ")\n",
    "\n",
    "print(\"✅ Saved SMOTE dataset with synthetic labels\")\n",
    "\n",
    "print(\"\\nTrain synthetic count:\", syn_train.sum().item())\n",
    "print(\"Val synthetic count: \", syn_val.sum().item())\n",
    "print(\"Test synthetic count: \", syn_test.sum().item())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SECTION C: SPLIT FIRST, THEN APPLY SMOTE ONLY ON TRAIN + is_synthetic LABEL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original shape: (59884, 1280)\n",
      "Original class distribution: {np.int64(0): np.int64(5250), np.int64(1): np.int64(42423), np.int64(2): np.int64(10535), np.int64(3): np.int64(1503), np.int64(4): np.int64(173)}\n",
      "\n",
      "Before SMOTE:\n",
      "Train shape: (47907, 1280) | class distribution: {np.int64(0): np.int64(4200), np.int64(1): np.int64(33938), np.int64(2): np.int64(8428), np.int64(3): np.int64(1202), np.int64(4): np.int64(139)}\n",
      "Val shape: (5988, 1280) | class distribution: {np.int64(0): np.int64(525), np.int64(1): np.int64(4242), np.int64(2): np.int64(1053), np.int64(3): np.int64(151), np.int64(4): np.int64(17)}\n",
      "Test shape: (5989, 1280) | class distribution: {np.int64(0): np.int64(525), np.int64(1): np.int64(4243), np.int64(2): np.int64(1054), np.int64(3): np.int64(150), np.int64(4): np.int64(17)}\n",
      "\n",
      "After SMOTE on train only:\n",
      "Train shape: (169690, 1280)\n",
      "Train class distribution: {np.int64(0): np.int64(33938), np.int64(1): np.int64(33938), np.int64(2): np.int64(33938), np.int64(3): np.int64(33938), np.int64(4): np.int64(33938)}\n",
      "Synthetic train samples: 121783\n",
      "Synthetic val samples: 0\n",
      "Synthetic test samples: 0\n",
      "\n",
      "✅ Saved split-first, SMOTE-on-train-only dataset with is_synthetic flag\n",
      "Saved to: ../Data_split/multiclass_data_train_SMOTE_only_with_flag\n",
      "\n",
      "Torch shapes:\n",
      "Train: torch.Size([169690, 1280]) torch.Size([169690]) torch.Size([169690])\n",
      "Val: torch.Size([5988, 1280]) torch.Size([5988]) torch.Size([5988])\n",
      "Test: torch.Size([5989, 1280]) torch.Size([5989]) torch.Size([5989])\n"
     ]
    }
   ],
   "source": [
    "data_save_location = \"../Data/multiclass_data_SMOTE_on_train_only\"\n",
    "smote_train, smote_val, smote_test = save_train_only_oversampled(\n",
    "    X,\n",
    "    y,\n",
    "    sampler=SMOTE(random_state=SEED),\n",
    "    sampler_label=\"SMOTE\",\n",
    "    save_dir=data_save_location,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SECTION D: SPLIT FIRST, THEN APPLY BORDERLINE-SMOTE ONLY ON TRAIN + is_synthetic LABEL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original shape: (59884, 1280)\n",
      "Original class distribution: {np.int64(0): np.int64(5250), np.int64(1): np.int64(42423), np.int64(2): np.int64(10535), np.int64(3): np.int64(1503), np.int64(4): np.int64(173)}\n",
      "\n",
      "Before Borderline-SMOTE:\n",
      "Train shape: (47907, 1280) | class distribution: {np.int64(0): np.int64(4200), np.int64(1): np.int64(33938), np.int64(2): np.int64(8428), np.int64(3): np.int64(1202), np.int64(4): np.int64(139)}\n",
      "Val shape: (5988, 1280) | class distribution: {np.int64(0): np.int64(525), np.int64(1): np.int64(4242), np.int64(2): np.int64(1053), np.int64(3): np.int64(151), np.int64(4): np.int64(17)}\n",
      "Test shape: (5989, 1280) | class distribution: {np.int64(0): np.int64(525), np.int64(1): np.int64(4243), np.int64(2): np.int64(1054), np.int64(3): np.int64(150), np.int64(4): np.int64(17)}\n",
      "\n",
      "After Borderline-SMOTE on train only:\n",
      "Train shape: (169690, 1280)\n",
      "Train class distribution: {np.int64(0): np.int64(33938), np.int64(1): np.int64(33938), np.int64(2): np.int64(33938), np.int64(3): np.int64(33938), np.int64(4): np.int64(33938)}\n",
      "Synthetic train samples: 121783\n",
      "Synthetic val samples: 0\n",
      "Synthetic test samples: 0\n",
      "\n",
      "✅ Saved split-first, Borderline-SMOTE-on-train-only dataset with is_synthetic flag\n",
      "Saved to: ../Data_split/multiclass_data_train_BorderlineSMOTE_only_with_flag\n",
      "\n",
      "Torch shapes:\n",
      "Train: torch.Size([169690, 1280]) torch.Size([169690]) torch.Size([169690])\n",
      "Val: torch.Size([5988, 1280]) torch.Size([5988]) torch.Size([5988])\n",
      "Test: torch.Size([5989, 1280]) torch.Size([5989]) torch.Size([5989])\n"
     ]
    }
   ],
   "source": [
    "# kind='borderline-1' is the usual default choice; change to kind='borderline-2' if needed.\n",
    "data_save_location = \"../Data/multiclass_data_BorderlineSMOTE_on_train_only\"\n",
    "bsmote_train, bsmote_val, bsmote_test = save_train_only_oversampled(\n",
    "    X,\n",
    "    y,\n",
    "    sampler=BorderlineSMOTE(kind=\"borderline-1\", random_state=SEED),\n",
    "    sampler_label=\"Borderline-SMOTE\",\n",
    "    save_dir=data_save_location,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pp_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.15"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}