Upload 2 files
Browse files
added evaluation pipeline and model weights
- ensemble_checkpoint.pth +3 -0
- final_ensemble_pipeline.ipynb +675 -0
ensemble_checkpoint.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fb5421224d0381aec5065db260c258b22e27594a432b2234d4ebb1e925c53589
|
| 3 |
+
size 34882689
|
final_ensemble_pipeline.ipynb
ADDED
|
@@ -0,0 +1,675 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "74bd5ceb-afa1-4bfd-ba39-10af717cf2a5",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"Remember to update the test-set path and the model checkpoint path before running!\n",
|
| 9 |
+
"Because the headlines are encoded into vectors with an embedding model, encoding the test set takes 10+ minutes. I could not run that step on my end because I do not have access to the hidden test set."
|
| 10 |
+
]
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"cell_type": "code",
|
| 14 |
+
"execution_count": 1,
|
| 15 |
+
"id": "a458f2b7-3ab1-479f-9627-ef7ef8ef76b4",
|
| 16 |
+
"metadata": {},
|
| 17 |
+
"outputs": [],
|
| 18 |
+
"source": [
|
| 19 |
+
"import torch\n",
|
| 20 |
+
"import torch.nn as nn\n",
|
| 21 |
+
"import torch.optim as optim\n",
|
| 22 |
+
"from torch.utils.data import Dataset, DataLoader, random_split, SubsetRandomSampler\n",
|
| 23 |
+
"from tqdm import tqdm\n",
|
| 24 |
+
"import numpy as np\n",
|
| 25 |
+
"import random\n",
|
| 26 |
+
"import os\n",
|
| 27 |
+
"import copy\n",
|
| 28 |
+
"from torch.utils.data import TensorDataset\n",
|
| 29 |
+
"import pandas as pd"
|
| 30 |
+
]
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"cell_type": "code",
|
| 34 |
+
"execution_count": 2,
|
| 35 |
+
"id": "d7943628-3454-4d21-a95d-ca53acd9b6dc",
|
| 36 |
+
"metadata": {},
|
| 37 |
+
"outputs": [],
|
| 38 |
+
"source": [
|
| 39 |
+
"class LabelSmoothingBCELoss(nn.Module):\n",
|
| 40 |
+
" def __init__(self, smoothing=0.1):\n",
|
| 41 |
+
" \"\"\"\n",
|
| 42 |
+
" Label Smoothing Binary Cross Entropy Loss\n",
|
| 43 |
+
" \n",
|
| 44 |
+
" Args:\n",
|
| 45 |
+
" smoothing (float): Amount of label smoothing to apply\n",
|
| 46 |
+
" \"\"\"\n",
|
| 47 |
+
" super(LabelSmoothingBCELoss, self).__init__()\n",
|
| 48 |
+
" self.smoothing = smoothing\n",
|
| 49 |
+
" \n",
|
| 50 |
+
" def forward(self, predictions, targets):\n",
|
| 51 |
+
" \"\"\"\n",
|
| 52 |
+
" Compute label-smoothed binary cross entropy loss\n",
|
| 53 |
+
" \n",
|
| 54 |
+
" Args:\n",
|
| 55 |
+
" predictions (torch.Tensor): Model predictions\n",
|
| 56 |
+
" targets (torch.Tensor): Binary labels\n",
|
| 57 |
+
" \n",
|
| 58 |
+
" Returns:\n",
|
| 59 |
+
" torch.Tensor: Smoothed loss\n",
|
| 60 |
+
" \"\"\"\n",
|
| 61 |
+
" # Apply label smoothing\n",
|
| 62 |
+
" smooth_targets = targets * (1 - self.smoothing) + 0.5 * self.smoothing\n",
|
| 63 |
+
" \n",
|
| 64 |
+
" # Standard Binary Cross Entropy Loss\n",
|
| 65 |
+
" loss = nn.functional.binary_cross_entropy(predictions, smooth_targets)\n",
|
| 66 |
+
" \n",
|
| 67 |
+
" return loss\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"class EarlyStoppingCallback:\n",
|
| 70 |
+
" def __init__(self, patience=5, min_delta=0.001):\n",
|
| 71 |
+
" \"\"\"\n",
|
| 72 |
+
" Early stopping mechanism\n",
|
| 73 |
+
" \n",
|
| 74 |
+
" Args:\n",
|
| 75 |
+
" patience (int): Number of epochs to wait for improvement\n",
|
| 76 |
+
" min_delta (float): Minimum change to qualify as an improvement\n",
|
| 77 |
+
" \"\"\"\n",
|
| 78 |
+
" self.patience = patience\n",
|
| 79 |
+
" self.min_delta = min_delta\n",
|
| 80 |
+
" self.counter = 0\n",
|
| 81 |
+
" self.best_loss = float('inf')\n",
|
| 82 |
+
" self.early_stop = False\n",
|
| 83 |
+
" self.best_model_state = None\n",
|
| 84 |
+
" \n",
|
| 85 |
+
" def __call__(self, val_loss, model):\n",
|
| 86 |
+
" \"\"\"\n",
|
| 87 |
+
" Check if training should stop\n",
|
| 88 |
+
" \n",
|
| 89 |
+
" Args:\n",
|
| 90 |
+
" val_loss (float): Current validation loss\n",
|
| 91 |
+
" model (nn.Module): Current model state\n",
|
| 92 |
+
" \n",
|
| 93 |
+
" Returns:\n",
|
| 94 |
+
" bool: Whether to stop training\n",
|
| 95 |
+
" \"\"\"\n",
|
| 96 |
+
" if val_loss < self.best_loss - self.min_delta:\n",
|
| 97 |
+
" self.best_loss = val_loss\n",
|
| 98 |
+
" self.counter = 0\n",
|
| 99 |
+
" # Save the best model state\n",
|
| 100 |
+
" self.best_model_state = copy.deepcopy(model.state_dict())\n",
|
| 101 |
+
" else:\n",
|
| 102 |
+
" self.counter += 1\n",
|
| 103 |
+
" if self.counter >= self.patience:\n",
|
| 104 |
+
" self.early_stop = True\n",
|
| 105 |
+
" \n",
|
| 106 |
+
" return self.early_stop\n",
|
| 107 |
+
"\n",
|
| 108 |
+
"class EnsembleMLPClassifier(nn.Module):\n",
|
| 109 |
+
" def __init__(self, \n",
|
| 110 |
+
" input_dim=1024, # BGE embedding dimension\n",
|
| 111 |
+
" hidden_layers=None,\n",
|
| 112 |
+
" dropout_rate=0.2,\n",
|
| 113 |
+
" activation=nn.ReLU(), # Allow passing activation functions dynamically\n",
|
| 114 |
+
" device=None):\n",
|
| 115 |
+
" super(EnsembleMLPClassifier, self).__init__()\n",
|
| 116 |
+
" \n",
|
| 117 |
+
" # Default configuration if not provided\n",
|
| 118 |
+
" if hidden_layers is None:\n",
|
| 119 |
+
" hidden_layers = [512, 256, 128]\n",
|
| 120 |
+
" \n",
|
| 121 |
+
" # Set device (GPU if available, else CPU)\n",
|
| 122 |
+
" self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
| 123 |
+
" \n",
|
| 124 |
+
" # Store initialization parameters\n",
|
| 125 |
+
" self.input_dim = input_dim\n",
|
| 126 |
+
" self.hidden_layers = hidden_layers\n",
|
| 127 |
+
" self.dropout_rate = dropout_rate\n",
|
| 128 |
+
" self.activation = activation\n",
|
| 129 |
+
" \n",
|
| 130 |
+
" # Add linear gate mechanism\n",
|
| 131 |
+
" self.gate = nn.Linear(input_dim, input_dim, bias=False)\n",
|
| 132 |
+
" \n",
|
| 133 |
+
" # Create layers dynamically based on hidden_layers specification\n",
|
| 134 |
+
" layers = []\n",
|
| 135 |
+
" prev_dim = input_dim\n",
|
| 136 |
+
" for hidden_dim in hidden_layers:\n",
|
| 137 |
+
" # Dense Layer with dynamic activation and BatchNorm\n",
|
| 138 |
+
" layers.extend([\n",
|
| 139 |
+
" nn.Linear(prev_dim, hidden_dim),\n",
|
| 140 |
+
" nn.BatchNorm1d(hidden_dim),\n",
|
| 141 |
+
" activation,\n",
|
| 142 |
+
" nn.Dropout(dropout_rate)\n",
|
| 143 |
+
" ])\n",
|
| 144 |
+
" prev_dim = hidden_dim\n",
|
| 145 |
+
" \n",
|
| 146 |
+
" # Final output layer for binary classification\n",
|
| 147 |
+
" layers.append(nn.Linear(prev_dim, 1))\n",
|
| 148 |
+
" layers.append(nn.Sigmoid())\n",
|
| 149 |
+
" \n",
|
| 150 |
+
" # Create the model and move to device\n",
|
| 151 |
+
" self.model = nn.Sequential(*layers)\n",
|
| 152 |
+
" self.to(self.device)\n",
|
| 153 |
+
"\n",
|
| 154 |
+
" def forward(self, x):\n",
|
| 155 |
+
" \"\"\"Forward pass through the network\"\"\"\n",
|
| 156 |
+
" # Apply gating mechanism\n",
|
| 157 |
+
" x = self.gate(x) * x\n",
|
| 158 |
+
" return self.model(x)\n",
|
| 159 |
+
"\n",
|
| 160 |
+
"class EnsembleClassifier:\n",
|
| 161 |
+
" def __init__(self, num_models=5, label_smoothing=0.1):\n",
|
| 162 |
+
" self.models = self._create_diverse_models(num_models)\n",
|
| 163 |
+
" self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
| 164 |
+
" self.label_smoothing = label_smoothing\n",
|
| 165 |
+
" self.model_weights = None \n",
|
| 166 |
+
" \n",
|
| 167 |
+
" def _create_diverse_models(self, num_models):\n",
|
| 168 |
+
" models = []\n",
|
| 169 |
+
" \n",
|
| 170 |
+
" # Predefined configurations for consistency across runs\n",
|
| 171 |
+
" architectures = [\n",
|
| 172 |
+
" {'hidden_layers': [512, 256, 128], 'dropout_rate': 0.2, 'activation': nn.ReLU()},\n",
|
| 173 |
+
" {'hidden_layers': [1024, 512], 'dropout_rate': 0.3, 'activation': nn.LeakyReLU()},\n",
|
| 174 |
+
" {'hidden_layers': [256, 128, 64], 'dropout_rate': 0.1, 'activation': nn.GELU()},\n",
|
| 175 |
+
" {'hidden_layers': [512, 128], 'dropout_rate': 0.25, 'activation': nn.SELU()},\n",
|
| 176 |
+
" {'hidden_layers': [256, 128], 'dropout_rate': 0.15, 'activation': nn.Tanh()}\n",
|
| 177 |
+
" ]\n",
|
| 178 |
+
" \n",
|
| 179 |
+
" # Optimizer strategies\n",
|
| 180 |
+
" optimizers = [optim.Adam, optim.AdamW, optim.SGD]\n",
|
| 181 |
+
" \n",
|
| 182 |
+
" for i in range(num_models):\n",
|
| 183 |
+
" # Use predefined architectures in a consistent order\n",
|
| 184 |
+
" config = architectures[i % len(architectures)]\n",
|
| 185 |
+
" optimizer_fn = optimizers[i % len(optimizers)]\n",
|
| 186 |
+
" \n",
|
| 187 |
+
" model = EnsembleMLPClassifier(\n",
|
| 188 |
+
" input_dim=1024,\n",
|
| 189 |
+
" hidden_layers=config['hidden_layers'],\n",
|
| 190 |
+
" dropout_rate=config['dropout_rate'],\n",
|
| 191 |
+
" activation=config['activation']\n",
|
| 192 |
+
" )\n",
|
| 193 |
+
" \n",
|
| 194 |
+
" # Custom weight initialization\n",
|
| 195 |
+
" def init_weights(m):\n",
|
| 196 |
+
" if isinstance(m, nn.Linear):\n",
|
| 197 |
+
" init_methods = [\n",
|
| 198 |
+
" nn.init.xavier_uniform_,\n",
|
| 199 |
+
" nn.init.kaiming_normal_,\n",
|
| 200 |
+
" nn.init.orthogonal_\n",
|
| 201 |
+
" ]\n",
|
| 202 |
+
" init_method = init_methods[i % len(init_methods)] # Consistent initialization\n",
|
| 203 |
+
" init_method(m.weight)\n",
|
| 204 |
+
" if m.bias is not None:\n",
|
| 205 |
+
" nn.init.zeros_(m.bias)\n",
|
| 206 |
+
" \n",
|
| 207 |
+
" model.model.apply(init_weights)\n",
|
| 208 |
+
" \n",
|
| 209 |
+
" # Attach optimizer to model instance for flexibility\n",
|
| 210 |
+
" model.optimizer_fn = optimizer_fn\n",
|
| 211 |
+
" \n",
|
| 212 |
+
" # Add L2 regularization to the model (Weight Decay)\n",
|
| 213 |
+
" model.regularization = {\n",
|
| 214 |
+
" 'weight_decay': 1e-5 # Example regularization value\n",
|
| 215 |
+
" }\n",
|
| 216 |
+
" \n",
|
| 217 |
+
" models.append(model)\n",
|
| 218 |
+
" \n",
|
| 219 |
+
" return models\n",
|
| 220 |
+
" \n",
|
| 221 |
+
" def train(self, train_dataset, batch_size=32, num_epochs=20):\n",
|
| 222 |
+
" for model_idx, model in enumerate(tqdm(self.models, desc=\"Training Models\", position=0)):\n",
|
| 223 |
+
" print(f\"Starting training for Model {model_idx + 1}/{len(self.models)}\")\n",
|
| 224 |
+
" \n",
|
| 225 |
+
" # Randomly split 80% for training and 20% for validation\n",
|
| 226 |
+
" total_size = len(train_dataset)\n",
|
| 227 |
+
" train_size = int(0.8 * total_size)\n",
|
| 228 |
+
" val_size = total_size - train_size\n",
|
| 229 |
+
" \n",
|
| 230 |
+
" train_subset, val_subset = random_split(train_dataset, [train_size, val_size])\n",
|
| 231 |
+
" \n",
|
| 232 |
+
" # Create data loaders for training and validation\n",
|
| 233 |
+
" train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)\n",
|
| 234 |
+
" val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)\n",
|
| 235 |
+
" \n",
|
| 236 |
+
" # Optimizer with learning rate scheduler\n",
|
| 237 |
+
" optimizer = optim.AdamW(model.parameters(), lr=1e-3)\n",
|
| 238 |
+
" scheduler = optim.lr_scheduler.CosineAnnealingLR(\n",
|
| 239 |
+
" optimizer, \n",
|
| 240 |
+
" T_max=num_epochs, \n",
|
| 241 |
+
" eta_min=1e-5\n",
|
| 242 |
+
" )\n",
|
| 243 |
+
" \n",
|
| 244 |
+
" # Label Smoothing Loss\n",
|
| 245 |
+
" criterion = LabelSmoothingBCELoss(smoothing=self.label_smoothing)\n",
|
| 246 |
+
" \n",
|
| 247 |
+
" # Early stopping\n",
|
| 248 |
+
" early_stopping = EarlyStoppingCallback(patience=4, min_delta=0.001)\n",
|
| 249 |
+
" \n",
|
| 250 |
+
" model.train()\n",
|
| 251 |
+
" epoch_progress = tqdm(range(num_epochs), desc=f\"Model {model_idx} Training\", position=1, leave=False)\n",
|
| 252 |
+
" \n",
|
| 253 |
+
" best_val_loss = float('inf')\n",
|
| 254 |
+
" for epoch in epoch_progress:\n",
|
| 255 |
+
" total_loss = 0\n",
|
| 256 |
+
" \n",
|
| 257 |
+
" # Training phase\n",
|
| 258 |
+
" for batch in train_loader:\n",
|
| 259 |
+
" inputs, labels = batch\n",
|
| 260 |
+
" inputs, labels = inputs.to(model.device), labels.to(model.device)\n",
|
| 261 |
+
" \n",
|
| 262 |
+
" optimizer.zero_grad()\n",
|
| 263 |
+
" outputs = model(inputs)\n",
|
| 264 |
+
" loss = criterion(outputs, labels.float().unsqueeze(1))\n",
|
| 265 |
+
" loss.backward()\n",
|
| 266 |
+
" optimizer.step()\n",
|
| 267 |
+
" \n",
|
| 268 |
+
" total_loss += loss.item()\n",
|
| 269 |
+
" avg_train_loss = total_loss / len(train_loader)\n",
|
| 270 |
+
" \n",
|
| 271 |
+
" # Validation phase\n",
|
| 272 |
+
" model.eval()\n",
|
| 273 |
+
" val_loss = 0\n",
|
| 274 |
+
" with torch.no_grad():\n",
|
| 275 |
+
" for val_batch in val_loader:\n",
|
| 276 |
+
" val_inputs, val_labels = val_batch\n",
|
| 277 |
+
" val_inputs, val_labels = val_inputs.to(model.device), val_labels.to(model.device)\n",
|
| 278 |
+
" val_outputs = model(val_inputs)\n",
|
| 279 |
+
" val_loss += criterion(val_outputs, val_labels.float().unsqueeze(1)).item()\n",
|
| 280 |
+
" \n",
|
| 281 |
+
" avg_val_loss = val_loss / len(val_loader)\n",
|
| 282 |
+
" epoch_progress.set_postfix({\n",
|
| 283 |
+
" 'train_loss': avg_train_loss,\n",
|
| 284 |
+
" 'val_loss': avg_val_loss\n",
|
| 285 |
+
" })\n",
|
| 286 |
+
" \n",
|
| 287 |
+
" # Early stopping check\n",
|
| 288 |
+
" if early_stopping(avg_val_loss, model):\n",
|
| 289 |
+
" if early_stopping.best_model_state:\n",
|
| 290 |
+
" model.load_state_dict(early_stopping.best_model_state)\n",
|
| 291 |
+
" print(f\"Early stopping triggered for Model {model_idx}\")\n",
|
| 292 |
+
" break\n",
|
| 293 |
+
" \n",
|
| 294 |
+
" # Learning rate adjustment\n",
|
| 295 |
+
" scheduler.step()\n",
|
| 296 |
+
" \n",
|
| 297 |
+
" # Reset to training mode\n",
|
| 298 |
+
" model.train()\n",
|
| 299 |
+
" \n",
|
| 300 |
+
" # Store model's final state after training\n",
|
| 301 |
+
" model.eval()\n",
|
| 302 |
+
" \n",
|
| 303 |
+
" def compute_test_weights(self, test_loader):\n",
|
| 304 |
+
" \"\"\"\n",
|
| 305 |
+
" Compute model weights based on test accuracy while emphasizing distinctions.\n",
|
| 306 |
+
" \"\"\"\n",
|
| 307 |
+
" model_accuracies = []\n",
|
| 308 |
+
" for model_idx, model in enumerate(self.models):\n",
|
| 309 |
+
" correct = 0\n",
|
| 310 |
+
" total = 0\n",
|
| 311 |
+
" model.eval()\n",
|
| 312 |
+
" with torch.no_grad():\n",
|
| 313 |
+
" for inputs, labels in test_loader:\n",
|
| 314 |
+
" inputs, labels = inputs.to(model.device), labels.to(model.device)\n",
|
| 315 |
+
" outputs = model(inputs)\n",
|
| 316 |
+
" preds = (outputs > 0.5).float()\n",
|
| 317 |
+
" correct += (preds == labels).sum().item()\n",
|
| 318 |
+
" total += labels.size(0)\n",
|
| 319 |
+
" accuracy = correct / total\n",
|
| 320 |
+
" model_accuracies.append(accuracy)\n",
|
| 321 |
+
" \n",
|
| 322 |
+
" # Apply a power transformation for distinction\n",
|
| 323 |
+
" accuracies = np.array(model_accuracies)\n",
|
| 324 |
+
" print(f\"Raw model accuracies: {accuracies}\")\n",
|
| 325 |
+
" \n",
|
| 326 |
+
" # Use power scaling to exaggerate differences (e.g., square the accuracies)\n",
|
| 327 |
+
" power_scaling_factor = 2 # Choose 2 for squaring, can experiment with higher values\n",
|
| 328 |
+
" scaled_accuracies = accuracies ** power_scaling_factor\n",
|
| 329 |
+
" \n",
|
| 330 |
+
" # Smooth the accuracies slightly to avoid over-reliance on any single model\n",
|
| 331 |
+
" smoothed_accuracies = scaled_accuracies * (1 - 0.1) + 0.1 * np.mean(scaled_accuracies)\n",
|
| 332 |
+
" \n",
|
| 333 |
+
" # Normalize weights so they sum to 1\n",
|
| 334 |
+
" weights = smoothed_accuracies / smoothed_accuracies.sum()\n",
|
| 335 |
+
" \n",
|
| 336 |
+
" # Store model weights\n",
|
| 337 |
+
" self.model_weights = torch.tensor(weights, dtype=torch.float32).to(self.device)\n",
|
| 338 |
+
" print(f\"Model weights after scaling: {self.model_weights}\")\n",
|
| 339 |
+
"\n",
|
| 340 |
+
"\n",
|
| 341 |
+
" def predict(self, test_loader, confidence_threshold=0.5, return_raw_scores=True):\n",
|
| 342 |
+
" \"\"\"\n",
|
| 343 |
+
" Prediction with confidence-weighted voting, optionally returning raw scores.\n",
|
| 344 |
+
" \"\"\"\n",
|
| 345 |
+
" if self.model_weights is None:\n",
|
| 346 |
+
" raise ValueError(\"Model weights not computed. Call compute_test_weights first.\")\n",
|
| 347 |
+
" \n",
|
| 348 |
+
" all_predictions = []\n",
|
| 349 |
+
" for model_idx, model in enumerate(self.models):\n",
|
| 350 |
+
" model.eval()\n",
|
| 351 |
+
" model_preds = []\n",
|
| 352 |
+
" with torch.no_grad():\n",
|
| 353 |
+
" for batch in test_loader:\n",
|
| 354 |
+
" inputs, _ = batch\n",
|
| 355 |
+
" inputs = inputs.to(model.device)\n",
|
| 356 |
+
" outputs = model(inputs)\n",
|
| 357 |
+
" model_preds.append(outputs)\n",
|
| 358 |
+
" \n",
|
| 359 |
+
" # Concatenate predictions for this model\n",
|
| 360 |
+
" all_predictions.append(torch.cat(model_preds))\n",
|
| 361 |
+
" \n",
|
| 362 |
+
" # Stack predictions and compute weighted average\n",
|
| 363 |
+
" stacked_preds = torch.stack(all_predictions, dim=1).squeeze(-1)\n",
|
| 364 |
+
" weighted_preds = (stacked_preds * self.model_weights.view(1, -1)).sum(dim=1)\n",
|
| 365 |
+
" \n",
|
| 366 |
+
" # Final prediction with thresholding\n",
|
| 367 |
+
" final_preds = (weighted_preds > confidence_threshold).float()\n",
|
| 368 |
+
" \n",
|
| 369 |
+
" # Optionally return raw probabilities for debugging\n",
|
| 370 |
+
" if return_raw_scores:\n",
|
| 371 |
+
" return final_preds, weighted_preds.cpu().numpy()\n",
|
| 372 |
+
" \n",
|
| 373 |
+
" return final_preds\n",
|
| 374 |
+
"\n",
|
| 375 |
+
"\n",
|
| 376 |
+
" def save_models(self, save_dir='ensemble_models/model_test_4'):\n",
|
| 377 |
+
" \"\"\"\n",
|
| 378 |
+
" Save ensemble model weights and model weights with progress tracking\n",
|
| 379 |
+
" \"\"\"\n",
|
| 380 |
+
" os.makedirs(save_dir, exist_ok=True)\n",
|
| 381 |
+
"\n",
|
| 382 |
+
" save_data = {\n",
|
| 383 |
+
" 'models': {},\n",
|
| 384 |
+
" 'model_weights': self.model_weights.cpu().numpy() if self.model_weights is not None else None\n",
|
| 385 |
+
" }\n",
|
| 386 |
+
"\n",
|
| 387 |
+
" for i, model in tqdm(enumerate(self.models), desc=\"Saving Models\", total=len(self.models)):\n",
|
| 388 |
+
" save_data['models'][i] = model.state_dict()\n",
|
| 389 |
+
"\n",
|
| 390 |
+
" torch.save(save_data, os.path.join(save_dir, 'ensemble_checkpoint.pth'))\n",
|
| 391 |
+
"\n",
|
| 392 |
+
" def load_models(self, save_dir='ensemble_models/model_test_4'):\n",
|
| 393 |
+
" \"\"\"\n",
|
| 394 |
+
" Load ensemble model weights and model weights with progress tracking\n",
|
| 395 |
+
" \"\"\"\n",
|
| 396 |
+
" checkpoint_path = os.path.join(save_dir, 'ensemble_checkpoint.pth')\n",
|
| 397 |
+
"\n",
|
| 398 |
+
" save_data = torch.load(checkpoint_path)\n",
|
| 399 |
+
"\n",
|
| 400 |
+
" for i, model in tqdm(enumerate(self.models), desc=\"Loading Models\", total=len(self.models)):\n",
|
| 401 |
+
" model.load_state_dict(save_data['models'][i])\n",
|
| 402 |
+
" model.eval() # Set to evaluation mode\n",
|
| 403 |
+
"\n",
|
| 404 |
+
" if save_data['model_weights'] is not None:\n",
|
| 405 |
+
" self.model_weights = torch.tensor(save_data['model_weights'], dtype=torch.float32).to(self.device)\n",
|
| 406 |
+
" \n",
|
| 407 |
+
" def evaluate(self, test_loader):\n",
|
| 408 |
+
" \"\"\"\n",
|
| 409 |
+
" Evaluate ensemble performance with weighted voting, supporting both CPU and GPU.\n",
|
| 410 |
+
" \"\"\"\n",
|
| 411 |
+
" # Collect ground truth labels\n",
|
| 412 |
+
" all_labels = torch.cat([labels for _, labels in test_loader], dim=0).to(self.device)\n",
|
| 413 |
+
" \n",
|
| 414 |
+
" # Get predictions for the entire test set\n",
|
| 415 |
+
" test_preds = self.predict(test_loader, return_raw_scores=True)\n",
|
| 416 |
+
" \n",
|
| 417 |
+
" # Ensure predictions and labels are on the same device\n",
|
| 418 |
+
" all_labels = all_labels.cpu().numpy().ravel() # Flatten to 1D\n",
|
| 419 |
+
" test_preds, raw_probs = test_preds\n",
|
| 420 |
+
" test_preds = test_preds.cpu().numpy().ravel() # Flatten to 1D\n",
|
| 421 |
+
" \n",
|
| 422 |
+
" # Print debug information\n",
|
| 423 |
+
" # print(\"Ground truth labels (all_labels):\", all_labels)\n",
|
| 424 |
+
" # print(\"Predicted classes (test_preds):\", test_preds)\n",
|
| 425 |
+
" # print(\"Raw probabilities (raw_probs):\", raw_probs) \n",
|
| 426 |
+
" \n",
|
| 427 |
+
" # Calculate metrics\n",
|
| 428 |
+
" accuracy = np.mean(test_preds == all_labels)\n",
|
| 429 |
+
" precision = precision_score(all_labels, test_preds, zero_division=0)\n",
|
| 430 |
+
" recall = recall_score(all_labels, test_preds, zero_division=0)\n",
|
| 431 |
+
" \n",
|
| 432 |
+
" return {\n",
|
| 433 |
+
" \"accuracy\": accuracy,\n",
|
| 434 |
+
" \"precision\": precision,\n",
|
| 435 |
+
" \"recall\": recall\n",
|
| 436 |
+
" }"
|
| 437 |
+
]
|
| 438 |
+
},
|
| 439 |
+
{
|
| 440 |
+
"cell_type": "code",
|
| 441 |
+
"execution_count": 3,
|
| 442 |
+
"id": "a95bb0eb-48ba-4c46-9cc5-4f6a1ee19dee",
|
| 443 |
+
"metadata": {},
|
| 444 |
+
"outputs": [
|
| 445 |
+
{
|
| 446 |
+
"name": "stdout",
|
| 447 |
+
"output_type": "stream",
|
| 448 |
+
"text": [
|
| 449 |
+
"Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
|
| 450 |
+
"Requirement already satisfied: FlagEmbedding in /opt/conda/lib/python3.11/site-packages (1.3.3)\n",
|
| 451 |
+
"Requirement already satisfied: torch>=1.6.0 in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (2.2.2+cu121)\n",
|
| 452 |
+
"Requirement already satisfied: transformers==4.44.2 in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (4.44.2)\n",
|
| 453 |
+
"Requirement already satisfied: datasets==2.19.0 in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (2.19.0)\n",
|
| 454 |
+
"Requirement already satisfied: accelerate>=0.20.1 in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (1.2.0)\n",
|
| 455 |
+
"Requirement already satisfied: sentence-transformers in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (3.3.1)\n",
|
| 456 |
+
"Requirement already satisfied: peft in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (0.14.0)\n",
|
| 457 |
+
"Requirement already satisfied: ir-datasets in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (0.5.9)\n",
|
| 458 |
+
"Requirement already satisfied: sentencepiece in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (0.2.0)\n",
|
| 459 |
+
"Requirement already satisfied: protobuf in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (4.25.3)\n",
|
| 460 |
+
"Requirement already satisfied: filelock in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (3.9.0)\n",
|
| 461 |
+
"Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (1.26.4)\n",
|
| 462 |
+
"Requirement already satisfied: pyarrow>=12.0.0 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (15.0.2)\n",
|
| 463 |
+
"Requirement already satisfied: pyarrow-hotfix in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (0.6)\n",
|
| 464 |
+
"Requirement already satisfied: dill<0.3.9,>=0.3.0 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (0.3.8)\n",
|
| 465 |
+
"Requirement already satisfied: pandas in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (2.2.2)\n",
|
| 466 |
+
"Requirement already satisfied: requests>=2.19.0 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (2.31.0)\n",
|
| 467 |
+
"Requirement already satisfied: tqdm>=4.62.1 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (4.66.2)\n",
|
| 468 |
+
"Requirement already satisfied: xxhash in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (3.5.0)\n",
|
| 469 |
+
"Requirement already satisfied: multiprocess in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (0.70.16)\n",
|
| 470 |
+
"Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in /opt/conda/lib/python3.11/site-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets==2.19.0->FlagEmbedding) (2024.3.1)\n",
|
| 471 |
+
"Requirement already satisfied: aiohttp in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (3.11.10)\n",
|
| 472 |
+
"Requirement already satisfied: huggingface-hub>=0.21.2 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (0.26.5)\n",
|
| 473 |
+
"Requirement already satisfied: packaging in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (24.0)\n",
|
| 474 |
+
"Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (6.0.1)\n",
|
| 475 |
+
"Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.11/site-packages (from transformers==4.44.2->FlagEmbedding) (2024.11.6)\n",
|
| 476 |
+
"Requirement already satisfied: safetensors>=0.4.1 in /opt/conda/lib/python3.11/site-packages (from transformers==4.44.2->FlagEmbedding) (0.4.5)\n",
|
| 477 |
+
"Requirement already satisfied: tokenizers<0.20,>=0.19 in /opt/conda/lib/python3.11/site-packages (from transformers==4.44.2->FlagEmbedding) (0.19.1)\n",
|
| 478 |
+
"Requirement already satisfied: psutil in /opt/conda/lib/python3.11/site-packages (from accelerate>=0.20.1->FlagEmbedding) (5.9.8)\n",
|
| 479 |
+
"Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (4.11.0)\n",
|
| 480 |
+
"Requirement already satisfied: sympy in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (1.12)\n",
|
| 481 |
+
"Requirement already satisfied: networkx in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (3.3)\n",
|
| 482 |
+
"Requirement already satisfied: jinja2 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (3.1.3)\n",
|
| 483 |
+
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (12.1.105)\n",
|
| 484 |
+
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (12.1.105)\n",
|
| 485 |
+
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (12.1.105)\n",
|
| 486 |
+
"Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (8.9.2.26)\n",
|
| 487 |
+
"Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (12.1.3.1)\n",
|
| 488 |
+
"Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (11.0.2.54)\n",
|
| 489 |
+
"Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (10.3.2.106)\n",
|
| 490 |
+
"Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (11.4.5.107)\n",
|
| 491 |
+
"Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (12.1.0.106)\n",
|
| 492 |
+
"Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (2.19.3)\n",
|
| 493 |
+
"Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (12.1.105)\n",
|
| 494 |
+
"Requirement already satisfied: triton==2.2.0 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (2.2.0)\n",
|
| 495 |
+
"Requirement already satisfied: nvidia-nvjitlink-cu12 in /opt/conda/lib/python3.11/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.6.0->FlagEmbedding) (12.4.127)\n",
|
| 496 |
+
"Requirement already satisfied: beautifulsoup4>=4.4.1 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (4.12.3)\n",
|
| 497 |
+
"Requirement already satisfied: inscriptis>=2.2.0 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (2.5.0)\n",
|
| 498 |
+
"Requirement already satisfied: lxml>=4.5.2 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (5.3.0)\n",
|
| 499 |
+
"Requirement already satisfied: trec-car-tools>=2.5.4 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (2.6)\n",
|
| 500 |
+
"Requirement already satisfied: lz4>=3.1.10 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (4.3.3)\n",
|
| 501 |
+
"Requirement already satisfied: warc3-wet>=0.2.3 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (0.2.5)\n",
|
| 502 |
+
"Requirement already satisfied: warc3-wet-clueweb09>=0.2.5 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (0.2.5)\n",
|
| 503 |
+
"Requirement already satisfied: zlib-state>=0.1.3 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (0.1.9)\n",
|
| 504 |
+
"Requirement already satisfied: ijson>=3.1.3 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (3.3.0)\n",
|
| 505 |
+
"Requirement already satisfied: unlzw3>=0.2.1 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (0.2.2)\n",
|
| 506 |
+
"Requirement already satisfied: scikit-learn in /opt/conda/lib/python3.11/site-packages (from sentence-transformers->FlagEmbedding) (1.4.2)\n",
|
| 507 |
+
"Requirement already satisfied: scipy in /opt/conda/lib/python3.11/site-packages (from sentence-transformers->FlagEmbedding) (1.13.0)\n",
|
| 508 |
+
"Requirement already satisfied: Pillow in /opt/conda/lib/python3.11/site-packages (from sentence-transformers->FlagEmbedding) (10.3.0)\n",
|
| 509 |
+
"Requirement already satisfied: soupsieve>1.2 in /opt/conda/lib/python3.11/site-packages (from beautifulsoup4>=4.4.1->ir-datasets->FlagEmbedding) (2.5)\n",
|
| 510 |
+
"Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (2.4.4)\n",
|
| 511 |
+
"Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (1.3.2)\n",
|
| 512 |
+
"Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (23.2.0)\n",
|
| 513 |
+
"Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (1.5.0)\n",
|
| 514 |
+
"Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (6.1.0)\n",
|
| 515 |
+
"Requirement already satisfied: propcache>=0.2.0 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (0.2.1)\n",
|
| 516 |
+
"Requirement already satisfied: yarl<2.0,>=1.17.0 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (1.18.3)\n",
|
| 517 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.11/site-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (3.3.2)\n",
|
| 518 |
+
"Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.11/site-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (3.7)\n",
|
| 519 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.11/site-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (2.2.1)\n",
|
| 520 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.11/site-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (2024.2.2)\n",
|
| 521 |
+
"Requirement already satisfied: cbor>=1.0.0 in /opt/conda/lib/python3.11/site-packages (from trec-car-tools>=2.5.4->ir-datasets->FlagEmbedding) (1.0.0)\n",
|
| 522 |
+
"Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.11/site-packages (from jinja2->torch>=1.6.0->FlagEmbedding) (2.1.5)\n",
|
| 523 |
+
"Requirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/lib/python3.11/site-packages (from pandas->datasets==2.19.0->FlagEmbedding) (2.9.0)\n",
|
| 524 |
+
"Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.11/site-packages (from pandas->datasets==2.19.0->FlagEmbedding) (2024.1)\n",
|
| 525 |
+
"Requirement already satisfied: tzdata>=2022.7 in /opt/conda/lib/python3.11/site-packages (from pandas->datasets==2.19.0->FlagEmbedding) (2024.1)\n",
|
| 526 |
+
"Requirement already satisfied: joblib>=1.2.0 in /opt/conda/lib/python3.11/site-packages (from scikit-learn->sentence-transformers->FlagEmbedding) (1.4.0)\n",
|
| 527 |
+
"Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.11/site-packages (from scikit-learn->sentence-transformers->FlagEmbedding) (3.4.0)\n",
|
| 528 |
+
"Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.11/site-packages (from sympy->torch>=1.6.0->FlagEmbedding) (1.3.0)\n",
|
| 529 |
+
"Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->datasets==2.19.0->FlagEmbedding) (1.16.0)\n",
|
| 530 |
+
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
|
| 531 |
+
"\u001b[0m"
|
| 532 |
+
]
|
| 533 |
+
},
|
| 534 |
+
{
|
| 535 |
+
"data": {
|
| 536 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 537 |
+
"model_id": "a24dee20be054f138b75c100ab2e6a36",
|
| 538 |
+
"version_major": 2,
|
| 539 |
+
"version_minor": 0
|
| 540 |
+
},
|
| 541 |
+
"text/plain": [
|
| 542 |
+
"Fetching 30 files: 0%| | 0/30 [00:00<?, ?it/s]"
|
| 543 |
+
]
|
| 544 |
+
},
|
| 545 |
+
"metadata": {},
|
| 546 |
+
"output_type": "display_data"
|
| 547 |
+
},
|
| 548 |
+
{
|
| 549 |
+
"name": "stdout",
|
| 550 |
+
"output_type": "stream",
|
| 551 |
+
"text": [
|
| 552 |
+
"Encoding titles...\n"
|
| 553 |
+
]
|
| 554 |
+
},
|
| 555 |
+
{
|
| 556 |
+
"name": "stderr",
|
| 557 |
+
"output_type": "stream",
|
| 558 |
+
"text": [
|
| 559 |
+
"You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
|
| 560 |
+
]
|
| 561 |
+
},
|
| 562 |
+
{
|
| 563 |
+
"name": "stdout",
|
| 564 |
+
"output_type": "stream",
|
| 565 |
+
"text": [
|
| 566 |
+
"Processed 20/20 titles\n"
|
| 567 |
+
]
|
| 568 |
+
}
|
| 569 |
+
],
|
| 570 |
+
"source": [
|
| 571 |
+
"%pip install FlagEmbedding\n",
|
| 572 |
+
"from FlagEmbedding import BGEM3FlagModel\n",
|
| 573 |
+
"model = BGEM3FlagModel('BAAI/bge-m3')\n",
|
| 574 |
+
"\n",
|
| 575 |
+
"# NOTE: point test_data_path at your own test CSV; it must contain 'title' and 'labels' columns\n",
|
| 576 |
+
"test_data_path = \"/home/jovyan/work/test_data_random_subset.csv\"\n",
|
| 577 |
+
"\n",
|
| 578 |
+
"data = pd.read_csv(test_data_path)\n",
|
| 579 |
+
"titles = data['title'].tolist()\n",
|
| 580 |
+
"labels = data['labels'].tolist()\n",
|
| 581 |
+
"\n",
|
| 582 |
+
"batch_size = 32\n",
|
| 583 |
+
"embeddings = []\n",
|
| 584 |
+
"\n",
|
| 585 |
+
"print('Encoding titles...')\n",
|
| 586 |
+
"for i in range(0, len(titles), batch_size):\n",
|
| 587 |
+
" batch = titles[i:i + batch_size]\n",
|
| 588 |
+
" batch_embeddings = model.encode(batch, batch_size=batch_size, max_length=512)['dense_vecs']\n",
|
| 589 |
+
" embeddings.extend(batch_embeddings)\n",
|
| 590 |
+
" print(f\"Processed {i + len(batch)}/{len(titles)} titles\")\n",
|
| 591 |
+
"\n",
|
| 592 |
+
"embeddings_df = pd.DataFrame(embeddings)\n",
|
| 593 |
+
"embeddings_df['label'] = labels\n",
|
| 594 |
+
"\n",
|
| 595 |
+
"# Convert embeddings and labels to PyTorch tensors\n",
|
| 596 |
+
"X_test = torch.FloatTensor(embeddings_df.iloc[:, :-1].values) # Features\n",
|
| 597 |
+
"y_test = torch.FloatTensor(embeddings_df['label'].values).view(-1, 1) # Labels\n",
|
| 598 |
+
"\n",
|
| 599 |
+
"# Create DataLoader for the test dataset\n",
|
| 600 |
+
"test_dataset = TensorDataset(X_test, y_test)\n",
|
| 601 |
+
"test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)"
|
| 602 |
+
]
|
| 603 |
+
},
|
| 604 |
+
{
|
| 605 |
+
"cell_type": "code",
|
| 606 |
+
"execution_count": 5,
|
| 607 |
+
"id": "c6bcf956-4e26-4278-a6fe-9955322cf06a",
|
| 608 |
+
"metadata": {},
|
| 609 |
+
"outputs": [
|
| 610 |
+
{
|
| 611 |
+
"name": "stderr",
|
| 612 |
+
"output_type": "stream",
|
| 613 |
+
"text": [
|
| 614 |
+
"Loading Models: 100%|██████████| 5/5 [00:00<00:00, 1799.05it/s]"
|
| 615 |
+
]
|
| 616 |
+
},
|
| 617 |
+
{
|
| 618 |
+
"name": "stdout",
|
| 619 |
+
"output_type": "stream",
|
| 620 |
+
"text": [
|
| 621 |
+
"{'accuracy': 0.9, 'precision': 0.9, 'recall': 0.9}\n"
|
| 622 |
+
]
|
| 623 |
+
},
|
| 624 |
+
{
|
| 625 |
+
"name": "stderr",
|
| 626 |
+
"output_type": "stream",
|
| 627 |
+
"text": [
|
| 628 |
+
"\n"
|
| 629 |
+
]
|
| 630 |
+
}
|
| 631 |
+
],
|
| 632 |
+
"source": [
|
| 633 |
+
"from sklearn.metrics import precision_score, recall_score\n",
|
| 634 |
+
"ensemble = EnsembleClassifier(5) \n",
|
| 635 |
+
"\n",
|
| 636 |
+
"# Load saved model weights\n",
|
| 637 |
+
"# Be sure to change save_dir to the actual directory holding the saved ensemble weights\n",
|
| 638 |
+
"ensemble.load_models(save_dir='/home/jovyan/work/ensemble_models/model_test_4')\n",
|
| 639 |
+
"\n",
|
| 640 |
+
"# Evaluate the ensemble\n",
|
| 641 |
+
"results = ensemble.evaluate(test_loader)\n",
|
| 642 |
+
"print(results)"
|
| 643 |
+
]
|
| 644 |
+
},
|
| 645 |
+
{
|
| 646 |
+
"cell_type": "code",
|
| 647 |
+
"execution_count": null,
|
| 648 |
+
"id": "77da0a63-cf76-4cbb-8b48-da115e124946",
|
| 649 |
+
"metadata": {},
|
| 650 |
+
"outputs": [],
|
| 651 |
+
"source": []
|
| 652 |
+
}
|
| 653 |
+
],
|
| 654 |
+
"metadata": {
|
| 655 |
+
"kernelspec": {
|
| 656 |
+
"display_name": "Python 3 (ipykernel)",
|
| 657 |
+
"language": "python",
|
| 658 |
+
"name": "python3"
|
| 659 |
+
},
|
| 660 |
+
"language_info": {
|
| 661 |
+
"codemirror_mode": {
|
| 662 |
+
"name": "ipython",
|
| 663 |
+
"version": 3
|
| 664 |
+
},
|
| 665 |
+
"file_extension": ".py",
|
| 666 |
+
"mimetype": "text/x-python",
|
| 667 |
+
"name": "python",
|
| 668 |
+
"nbconvert_exporter": "python",
|
| 669 |
+
"pygments_lexer": "ipython3",
|
| 670 |
+
"version": "3.11.8"
|
| 671 |
+
}
|
| 672 |
+
},
|
| 673 |
+
"nbformat": 4,
|
| 674 |
+
"nbformat_minor": 5
|
| 675 |
+
}
|