Spaces:
Running
Running
jl commited on
Commit Β·
3e8208f
1
Parent(s): 8092c6e
fix: regex filter for text
Browse files- notebooks/thesis-proposed-model-vo3.ipynb +0 -1
- notebooks/thesis-proposed-model-vo4.ipynb +1424 -0
- src/app.py +18 -4
notebooks/thesis-proposed-model-vo3.ipynb
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
{"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"name":"python","version":"3.12.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[{"sourceType":"datasetVersion","sourceId":14919503,"datasetId":9546304,"databundleVersionId":15785969}],"dockerImageVersionId":31286,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.utils.data import DataLoader, Dataset\nfrom torch.optim import AdamW\nfrom sklearn.model_selection import train_test_split\nfrom transformers import AutoModel, AutoTokenizer\n\nimport numpy as np\nimport pandas as pd\nimport time\nimport os\nfrom tqdm.auto import tqdm\n\nfrom sklearn.metrics import (\n accuracy_score,\n precision_recall_fscore_support,\n classification_report,\n confusion_matrix\n)\n\nimport matplotlib.pyplot as plt\nimport seaborn as sns","metadata":{"_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class AdditionalCustomDataset(Dataset):\n \"\"\"\n Pre-tokenizes all data during initialization for much faster training.\n \"\"\"\n def __init__(self, texts, labels, additional_texts, tokenizer, bert_tokenizer, max_length):\n self.labels = labels\n self.max_length = max_length\n \n # Pre-tokenize ALL data once during initialization (MUCH faster than tokenizing in __getitem__)\n print(\"Pre-tokenizing primary texts...\")\n primary_encodings = tokenizer(\n texts,\n max_length=max_length,\n truncation=True,\n padding='max_length',\n return_tensors='pt'\n )\n self.primary_input_ids = primary_encodings['input_ids']\n self.primary_attention_mask = primary_encodings['attention_mask']\n \n print(\"Pre-tokenizing additional texts...\")\n additional_encodings = bert_tokenizer(\n additional_texts,\n max_length=max_length,\n truncation=True,\n padding='max_length',\n return_tensors='pt'\n )\n self.additional_input_ids = additional_encodings['input_ids']\n self.additional_attention_mask = additional_encodings['attention_mask']\n \n print(f\"Pre-tokenization complete. Dataset size: {len(self.labels)}\")\n\n def __len__(self):\n return len(self.labels)\n\n def __getitem__(self, idx):\n return (\n self.primary_input_ids[idx],\n self.primary_attention_mask[idx],\n self.additional_input_ids[idx],\n self.additional_attention_mask[idx],\n self.labels[idx]\n )","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class ProjectionMLP(nn.Module):\n def __init__(self, input_size, output_size):\n super(ProjectionMLP, self).__init__()\n self.layers = nn.Sequential(\n nn.Linear(input_size, output_size),\n nn.ReLU(),\n nn.Linear(output_size, 2)\n )\n\n def forward(self, x):\n return self.layers(x)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class GumbelTokenSelector(nn.Module):\n def __init__(self, hidden_size, tau=1.0):\n super().__init__()\n self.tau = tau\n self.proj = nn.Linear(hidden_size * 2, 1)\n \n def forward(self, token_embeddings, cls_embedding, training=True):\n \"\"\"\n token_embeddings: (B, L, H)\n cls_embedding: (B, H)\n \"\"\"\n B, L, H = token_embeddings.size()\n \n cls_exp = cls_embedding.unsqueeze(1).expand(-1, L, -1)\n x = torch.cat([token_embeddings, cls_exp], dim=-1)\n \n logits = self.proj(x).squeeze(-1) # (B, L)\n \n if training:\n probs = F.gumbel_softmax(\n torch.stack([logits, torch.zeros_like(logits)], dim=-1),\n tau=self.tau,\n hard=False\n )[..., 0]\n else:\n probs = torch.sigmoid(logits)\n \n return probs, logits","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class MultiScaleAttentionCNN(nn.Module):\n def __init__(\n self,\n hidden_size=768,\n num_filters=128,\n kernel_sizes=(2, 3, 4),\n dropout=0.3,\n ):\n super().__init__()\n \n self.convs = nn.ModuleList([\n nn.Conv1d(hidden_size, num_filters, k)\n for k in kernel_sizes\n ])\n \n self.attention_fc = nn.Linear(num_filters, 1)\n self.dropout = nn.Dropout(dropout)\n self.out_dim = num_filters * len(kernel_sizes)\n \n def forward(self, x, mask):\n \"\"\"\n x: (B, L, H)\n mask: (B, L)\n \"\"\"\n x = x.transpose(1, 2) # (B, H, L)\n feats = []\n \n for conv in self.convs:\n h = F.relu(conv(x)) # (B, C, L')\n h = h.transpose(1, 2) # (B, L', C)\n \n attn = self.attention_fc(h).squeeze(-1)\n attn = attn.masked_fill(mask[:, :attn.size(1)] == 0, -1e9)\n alpha = F.softmax(attn, dim=1)\n \n pooled = torch.sum(h * alpha.unsqueeze(-1), dim=1)\n feats.append(pooled)\n \n out = torch.cat(feats, dim=1)\n return self.dropout(out)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class TemporalCNN(nn.Module):\n def __init__(\n self,\n hidden_size=768,\n num_filters=128,\n kernel_sizes=(2, 3, 4),\n dropout=0.1,\n dilation_base=2,\n ):\n super().__init__()\n \n self.kernel_sizes = kernel_sizes\n self.dilation_base = dilation_base\n \n # Dilated convolutions with exponentially increasing dilation rates\n self.convs = nn.ModuleList([\n nn.Conv1d(\n hidden_size, \n num_filters, \n k,\n dilation=dilation_base ** i, # dilation = 2^i\n padding=0 # we'll handle causal padding manually\n )\n for i, k in enumerate(kernel_sizes)\n ])\n \n self.dropout = nn.Dropout(dropout)\n self.out_dim = num_filters * len(kernel_sizes)\n \n def _causal_padding(self, x, kernel_size, dilation):\n \"\"\"\n Apply left padding only (causal) to ensure output at time t\n only depends on inputs from time 0 to t.\n \"\"\"\n # Calculate required padding: (kernel_size - 1) * dilation\n padding = (kernel_size - 1) * dilation\n # Pad only on the left (temporal dimension)\n return F.pad(x, (padding, 0))\n \n def forward(self, x, attention_mask):\n \"\"\"\n x: (B, L, H)\n attention_mask: (B, L)\n \"\"\"\n # zero-out padding tokens\n mask = attention_mask.unsqueeze(-1)\n x = x * mask\n \n # (B, H, L) for Conv1d\n x = x.transpose(1, 2)\n \n feats = []\n for i, conv in enumerate(self.convs):\n kernel_size = self.kernel_sizes[i]\n dilation = self.dilation_base ** i\n \n # Apply causal padding (left-only)\n x_padded = self._causal_padding(x, kernel_size, dilation)\n \n # Apply dilated convolution\n c = F.relu(conv(x_padded)) # (B, C, L')\n \n # Global max pooling over the temporal dimension\n p = F.max_pool1d(c, kernel_size=c.size(2)).squeeze(2) # (B, C)\n \n feats.append(p)\n \n out = torch.cat(feats, dim=1) # (B, num_filters * len(kernel_sizes))\n return self.dropout(out)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class ConcatModel(nn.Module):\n def __init__(\n self,\n hatebert_model,\n additional_model,\n temporal_cnn,\n msa_cnn,\n selector,\n projection_mlp,\n unfreeze_n_layers_hate=12, #hatebert unfreeze all 12 layers\n unfreeze_n_layers_add=0, # additional bert freeze all layers\n hate_pooler=True,\n add_pooler=False\n \n ):\n super().__init__()\n \n self.hatebert_model = hatebert_model\n self.additional_model = additional_model\n \n self.temporal_cnn = temporal_cnn\n self.msa_cnn = msa_cnn\n self.selector = selector\n self.projection_mlp = projection_mlp\n \n # freeze everything on additional bert\n for p in self.additional_model.parameters():\n p.requires_grad = False\n\n # will unfreeze last n layers of additional model\n add_num_layers = len(self.additional_model.encoder.layer)\n for i in range(add_num_layers - unfreeze_n_layers_add, add_num_layers):\n for param in self.additional_model.encoder.layer[i].parameters():\n param.requires_grad = True\n\n if self.additional_model.pooler is not None and add_pooler:\n for p in self.additional_model.pooler.parameters():\n p.requires_grad = True\n\n # freeze everything on hatebert\n for p in self.hatebert_model.parameters():\n p.requires_grad = False\n \n # FIX: Use hatebert_model to get the correct number of layers\n hate_num_layers = len(self.hatebert_model.encoder.layer)\n for i in range(hate_num_layers - unfreeze_n_layers_hate, hate_num_layers):\n for param in self.hatebert_model.encoder.layer[i].parameters():\n param.requires_grad = True\n \n if self.hatebert_model.pooler is not None and hate_pooler:\n for p in self.hatebert_model.pooler.parameters():\n p.requires_grad = True\n\n \n def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask):\n # ================= HateBERT =================\n hate_outputs = self.hatebert_model(\n input_ids=input_ids,\n attention_mask=attention_mask,\n )\n seq_emb = hate_outputs.last_hidden_state # (B, L, H)\n cls_emb = seq_emb[:, 0, :] # (B, H)\n \n # ---- Token Selector ----\n token_probs, token_logits = self.selector(seq_emb, cls_emb, self.training)\n \n # ---- Temporal CNN on FULL embeddings (NOT masked) ----\n temporal_feat = self.temporal_cnn(seq_emb, attention_mask)\n \n # ---- Rationale-Weighted Summary Vector H_r ----\n weights = token_probs.unsqueeze(-1) # (B, L, 1)\n H_r = (seq_emb * weights).sum(dim=1) / (weights.sum(dim=1) + 1e-6)\n \n # ================= Frozen Rationale BERT =================\n add_outputs = self.additional_model(\n input_ids=additional_input_ids,\n attention_mask=additional_attention_mask,\n )\n add_seq = add_outputs.last_hidden_state\n \n # ---- Multi-Scale Attention CNN ----\n msa_feat = self.msa_cnn(add_seq, additional_attention_mask)\n \n # ================= CONCAT (4 components) =================\n concat = torch.cat([cls_emb, temporal_feat, msa_feat, H_r], dim=1)\n \n logits = self.projection_mlp(concat)\n return logits","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class EarlyStopping:\n \"\"\"\n Early stopping to stop training when validation loss doesn't improve.\n \"\"\"\n def __init__(self, patience=10, min_delta=1.0, mode='min', verbose=True):\n \"\"\"\n Args:\n patience (int): How many epochs to wait after last improvement.\n min_delta (float): Minimum change to qualify as an improvement.\n mode (str): 'min' for loss, 'max' for accuracy/f1.\n verbose (bool): Print messages when improvement occurs.\n \"\"\"\n self.patience = patience\n self.min_delta = min_delta\n self.mode = mode\n self.verbose = verbose\n self.counter = 0\n self.best_score = None\n self.early_stop = False\n self.best_model_state = None\n \n def __call__(self, current_score, model):\n \"\"\"\n Call this after each epoch with the validation metric and model.\n \n Args:\n current_score: Current epoch's validation metric (loss, accuracy, f1, etc.)\n model: The model to save if there's improvement\n \n Returns:\n bool: True if training should stop, False otherwise\n \"\"\"\n if self.best_score is None:\n # First epoch\n self.best_score = current_score\n self.save_checkpoint(model)\n if self.verbose:\n print(f\"Initial best score: {self.best_score:.4f}\")\n else:\n # Check if there's improvement\n if self.mode == 'min':\n improved = current_score < (self.best_score - self.min_delta)\n else: # mode == 'max'\n improved = current_score > (self.best_score + self.min_delta)\n \n if improved:\n self.best_score = current_score\n self.save_checkpoint(model)\n self.counter = 0\n if self.verbose:\n print(f\"Validation improved! New best score: {self.best_score:.4f}\")\n else:\n self.counter += 1\n if self.verbose:\n print(f\"No improvement. Patience counter: {self.counter}/{self.patience}\")\n \n if self.counter >= self.patience:\n self.early_stop = True\n if self.verbose:\n print(f\"Early stopping triggered! Best score: {self.best_score:.4f}\")\n \n return self.early_stop\n \n def save_checkpoint(self, model):\n \"\"\"Save model state dict\"\"\"\n import copy\n self.best_model_state = copy.deepcopy(model.state_dict())\n \n def load_best_model(self, model):\n \"\"\"Load the best model state into the model\"\"\"\n if self.best_model_state is not None:\n model.load_state_dict(self.best_model_state)\n if self.verbose:\n print(f\"Loaded best model with score: {self.best_score:.4f}\")\n return model","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"def main(args):\n torch.manual_seed(args.seed)\n torch.cuda.empty_cache()\n device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n\n\n file_map = {\n \"gab\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_GAB_dataset(85-15).csv',\n \"twitter\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_Twitter_dataset(85-15).csv',\n \"reddit\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_REDDIT_dataset(85-15).csv',\n \"youtube\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_YOUTUBE_dataset(85-15).csv',\n \"implicit\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_IMPLICIT_dataset(85-15).csv'\n }\n\n file_path = file_map[args.dataset]\n df = pd.read_csv(file_path)\n train_df = df[df['exp_split'] == 'train']\n test_df = df[df['exp_split'] == 'test']\n\n print(\"Train df: \", len(train_df))\n print(\"Test_df: \", len(test_df))\n\n import gc\n # del variables\n gc.collect()\n\n \n tokenizer = args.hate_tokenizer ## need this for tokenizing the input text in data loader\n tokenizer_bert = args.bert_tokenizer\n #Splitting training and validation testing split to test accuracy\n train_idx, val_idx = train_test_split(\n train_df.index,\n test_size=0.2,\n stratify=train_df[\"label\"],\n random_state=args.seed\n )\n \n if args.dataset == \"implicit\":\n train_text = train_df.loc[train_idx, \"post\"].tolist()\n val_texts = train_df.loc[val_idx, \"post\"].tolist()\n else:\n train_text = train_df.loc[train_idx, \"text\"].tolist()\n val_texts = train_df.loc[val_idx, \"text\"].tolist()\n \n add_train_text = train_df.loc[train_idx, \"Mistral_Rationales\"].tolist()\n add_val_texts = train_df.loc[val_idx, \"Mistral_Rationales\"].tolist()\n \n train_labels = train_df.loc[train_idx, \"label\"].tolist()\n val_labels = train_df.loc[val_idx, \"label\"].tolist()\n\n train_dataset = AdditionalCustomDataset(\n train_text,\n train_labels,\n add_train_text,\n tokenizer,\n tokenizer_bert,\n max_length=512\n )\n \n val_dataset = AdditionalCustomDataset(\n val_texts,\n val_labels,\n add_val_texts,\n tokenizer,\n tokenizer_bert,\n max_length=512\n )\n\n #Creating dataloader object to train the model\n # num_workers=2 for parallel data loading, pin_memory=True for faster GPU transfer\n train_dataloader = DataLoader(\n train_dataset, \n batch_size=args.batch_size, \n shuffle=True,\n num_workers=2,\n pin_memory=True if torch.cuda.is_available() else False\n )\n val_dataloader = DataLoader(\n val_dataset, \n batch_size=args.batch_size, \n shuffle=False,\n num_workers=2,\n pin_memory=True if torch.cuda.is_available() else False\n )\n hatebert_model = args.hatebert_model\n additional_model = args.additional_model\n \n \n temporal_cnn = TemporalCNN(\n hidden_size=768,\n num_filters=args.temp_filters, # 64 - 128\n kernel_sizes=(2, 3, 4),\n dropout=args.temp_dropout,\n dilation_base=args.temp_dilate # 0 - 4\n ).to(device)\n\n msa_cnn = MultiScaleAttentionCNN(\n hidden_size=768,\n num_filters=args.msa_filters, # 64 - 128\n kernel_sizes=(2, 3, 4),\n dropout=args.msa_dropout\n ).to(device)\n \n selector = GumbelTokenSelector(\n hidden_size=768,\n tau=1.0\n ).to(device)\n \n projection_mlp = ProjectionMLP(\n input_size=temporal_cnn.out_dim + msa_cnn.out_dim + 768 * 2,\n output_size=512\n ).to(device)\n\n\n\n concat_model = ConcatModel(\n hatebert_model=hatebert_model,\n additional_model=additional_model,\n temporal_cnn=temporal_cnn,\n msa_cnn=msa_cnn,\n selector=selector,\n projection_mlp=projection_mlp,\n unfreeze_n_layers_hate=args.hate_layers, #hatebert unfreeze all 12 layers id 12 by default\n unfreeze_n_layers_add=args.add_layers, # additional bert freeze all layers if 0 by default\n hate_pooler=args.hate_pooler, #bool to controle if pooler is frozen or not true=not frozen\n add_pooler=args.add_pooler #bool to controle if pooler is frozen or not true=not froze\n ).to(device)\n\n optimizer = AdamW(\n concat_model.parameters(),\n lr=args.lr, # 2e-5\n weight_decay=args.wd\n )\n criterion = nn.CrossEntropyLoss().to(device)\n\n # criterion = criterion.to(device)\n\n os.makedirs(\"/kaggle/working/models\", exist_ok=True)\n\n history = {\n \"train_loss\": [],\n \"val_loss\": [], #loss\n \"train_acc\": [],\n \"train_precision\": [],\n \"train_recall\": [],\n \"train_f1\": [], #train binary\n \"val_acc\": [],\n \"val_precision\": [],\n \"val_recall\": [],\n \"val_f1\": [], #validation binary\n \"train_acc_weighted\": [],\n \"train_precision_weighted\": [],\n \"train_recall_weighted\": [],\n \"train_f1_weighted\": [], # train weighted\n \"val_acc_weighted\": [],\n \"val_precision_weighted\": [],\n \"val_recall_weighted\": [],\n \"val_f1_weighted\": [], # validation weighted\n \"train_acc_macro\": [],\n \"train_precision_macro\": [],\n \"train_recall_macro\": [],\n \"train_f1_macro\": [], #train macro\n \"val_acc_macro\": [],\n \"val_precision_macro\": [],\n \"val_recall_macro\": [],\n \"val_f1_macro\": [],# validation macro\n \"epoch_time\": [],\n \"train_throughput\": [],\n \"val_confidence_mean\": [],\n \"val_confidence_std\": [],\n \"gpu_memory_mb\": [],\n }\n\n early_stopping = EarlyStopping(patience=args.patience, min_delta=0.001, mode='min', verbose=True) # early stop on loss\n\n for epoch in range(args.num_epochs):\n epoch_val_confidences = []\n epoch_start_time = time.time()\n samples_seen = 0\n \n concat_model.train()\n\n train_losses = []\n train_preds = []\n train_labels_epoch = []\n train_accuracy = 0\n train_epoch_size = 0\n\n with tqdm(train_dataloader, desc=f'Epoch {epoch + 1}', dynamic_ncols=True) as loop:\n for batch in loop:\n input_ids, attention_mask, additional_input_ids, additional_attention_mask, labels = batch\n \n samples_seen += labels.size(0)\n \n if torch.cuda.is_available():\n input_ids = input_ids.to(device)\n attention_mask = attention_mask.to(device)\n additional_input_ids = additional_input_ids.to(device)\n additional_attention_mask = additional_attention_mask.to(device)\n labels = labels.to(device)\n\n # Forward pass through the ConcatModel\n optimizer.zero_grad()\n outputs = concat_model(input_ids=input_ids, attention_mask=attention_mask, additional_input_ids=additional_input_ids, additional_attention_mask=additional_attention_mask)\n loss = criterion(outputs, labels)\n\n # Backward pass and optimization\n loss.backward()\n torch.nn.utils.clip_grad_norm_(concat_model.parameters(), max_norm=args.max_grad_norm)\n optimizer.step()\n \n probs = torch.softmax(outputs, dim=1)\n confidences, predictions = torch.max(probs, dim=1)\n train_preds.extend(predictions.cpu().numpy())\n train_labels_epoch.extend(labels.cpu().numpy())\n\n train_losses.append(loss.item())\n\n # Update accuracy and epoch size\n predictions = torch.argmax(outputs, dim=1)\n train_accuracy += (predictions == labels).sum().item()\n train_epoch_size += len(labels)\n \n epoch_train_time = time.time() - epoch_start_time\n train_throughput = samples_seen / epoch_train_time \n \n # Calculate train metrics (binary, weighted, macro)\n train_precision, train_recall, train_f1, _ = precision_recall_fscore_support(\n train_labels_epoch, train_preds, average='binary'\n )\n train_precision_weighted, train_recall_weighted, train_f1_weighted, _ = precision_recall_fscore_support(\n train_labels_epoch, train_preds, average='weighted'\n )\n train_precision_macro, train_recall_macro, train_f1_macro, _ = precision_recall_fscore_support(\n train_labels_epoch, train_preds, average='macro'\n )\n train_acc = accuracy_score(train_labels_epoch, train_preds)\n\n # Evaluation on the validation set\n concat_model.eval()\n\n val_predictions = []\n val_labels_epoch = []\n val_loss = 0\n num_batches = 0\n\n with torch.no_grad(), tqdm(val_dataloader, desc='Validation', dynamic_ncols=True) as loop:\n for batch in loop:\n input_ids, attention_mask, additional_input_ids, additional_attention_mask, labels = batch\n\n if torch.cuda.is_available():\n input_ids = input_ids.to(device)\n attention_mask = attention_mask.to(device)\n additional_input_ids = additional_input_ids.to(device)\n additional_attention_mask = additional_attention_mask.to(device)\n labels = labels.to(device)\n\n # Forward pass through the ConcatModel\n outputs = concat_model(input_ids=input_ids, attention_mask=attention_mask, additional_input_ids=additional_input_ids, additional_attention_mask=additional_attention_mask)\n loss = criterion(outputs, labels)\n val_loss += loss.item()\n num_batches += 1\n probs = torch.softmax(outputs, dim=1)\n confidences, predictions = torch.max(probs, dim=1)\n\n epoch_val_confidences.extend(confidences.cpu().numpy())\n \n val_predictions.extend(predictions.cpu().numpy())\n val_labels_epoch.extend(labels.cpu().numpy())\n \n val_loss /= num_batches\n \n # Calculate validation metrics (binary)\n val_accuracy = accuracy_score(val_labels_epoch, val_predictions)\n val_precision, val_recall, val_f1, _ = precision_recall_fscore_support(\n val_labels_epoch, val_predictions, average='binary'\n )\n # Calculate validation metrics (weighted)\n val_precision_weighted, val_recall_weighted, val_f1_weighted, _ = precision_recall_fscore_support(\n val_labels_epoch, val_predictions, average='weighted'\n )\n # Calculate validation metrics (macro)\n val_precision_macro, val_recall_macro, val_f1_macro, _ = precision_recall_fscore_support(\n val_labels_epoch, val_predictions, average='macro'\n )\n \n print(f\"Epoch {epoch}:\")\n print(f\" Train Accuracy: {train_acc:.4f}\")\n print(f\" Validation Accuracy: {val_accuracy:.4f}\")\n print(f\" Train Precision: {train_precision:.4f}, Recall: {train_recall:.4f}, F1: {train_f1:.4f}\")\n print(f\" Val Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1: {val_f1:.4f}\")\n print(f\" Avg. Train Loss: {sum(train_losses) / len(train_losses):.4f}\")\n print(f\" Validation Loss: {val_loss:.4f}\")\n epoch_time = time.time() - epoch_start_time\n conf_mean = np.mean(epoch_val_confidences)\n conf_std = np.std(epoch_val_confidences)\n\n # Append all train metrics to history\n history[\"train_loss\"].append(np.mean(train_losses))\n history[\"train_acc\"].append(train_acc)\n history[\"train_precision\"].append(train_precision)\n history[\"train_recall\"].append(train_recall)\n history[\"train_f1\"].append(train_f1)\n history[\"train_acc_weighted\"].append(train_acc) # accuracy is same for all averaging\n history[\"train_precision_weighted\"].append(train_precision_weighted)\n history[\"train_recall_weighted\"].append(train_recall_weighted)\n history[\"train_f1_weighted\"].append(train_f1_weighted)\n history[\"train_acc_macro\"].append(train_acc) # accuracy is same for all averaging\n history[\"train_precision_macro\"].append(train_precision_macro)\n history[\"train_recall_macro\"].append(train_recall_macro)\n history[\"train_f1_macro\"].append(train_f1_macro)\n \n # Append all validation metrics to history\n history[\"val_loss\"].append(val_loss)\n history[\"val_acc\"].append(val_accuracy)\n history[\"val_precision\"].append(val_precision)\n history[\"val_recall\"].append(val_recall)\n history[\"val_f1\"].append(val_f1)\n history[\"val_acc_weighted\"].append(val_accuracy) # accuracy is same for all averaging\n history[\"val_precision_weighted\"].append(val_precision_weighted)\n history[\"val_recall_weighted\"].append(val_recall_weighted)\n history[\"val_f1_weighted\"].append(val_f1_weighted)\n history[\"val_acc_macro\"].append(val_accuracy) # accuracy is same for all averaging\n history[\"val_precision_macro\"].append(val_precision_macro)\n history[\"val_recall_macro\"].append(val_recall_macro)\n history[\"val_f1_macro\"].append(val_f1_macro)\n \n # Append efficiency metrics\n history[\"epoch_time\"].append(epoch_time)\n history[\"train_throughput\"].append(train_throughput)\n history[\"val_confidence_mean\"].append(conf_mean)\n history[\"val_confidence_std\"].append(conf_std)\n \n if torch.cuda.is_available():\n history[\"gpu_memory_mb\"].append(\n torch.cuda.max_memory_allocated() / 1024**2\n )\n torch.cuda.reset_peak_memory_stats()\n else:\n history[\"gpu_memory_mb\"].append(0)\n \n print(f\" Epoch Time (s): {epoch_time:.2f}\")\n print(f\" Throughput (samples/sec): {train_throughput:.2f}\")\n print(f\" Val Confidence Mean: {conf_mean:.4f} Β± {conf_std:.4f}\")\n \n current_metric = val_loss\n if early_stopping(current_metric, concat_model):\n print(f\"\\n{'='*50}\")\n print(f\"Early stopping at epoch {epoch+1}\")\n print(f\"{'='*50}\\n\")\n break\n model = early_stopping.load_best_model(concat_model)\n torch.save(model.state_dict(), f\"/kaggle/working/models/{args.dataset}_concat_model.pt\")\n print(f\"Best model saved to /kaggle/working/models/{args.dataset}_concat_model.pt\")\n\n checkpoint = {\n \"model_state_dict\": model.state_dict(),\n \"history\": history,\n }\n torch.save(checkpoint, f\"/kaggle/working/models/{args.dataset}_concat_checkpoint.pt\")\n print(f\"Checkpoint with history saved to /kaggle/working/models/{args.dataset}_concat_checkpoint.pt\")\n\n if args.dataset == \"implicit\":\n test_texts = test_df[\"post\"].tolist()\n else:\n test_texts = test_df[\"text\"].tolist()\n \n add_test_texts = test_df[\"Mistral_Rationales\"].tolist()\n test_labels = test_df[\"label\"].tolist()\n\n test_dataset = AdditionalCustomDataset(test_texts, test_labels, add_test_texts, tokenizer, tokenizer_bert, max_length=512)\n test_dataloader = DataLoader(\n test_dataset, \n batch_size=args.batch_size, # Use same batch size as training for faster inference\n shuffle=False,\n num_workers=2,\n pin_memory=True if torch.cuda.is_available() else False\n )\n\n # ================= TEST EVALUATION WITH EFFICIENCY =================\n model.eval()\n test_predictions = []\n test_true_labels = []\n test_confidences = []\n\n samples_seen = 0\n test_start_time = time.time()\n \n if torch.cuda.is_available():\n torch.cuda.reset_peak_memory_stats()\n \n with torch.no_grad(), tqdm(test_dataloader, desc='Testing', dynamic_ncols=True) as loop:\n for batch in loop:\n input_ids, attention_mask, additional_input_ids, additional_attention_mask, labels = batch\n \n batch_size = labels.size(0)\n samples_seen += batch_size\n \n input_ids = input_ids.to(device)\n attention_mask = attention_mask.to(device)\n additional_input_ids = additional_input_ids.to(device)\n additional_attention_mask = additional_attention_mask.to(device)\n labels = labels.to(device)\n \n outputs = model(\n input_ids=input_ids,\n attention_mask=attention_mask,\n additional_input_ids=additional_input_ids,\n additional_attention_mask=additional_attention_mask\n )\n \n probs = torch.softmax(outputs, dim=1)\n confidences, preds = torch.max(probs, dim=1)\n \n test_confidences.extend(confidences.cpu().numpy())\n test_predictions.extend(preds.cpu().numpy())\n test_true_labels.extend(labels.cpu().numpy())\n\n \n # ================= TEST METRICS =================\n test_time = time.time() - test_start_time\n test_throughput = samples_seen / test_time\n \n accuracy = accuracy_score(test_true_labels, test_predictions)\n precision, recall, f1, _ = precision_recall_fscore_support(\n test_true_labels, test_predictions, average='weighted'\n )\n conf_mean = np.mean(test_confidences)\n conf_std = np.std(test_confidences)\n cm = confusion_matrix(test_true_labels, test_predictions)\n \n gpu_memory_mb = torch.cuda.max_memory_allocated() / 1024**2 if torch.cuda.is_available() else 0\n\n print(\"\\n================= FINAL TEST RESULTS =================\")\n print(f\"Dataset: {args.dataset}, Seed: {args.seed}, Epochs: {args.num_epochs}\")\n print(f\"Test Accuracy : {accuracy:.4f}\")\n print(f\"Test Precision: {precision:.4f}\")\n print(f\"Test Recall : {recall:.4f}\")\n print(f\"Test F1-score : {f1:.4f}\")\n print(f\"Test Confidence Mean Β± Std: {conf_mean:.4f} Β± {conf_std:.4f}\")\n print(f\"Test Time (s) : {test_time:.2f}\")\n print(f\"Throughput (samples/sec) : {test_throughput:.2f}\")\n print(f\"Peak GPU Memory (MB) : {gpu_memory_mb:.2f}\")\n print(\"\\nClassification Report:\")\n print(classification_report(test_true_labels, test_predictions))\n print(\"\\nConfusion Matrix:\")\n # Plot confusion matrix with seaborn\n plt.figure(figsize=(8, 6))\n sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', \n xticklabels=['Non-Hate', 'Hate'],\n yticklabels=['Non-Hate', 'Hate'],\n cbar_kws={'label': 'Count'})\n plt.title('Confusion Matrix', fontsize=14, pad=20)\n plt.ylabel('True label', fontsize=12)\n plt.xlabel('Predicted label', fontsize=12)\n plt.tight_layout()\n plt.show()\n print(\"======================================================\")\n\n return model, history","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"from argparse import Namespace\nfrom transformers import AutoTokenizer, AutoModel\n\n# Load tokenizers and models ONCE (outside objective to save time)\ndevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\nhate_tokenizer = AutoTokenizer.from_pretrained(\"GroNLP/hateBERT\")\nbert_tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n\ntorch.cuda.empty_cache()\n\n# Load fresh models for each trial (to reset weights)\nhatebert_model = AutoModel.from_pretrained(\"GroNLP/hateBERT\").to(device)\nadditional_model = AutoModel.from_pretrained(\"bert-base-uncased\").to(device)\n\n# Sample hyperparameters\nargs = Namespace(\n # Fixed\n seed=42,\n dataset=\"reddit\", # Change as needed: \"gab\", \"twitter\", \"reddit\", \"youtube\", \"implicit\"\n \n # Tokenizers & Models\n hate_tokenizer=hate_tokenizer,\n bert_tokenizer=bert_tokenizer,\n hatebert_model=hatebert_model,\n additional_model=additional_model,\n \n # Training hyperparameters\n batch_size=32,\n num_epochs=20,\n lr=1e-5,\n wd=0.01,\n patience=2,\n max_grad_norm=2.0,\n \n # TemporalCNN hyperparameters\n temp_filters=256,\n temp_dropout=0.1,\n temp_dilate=3,\n \n # MultiScaleAttentionCNN hyperparameters\n msa_filters=64,\n msa_dropout=0.23,\n \n # Layer unfreezing hyperparameters\n hate_layers=10,\n add_layers=0,\n hate_pooler=True,\n add_pooler=True,\n)\n \nmodel, history = main(args)\n\n\n ","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"fig, axes = plt.subplots(3, 4, figsize=(20, 15))\n\n# Row 1: Train vs Val comparisons\n# Accuracy\naxes[0, 0].plot(history['train_acc'], label='Train', marker='o')\naxes[0, 0].plot(history['val_acc'], label='Validation', marker='s')\naxes[0, 0].set_title('Accuracy')\naxes[0, 0].set_xlabel('Epoch')\naxes[0, 0].set_ylabel('Accuracy')\naxes[0, 0].legend()\naxes[0, 0].grid(True)\n\n# Loss\naxes[0, 1].plot(history['train_loss'], label='Train', marker='o')\naxes[0, 1].plot(history['val_loss'], label='Validation', marker='s')\naxes[0, 1].set_title('Loss')\naxes[0, 1].set_xlabel('Epoch')\naxes[0, 1].set_ylabel('Loss')\naxes[0, 1].legend()\naxes[0, 1].grid(True)\n\n# F1 Score\naxes[0, 2].plot(history['train_f1'], label='Train', marker='o')\naxes[0, 2].plot(history['val_f1'], label='Validation', marker='s')\naxes[0, 2].set_title('F1 Score (Binary)')\naxes[0, 2].set_xlabel('Epoch')\naxes[0, 2].set_ylabel('F1')\naxes[0, 2].legend()\naxes[0, 2].grid(True)\n\n# Precision\naxes[0, 3].plot(history['train_precision'], label='Train', marker='o')\naxes[0, 3].plot(history['val_precision'], label='Validation', marker='s')\naxes[0, 3].set_title('Precision (Binary)')\naxes[0, 3].set_xlabel('Epoch')\naxes[0, 3].set_ylabel('Precision')\naxes[0, 3].legend()\naxes[0, 3].grid(True)\n\n# Row 2: More Train vs Val + Individual metrics\n# Recall\naxes[1, 0].plot(history['train_recall'], label='Train', marker='o')\naxes[1, 0].plot(history['val_recall'], label='Validation', marker='s')\naxes[1, 0].set_title('Recall (Binary)')\naxes[1, 0].set_xlabel('Epoch')\naxes[1, 0].set_ylabel('Recall')\naxes[1, 0].legend()\naxes[1, 0].grid(True)\n\n# Epoch Time\naxes[1, 1].plot(history['epoch_time'], marker='o', color='green')\naxes[1, 1].set_title('Epoch Time')\naxes[1, 1].set_xlabel('Epoch')\naxes[1, 1].set_ylabel('Time (s)')\naxes[1, 1].grid(True)\n\n# Train Throughput\naxes[1, 2].plot(history['train_throughput'], marker='o', color='purple')\naxes[1, 2].set_title('Train Throughput')\naxes[1, 2].set_xlabel('Epoch')\naxes[1, 2].set_ylabel('Samples/sec')\naxes[1, 2].grid(True)\n\n# Validation Confidence\naxes[1, 3].errorbar(range(len(history['val_confidence_mean'])), \n history['val_confidence_mean'], \n yerr=history['val_confidence_std'], \n marker='o', color='orange', capsize=3)\naxes[1, 3].set_title('Validation Confidence')\naxes[1, 3].set_xlabel('Epoch')\naxes[1, 3].set_ylabel('Confidence')\naxes[1, 3].grid(True)\n\n# Row 3: GPU Memory and Weighted/Macro metrics\n# GPU Memory\naxes[2, 0].plot(history['gpu_memory_mb'], marker='o', color='red')\naxes[2, 0].set_title('GPU Memory Usage')\naxes[2, 0].set_xlabel('Epoch')\naxes[2, 0].set_ylabel('Memory (MB)')\naxes[2, 0].grid(True)\n\n# Weighted F1 comparison\naxes[2, 1].plot(history['train_f1_weighted'], label='Train', marker='o')\naxes[2, 1].plot(history['val_f1_weighted'], label='Validation', marker='s')\naxes[2, 1].set_title('F1 Score (Weighted)')\naxes[2, 1].set_xlabel('Epoch')\naxes[2, 1].set_ylabel('F1')\naxes[2, 1].legend()\naxes[2, 1].grid(True)\n\n# Macro F1 comparison\naxes[2, 2].plot(history['train_f1_macro'], label='Train', marker='o')\naxes[2, 2].plot(history['val_f1_macro'], label='Validation', marker='s')\naxes[2, 2].set_title('F1 Score (Macro)')\naxes[2, 2].set_xlabel('Epoch')\naxes[2, 2].set_ylabel('F1')\naxes[2, 2].legend()\naxes[2, 2].grid(True)\n\n# Hide last subplot\naxes[2, 3].axis('off')\n\nplt.suptitle('Training History', fontsize=16, y=1.02)\nplt.tight_layout()\nplt.show()","metadata":{"trusted":true},"outputs":[],"execution_count":null}]}
|
|
|
|
|
|
notebooks/thesis-proposed-model-vo4.ipynb
ADDED
|
@@ -0,0 +1,1424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"metadata": {
|
| 7 |
+
"_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
|
| 8 |
+
"_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
|
| 9 |
+
"trusted": true
|
| 10 |
+
},
|
| 11 |
+
"outputs": [],
|
| 12 |
+
"source": [
|
| 13 |
+
"import torch\n",
|
| 14 |
+
"import torch.nn as nn\n",
|
| 15 |
+
"import torch.nn.functional as F\n",
|
| 16 |
+
"from torch.utils.data import DataLoader, Dataset\n",
|
| 17 |
+
"from torch.optim import AdamW\n",
|
| 18 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 19 |
+
"from transformers import AutoModel, AutoTokenizer\n",
|
| 20 |
+
"\n",
|
| 21 |
+
"import numpy as np\n",
|
| 22 |
+
"import pandas as pd\n",
|
| 23 |
+
"import time\n",
|
| 24 |
+
"import os\n",
|
| 25 |
+
"from tqdm.auto import tqdm\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"from sklearn.metrics import (\n",
|
| 28 |
+
" accuracy_score,\n",
|
| 29 |
+
" precision_recall_fscore_support,\n",
|
| 30 |
+
" classification_report,\n",
|
| 31 |
+
" confusion_matrix\n",
|
| 32 |
+
")\n",
|
| 33 |
+
"\n",
|
| 34 |
+
"import matplotlib.pyplot as plt\n",
|
| 35 |
+
"import seaborn as sns"
|
| 36 |
+
]
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"cell_type": "code",
|
| 40 |
+
"execution_count": null,
|
| 41 |
+
"metadata": {},
|
| 42 |
+
"outputs": [],
|
| 43 |
+
"source": [
|
| 44 |
+
"class AdditionalCustomDataset(Dataset):\n",
|
| 45 |
+
" \"\"\"\n",
|
| 46 |
+
" Pre-tokenizes all data during initialization for much faster training.\n",
|
| 47 |
+
" \"\"\"\n",
|
| 48 |
+
" def __init__(self, texts, labels, additional_texts, tokenizer, bert_tokenizer, max_length):\n",
|
| 49 |
+
" # Convert labels to tensor IMMEDIATELY to avoid type issues\n",
|
| 50 |
+
" self.labels = torch.tensor(labels, dtype=torch.long)\n",
|
| 51 |
+
" self.max_length = max_length\n",
|
| 52 |
+
" \n",
|
| 53 |
+
" # Pre-tokenize ALL data once during initialization\n",
|
| 54 |
+
" print(\"Pre-tokenizing primary texts...\")\n",
|
| 55 |
+
" primary_encodings = tokenizer(\n",
|
| 56 |
+
" texts,\n",
|
| 57 |
+
" max_length=max_length,\n",
|
| 58 |
+
" truncation=True,\n",
|
| 59 |
+
" padding='max_length',\n",
|
| 60 |
+
" return_tensors='pt'\n",
|
| 61 |
+
" )\n",
|
| 62 |
+
" self.primary_input_ids = primary_encodings['input_ids']\n",
|
| 63 |
+
" self.primary_attention_mask = primary_encodings['attention_mask']\n",
|
| 64 |
+
" \n",
|
| 65 |
+
" print(\"Pre-tokenizing additional texts...\")\n",
|
| 66 |
+
" additional_encodings = bert_tokenizer(\n",
|
| 67 |
+
" additional_texts,\n",
|
| 68 |
+
" max_length=max_length,\n",
|
| 69 |
+
" truncation=True,\n",
|
| 70 |
+
" padding='max_length',\n",
|
| 71 |
+
" return_tensors='pt'\n",
|
| 72 |
+
" )\n",
|
| 73 |
+
" self.additional_input_ids = additional_encodings['input_ids']\n",
|
| 74 |
+
" self.additional_attention_mask = additional_encodings['attention_mask']\n",
|
| 75 |
+
" \n",
|
| 76 |
+
" print(f\"Pre-tokenization complete. Dataset size: {len(self.labels)}\")\n",
|
| 77 |
+
" \n",
|
| 78 |
+
" # Optional: Move to pinned memory for faster GPU transfer\n",
|
| 79 |
+
" if torch.cuda.is_available():\n",
|
| 80 |
+
" self.primary_input_ids = self.primary_input_ids.pin_memory()\n",
|
| 81 |
+
" self.primary_attention_mask = self.primary_attention_mask.pin_memory()\n",
|
| 82 |
+
" self.additional_input_ids = self.additional_input_ids.pin_memory()\n",
|
| 83 |
+
" self.additional_attention_mask = self.additional_attention_mask.pin_memory()\n",
|
| 84 |
+
" self.labels = self.labels.pin_memory()\n",
|
| 85 |
+
"\n",
|
| 86 |
+
" def __len__(self):\n",
|
| 87 |
+
" return len(self.labels)\n",
|
| 88 |
+
"\n",
|
| 89 |
+
" def __getitem__(self, idx):\n",
|
| 90 |
+
" return (\n",
|
| 91 |
+
" self.primary_input_ids[idx],\n",
|
| 92 |
+
" self.primary_attention_mask[idx],\n",
|
| 93 |
+
" self.additional_input_ids[idx],\n",
|
| 94 |
+
" self.additional_attention_mask[idx],\n",
|
| 95 |
+
" self.labels[idx] # Now returns torch.long tensor\n",
|
| 96 |
+
" )"
|
| 97 |
+
]
|
| 98 |
+
},
|
| 99 |
+
{
|
| 100 |
+
"cell_type": "code",
|
| 101 |
+
"execution_count": null,
|
| 102 |
+
"metadata": {
|
| 103 |
+
"trusted": true
|
| 104 |
+
},
|
| 105 |
+
"outputs": [],
|
| 106 |
+
"source": [
|
| 107 |
+
"class ProjectionMLP(nn.Module):\n",
|
| 108 |
+
" def __init__(self, input_size, output_size):\n",
|
| 109 |
+
" super(ProjectionMLP, self).__init__()\n",
|
| 110 |
+
" self.layers = nn.Sequential(\n",
|
| 111 |
+
" nn.Linear(input_size, output_size),\n",
|
| 112 |
+
" nn.ReLU(),\n",
|
| 113 |
+
" nn.Linear(output_size, 2)\n",
|
| 114 |
+
" )\n",
|
| 115 |
+
"\n",
|
| 116 |
+
" def forward(self, x):\n",
|
| 117 |
+
" return self.layers(x)"
|
| 118 |
+
]
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"cell_type": "code",
|
| 122 |
+
"execution_count": null,
|
| 123 |
+
"metadata": {
|
| 124 |
+
"trusted": true
|
| 125 |
+
},
|
| 126 |
+
"outputs": [],
|
| 127 |
+
"source": [
|
| 128 |
+
"class GumbelTokenSelector(nn.Module):\n",
|
| 129 |
+
" def __init__(self, hidden_size, tau=1.0):\n",
|
| 130 |
+
" super().__init__()\n",
|
| 131 |
+
" self.tau = tau\n",
|
| 132 |
+
" self.proj = nn.Linear(hidden_size * 2, 1)\n",
|
| 133 |
+
" \n",
|
| 134 |
+
" def forward(self, token_embeddings, cls_embedding, training=True):\n",
|
| 135 |
+
" \"\"\"\n",
|
| 136 |
+
" token_embeddings: (B, L, H)\n",
|
| 137 |
+
" cls_embedding: (B, H)\n",
|
| 138 |
+
" \"\"\"\n",
|
| 139 |
+
" B, L, H = token_embeddings.size()\n",
|
| 140 |
+
" \n",
|
| 141 |
+
" cls_exp = cls_embedding.unsqueeze(1).expand(-1, L, -1)\n",
|
| 142 |
+
" x = torch.cat([token_embeddings, cls_exp], dim=-1)\n",
|
| 143 |
+
" \n",
|
| 144 |
+
" logits = self.proj(x).squeeze(-1) # (B, L)\n",
|
| 145 |
+
" \n",
|
| 146 |
+
" if training:\n",
|
| 147 |
+
" probs = F.gumbel_softmax(\n",
|
| 148 |
+
" torch.stack([logits, torch.zeros_like(logits)], dim=-1),\n",
|
| 149 |
+
" tau=self.tau,\n",
|
| 150 |
+
" hard=False\n",
|
| 151 |
+
" )[..., 0]\n",
|
| 152 |
+
" else:\n",
|
| 153 |
+
" probs = torch.sigmoid(logits)\n",
|
| 154 |
+
" \n",
|
| 155 |
+
" return probs, logits"
|
| 156 |
+
]
|
| 157 |
+
},
|
| 158 |
+
{
|
| 159 |
+
"cell_type": "code",
|
| 160 |
+
"execution_count": null,
|
| 161 |
+
"metadata": {},
|
| 162 |
+
"outputs": [],
|
| 163 |
+
"source": [
|
| 164 |
+
"class AnnealedGumbelSelector(nn.Module):\n",
|
| 165 |
+
" def __init__(self, hidden_size, initial_tau=1.0, min_tau=0.1, decay_rate=0.9, hard=True):\n",
|
| 166 |
+
" \"\"\"\n",
|
| 167 |
+
" Args:\n",
|
| 168 |
+
" initial_tau: Starting temperature (high = random/exploration).\n",
|
| 169 |
+
" min_tau: Lowest temperature allowed (low = deterministic/exploitation).\n",
|
| 170 |
+
" decay_rate: How much to multiply tau by after every epoch.\n",
|
| 171 |
+
" hard: If True, uses Straight-Through Estimator (returns 0 or 1, but gradients flow).\n",
|
| 172 |
+
" \"\"\"\n",
|
| 173 |
+
" super().__init__()\n",
|
| 174 |
+
" self.initial_tau = initial_tau\n",
|
| 175 |
+
" self.min_tau = min_tau\n",
|
| 176 |
+
" self.decay_rate = decay_rate\n",
|
| 177 |
+
" self.hard = hard\n",
|
| 178 |
+
" \n",
|
| 179 |
+
" # Register tau as a buffer so it saves with model.state_dict() but isn't a trained parameter\n",
|
| 180 |
+
" self.register_buffer('tau', torch.tensor(initial_tau))\n",
|
| 181 |
+
" \n",
|
| 182 |
+
" # Project [Token_Emb; CLS_Emb] -> 1 scalar score per token\n",
|
| 183 |
+
" self.proj = nn.Linear(hidden_size * 2, 1)\n",
|
| 184 |
+
"\n",
|
| 185 |
+
" def step_tau(self):\n",
|
| 186 |
+
" \"\"\"Call this at the end of every epoch to cool it down\"\"\"\n",
|
| 187 |
+
" new_tau = max(self.min_tau, self.tau * self.decay_rate)\n",
|
| 188 |
+
" self.tau.fill_(new_tau)\n",
|
| 189 |
+
" return self.tau.item()\n",
|
| 190 |
+
"\n",
|
| 191 |
+
" def forward(self, token_embeddings, cls_embedding, training=True):\n",
|
| 192 |
+
" \"\"\"\n",
|
| 193 |
+
" Returns:\n",
|
| 194 |
+
" probs: (B, L) - selected masks (0.0 to 1.0)\n",
|
| 195 |
+
" logits: (B, L) - raw scores\n",
|
| 196 |
+
" \"\"\"\n",
|
| 197 |
+
" B, L, H = token_embeddings.size()\n",
|
| 198 |
+
"\n",
|
| 199 |
+
" # Create context aware embeddings: (B, L, 2*H)\n",
|
| 200 |
+
" cls_exp = cls_embedding.unsqueeze(1).expand(-1, L, -1)\n",
|
| 201 |
+
" x = torch.cat([token_embeddings, cls_exp], dim=-1)\n",
|
| 202 |
+
"\n",
|
| 203 |
+
" # Predict logits: (B, L)\n",
|
| 204 |
+
" logits = self.proj(x).squeeze(-1)\n",
|
| 205 |
+
"\n",
|
| 206 |
+
" if training:\n",
|
| 207 |
+
" # 1. Generate Gumbel Noise\n",
|
| 208 |
+
" # noise = -log(-log(uniform(0,1)))\n",
|
| 209 |
+
" uniform_noise = torch.rand_like(logits)\n",
|
| 210 |
+
" gumbel_noise = -torch.log(-torch.log(uniform_noise + 1e-9) + 1e-9)\n",
|
| 211 |
+
" \n",
|
| 212 |
+
" # 2. Add noise and scale by temperature\n",
|
| 213 |
+
" y_soft = torch.sigmoid((logits + gumbel_noise) / self.tau)\n",
|
| 214 |
+
"\n",
|
| 215 |
+
" if self.hard:\n",
|
| 216 |
+
" # 3. Straight-Through Estimator\n",
|
| 217 |
+
" # Forward pass is binary (0 or 1), Backward pass uses y_soft gradients\n",
|
| 218 |
+
" y_hard = (y_soft > 0.5).float()\n",
|
| 219 |
+
" # (y_hard - y_soft).detach() is 0 during backward pass\n",
|
| 220 |
+
" probs = y_hard - y_soft.detach() + y_soft\n",
|
| 221 |
+
" else:\n",
|
| 222 |
+
" probs = y_soft\n",
|
| 223 |
+
" else:\n",
|
| 224 |
+
" # During inference, just threshold the logits (deterministic)\n",
|
| 225 |
+
" if self.hard:\n",
|
| 226 |
+
" probs = (torch.sigmoid(logits) > 0.5).float()\n",
|
| 227 |
+
" else:\n",
|
| 228 |
+
" probs = torch.sigmoid(logits)\n",
|
| 229 |
+
"\n",
|
| 230 |
+
" return probs, logits"
|
| 231 |
+
]
|
| 232 |
+
},
|
| 233 |
+
{
|
| 234 |
+
"cell_type": "code",
|
| 235 |
+
"execution_count": null,
|
| 236 |
+
"metadata": {
|
| 237 |
+
"trusted": true
|
| 238 |
+
},
|
| 239 |
+
"outputs": [],
|
| 240 |
+
"source": [
|
| 241 |
+
"class MultiScaleAttentionCNN(nn.Module):\n",
|
| 242 |
+
" def __init__(\n",
|
| 243 |
+
" self,\n",
|
| 244 |
+
" hidden_size=768,\n",
|
| 245 |
+
" num_filters=128,\n",
|
| 246 |
+
" kernel_sizes=(2, 3, 4),\n",
|
| 247 |
+
" dropout=0.3,\n",
|
| 248 |
+
" ):\n",
|
| 249 |
+
" super().__init__()\n",
|
| 250 |
+
" \n",
|
| 251 |
+
" self.convs = nn.ModuleList([\n",
|
| 252 |
+
" nn.Conv1d(hidden_size, num_filters, k)\n",
|
| 253 |
+
" for k in kernel_sizes\n",
|
| 254 |
+
" ])\n",
|
| 255 |
+
" \n",
|
| 256 |
+
" self.attention_fc = nn.Linear(num_filters, 1)\n",
|
| 257 |
+
" self.dropout = nn.Dropout(dropout)\n",
|
| 258 |
+
" self.out_dim = num_filters * len(kernel_sizes)\n",
|
| 259 |
+
" \n",
|
| 260 |
+
" def forward(self, x, mask):\n",
|
| 261 |
+
" \"\"\"\n",
|
| 262 |
+
" x: (B, L, H)\n",
|
| 263 |
+
" mask: (B, L)\n",
|
| 264 |
+
" \"\"\"\n",
|
| 265 |
+
" x = x.transpose(1, 2) # (B, H, L)\n",
|
| 266 |
+
" feats = []\n",
|
| 267 |
+
" \n",
|
| 268 |
+
" for conv in self.convs:\n",
|
| 269 |
+
" h = F.relu(conv(x)) # (B, C, L')\n",
|
| 270 |
+
" h = h.transpose(1, 2) # (B, L', C)\n",
|
| 271 |
+
" \n",
|
| 272 |
+
" attn = self.attention_fc(h).squeeze(-1)\n",
|
| 273 |
+
" attn = attn.masked_fill(mask[:, :attn.size(1)] == 0, -1e9)\n",
|
| 274 |
+
" alpha = F.softmax(attn, dim=1)\n",
|
| 275 |
+
" \n",
|
| 276 |
+
" pooled = torch.sum(h * alpha.unsqueeze(-1), dim=1)\n",
|
| 277 |
+
" feats.append(pooled)\n",
|
| 278 |
+
" \n",
|
| 279 |
+
" out = torch.cat(feats, dim=1)\n",
|
| 280 |
+
" return self.dropout(out)"
|
| 281 |
+
]
|
| 282 |
+
},
|
| 283 |
+
{
|
| 284 |
+
"cell_type": "code",
|
| 285 |
+
"execution_count": null,
|
| 286 |
+
"metadata": {},
|
| 287 |
+
"outputs": [],
|
| 288 |
+
"source": [
|
| 289 |
+
"class TemporalCNN(nn.Module):\n",
|
| 290 |
+
" \"\"\"\n",
|
| 291 |
+
" Standard TextCNN implementation (better for classification than Causal TCN).\n",
|
| 292 |
+
" Input: sequence_embeddings (B, L, H), attention_mask (B, L)\n",
|
| 293 |
+
" Output: pooled vector (B, output_dim) \n",
|
| 294 |
+
" \"\"\"\n",
|
| 295 |
+
" def __init__(self, input_dim=768, num_filters=256, kernel_sizes=(2,3,4), dropout=0.3):\n",
|
| 296 |
+
" super().__init__()\n",
|
| 297 |
+
" self.input_dim = input_dim\n",
|
| 298 |
+
" \n",
|
| 299 |
+
" # Output dim = filters * kernels * 2 (because of max + mean pooling)\n",
|
| 300 |
+
" self.out_dim = num_filters * len(kernel_sizes) * 2\n",
|
| 301 |
+
"\n",
|
| 302 |
+
" # Convs expect (B, C_in, L) where C_in = input_dim\n",
|
| 303 |
+
" self.convs = nn.ModuleList([\n",
|
| 304 |
+
" nn.Conv1d(\n",
|
| 305 |
+
" in_channels=input_dim, \n",
|
| 306 |
+
" out_channels=num_filters, \n",
|
| 307 |
+
" kernel_size=k, \n",
|
| 308 |
+
" padding=k//2 # 'Same' style padding to look at neighbors on both sides\n",
|
| 309 |
+
" )\n",
|
| 310 |
+
" for k in kernel_sizes\n",
|
| 311 |
+
" ])\n",
|
| 312 |
+
" self.dropout = nn.Dropout(dropout)\n",
|
| 313 |
+
"\n",
|
| 314 |
+
" def forward(self, sequence_embeddings, attention_mask=None):\n",
|
| 315 |
+
" \"\"\"\n",
|
| 316 |
+
" sequence_embeddings: (B, L, H)\n",
|
| 317 |
+
" attention_mask: (B, L)\n",
|
| 318 |
+
" \"\"\"\n",
|
| 319 |
+
" # transpose to (B, H, L) for Conv1d\n",
|
| 320 |
+
" x = sequence_embeddings.transpose(1, 2).contiguous()\n",
|
| 321 |
+
" \n",
|
| 322 |
+
" pooled_outputs = []\n",
|
| 323 |
+
" for conv in self.convs:\n",
|
| 324 |
+
" # Apply Convolution\n",
|
| 325 |
+
" conv_out = F.relu(conv(x)) # (B, num_filters, L_out)\n",
|
| 326 |
+
" \n",
|
| 327 |
+
" # Handle Masking for Pooling\n",
|
| 328 |
+
" if attention_mask is not None:\n",
|
| 329 |
+
" # Resize mask to match conv output length (in case of slight size change)\n",
|
| 330 |
+
" L_out = conv_out.size(2)\n",
|
| 331 |
+
" mask = attention_mask.float()\n",
|
| 332 |
+
" if mask.size(1) != L_out:\n",
|
| 333 |
+
" mask = F.interpolate(mask.unsqueeze(1), size=L_out, mode='nearest').squeeze(1)\n",
|
| 334 |
+
" \n",
|
| 335 |
+
" # (B, 1, L_out) to broadcast over filters\n",
|
| 336 |
+
" mask_expanded = mask.unsqueeze(1).to(conv_out.device)\n",
|
| 337 |
+
" \n",
|
| 338 |
+
" # 1. Max Pooling with Mask (fill padded areas with -inf)\n",
|
| 339 |
+
" # We use a large negative number instead of true -inf to avoid NaNs\n",
|
| 340 |
+
" neg_val = -1e9\n",
|
| 341 |
+
" masked_for_max = torch.where(mask_expanded.bool(), conv_out, torch.tensor(neg_val, device=conv_out.device))\n",
|
| 342 |
+
" max_pooled = torch.max(masked_for_max, dim=2)[0]\n",
|
| 343 |
+
" \n",
|
| 344 |
+
" # 2. Mean Pooling with Mask (zero out padding, divide by valid length)\n",
|
| 345 |
+
" sum_masked = (conv_out * mask_expanded).sum(dim=2)\n",
|
| 346 |
+
" lengths = mask_expanded.sum(dim=2).clamp_min(1e-9)\n",
|
| 347 |
+
" mean_pooled = sum_masked / lengths\n",
|
| 348 |
+
" else:\n",
|
| 349 |
+
" max_pooled = torch.max(conv_out, dim=2)[0]\n",
|
| 350 |
+
" mean_pooled = conv_out.mean(dim=2)\n",
|
| 351 |
+
" \n",
|
| 352 |
+
" pooled_outputs.append(max_pooled)\n",
|
| 353 |
+
" pooled_outputs.append(mean_pooled)\n",
|
| 354 |
+
" \n",
|
| 355 |
+
" # Concatenate: (B, num_filters * len(kernels) * 2)\n",
|
| 356 |
+
" out = torch.cat(pooled_outputs, dim=1)\n",
|
| 357 |
+
" return self.dropout(out)"
|
| 358 |
+
]
|
| 359 |
+
},
|
| 360 |
+
{
|
| 361 |
+
"cell_type": "code",
|
| 362 |
+
"execution_count": null,
|
| 363 |
+
"metadata": {
|
| 364 |
+
"trusted": true
|
| 365 |
+
},
|
| 366 |
+
"outputs": [],
|
| 367 |
+
"source": [
|
| 368 |
+
"class ConcatModel(nn.Module):\n",
|
| 369 |
+
" def __init__(\n",
|
| 370 |
+
" self,\n",
|
| 371 |
+
" hatebert_model,\n",
|
| 372 |
+
" additional_model,\n",
|
| 373 |
+
" temporal_cnn,\n",
|
| 374 |
+
" msa_cnn,\n",
|
| 375 |
+
" selector,\n",
|
| 376 |
+
" projection_mlp,\n",
|
| 377 |
+
" unfreeze_n_layers_hate=12, #hatebert unfreeze all 12 layers\n",
|
| 378 |
+
" hate_pooler=True,\n",
|
| 379 |
+
" \n",
|
| 380 |
+
" ):\n",
|
| 381 |
+
" super().__init__()\n",
|
| 382 |
+
" \n",
|
| 383 |
+
" self.hatebert_model = hatebert_model\n",
|
| 384 |
+
" self.additional_model = additional_model\n",
|
| 385 |
+
" \n",
|
| 386 |
+
" self.temporal_cnn = temporal_cnn\n",
|
| 387 |
+
" self.msa_cnn = msa_cnn\n",
|
| 388 |
+
" self.selector = selector\n",
|
| 389 |
+
" self.projection_mlp = projection_mlp\n",
|
| 390 |
+
" \n",
|
| 391 |
+
" # freeze everything on additional bert\n",
|
| 392 |
+
" for p in self.additional_model.parameters():\n",
|
| 393 |
+
" p.requires_grad = False\n",
|
| 394 |
+
"\n",
|
| 395 |
+
" # freeze everything on hatebert\n",
|
| 396 |
+
" for p in self.hatebert_model.parameters():\n",
|
| 397 |
+
" p.requires_grad = False\n",
|
| 398 |
+
" \n",
|
| 399 |
+
" # FIX: Use hatebert_model to get the correct number of layers\n",
|
| 400 |
+
" hate_num_layers = len(self.hatebert_model.encoder.layer)\n",
|
| 401 |
+
" for i in range(hate_num_layers - unfreeze_n_layers_hate, hate_num_layers):\n",
|
| 402 |
+
" for param in self.hatebert_model.encoder.layer[i].parameters():\n",
|
| 403 |
+
" param.requires_grad = True\n",
|
| 404 |
+
" \n",
|
| 405 |
+
" if self.hatebert_model.pooler is not None and hate_pooler:\n",
|
| 406 |
+
" for p in self.hatebert_model.pooler.parameters():\n",
|
| 407 |
+
" p.requires_grad = True\n",
|
| 408 |
+
"\n",
|
| 409 |
+
" \n",
|
| 410 |
+
" def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask):\n",
|
| 411 |
+
" # ================= HateBERT =================\n",
|
| 412 |
+
" hate_outputs = self.hatebert_model(\n",
|
| 413 |
+
" input_ids=input_ids,\n",
|
| 414 |
+
" attention_mask=attention_mask,\n",
|
| 415 |
+
" )\n",
|
| 416 |
+
" seq_emb = hate_outputs.last_hidden_state # (B, L, H)\n",
|
| 417 |
+
" cls_emb = seq_emb[:, 0, :] # (B, H)\n",
|
| 418 |
+
" \n",
|
| 419 |
+
" # ---- Token Selector ----\n",
|
| 420 |
+
" token_probs, token_logits = self.selector(seq_emb, cls_emb, self.training)\n",
|
| 421 |
+
" \n",
|
| 422 |
+
" # ---- Temporal CNN on FULL embeddings (NOT masked) ----\n",
|
| 423 |
+
" temporal_feat = self.temporal_cnn(seq_emb, attention_mask)\n",
|
| 424 |
+
" \n",
|
| 425 |
+
" # ---- Rationale-Weighted Summary Vector H_r ----\n",
|
| 426 |
+
" weights = token_probs.unsqueeze(-1) # (B, L, 1)\n",
|
| 427 |
+
" H_r = (seq_emb * weights).sum(dim=1) / (weights.sum(dim=1) + 1e-6)\n",
|
| 428 |
+
" \n",
|
| 429 |
+
" # ================= Frozen Rationale BERT =================\n",
|
| 430 |
+
" add_outputs = self.additional_model(\n",
|
| 431 |
+
" input_ids=additional_input_ids,\n",
|
| 432 |
+
" attention_mask=additional_attention_mask,\n",
|
| 433 |
+
" )\n",
|
| 434 |
+
" add_seq = add_outputs.last_hidden_state\n",
|
| 435 |
+
" \n",
|
| 436 |
+
" # ---- Multi-Scale Attention CNN ----\n",
|
| 437 |
+
" msa_feat = self.msa_cnn(add_seq, additional_attention_mask)\n",
|
| 438 |
+
" \n",
|
| 439 |
+
" # ================= CONCAT (4 components) =================\n",
|
| 440 |
+
" concat = torch.cat([cls_emb, temporal_feat, msa_feat, H_r], dim=1)\n",
|
| 441 |
+
" \n",
|
| 442 |
+
" logits = self.projection_mlp(concat)\n",
|
| 443 |
+
" return logits"
|
| 444 |
+
]
|
| 445 |
+
},
|
| 446 |
+
{
|
| 447 |
+
"cell_type": "code",
|
| 448 |
+
"execution_count": null,
|
| 449 |
+
"metadata": {
|
| 450 |
+
"trusted": true
|
| 451 |
+
},
|
| 452 |
+
"outputs": [],
|
| 453 |
+
"source": [
|
| 454 |
+
"class EarlyStopping:\n",
|
| 455 |
+
" \"\"\"\n",
|
| 456 |
+
" Early stopping to stop training when validation loss doesn't improve.\n",
|
| 457 |
+
" \"\"\"\n",
|
| 458 |
+
" def __init__(self, patience=10, min_delta=1.0, mode='min', verbose=True):\n",
|
| 459 |
+
" \"\"\"\n",
|
| 460 |
+
" Args:\n",
|
| 461 |
+
" patience (int): How many epochs to wait after last improvement.\n",
|
| 462 |
+
" min_delta (float): Minimum change to qualify as an improvement.\n",
|
| 463 |
+
" mode (str): 'min' for loss, 'max' for accuracy/f1.\n",
|
| 464 |
+
" verbose (bool): Print messages when improvement occurs.\n",
|
| 465 |
+
" \"\"\"\n",
|
| 466 |
+
" self.patience = patience\n",
|
| 467 |
+
" self.min_delta = min_delta\n",
|
| 468 |
+
" self.mode = mode\n",
|
| 469 |
+
" self.verbose = verbose\n",
|
| 470 |
+
" self.counter = 0\n",
|
| 471 |
+
" self.best_score = None\n",
|
| 472 |
+
" self.early_stop = False\n",
|
| 473 |
+
" self.best_model_state = None\n",
|
| 474 |
+
" \n",
|
| 475 |
+
" def __call__(self, current_score, model):\n",
|
| 476 |
+
" \"\"\"\n",
|
| 477 |
+
" Call this after each epoch with the validation metric and model.\n",
|
| 478 |
+
" \n",
|
| 479 |
+
" Args:\n",
|
| 480 |
+
" current_score: Current epoch's validation metric (loss, accuracy, f1, etc.)\n",
|
| 481 |
+
" model: The model to save if there's improvement\n",
|
| 482 |
+
" \n",
|
| 483 |
+
" Returns:\n",
|
| 484 |
+
" bool: True if training should stop, False otherwise\n",
|
| 485 |
+
" \"\"\"\n",
|
| 486 |
+
" if self.best_score is None:\n",
|
| 487 |
+
" # First epoch\n",
|
| 488 |
+
" self.best_score = current_score\n",
|
| 489 |
+
" self.save_checkpoint(model)\n",
|
| 490 |
+
" if self.verbose:\n",
|
| 491 |
+
" print(f\"Initial best score: {self.best_score:.4f}\")\n",
|
| 492 |
+
" else:\n",
|
| 493 |
+
" # Check if there's improvement\n",
|
| 494 |
+
" if self.mode == 'min':\n",
|
| 495 |
+
" improved = current_score < (self.best_score - self.min_delta)\n",
|
| 496 |
+
" else: # mode == 'max'\n",
|
| 497 |
+
" improved = current_score > (self.best_score + self.min_delta)\n",
|
| 498 |
+
" \n",
|
| 499 |
+
" if improved:\n",
|
| 500 |
+
" self.best_score = current_score\n",
|
| 501 |
+
" self.save_checkpoint(model)\n",
|
| 502 |
+
" self.counter = 0\n",
|
| 503 |
+
" if self.verbose:\n",
|
| 504 |
+
" print(f\"Validation improved! New best score: {self.best_score:.4f}\")\n",
|
| 505 |
+
" else:\n",
|
| 506 |
+
" self.counter += 1\n",
|
| 507 |
+
" if self.verbose:\n",
|
| 508 |
+
" print(f\"No improvement. Patience counter: {self.counter}/{self.patience}\")\n",
|
| 509 |
+
" \n",
|
| 510 |
+
" if self.counter >= self.patience:\n",
|
| 511 |
+
" self.early_stop = True\n",
|
| 512 |
+
" if self.verbose:\n",
|
| 513 |
+
" print(f\"Early stopping triggered! Best score: {self.best_score:.4f}\")\n",
|
| 514 |
+
" \n",
|
| 515 |
+
" return self.early_stop\n",
|
| 516 |
+
" \n",
|
| 517 |
+
" def save_checkpoint(self, model):\n",
|
| 518 |
+
" \"\"\"Save model state dict\"\"\"\n",
|
| 519 |
+
" import copy\n",
|
| 520 |
+
" self.best_model_state = copy.deepcopy(model.state_dict())\n",
|
| 521 |
+
" \n",
|
| 522 |
+
" def load_best_model(self, model):\n",
|
| 523 |
+
" \"\"\"Load the best model state into the model\"\"\"\n",
|
| 524 |
+
" if self.best_model_state is not None:\n",
|
| 525 |
+
" model.load_state_dict(self.best_model_state)\n",
|
| 526 |
+
" if self.verbose:\n",
|
| 527 |
+
" print(f\"Loaded best model with score: {self.best_score:.4f}\")\n",
|
| 528 |
+
" return model"
|
| 529 |
+
]
|
| 530 |
+
},
|
| 531 |
+
{
|
| 532 |
+
"cell_type": "code",
|
| 533 |
+
"execution_count": null,
|
| 534 |
+
"metadata": {
|
| 535 |
+
"trusted": true
|
| 536 |
+
},
|
| 537 |
+
"outputs": [],
|
| 538 |
+
"source": [
|
| 539 |
+
"def main(args):\n",
|
| 540 |
+
" torch.manual_seed(args.seed)\n",
|
| 541 |
+
" torch.cuda.empty_cache()\n",
|
| 542 |
+
" device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
|
| 543 |
+
"\n",
|
| 544 |
+
"\n",
|
| 545 |
+
" file_map = {\n",
|
| 546 |
+
" \"gab\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_GAB_dataset(85-15).csv',\n",
|
| 547 |
+
" \"twitter\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_Twitter_dataset(85-15).csv',\n",
|
| 548 |
+
" \"reddit\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_REDDIT_dataset(85-15).csv',\n",
|
| 549 |
+
" \"youtube\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_YOUTUBE_dataset(85-15).csv',\n",
|
| 550 |
+
" \"implicit\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_IMPLICIT_dataset(85-15).csv'\n",
|
| 551 |
+
" }\n",
|
| 552 |
+
"\n",
|
| 553 |
+
" file_path = file_map[args.dataset]\n",
|
| 554 |
+
" df = pd.read_csv(file_path)\n",
|
| 555 |
+
" train_df = df[df['exp_split'] == 'train']\n",
|
| 556 |
+
" test_df = df[df['exp_split'] == 'test']\n",
|
| 557 |
+
"\n",
|
| 558 |
+
" print(\"Train df: \", len(train_df))\n",
|
| 559 |
+
" print(\"Test_df: \", len(test_df))\n",
|
| 560 |
+
"\n",
|
| 561 |
+
" import gc\n",
|
| 562 |
+
" # del variables\n",
|
| 563 |
+
" gc.collect()\n",
|
| 564 |
+
"\n",
|
| 565 |
+
" \n",
|
| 566 |
+
" tokenizer = args.hate_tokenizer ## need this for tokenizing the input text in data loader\n",
|
| 567 |
+
" tokenizer_bert = args.bert_tokenizer\n",
|
| 568 |
+
" #Splitting training and validation testing split to test accuracy\n",
|
| 569 |
+
" train_idx, val_idx = train_test_split(\n",
|
| 570 |
+
" train_df.index,\n",
|
| 571 |
+
" test_size=0.2,\n",
|
| 572 |
+
" stratify=train_df[\"label\"],\n",
|
| 573 |
+
" random_state=args.seed\n",
|
| 574 |
+
" )\n",
|
| 575 |
+
" \n",
|
| 576 |
+
" if args.dataset == \"implicit\":\n",
|
| 577 |
+
" train_text = train_df.loc[train_idx, \"post\"].tolist()\n",
|
| 578 |
+
" val_texts = train_df.loc[val_idx, \"post\"].tolist()\n",
|
| 579 |
+
" else:\n",
|
| 580 |
+
" train_text = train_df.loc[train_idx, \"text\"].tolist()\n",
|
| 581 |
+
" val_texts = train_df.loc[val_idx, \"text\"].tolist()\n",
|
| 582 |
+
" \n",
|
| 583 |
+
" add_train_text = train_df.loc[train_idx, \"Mistral_Rationales\"].tolist()\n",
|
| 584 |
+
" add_val_texts = train_df.loc[val_idx, \"Mistral_Rationales\"].tolist()\n",
|
| 585 |
+
" \n",
|
| 586 |
+
" train_labels = train_df.loc[train_idx, \"label\"].tolist()\n",
|
| 587 |
+
" val_labels = train_df.loc[val_idx, \"label\"].tolist()\n",
|
| 588 |
+
"\n",
|
| 589 |
+
" train_dataset = AdditionalCustomDataset(\n",
|
| 590 |
+
" train_text,\n",
|
| 591 |
+
" train_labels,\n",
|
| 592 |
+
" add_train_text,\n",
|
| 593 |
+
" tokenizer,\n",
|
| 594 |
+
" tokenizer_bert,\n",
|
| 595 |
+
" max_length=512\n",
|
| 596 |
+
" )\n",
|
| 597 |
+
" \n",
|
| 598 |
+
" val_dataset = AdditionalCustomDataset(\n",
|
| 599 |
+
" val_texts,\n",
|
| 600 |
+
" val_labels,\n",
|
| 601 |
+
" add_val_texts,\n",
|
| 602 |
+
" tokenizer,\n",
|
| 603 |
+
" tokenizer_bert,\n",
|
| 604 |
+
" max_length=512\n",
|
| 605 |
+
" )\n",
|
| 606 |
+
"\n",
|
| 607 |
+
" #Creating dataloader object to train the model\n",
|
| 608 |
+
" # num_workers=2 for parallel data loading, pin_memory=True for faster GPU transfer\n",
|
| 609 |
+
" train_dataloader = DataLoader(\n",
|
| 610 |
+
" train_dataset, \n",
|
| 611 |
+
" batch_size=args.batch_size, \n",
|
| 612 |
+
" shuffle=True,\n",
|
| 613 |
+
" num_workers=2,\n",
|
| 614 |
+
" pin_memory=True if torch.cuda.is_available() else False\n",
|
| 615 |
+
" )\n",
|
| 616 |
+
" val_dataloader = DataLoader(\n",
|
| 617 |
+
" val_dataset, \n",
|
| 618 |
+
" batch_size=args.batch_size, \n",
|
| 619 |
+
" shuffle=False,\n",
|
| 620 |
+
" num_workers=2,\n",
|
| 621 |
+
" pin_memory=True if torch.cuda.is_available() else False\n",
|
| 622 |
+
" )\n",
|
| 623 |
+
" hatebert_model = args.hatebert_model\n",
|
| 624 |
+
" additional_model = args.additional_model\n",
|
| 625 |
+
" \n",
|
| 626 |
+
" \n",
|
| 627 |
+
" temporal_cnn = TemporalCNN(\n",
|
| 628 |
+
" input_dim=768,\n",
|
| 629 |
+
" num_filters=args.temp_filters, # 64 - 128\n",
|
| 630 |
+
" kernel_sizes=(2, 3, 4),\n",
|
| 631 |
+
" dropout=args.temp_dropout,\n",
|
| 632 |
+
" ).to(device)\n",
|
| 633 |
+
"\n",
|
| 634 |
+
" msa_cnn = MultiScaleAttentionCNN(\n",
|
| 635 |
+
" hidden_size=768,\n",
|
| 636 |
+
" num_filters=args.msa_filters, # 64 - 128\n",
|
| 637 |
+
" kernel_sizes=(2, 3, 4),\n",
|
| 638 |
+
" dropout=args.msa_dropout\n",
|
| 639 |
+
" ).to(device)\n",
|
| 640 |
+
" \n",
|
| 641 |
+
" selector = AnnealedGumbelSelector(\n",
|
| 642 |
+
" hidden_size=768,\n",
|
| 643 |
+
" initial_tau=1.0, # Start hot (random exploration)\n",
|
| 644 |
+
" min_tau=0.1, # End cold (deterministic selection)\n",
|
| 645 |
+
" decay_rate=0.9, # Decrease by 10% every epoch\n",
|
| 646 |
+
" hard=True # Force binary 0/1 selections\n",
|
| 647 |
+
" ).to(device)\n",
|
| 648 |
+
" \n",
|
| 649 |
+
" projection_mlp = ProjectionMLP(\n",
|
| 650 |
+
" input_size=temporal_cnn.out_dim + msa_cnn.out_dim + 768 * 2,\n",
|
| 651 |
+
" output_size=512\n",
|
| 652 |
+
" ).to(device)\n",
|
| 653 |
+
"\n",
|
| 654 |
+
"\n",
|
| 655 |
+
"\n",
|
| 656 |
+
" concat_model = ConcatModel(\n",
|
| 657 |
+
" hatebert_model=hatebert_model,\n",
|
| 658 |
+
" additional_model=additional_model,\n",
|
| 659 |
+
" temporal_cnn=temporal_cnn,\n",
|
| 660 |
+
" msa_cnn=msa_cnn,\n",
|
| 661 |
+
" selector=selector,\n",
|
| 662 |
+
" projection_mlp=projection_mlp,\n",
|
| 663 |
+
" unfreeze_n_layers_hate=args.hate_layers, #hatebert unfreeze all 12 layers id 12 by default\n",
|
| 664 |
+
" hate_pooler=args.hate_pooler, #bool to controle if pooler is frozen or not true=not frozen\n",
|
| 665 |
+
" ).to(device)\n",
|
| 666 |
+
"\n",
|
| 667 |
+
" optimizer = AdamW(\n",
|
| 668 |
+
" [\n",
|
| 669 |
+
" {'params': concat_model.hatebert_model.parameters(), 'lr': args.hatebert_lr},\n",
|
| 670 |
+
" {'params': concat_model.additional_model.parameters(), 'lr': args.additional_lr},\n",
|
| 671 |
+
" {'params': concat_model.temporal_cnn.parameters(), 'lr': args.temp_lr},\n",
|
| 672 |
+
" {'params': concat_model.msa_cnn.parameters(), 'lr': args.msa_lr},\n",
|
| 673 |
+
" {'params': concat_model.projection_mlp.parameters(), 'lr': args.proj_lr}\n",
|
| 674 |
+
" ],\n",
|
| 675 |
+
" weight_decay=args.wd\n",
|
| 676 |
+
" )\n",
|
| 677 |
+
" criterion = nn.CrossEntropyLoss().to(device)\n",
|
| 678 |
+
"\n",
|
| 679 |
+
" # criterion = criterion.to(device)\n",
|
| 680 |
+
"\n",
|
| 681 |
+
" os.makedirs(\"/kaggle/working/models\", exist_ok=True)\n",
|
| 682 |
+
"\n",
|
| 683 |
+
" history = {\n",
|
| 684 |
+
" \"train_loss\": [],\n",
|
| 685 |
+
" \"val_loss\": [], #loss\n",
|
| 686 |
+
" \"train_acc\": [],\n",
|
| 687 |
+
" \"train_precision\": [],\n",
|
| 688 |
+
" \"train_recall\": [],\n",
|
| 689 |
+
" \"train_f1\": [], #train binary\n",
|
| 690 |
+
" \"val_acc\": [],\n",
|
| 691 |
+
" \"val_precision\": [],\n",
|
| 692 |
+
" \"val_recall\": [],\n",
|
| 693 |
+
" \"val_f1\": [], #validation binary\n",
|
| 694 |
+
" \"train_acc_weighted\": [],\n",
|
| 695 |
+
" \"train_precision_weighted\": [],\n",
|
| 696 |
+
" \"train_recall_weighted\": [],\n",
|
| 697 |
+
" \"train_f1_weighted\": [], # train weighted\n",
|
| 698 |
+
" \"val_acc_weighted\": [],\n",
|
| 699 |
+
" \"val_precision_weighted\": [],\n",
|
| 700 |
+
" \"val_recall_weighted\": [],\n",
|
| 701 |
+
" \"val_f1_weighted\": [], # validation weighted\n",
|
| 702 |
+
" \"train_acc_macro\": [],\n",
|
| 703 |
+
" \"train_precision_macro\": [],\n",
|
| 704 |
+
" \"train_recall_macro\": [],\n",
|
| 705 |
+
" \"train_f1_macro\": [], #train macro\n",
|
| 706 |
+
" \"val_acc_macro\": [],\n",
|
| 707 |
+
" \"val_precision_macro\": [],\n",
|
| 708 |
+
" \"val_recall_macro\": [],\n",
|
| 709 |
+
" \"val_f1_macro\": [],# validation macro\n",
|
| 710 |
+
" \"epoch_time\": [],\n",
|
| 711 |
+
" \"train_throughput\": [],\n",
|
| 712 |
+
" \"val_confidence_mean\": [],\n",
|
| 713 |
+
" \"val_confidence_std\": [],\n",
|
| 714 |
+
" \"gpu_memory_mb\": [],\n",
|
| 715 |
+
" }\n",
|
| 716 |
+
"\n",
|
| 717 |
+
" early_stopping = EarlyStopping(patience=args.patience, min_delta=0.001, mode='min', verbose=True) # early stop on loss\n",
|
| 718 |
+
"\n",
|
| 719 |
+
" for epoch in range(args.num_epochs):\n",
|
| 720 |
+
" epoch_val_confidences = []\n",
|
| 721 |
+
" epoch_start_time = time.time()\n",
|
| 722 |
+
" samples_seen = 0\n",
|
| 723 |
+
" \n",
|
| 724 |
+
" concat_model.train()\n",
|
| 725 |
+
"\n",
|
| 726 |
+
" train_losses = []\n",
|
| 727 |
+
" train_preds = []\n",
|
| 728 |
+
" train_labels_epoch = []\n",
|
| 729 |
+
" train_accuracy = 0\n",
|
| 730 |
+
" train_epoch_size = 0\n",
|
| 731 |
+
"\n",
|
| 732 |
+
" with tqdm(train_dataloader, desc=f'Epoch {epoch + 1}', dynamic_ncols=True) as loop:\n",
|
| 733 |
+
" for batch in loop:\n",
|
| 734 |
+
" input_ids, attention_mask, additional_input_ids, additional_attention_mask, labels = batch\n",
|
| 735 |
+
" \n",
|
| 736 |
+
" samples_seen += labels.size(0)\n",
|
| 737 |
+
" \n",
|
| 738 |
+
" if torch.cuda.is_available():\n",
|
| 739 |
+
" input_ids = input_ids.to(device)\n",
|
| 740 |
+
" attention_mask = attention_mask.to(device)\n",
|
| 741 |
+
" additional_input_ids = additional_input_ids.to(device)\n",
|
| 742 |
+
" additional_attention_mask = additional_attention_mask.to(device)\n",
|
| 743 |
+
" labels = labels.to(device)\n",
|
| 744 |
+
"\n",
|
| 745 |
+
" # Forward pass through the ConcatModel\n",
|
| 746 |
+
" optimizer.zero_grad()\n",
|
| 747 |
+
" outputs = concat_model(input_ids=input_ids, attention_mask=attention_mask, additional_input_ids=additional_input_ids, additional_attention_mask=additional_attention_mask)\n",
|
| 748 |
+
" loss = criterion(outputs, labels)\n",
|
| 749 |
+
"\n",
|
| 750 |
+
" # Backward pass and optimization\n",
|
| 751 |
+
" loss.backward()\n",
|
| 752 |
+
" torch.nn.utils.clip_grad_norm_(concat_model.parameters(), max_norm=args.max_grad_norm)\n",
|
| 753 |
+
" optimizer.step()\n",
|
| 754 |
+
" \n",
|
| 755 |
+
" probs = torch.softmax(outputs, dim=1)\n",
|
| 756 |
+
" confidences, predictions = torch.max(probs, dim=1)\n",
|
| 757 |
+
" train_preds.extend(predictions.cpu().numpy())\n",
|
| 758 |
+
" train_labels_epoch.extend(labels.cpu().numpy())\n",
|
| 759 |
+
"\n",
|
| 760 |
+
" train_losses.append(loss.item())\n",
|
| 761 |
+
"\n",
|
| 762 |
+
" # Update accuracy and epoch size\n",
|
| 763 |
+
" predictions = torch.argmax(outputs, dim=1)\n",
|
| 764 |
+
" train_accuracy += (predictions == labels).sum().item()\n",
|
| 765 |
+
" train_epoch_size += len(labels)\n",
|
| 766 |
+
" \n",
|
| 767 |
+
" epoch_train_time = time.time() - epoch_start_time\n",
|
| 768 |
+
" train_throughput = samples_seen / epoch_train_time \n",
|
| 769 |
+
" \n",
|
| 770 |
+
" # Calculate train metrics (binary, weighted, macro)\n",
|
| 771 |
+
" train_precision, train_recall, train_f1, _ = precision_recall_fscore_support(\n",
|
| 772 |
+
" train_labels_epoch, train_preds, average='binary'\n",
|
| 773 |
+
" )\n",
|
| 774 |
+
" train_precision_weighted, train_recall_weighted, train_f1_weighted, _ = precision_recall_fscore_support(\n",
|
| 775 |
+
" train_labels_epoch, train_preds, average='weighted'\n",
|
| 776 |
+
" )\n",
|
| 777 |
+
" train_precision_macro, train_recall_macro, train_f1_macro, _ = precision_recall_fscore_support(\n",
|
| 778 |
+
" train_labels_epoch, train_preds, average='macro'\n",
|
| 779 |
+
" )\n",
|
| 780 |
+
" train_acc = accuracy_score(train_labels_epoch, train_preds)\n",
|
| 781 |
+
"\n",
|
| 782 |
+
" # Evaluation on the validation set\n",
|
| 783 |
+
" concat_model.eval()\n",
|
| 784 |
+
"\n",
|
| 785 |
+
" val_predictions = []\n",
|
| 786 |
+
" val_labels_epoch = []\n",
|
| 787 |
+
" val_loss = 0\n",
|
| 788 |
+
" num_batches = 0\n",
|
| 789 |
+
"\n",
|
| 790 |
+
" with torch.no_grad(), tqdm(val_dataloader, desc='Validation', dynamic_ncols=True) as loop:\n",
|
| 791 |
+
" for batch in loop:\n",
|
| 792 |
+
" input_ids, attention_mask, additional_input_ids, additional_attention_mask, labels = batch\n",
|
| 793 |
+
"\n",
|
| 794 |
+
" if torch.cuda.is_available():\n",
|
| 795 |
+
" input_ids = input_ids.to(device)\n",
|
| 796 |
+
" attention_mask = attention_mask.to(device)\n",
|
| 797 |
+
" additional_input_ids = additional_input_ids.to(device)\n",
|
| 798 |
+
" additional_attention_mask = additional_attention_mask.to(device)\n",
|
| 799 |
+
" labels = labels.to(device)\n",
|
| 800 |
+
"\n",
|
| 801 |
+
" # Forward pass through the ConcatModel\n",
|
| 802 |
+
" outputs = concat_model(input_ids=input_ids, attention_mask=attention_mask, additional_input_ids=additional_input_ids, additional_attention_mask=additional_attention_mask)\n",
|
| 803 |
+
" loss = criterion(outputs, labels)\n",
|
| 804 |
+
" val_loss += loss.item()\n",
|
| 805 |
+
" num_batches += 1\n",
|
| 806 |
+
" probs = torch.softmax(outputs, dim=1)\n",
|
| 807 |
+
" confidences, predictions = torch.max(probs, dim=1)\n",
|
| 808 |
+
"\n",
|
| 809 |
+
" epoch_val_confidences.extend(confidences.cpu().numpy())\n",
|
| 810 |
+
" \n",
|
| 811 |
+
" val_predictions.extend(predictions.cpu().numpy())\n",
|
| 812 |
+
" val_labels_epoch.extend(labels.cpu().numpy())\n",
|
| 813 |
+
" \n",
|
| 814 |
+
" val_loss /= num_batches\n",
|
| 815 |
+
" \n",
|
| 816 |
+
" # Calculate validation metrics (binary)\n",
|
| 817 |
+
" val_accuracy = accuracy_score(val_labels_epoch, val_predictions)\n",
|
| 818 |
+
" val_precision, val_recall, val_f1, _ = precision_recall_fscore_support(\n",
|
| 819 |
+
" val_labels_epoch, val_predictions, average='binary'\n",
|
| 820 |
+
" )\n",
|
| 821 |
+
" # Calculate validation metrics (weighted)\n",
|
| 822 |
+
" val_precision_weighted, val_recall_weighted, val_f1_weighted, _ = precision_recall_fscore_support(\n",
|
| 823 |
+
" val_labels_epoch, val_predictions, average='weighted'\n",
|
| 824 |
+
" )\n",
|
| 825 |
+
" # Calculate validation metrics (macro)\n",
|
| 826 |
+
" val_precision_macro, val_recall_macro, val_f1_macro, _ = precision_recall_fscore_support(\n",
|
| 827 |
+
" val_labels_epoch, val_predictions, average='macro'\n",
|
| 828 |
+
" )\n",
|
| 829 |
+
" \n",
|
| 830 |
+
" print(f\"Epoch {epoch}:\")\n",
|
| 831 |
+
" print(f\" Train Accuracy: {train_acc:.4f}\")\n",
|
| 832 |
+
" print(f\" Validation Accuracy: {val_accuracy:.4f}\")\n",
|
| 833 |
+
" print(f\" Train Precision: {train_precision:.4f}, Recall: {train_recall:.4f}, F1: {train_f1:.4f}\")\n",
|
| 834 |
+
" print(f\" Val Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1: {val_f1:.4f}\")\n",
|
| 835 |
+
" print(f\" Avg. Train Loss: {sum(train_losses) / len(train_losses):.4f}\")\n",
|
| 836 |
+
" print(f\" Validation Loss: {val_loss:.4f}\")\n",
|
| 837 |
+
" epoch_time = time.time() - epoch_start_time\n",
|
| 838 |
+
" conf_mean = np.mean(epoch_val_confidences)\n",
|
| 839 |
+
" conf_std = np.std(epoch_val_confidences)\n",
|
| 840 |
+
"\n",
|
| 841 |
+
" # Append all train metrics to history\n",
|
| 842 |
+
" history[\"train_loss\"].append(np.mean(train_losses))\n",
|
| 843 |
+
" history[\"train_acc\"].append(train_acc)\n",
|
| 844 |
+
" history[\"train_precision\"].append(train_precision)\n",
|
| 845 |
+
" history[\"train_recall\"].append(train_recall)\n",
|
| 846 |
+
" history[\"train_f1\"].append(train_f1)\n",
|
| 847 |
+
" history[\"train_acc_weighted\"].append(train_acc) # accuracy is same for all averaging\n",
|
| 848 |
+
" history[\"train_precision_weighted\"].append(train_precision_weighted)\n",
|
| 849 |
+
" history[\"train_recall_weighted\"].append(train_recall_weighted)\n",
|
| 850 |
+
" history[\"train_f1_weighted\"].append(train_f1_weighted)\n",
|
| 851 |
+
" history[\"train_acc_macro\"].append(train_acc) # accuracy is same for all averaging\n",
|
| 852 |
+
" history[\"train_precision_macro\"].append(train_precision_macro)\n",
|
| 853 |
+
" history[\"train_recall_macro\"].append(train_recall_macro)\n",
|
| 854 |
+
" history[\"train_f1_macro\"].append(train_f1_macro)\n",
|
| 855 |
+
" \n",
|
| 856 |
+
" # Append all validation metrics to history\n",
|
| 857 |
+
" history[\"val_loss\"].append(val_loss)\n",
|
| 858 |
+
" history[\"val_acc\"].append(val_accuracy)\n",
|
| 859 |
+
" history[\"val_precision\"].append(val_precision)\n",
|
| 860 |
+
" history[\"val_recall\"].append(val_recall)\n",
|
| 861 |
+
" history[\"val_f1\"].append(val_f1)\n",
|
| 862 |
+
" history[\"val_acc_weighted\"].append(val_accuracy) # accuracy is same for all averaging\n",
|
| 863 |
+
" history[\"val_precision_weighted\"].append(val_precision_weighted)\n",
|
| 864 |
+
" history[\"val_recall_weighted\"].append(val_recall_weighted)\n",
|
| 865 |
+
" history[\"val_f1_weighted\"].append(val_f1_weighted)\n",
|
| 866 |
+
" history[\"val_acc_macro\"].append(val_accuracy) # accuracy is same for all averaging\n",
|
| 867 |
+
" history[\"val_precision_macro\"].append(val_precision_macro)\n",
|
| 868 |
+
" history[\"val_recall_macro\"].append(val_recall_macro)\n",
|
| 869 |
+
" history[\"val_f1_macro\"].append(val_f1_macro)\n",
|
| 870 |
+
" \n",
|
| 871 |
+
" # Append efficiency metrics\n",
|
| 872 |
+
" history[\"epoch_time\"].append(epoch_time)\n",
|
| 873 |
+
" history[\"train_throughput\"].append(train_throughput)\n",
|
| 874 |
+
" history[\"val_confidence_mean\"].append(conf_mean)\n",
|
| 875 |
+
" history[\"val_confidence_std\"].append(conf_std)\n",
|
| 876 |
+
" \n",
|
| 877 |
+
" if torch.cuda.is_available():\n",
|
| 878 |
+
" history[\"gpu_memory_mb\"].append(\n",
|
| 879 |
+
" torch.cuda.max_memory_allocated() / 1024**2\n",
|
| 880 |
+
" )\n",
|
| 881 |
+
" torch.cuda.reset_peak_memory_stats()\n",
|
| 882 |
+
" else:\n",
|
| 883 |
+
" history[\"gpu_memory_mb\"].append(0)\n",
|
| 884 |
+
" \n",
|
| 885 |
+
" print(f\" Epoch Time (s): {epoch_time:.2f}\")\n",
|
| 886 |
+
" print(f\" Throughput (samples/sec): {train_throughput:.2f}\")\n",
|
| 887 |
+
" print(f\" Val Confidence Mean: {conf_mean:.4f} Β± {conf_std:.4f}\")\n",
|
| 888 |
+
"\n",
|
| 889 |
+
" new_tau = concat_model.selector.step_tau()\n",
|
| 890 |
+
" print(f\" Updated Gumbel Tau: {new_tau:.4f}\")\n",
|
| 891 |
+
" \n",
|
| 892 |
+
" current_metric = val_loss\n",
|
| 893 |
+
" if early_stopping(current_metric, concat_model):\n",
|
| 894 |
+
" print(f\"\\n{'='*50}\")\n",
|
| 895 |
+
" print(f\"Early stopping at epoch {epoch+1}\")\n",
|
| 896 |
+
" print(f\"{'='*50}\\n\")\n",
|
| 897 |
+
" break\n",
|
| 898 |
+
" model = early_stopping.load_best_model(concat_model)\n",
|
| 899 |
+
" torch.save(model.state_dict(), f\"/kaggle/working/models/{args.dataset}_concat_model.pt\")\n",
|
| 900 |
+
" print(f\"Best model saved to /kaggle/working/models/{args.dataset}_concat_model.pt\")\n",
|
| 901 |
+
"\n",
|
| 902 |
+
" checkpoint = {\n",
|
| 903 |
+
" \"history\": history,\n",
|
| 904 |
+
" }\n",
|
| 905 |
+
" torch.save(checkpoint, f\"/kaggle/working/models/{args.dataset}_concat_checkpoint.pt\")\n",
|
| 906 |
+
" print(f\"Checkpoint with history saved to /kaggle/working/models/{args.dataset}_concat_checkpoint.pt\")\n",
|
| 907 |
+
"\n",
|
| 908 |
+
" if args.dataset == \"implicit\":\n",
|
| 909 |
+
" test_texts = test_df[\"post\"].tolist()\n",
|
| 910 |
+
" else:\n",
|
| 911 |
+
" test_texts = test_df[\"text\"].tolist()\n",
|
| 912 |
+
" \n",
|
| 913 |
+
" add_test_texts = test_df[\"Mistral_Rationales\"].tolist()\n",
|
| 914 |
+
" test_labels = test_df[\"label\"].tolist()\n",
|
| 915 |
+
"\n",
|
| 916 |
+
" test_dataset = AdditionalCustomDataset(test_texts, test_labels, add_test_texts, tokenizer, tokenizer_bert, max_length=512)\n",
|
| 917 |
+
" test_dataloader = DataLoader(\n",
|
| 918 |
+
" test_dataset, \n",
|
| 919 |
+
" batch_size=args.batch_size, # Use same batch size as training for faster inference\n",
|
| 920 |
+
" shuffle=False,\n",
|
| 921 |
+
" num_workers=2,\n",
|
| 922 |
+
" pin_memory=True if torch.cuda.is_available() else False\n",
|
| 923 |
+
" )\n",
|
| 924 |
+
"\n",
|
| 925 |
+
" # ================= TEST EVALUATION WITH EFFICIENCY =================\n",
|
| 926 |
+
" model.eval()\n",
|
| 927 |
+
" test_predictions = []\n",
|
| 928 |
+
" test_true_labels = []\n",
|
| 929 |
+
" test_confidences = []\n",
|
| 930 |
+
"\n",
|
| 931 |
+
" samples_seen = 0\n",
|
| 932 |
+
" test_start_time = time.time()\n",
|
| 933 |
+
" \n",
|
| 934 |
+
" if torch.cuda.is_available():\n",
|
| 935 |
+
" torch.cuda.reset_peak_memory_stats()\n",
|
| 936 |
+
" \n",
|
| 937 |
+
" with torch.no_grad(), tqdm(test_dataloader, desc='Testing', dynamic_ncols=True) as loop:\n",
|
| 938 |
+
" for batch in loop:\n",
|
| 939 |
+
" input_ids, attention_mask, additional_input_ids, additional_attention_mask, labels = batch\n",
|
| 940 |
+
" \n",
|
| 941 |
+
" batch_size = labels.size(0)\n",
|
| 942 |
+
" samples_seen += batch_size\n",
|
| 943 |
+
" \n",
|
| 944 |
+
" input_ids = input_ids.to(device)\n",
|
| 945 |
+
" attention_mask = attention_mask.to(device)\n",
|
| 946 |
+
" additional_input_ids = additional_input_ids.to(device)\n",
|
| 947 |
+
" additional_attention_mask = additional_attention_mask.to(device)\n",
|
| 948 |
+
" labels = labels.to(device)\n",
|
| 949 |
+
" \n",
|
| 950 |
+
" outputs = model(\n",
|
| 951 |
+
" input_ids=input_ids,\n",
|
| 952 |
+
" attention_mask=attention_mask,\n",
|
| 953 |
+
" additional_input_ids=additional_input_ids,\n",
|
| 954 |
+
" additional_attention_mask=additional_attention_mask\n",
|
| 955 |
+
" )\n",
|
| 956 |
+
" \n",
|
| 957 |
+
" probs = torch.softmax(outputs, dim=1)\n",
|
| 958 |
+
" confidences, preds = torch.max(probs, dim=1)\n",
|
| 959 |
+
" \n",
|
| 960 |
+
" test_confidences.extend(confidences.cpu().numpy())\n",
|
| 961 |
+
" test_predictions.extend(preds.cpu().numpy())\n",
|
| 962 |
+
" test_true_labels.extend(labels.cpu().numpy())\n",
|
| 963 |
+
"\n",
|
| 964 |
+
" \n",
|
| 965 |
+
" # ================= TEST METRICS =================\n",
|
| 966 |
+
" test_time = time.time() - test_start_time\n",
|
| 967 |
+
" test_throughput = samples_seen / test_time\n",
|
| 968 |
+
" \n",
|
| 969 |
+
" accuracy = accuracy_score(test_true_labels, test_predictions)\n",
|
| 970 |
+
" precision, recall, f1, _ = precision_recall_fscore_support(\n",
|
| 971 |
+
" test_true_labels, test_predictions, average='weighted'\n",
|
| 972 |
+
" )\n",
|
| 973 |
+
" conf_mean = np.mean(test_confidences)\n",
|
| 974 |
+
" conf_std = np.std(test_confidences)\n",
|
| 975 |
+
" cm = confusion_matrix(test_true_labels, test_predictions)\n",
|
| 976 |
+
" \n",
|
| 977 |
+
" gpu_memory_mb = torch.cuda.max_memory_allocated() / 1024**2 if torch.cuda.is_available() else 0\n",
|
| 978 |
+
"\n",
|
| 979 |
+
" print(\"\\n================= FINAL TEST RESULTS =================\")\n",
|
| 980 |
+
" print(f\"Dataset: {args.dataset}, Seed: {args.seed}, Epochs: {args.num_epochs}\")\n",
|
| 981 |
+
" print(f\"Test Accuracy : {accuracy:.4f}\")\n",
|
| 982 |
+
" print(f\"Test Precision: {precision:.4f}\")\n",
|
| 983 |
+
" print(f\"Test Recall : {recall:.4f}\")\n",
|
| 984 |
+
" print(f\"Test F1-score : {f1:.4f}\")\n",
|
| 985 |
+
" print(f\"Test Confidence Mean Β± Std: {conf_mean:.4f} Β± {conf_std:.4f}\")\n",
|
| 986 |
+
" print(f\"Test Time (s) : {test_time:.2f}\")\n",
|
| 987 |
+
" print(f\"Throughput (samples/sec) : {test_throughput:.2f}\")\n",
|
| 988 |
+
" print(f\"Peak GPU Memory (MB) : {gpu_memory_mb:.2f}\")\n",
|
| 989 |
+
" print(\"\\nClassification Report:\")\n",
|
| 990 |
+
" print(classification_report(test_true_labels, test_predictions))\n",
|
| 991 |
+
" print(\"\\nConfusion Matrix:\")\n",
|
| 992 |
+
" # Plot confusion matrix with seaborn\n",
|
| 993 |
+
" plt.figure(figsize=(8, 6))\n",
|
| 994 |
+
" sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', \n",
|
| 995 |
+
" xticklabels=['Non-Hate', 'Hate'],\n",
|
| 996 |
+
" yticklabels=['Non-Hate', 'Hate'],\n",
|
| 997 |
+
" cbar_kws={'label': 'Count'})\n",
|
| 998 |
+
" plt.title('Confusion Matrix', fontsize=14, pad=20)\n",
|
| 999 |
+
" plt.ylabel('True label', fontsize=12)\n",
|
| 1000 |
+
" plt.xlabel('Predicted label', fontsize=12)\n",
|
| 1001 |
+
" plt.tight_layout()\n",
|
| 1002 |
+
" plt.show()\n",
|
| 1003 |
+
" print(\"======================================================\")\n",
|
| 1004 |
+
"\n",
|
| 1005 |
+
" return model, history"
|
| 1006 |
+
]
|
| 1007 |
+
},
|
| 1008 |
+
{
|
| 1009 |
+
"cell_type": "code",
|
| 1010 |
+
"execution_count": null,
|
| 1011 |
+
"metadata": {
|
| 1012 |
+
"trusted": true
|
| 1013 |
+
},
|
| 1014 |
+
"outputs": [],
|
| 1015 |
+
"source": [
|
| 1016 |
+
"from argparse import Namespace\n",
|
| 1017 |
+
"from transformers import AutoTokenizer, AutoModel\n",
|
| 1018 |
+
"\n",
|
| 1019 |
+
"# Load tokenizers and models ONCE (outside objective to save time)\n",
|
| 1020 |
+
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
|
| 1021 |
+
"hate_tokenizer = AutoTokenizer.from_pretrained(\"GroNLP/hateBERT\")\n",
|
| 1022 |
+
"bert_tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
|
| 1023 |
+
"\n",
|
| 1024 |
+
"torch.cuda.empty_cache()\n",
|
| 1025 |
+
"\n",
|
| 1026 |
+
"# Load fresh models for each trial (to reset weights)\n",
|
| 1027 |
+
"hatebert_model = AutoModel.from_pretrained(\"GroNLP/hateBERT\").to(device)\n",
|
| 1028 |
+
"additional_model = AutoModel.from_pretrained(\"bert-base-uncased\").to(device)\n",
|
| 1029 |
+
"\n",
|
| 1030 |
+
"# Sample hyperparameters\n",
|
| 1031 |
+
"args = Namespace(\n",
|
| 1032 |
+
" # Fixed\n",
|
| 1033 |
+
" seed=42,\n",
|
| 1034 |
+
" dataset=\"reddit\", # Change as needed: \"gab\", \"twitter\", \"reddit\", \"youtube\", \"implicit\"\n",
|
| 1035 |
+
"\n",
|
| 1036 |
+
" # Tokenizers & Models\n",
|
| 1037 |
+
" hate_tokenizer=hate_tokenizer,\n",
|
| 1038 |
+
" bert_tokenizer=bert_tokenizer,\n",
|
| 1039 |
+
" hatebert_model=hatebert_model,\n",
|
| 1040 |
+
" additional_model=additional_model,\n",
|
| 1041 |
+
"\n",
|
| 1042 |
+
" # Training hyperparameters\n",
|
| 1043 |
+
" batch_size=32,\n",
|
| 1044 |
+
" num_epochs=20,\n",
|
| 1045 |
+
" wd=0.01, # Weight decay\n",
|
| 1046 |
+
" patience=5, # Early stopping patience\n",
|
| 1047 |
+
" max_grad_norm=2.0, # Gradient clipping\n",
|
| 1048 |
+
"\n",
|
| 1049 |
+
" # Learning Rates (Missing but required by optimizer in main)\n",
|
| 1050 |
+
" hatebert_lr=2e-5, # Lower LR for pre-trained models\n",
|
| 1051 |
+
" additional_lr=2e-5, # Usually kept low or 0 if frozen\n",
|
| 1052 |
+
" temp_lr=1e-3, # Higher LR for initialized layers (CNN)\n",
|
| 1053 |
+
" msa_lr=1e-3, # Higher LR for initialized layers (CNN)\n",
|
| 1054 |
+
" proj_lr=1e-3, # Higher LR for Head (MLP)\n",
|
| 1055 |
+
"\n",
|
| 1056 |
+
" # TemporalCNN hyperparameters\n",
|
| 1057 |
+
" temp_filters=256,\n",
|
| 1058 |
+
" temp_dropout=0.1,\n",
|
| 1059 |
+
" temp_dilate=3, # Note: This is in args, but TemporalCNN class doesn't use it. \n",
|
| 1060 |
+
" # Code fix required in main() or TemporalCNN class.\n",
|
| 1061 |
+
"\n",
|
| 1062 |
+
" # MultiScaleAttentionCNN hyperparameters\n",
|
| 1063 |
+
" msa_filters=64,\n",
|
| 1064 |
+
" msa_dropout=0.23,\n",
|
| 1065 |
+
"\n",
|
| 1066 |
+
" # Layer unfreezing hyperparameters\n",
|
| 1067 |
+
" hate_layers=10, # How many layers of HateBERT to unfreeze\n",
|
| 1068 |
+
" add_layers=0, # How many layers of Add. BERT to unfreeze\n",
|
| 1069 |
+
" hate_pooler=True,\n",
|
| 1070 |
+
" add_pooler=True,\n",
|
| 1071 |
+
") \n",
|
| 1072 |
+
"model, history = main(args)\n",
|
| 1073 |
+
"\n",
|
| 1074 |
+
"\n",
|
| 1075 |
+
" "
|
| 1076 |
+
]
|
| 1077 |
+
},
|
| 1078 |
+
{
|
| 1079 |
+
"cell_type": "code",
|
| 1080 |
+
"execution_count": null,
|
| 1081 |
+
"metadata": {},
|
| 1082 |
+
"outputs": [],
|
| 1083 |
+
"source": [
|
| 1084 |
+
"import optuna\n",
|
| 1085 |
+
"from optuna.pruners import MedianPruner\n",
|
| 1086 |
+
"\n",
|
| 1087 |
+
"# Load tokenizers and models ONCE (outside objective to save time)\n",
|
| 1088 |
+
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
|
| 1089 |
+
"hate_tokenizer = AutoTokenizer.from_pretrained(\"GroNLP/hateBERT\")\n",
|
| 1090 |
+
"bert_tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
|
| 1091 |
+
"\n",
|
| 1092 |
+
"def objective(trial):\n",
|
| 1093 |
+
" \"\"\"\n",
|
| 1094 |
+
" Optuna objective function that samples hyperparameters and trains the model.\n",
|
| 1095 |
+
" Returns: tuple of (validation_loss, validation_accuracy)\n",
|
| 1096 |
+
" - Minimize validation loss\n",
|
| 1097 |
+
" - Maximize validation accuracy\n",
|
| 1098 |
+
" \"\"\"\n",
|
| 1099 |
+
" # Clear GPU cache before each trial\n",
|
| 1100 |
+
" torch.cuda.empty_cache()\n",
|
| 1101 |
+
" \n",
|
| 1102 |
+
" # Load fresh models for each trial (to reset weights)\n",
|
| 1103 |
+
" hatebert_model = AutoModel.from_pretrained(\"GroNLP/hateBERT\").to(device)\n",
|
| 1104 |
+
" additional_model = AutoModel.from_pretrained(\"bert-base-uncased\").to(device)\n",
|
| 1105 |
+
" \n",
|
| 1106 |
+
" # ========== SAMPLE HYPERPARAMETERS USING OPTUNA ==========\n",
|
| 1107 |
+
" # Training hyperparameters\n",
|
| 1108 |
+
" batch_size = trial.suggest_int(\"batch_size\", 16, 64, step=16) # 16, 32, 48, 64\n",
|
| 1109 |
+
" num_epochs = trial.suggest_int(\"num_epochs\", 15, 30, step=5) # 15, 20, 25, 30\n",
|
| 1110 |
+
" weight_decay = trial.suggest_float(\"weight_decay\", 1e-4, 1e-1, log=True)\n",
|
| 1111 |
+
" max_grad_norm = trial.suggest_float(\"max_grad_norm\", 0.5, 3.0)\n",
|
| 1112 |
+
" patience = trial.suggest_int(\"patience\", 3, 8)\n",
|
| 1113 |
+
" \n",
|
| 1114 |
+
" # Learning Rates (use log scale for better sampling across orders of magnitude)\n",
|
| 1115 |
+
" hatebert_lr = trial.suggest_float(\"hatebert_lr\", 1e-6, 5e-5, log=True)\n",
|
| 1116 |
+
" additional_lr = trial.suggest_float(\"additional_lr\", 1e-6, 5e-5, log=True)\n",
|
| 1117 |
+
" temp_lr = trial.suggest_float(\"temp_lr\", 1e-5, 1e-2, log=True)\n",
|
| 1118 |
+
" msa_lr = trial.suggest_float(\"msa_lr\", 1e-5, 1e-2, log=True)\n",
|
| 1119 |
+
" proj_lr = trial.suggest_float(\"proj_lr\", 1e-5, 1e-2, log=True)\n",
|
| 1120 |
+
" \n",
|
| 1121 |
+
" # TemporalCNN hyperparameters\n",
|
| 1122 |
+
" temp_filters = trial.suggest_int(\"temp_filters\", 128, 512, step=64) # 128, 192, 256, ...\n",
|
| 1123 |
+
" temp_dropout = trial.suggest_float(\"temp_dropout\", 0.0, 0.5, step=0.05) # 0.0, 0.05, 0.1, ...\n",
|
| 1124 |
+
" kernel_size_2_weight = trial.suggest_float(\"kernel_size_weight\", 0.3, 1.0) # Weight for inclusion\n",
|
| 1125 |
+
" \n",
|
| 1126 |
+
" # MultiScaleAttentionCNN hyperparameters\n",
|
| 1127 |
+
" msa_filters = trial.suggest_int(\"msa_filters\", 32, 256, step=32) # 32, 64, 96, 128, ...\n",
|
| 1128 |
+
" msa_dropout = trial.suggest_float(\"msa_dropout\", 0.0, 0.5, step=0.05)\n",
|
| 1129 |
+
" \n",
|
| 1130 |
+
" # Layer unfreezing hyperparameters\n",
|
| 1131 |
+
" hate_layers = trial.suggest_int(\"hate_layers\", 4, 12) # How many layers to unfreeze (0-12)\n",
|
| 1132 |
+
" hate_pooler = trial.suggest_categorical(\"hate_pooler\", [True, False])\n",
|
| 1133 |
+
" \n",
|
| 1134 |
+
" # Create args Namespace with sampled hyperparameters\n",
|
| 1135 |
+
" args = Namespace(\n",
|
| 1136 |
+
" # Fixed\n",
|
| 1137 |
+
" seed=42,\n",
|
| 1138 |
+
" dataset=\"reddit\", # Change as needed\n",
|
| 1139 |
+
" \n",
|
| 1140 |
+
" # Tokenizers & Models\n",
|
| 1141 |
+
" hate_tokenizer=hate_tokenizer,\n",
|
| 1142 |
+
" bert_tokenizer=bert_tokenizer,\n",
|
| 1143 |
+
" hatebert_model=hatebert_model,\n",
|
| 1144 |
+
" additional_model=additional_model,\n",
|
| 1145 |
+
" \n",
|
| 1146 |
+
" # Sampled training hyperparameters\n",
|
| 1147 |
+
" batch_size=batch_size,\n",
|
| 1148 |
+
" num_epochs=num_epochs,\n",
|
| 1149 |
+
" wd=weight_decay,\n",
|
| 1150 |
+
" patience=patience,\n",
|
| 1151 |
+
" max_grad_norm=max_grad_norm,\n",
|
| 1152 |
+
" \n",
|
| 1153 |
+
" # Sampled learning rates\n",
|
| 1154 |
+
" hatebert_lr=hatebert_lr,\n",
|
| 1155 |
+
" additional_lr=additional_lr,\n",
|
| 1156 |
+
" temp_lr=temp_lr,\n",
|
| 1157 |
+
" msa_lr=msa_lr,\n",
|
| 1158 |
+
" proj_lr=proj_lr,\n",
|
| 1159 |
+
" \n",
|
| 1160 |
+
" # Sampled TemporalCNN hyperparameters\n",
|
| 1161 |
+
" temp_filters=temp_filters,\n",
|
| 1162 |
+
" temp_dropout=temp_dropout,\n",
|
| 1163 |
+
" temp_dilate=3, # Fixed (not optimized)\n",
|
| 1164 |
+
" \n",
|
| 1165 |
+
" # Sampled MultiScaleAttentionCNN hyperparameters\n",
|
| 1166 |
+
" msa_filters=msa_filters,\n",
|
| 1167 |
+
" msa_dropout=msa_dropout,\n",
|
| 1168 |
+
" \n",
|
| 1169 |
+
" # Sampled layer unfreezing hyperparameters\n",
|
| 1170 |
+
" hate_layers=hate_layers,\n",
|
| 1171 |
+
" add_layers=0, # Fixed (frozen)\n",
|
| 1172 |
+
" hate_pooler=hate_pooler,\n",
|
| 1173 |
+
" add_pooler=False, # Fixed\n",
|
| 1174 |
+
" )\n",
|
| 1175 |
+
" \n",
|
| 1176 |
+
" try:\n",
|
| 1177 |
+
" print(f\"\\n{'='*60}\")\n",
|
| 1178 |
+
" print(f\"Trial {trial.number} started\")\n",
|
| 1179 |
+
" print(f\"Batch Size: {batch_size}, Epochs: {num_epochs}, Patience: {patience}\")\n",
|
| 1180 |
+
" print(f\"HateBERT LR: {hatebert_lr:.2e}, Temp LR: {temp_lr:.2e}, MSA LR: {msa_lr:.2e}\")\n",
|
| 1181 |
+
" print(f\"Temp Filters: {temp_filters}, MSA Filters: {msa_filters}\")\n",
|
| 1182 |
+
" print(f\"{'='*60}\\n\")\n",
|
| 1183 |
+
" \n",
|
| 1184 |
+
" model, history = main(args)\n",
|
| 1185 |
+
" \n",
|
| 1186 |
+
" # Get best validation loss and corresponding accuracy\n",
|
| 1187 |
+
" val_loss = min(history[\"val_loss\"])\n",
|
| 1188 |
+
" val_loss_index = history[\"val_loss\"].index(val_loss)\n",
|
| 1189 |
+
" val_acc = history['val_acc'][val_loss_index]\n",
|
| 1190 |
+
" val_f1_macro = history['val_f1_macro'][val_loss_index]\n",
|
| 1191 |
+
" \n",
|
| 1192 |
+
" print(f\"\\nTrial {trial.number} completed:\")\n",
|
| 1193 |
+
" print(f\" Best Val Loss: {val_loss:.4f}\")\n",
|
| 1194 |
+
" print(f\" Val Acc at Best Loss: {val_acc:.4f}\")\n",
|
| 1195 |
+
" print(f\" Val F1-Macro at Best Loss: {val_f1_macro:.4f}\")\n",
|
| 1196 |
+
" \n",
|
| 1197 |
+
" # Return tuple for multi-objective optimization\n",
|
| 1198 |
+
" # Optuna will minimize val_loss and maximize val_acc\n",
|
| 1199 |
+
" return val_loss, val_acc\n",
|
| 1200 |
+
" \n",
|
| 1201 |
+
" except Exception as e:\n",
|
| 1202 |
+
" print(f\"\\nβ Trial {trial.number} FAILED with error: {e}\")\n",
|
| 1203 |
+
" import traceback\n",
|
| 1204 |
+
" traceback.print_exc()\n",
|
| 1205 |
+
" return float(\"inf\"), 0.0 # Return worst case for failed trials\n",
|
| 1206 |
+
"\n",
|
| 1207 |
+
"\n",
|
| 1208 |
+
"# ========== CREATE AND RUN OPTUNA STUDY ==========\n",
|
| 1209 |
+
"# Using in-memory storage for simplicity (change to PostgreSQL if you need persistence)\n",
|
| 1210 |
+
"study = optuna.create_study(\n",
|
| 1211 |
+
" storage=None, # Use in-memory storage\n",
|
| 1212 |
+
" study_name=\"proposed-modelv4-optim\",\n",
|
| 1213 |
+
" directions=[\"minimize\", \"maximize\"], # Minimize val_loss, Maximize val_acc\n",
|
| 1214 |
+
" pruner=MedianPruner(\n",
|
| 1215 |
+
" n_startup_trials=5, # Don't prune the first 5 trials\n",
|
| 1216 |
+
" n_warmup_steps=3, # Warmup steps before pruning kicks in\n",
|
| 1217 |
+
" interval_steps=1 # Check for pruning after each epoch\n",
|
| 1218 |
+
" ),\n",
|
| 1219 |
+
" sampler=optuna.samplers.TPESampler(seed=42), # Tree-structured Parzen Estimator\n",
|
| 1220 |
+
" load_if_exists=False,\n",
|
| 1221 |
+
")\n",
|
| 1222 |
+
"\n",
|
| 1223 |
+
"print(\"\\n\" + \"=\"*80)\n",
|
| 1224 |
+
"print(\"STARTING OPTUNA HYPERPARAMETER OPTIMIZATION\")\n",
|
| 1225 |
+
"print(\"Multi-Objective: Minimize Val Loss | Maximize Val Accuracy\")\n",
|
| 1226 |
+
"print(\"=\"*80 + \"\\n\")\n",
|
| 1227 |
+
"\n",
|
| 1228 |
+
"study.optimize(\n",
|
| 1229 |
+
" objective,\n",
|
| 1230 |
+
" n_trials=50, # Reduced for testing; increase for production\n",
|
| 1231 |
+
" timeout=None, # Optional: timeout in seconds\n",
|
| 1232 |
+
" show_progress_bar=True,\n",
|
| 1233 |
+
" callbacks=[\n",
|
| 1234 |
+
" optuna.study.MaxTrialsCallback(max_trials=50),\n",
|
| 1235 |
+
" ]\n",
|
| 1236 |
+
")\n",
|
| 1237 |
+
"\n",
|
| 1238 |
+
"# ========== PRINT RESULTS ==========\n",
|
| 1239 |
+
"print(\"\\n\" + \"=\"*80)\n",
|
| 1240 |
+
"print(\"OPTIMIZATION COMPLETE\")\n",
|
| 1241 |
+
"print(\"=\"*80)\n",
|
| 1242 |
+
"\n",
|
| 1243 |
+
"# For multi-objective optimization, there are multiple Pareto-optimal trials\n",
|
| 1244 |
+
"print(f\"\\nNumber of Pareto-optimal trials: {len(study.best_trials)}\")\n",
|
| 1245 |
+
"print(\"\\nPareto-optimal trials (ranked by loss):\")\n",
|
| 1246 |
+
"print(\"-\" * 80)\n",
|
| 1247 |
+
"\n",
|
| 1248 |
+
"# Sort by first objective (val_loss) for display\n",
|
| 1249 |
+
"best_trials_sorted = sorted(study.best_trials, key=lambda t: t.values[0])\n",
|
| 1250 |
+
"\n",
|
| 1251 |
+
"for i, trial in enumerate(best_trials_sorted[:5]): # Show top 5\n",
|
| 1252 |
+
" print(f\"\\nTrial #{trial.number}:\")\n",
|
| 1253 |
+
" print(f\" Validation Loss: {trial.values[0]:.4f}\")\n",
|
| 1254 |
+
" print(f\" Validation Accuracy: {trial.values[1]:.4f}\")\n",
|
| 1255 |
+
" print(f\" Key Hyperparameters:\")\n",
|
| 1256 |
+
" print(f\" - Batch Size: {trial.params.get('batch_size')}\")\n",
|
| 1257 |
+
" print(f\" - Temp Filters: {trial.params.get('temp_filters')}\")\n",
|
| 1258 |
+
" print(f\" - MSA Filters: {trial.params.get('msa_filters')}\")\n",
|
| 1259 |
+
" print(f\" - HateBERT LR: {trial.params.get('hatebert_lr'):.2e}\")\n",
|
| 1260 |
+
" print(f\" - Temp LR: {trial.params.get('temp_lr'):.2e}\")\n",
|
| 1261 |
+
" print(f\" - Hate Layers: {trial.params.get('hate_layers')}\")\n",
|
| 1262 |
+
"\n",
|
| 1263 |
+
"# Print top 5 trials as DataFrame\n",
|
| 1264 |
+
"print(f\"\\n\" + \"=\"*80)\n",
|
| 1265 |
+
"print(\"TOP 5 TRIALS (by validation loss)\")\n",
|
| 1266 |
+
"print(\"=\"*80)\n",
|
| 1267 |
+
"trials_df = study.trials_dataframe().sort_values('values_0').head(5)\n",
|
| 1268 |
+
"print(trials_df[['number', 'values_0', 'values_1', 'state', 'params_batch_size', 'params_temp_filters', 'params_msa_filters']].to_string())\n"
|
| 1269 |
+
]
|
| 1270 |
+
},
|
| 1271 |
+
{
|
| 1272 |
+
"cell_type": "code",
|
| 1273 |
+
"execution_count": null,
|
| 1274 |
+
"metadata": {
|
| 1275 |
+
"trusted": true
|
| 1276 |
+
},
|
| 1277 |
+
"outputs": [],
|
| 1278 |
+
"source": [
|
| 1279 |
+
"fig, axes = plt.subplots(3, 4, figsize=(20, 15))\n",
|
| 1280 |
+
"\n",
|
| 1281 |
+
"# Row 1: Train vs Val comparisons\n",
|
| 1282 |
+
"# Accuracy\n",
|
| 1283 |
+
"axes[0, 0].plot(history['train_acc'], label='Train', marker='o')\n",
|
| 1284 |
+
"axes[0, 0].plot(history['val_acc'], label='Validation', marker='s')\n",
|
| 1285 |
+
"axes[0, 0].set_title('Accuracy')\n",
|
| 1286 |
+
"axes[0, 0].set_xlabel('Epoch')\n",
|
| 1287 |
+
"axes[0, 0].set_ylabel('Accuracy')\n",
|
| 1288 |
+
"axes[0, 0].legend()\n",
|
| 1289 |
+
"axes[0, 0].grid(True)\n",
|
| 1290 |
+
"\n",
|
| 1291 |
+
"# Loss\n",
|
| 1292 |
+
"axes[0, 1].plot(history['train_loss'], label='Train', marker='o')\n",
|
| 1293 |
+
"axes[0, 1].plot(history['val_loss'], label='Validation', marker='s')\n",
|
| 1294 |
+
"axes[0, 1].set_title('Loss')\n",
|
| 1295 |
+
"axes[0, 1].set_xlabel('Epoch')\n",
|
| 1296 |
+
"axes[0, 1].set_ylabel('Loss')\n",
|
| 1297 |
+
"axes[0, 1].legend()\n",
|
| 1298 |
+
"axes[0, 1].grid(True)\n",
|
| 1299 |
+
"\n",
|
| 1300 |
+
"# F1 Score\n",
|
| 1301 |
+
"axes[0, 2].plot(history['train_f1'], label='Train', marker='o')\n",
|
| 1302 |
+
"axes[0, 2].plot(history['val_f1'], label='Validation', marker='s')\n",
|
| 1303 |
+
"axes[0, 2].set_title('F1 Score (Binary)')\n",
|
| 1304 |
+
"axes[0, 2].set_xlabel('Epoch')\n",
|
| 1305 |
+
"axes[0, 2].set_ylabel('F1')\n",
|
| 1306 |
+
"axes[0, 2].legend()\n",
|
| 1307 |
+
"axes[0, 2].grid(True)\n",
|
| 1308 |
+
"\n",
|
| 1309 |
+
"# Precision\n",
|
| 1310 |
+
"axes[0, 3].plot(history['train_precision'], label='Train', marker='o')\n",
|
| 1311 |
+
"axes[0, 3].plot(history['val_precision'], label='Validation', marker='s')\n",
|
| 1312 |
+
"axes[0, 3].set_title('Precision (Binary)')\n",
|
| 1313 |
+
"axes[0, 3].set_xlabel('Epoch')\n",
|
| 1314 |
+
"axes[0, 3].set_ylabel('Precision')\n",
|
| 1315 |
+
"axes[0, 3].legend()\n",
|
| 1316 |
+
"axes[0, 3].grid(True)\n",
|
| 1317 |
+
"\n",
|
| 1318 |
+
"# Row 2: More Train vs Val + Individual metrics\n",
|
| 1319 |
+
"# Recall\n",
|
| 1320 |
+
"axes[1, 0].plot(history['train_recall'], label='Train', marker='o')\n",
|
| 1321 |
+
"axes[1, 0].plot(history['val_recall'], label='Validation', marker='s')\n",
|
| 1322 |
+
"axes[1, 0].set_title('Recall (Binary)')\n",
|
| 1323 |
+
"axes[1, 0].set_xlabel('Epoch')\n",
|
| 1324 |
+
"axes[1, 0].set_ylabel('Recall')\n",
|
| 1325 |
+
"axes[1, 0].legend()\n",
|
| 1326 |
+
"axes[1, 0].grid(True)\n",
|
| 1327 |
+
"\n",
|
| 1328 |
+
"# Epoch Time\n",
|
| 1329 |
+
"axes[1, 1].plot(history['epoch_time'], marker='o', color='green')\n",
|
| 1330 |
+
"axes[1, 1].set_title('Epoch Time')\n",
|
| 1331 |
+
"axes[1, 1].set_xlabel('Epoch')\n",
|
| 1332 |
+
"axes[1, 1].set_ylabel('Time (s)')\n",
|
| 1333 |
+
"axes[1, 1].grid(True)\n",
|
| 1334 |
+
"\n",
|
| 1335 |
+
"# Train Throughput\n",
|
| 1336 |
+
"axes[1, 2].plot(history['train_throughput'], marker='o', color='purple')\n",
|
| 1337 |
+
"axes[1, 2].set_title('Train Throughput')\n",
|
| 1338 |
+
"axes[1, 2].set_xlabel('Epoch')\n",
|
| 1339 |
+
"axes[1, 2].set_ylabel('Samples/sec')\n",
|
| 1340 |
+
"axes[1, 2].grid(True)\n",
|
| 1341 |
+
"\n",
|
| 1342 |
+
"# Validation Confidence\n",
|
| 1343 |
+
"axes[1, 3].errorbar(range(len(history['val_confidence_mean'])), \n",
|
| 1344 |
+
" history['val_confidence_mean'], \n",
|
| 1345 |
+
" yerr=history['val_confidence_std'], \n",
|
| 1346 |
+
" marker='o', color='orange', capsize=3)\n",
|
| 1347 |
+
"axes[1, 3].set_title('Validation Confidence')\n",
|
| 1348 |
+
"axes[1, 3].set_xlabel('Epoch')\n",
|
| 1349 |
+
"axes[1, 3].set_ylabel('Confidence')\n",
|
| 1350 |
+
"axes[1, 3].grid(True)\n",
|
| 1351 |
+
"\n",
|
| 1352 |
+
"# Row 3: GPU Memory and Weighted/Macro metrics\n",
|
| 1353 |
+
"# GPU Memory\n",
|
| 1354 |
+
"axes[2, 0].plot(history['gpu_memory_mb'], marker='o', color='red')\n",
|
| 1355 |
+
"axes[2, 0].set_title('GPU Memory Usage')\n",
|
| 1356 |
+
"axes[2, 0].set_xlabel('Epoch')\n",
|
| 1357 |
+
"axes[2, 0].set_ylabel('Memory (MB)')\n",
|
| 1358 |
+
"axes[2, 0].grid(True)\n",
|
| 1359 |
+
"\n",
|
| 1360 |
+
"# Weighted F1 comparison\n",
|
| 1361 |
+
"axes[2, 1].plot(history['train_f1_weighted'], label='Train', marker='o')\n",
|
| 1362 |
+
"axes[2, 1].plot(history['val_f1_weighted'], label='Validation', marker='s')\n",
|
| 1363 |
+
"axes[2, 1].set_title('F1 Score (Weighted)')\n",
|
| 1364 |
+
"axes[2, 1].set_xlabel('Epoch')\n",
|
| 1365 |
+
"axes[2, 1].set_ylabel('F1')\n",
|
| 1366 |
+
"axes[2, 1].legend()\n",
|
| 1367 |
+
"axes[2, 1].grid(True)\n",
|
| 1368 |
+
"\n",
|
| 1369 |
+
"# Macro F1 comparison\n",
|
| 1370 |
+
"axes[2, 2].plot(history['train_f1_macro'], label='Train', marker='o')\n",
|
| 1371 |
+
"axes[2, 2].plot(history['val_f1_macro'], label='Validation', marker='s')\n",
|
| 1372 |
+
"axes[2, 2].set_title('F1 Score (Macro)')\n",
|
| 1373 |
+
"axes[2, 2].set_xlabel('Epoch')\n",
|
| 1374 |
+
"axes[2, 2].set_ylabel('F1')\n",
|
| 1375 |
+
"axes[2, 2].legend()\n",
|
| 1376 |
+
"axes[2, 2].grid(True)\n",
|
| 1377 |
+
"\n",
|
| 1378 |
+
"# Hide last subplot\n",
|
| 1379 |
+
"axes[2, 3].axis('off')\n",
|
| 1380 |
+
"\n",
|
| 1381 |
+
"plt.suptitle('Training History', fontsize=16, y=1.02)\n",
|
| 1382 |
+
"plt.tight_layout()\n",
|
| 1383 |
+
"plt.show()"
|
| 1384 |
+
]
|
| 1385 |
+
}
|
| 1386 |
+
],
|
| 1387 |
+
"metadata": {
|
| 1388 |
+
"kaggle": {
|
| 1389 |
+
"accelerator": "gpu",
|
| 1390 |
+
"dataSources": [
|
| 1391 |
+
{
|
| 1392 |
+
"databundleVersionId": 15785969,
|
| 1393 |
+
"datasetId": 9546304,
|
| 1394 |
+
"sourceId": 14919503,
|
| 1395 |
+
"sourceType": "datasetVersion"
|
| 1396 |
+
}
|
| 1397 |
+
],
|
| 1398 |
+
"dockerImageVersionId": 31286,
|
| 1399 |
+
"isGpuEnabled": true,
|
| 1400 |
+
"isInternetEnabled": true,
|
| 1401 |
+
"language": "python",
|
| 1402 |
+
"sourceType": "notebook"
|
| 1403 |
+
},
|
| 1404 |
+
"kernelspec": {
|
| 1405 |
+
"display_name": "Python 3",
|
| 1406 |
+
"language": "python",
|
| 1407 |
+
"name": "python3"
|
| 1408 |
+
},
|
| 1409 |
+
"language_info": {
|
| 1410 |
+
"codemirror_mode": {
|
| 1411 |
+
"name": "ipython",
|
| 1412 |
+
"version": 3
|
| 1413 |
+
},
|
| 1414 |
+
"file_extension": ".py",
|
| 1415 |
+
"mimetype": "text/x-python",
|
| 1416 |
+
"name": "python",
|
| 1417 |
+
"nbconvert_exporter": "python",
|
| 1418 |
+
"pygments_lexer": "ipython3",
|
| 1419 |
+
"version": "3.12.12"
|
| 1420 |
+
}
|
| 1421 |
+
},
|
| 1422 |
+
"nbformat": 4,
|
| 1423 |
+
"nbformat_minor": 4
|
| 1424 |
+
}
|
src/app.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import gc
|
|
|
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
from hatespeech_model import predict_hatespeech, load_model_from_hf, predict_hatespeech_from_file, get_rationale_from_mistral, preprocess_rationale_mistral
|
|
@@ -24,6 +25,16 @@ def load_cached_model(model_type="altered"):
|
|
| 24 |
"""Load and cache the model"""
|
| 25 |
return load_model_from_hf(model_type=model_type)
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
# Custom CSS
|
| 28 |
st.markdown("""
|
| 29 |
<style>
|
|
@@ -156,21 +167,24 @@ classify_button = st.button("π Analyze Text", type="primary", use_container_w
|
|
| 156 |
|
| 157 |
if classify_button:
|
| 158 |
if user_input and user_input.strip():
|
|
|
|
|
|
|
|
|
|
| 159 |
with st.spinner('π Generating rationale from Mistral AI...'):
|
| 160 |
# --- Step 1: Get rationale from Mistral ---
|
| 161 |
try:
|
| 162 |
-
raw_rationale = get_rationale_from_mistral(
|
| 163 |
cleaned_rationale = preprocess_rationale_mistral(raw_rationale)
|
| 164 |
print(f"Raw rationale from Mistral: {raw_rationale}")
|
| 165 |
except Exception as e:
|
| 166 |
st.error(f"β Error generating/processing rationale: {str(e)}")
|
| 167 |
-
cleaned_rationale =
|
| 168 |
|
| 169 |
with st.spinner('π Analyzing text with models...'):
|
| 170 |
# Run enhanced model
|
| 171 |
enhanced_start = time.time()
|
| 172 |
enhanced_model_result = predict_hatespeech(
|
| 173 |
-
text=
|
| 174 |
rationale=cleaned_rationale, # use cleaned rationale
|
| 175 |
model=enhanced_model,
|
| 176 |
tokenizer_hatebert=enhanced_tokenizer_hatebert,
|
|
@@ -184,7 +198,7 @@ if classify_button:
|
|
| 184 |
# Run base model
|
| 185 |
base_start = time.time()
|
| 186 |
base_model_result = predict_hatespeech(
|
| 187 |
-
text=
|
| 188 |
rationale=cleaned_rationale, # use cleaned rationale
|
| 189 |
model=base_model,
|
| 190 |
tokenizer_hatebert=base_tokenizer_hatebert,
|
|
|
|
| 1 |
import gc
|
| 2 |
+
import re
|
| 3 |
|
| 4 |
import streamlit as st
|
| 5 |
from hatespeech_model import predict_hatespeech, load_model_from_hf, predict_hatespeech_from_file, get_rationale_from_mistral, preprocess_rationale_mistral
|
|
|
|
| 25 |
"""Load and cache the model"""
|
| 26 |
return load_model_from_hf(model_type=model_type)
|
| 27 |
|
| 28 |
+
def clean_user_input(text):
|
| 29 |
+
"""Remove URLs and special characters (except exclamation points) from text"""
|
| 30 |
+
# Remove URLs
|
| 31 |
+
text = re.sub(r'https?://\S+|www\.\S+', '', text)
|
| 32 |
+
# Remove special characters except exclamation points
|
| 33 |
+
text = re.sub(r'[^a-zA-Z0-9\s!]', '', text)
|
| 34 |
+
# Remove extra whitespace
|
| 35 |
+
text = re.sub(r'\s+', ' ', text).strip()
|
| 36 |
+
return text
|
| 37 |
+
|
| 38 |
# Custom CSS
|
| 39 |
st.markdown("""
|
| 40 |
<style>
|
|
|
|
| 167 |
|
| 168 |
if classify_button:
|
| 169 |
if user_input and user_input.strip():
|
| 170 |
+
# Clean the input text
|
| 171 |
+
cleaned_input = clean_user_input(user_input)
|
| 172 |
+
|
| 173 |
with st.spinner('π Generating rationale from Mistral AI...'):
|
| 174 |
# --- Step 1: Get rationale from Mistral ---
|
| 175 |
try:
|
| 176 |
+
raw_rationale = get_rationale_from_mistral(cleaned_input)
|
| 177 |
cleaned_rationale = preprocess_rationale_mistral(raw_rationale)
|
| 178 |
print(f"Raw rationale from Mistral: {raw_rationale}")
|
| 179 |
except Exception as e:
|
| 180 |
st.error(f"β Error generating/processing rationale: {str(e)}")
|
| 181 |
+
cleaned_rationale = cleaned_input # fallback to cleaned input
|
| 182 |
|
| 183 |
with st.spinner('π Analyzing text with models...'):
|
| 184 |
# Run enhanced model
|
| 185 |
enhanced_start = time.time()
|
| 186 |
enhanced_model_result = predict_hatespeech(
|
| 187 |
+
text=cleaned_input,
|
| 188 |
rationale=cleaned_rationale, # use cleaned rationale
|
| 189 |
model=enhanced_model,
|
| 190 |
tokenizer_hatebert=enhanced_tokenizer_hatebert,
|
|
|
|
| 198 |
# Run base model
|
| 199 |
base_start = time.time()
|
| 200 |
base_model_result = predict_hatespeech(
|
| 201 |
+
text=cleaned_input,
|
| 202 |
rationale=cleaned_rationale, # use cleaned rationale
|
| 203 |
model=base_model,
|
| 204 |
tokenizer_hatebert=base_tokenizer_hatebert,
|