jl commited on
Commit
3e8208f
Β·
1 Parent(s): 8092c6e

fix: regex filter for text

Browse files
notebooks/thesis-proposed-model-vo3.ipynb DELETED
@@ -1 +0,0 @@
1
- {"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"name":"python","version":"3.12.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[{"sourceType":"datasetVersion","sourceId":14919503,"datasetId":9546304,"databundleVersionId":15785969}],"dockerImageVersionId":31286,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.utils.data import DataLoader, Dataset\nfrom torch.optim import AdamW\nfrom sklearn.model_selection import train_test_split\nfrom transformers import AutoModel, AutoTokenizer\n\nimport numpy as np\nimport pandas as pd\nimport time\nimport os\nfrom tqdm.auto import tqdm\n\nfrom sklearn.metrics import (\n accuracy_score,\n precision_recall_fscore_support,\n classification_report,\n confusion_matrix\n)\n\nimport matplotlib.pyplot as plt\nimport seaborn as sns","metadata":{"_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class AdditionalCustomDataset(Dataset):\n \"\"\"\n Pre-tokenizes all data during initialization for much faster training.\n \"\"\"\n def __init__(self, texts, labels, additional_texts, tokenizer, bert_tokenizer, max_length):\n self.labels = labels\n self.max_length = max_length\n \n # Pre-tokenize ALL data once during initialization (MUCH faster than tokenizing in __getitem__)\n print(\"Pre-tokenizing primary texts...\")\n primary_encodings = tokenizer(\n texts,\n max_length=max_length,\n truncation=True,\n padding='max_length',\n return_tensors='pt'\n )\n self.primary_input_ids = 
primary_encodings['input_ids']\n self.primary_attention_mask = primary_encodings['attention_mask']\n \n print(\"Pre-tokenizing additional texts...\")\n additional_encodings = bert_tokenizer(\n additional_texts,\n max_length=max_length,\n truncation=True,\n padding='max_length',\n return_tensors='pt'\n )\n self.additional_input_ids = additional_encodings['input_ids']\n self.additional_attention_mask = additional_encodings['attention_mask']\n \n print(f\"Pre-tokenization complete. Dataset size: {len(self.labels)}\")\n\n def __len__(self):\n return len(self.labels)\n\n def __getitem__(self, idx):\n return (\n self.primary_input_ids[idx],\n self.primary_attention_mask[idx],\n self.additional_input_ids[idx],\n self.additional_attention_mask[idx],\n self.labels[idx]\n )","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class ProjectionMLP(nn.Module):\n def __init__(self, input_size, output_size):\n super(ProjectionMLP, self).__init__()\n self.layers = nn.Sequential(\n nn.Linear(input_size, output_size),\n nn.ReLU(),\n nn.Linear(output_size, 2)\n )\n\n def forward(self, x):\n return self.layers(x)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class GumbelTokenSelector(nn.Module):\n def __init__(self, hidden_size, tau=1.0):\n super().__init__()\n self.tau = tau\n self.proj = nn.Linear(hidden_size * 2, 1)\n \n def forward(self, token_embeddings, cls_embedding, training=True):\n \"\"\"\n token_embeddings: (B, L, H)\n cls_embedding: (B, H)\n \"\"\"\n B, L, H = token_embeddings.size()\n \n cls_exp = cls_embedding.unsqueeze(1).expand(-1, L, -1)\n x = torch.cat([token_embeddings, cls_exp], dim=-1)\n \n logits = self.proj(x).squeeze(-1) # (B, L)\n \n if training:\n probs = F.gumbel_softmax(\n torch.stack([logits, torch.zeros_like(logits)], dim=-1),\n tau=self.tau,\n hard=False\n )[..., 0]\n else:\n probs = torch.sigmoid(logits)\n \n return probs, 
logits","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class MultiScaleAttentionCNN(nn.Module):\n def __init__(\n self,\n hidden_size=768,\n num_filters=128,\n kernel_sizes=(2, 3, 4),\n dropout=0.3,\n ):\n super().__init__()\n \n self.convs = nn.ModuleList([\n nn.Conv1d(hidden_size, num_filters, k)\n for k in kernel_sizes\n ])\n \n self.attention_fc = nn.Linear(num_filters, 1)\n self.dropout = nn.Dropout(dropout)\n self.out_dim = num_filters * len(kernel_sizes)\n \n def forward(self, x, mask):\n \"\"\"\n x: (B, L, H)\n mask: (B, L)\n \"\"\"\n x = x.transpose(1, 2) # (B, H, L)\n feats = []\n \n for conv in self.convs:\n h = F.relu(conv(x)) # (B, C, L')\n h = h.transpose(1, 2) # (B, L', C)\n \n attn = self.attention_fc(h).squeeze(-1)\n attn = attn.masked_fill(mask[:, :attn.size(1)] == 0, -1e9)\n alpha = F.softmax(attn, dim=1)\n \n pooled = torch.sum(h * alpha.unsqueeze(-1), dim=1)\n feats.append(pooled)\n \n out = torch.cat(feats, dim=1)\n return self.dropout(out)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class TemporalCNN(nn.Module):\n def __init__(\n self,\n hidden_size=768,\n num_filters=128,\n kernel_sizes=(2, 3, 4),\n dropout=0.1,\n dilation_base=2,\n ):\n super().__init__()\n \n self.kernel_sizes = kernel_sizes\n self.dilation_base = dilation_base\n \n # Dilated convolutions with exponentially increasing dilation rates\n self.convs = nn.ModuleList([\n nn.Conv1d(\n hidden_size, \n num_filters, \n k,\n dilation=dilation_base ** i, # dilation = 2^i\n padding=0 # we'll handle causal padding manually\n )\n for i, k in enumerate(kernel_sizes)\n ])\n \n self.dropout = nn.Dropout(dropout)\n self.out_dim = num_filters * len(kernel_sizes)\n \n def _causal_padding(self, x, kernel_size, dilation):\n \"\"\"\n Apply left padding only (causal) to ensure output at time t\n only depends on inputs from time 0 to t.\n \"\"\"\n # Calculate required padding: (kernel_size - 1) * 
dilation\n padding = (kernel_size - 1) * dilation\n # Pad only on the left (temporal dimension)\n return F.pad(x, (padding, 0))\n \n def forward(self, x, attention_mask):\n \"\"\"\n x: (B, L, H)\n attention_mask: (B, L)\n \"\"\"\n # zero-out padding tokens\n mask = attention_mask.unsqueeze(-1)\n x = x * mask\n \n # (B, H, L) for Conv1d\n x = x.transpose(1, 2)\n \n feats = []\n for i, conv in enumerate(self.convs):\n kernel_size = self.kernel_sizes[i]\n dilation = self.dilation_base ** i\n \n # Apply causal padding (left-only)\n x_padded = self._causal_padding(x, kernel_size, dilation)\n \n # Apply dilated convolution\n c = F.relu(conv(x_padded)) # (B, C, L')\n \n # Global max pooling over the temporal dimension\n p = F.max_pool1d(c, kernel_size=c.size(2)).squeeze(2) # (B, C)\n \n feats.append(p)\n \n out = torch.cat(feats, dim=1) # (B, num_filters * len(kernel_sizes))\n return self.dropout(out)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class ConcatModel(nn.Module):\n def __init__(\n self,\n hatebert_model,\n additional_model,\n temporal_cnn,\n msa_cnn,\n selector,\n projection_mlp,\n unfreeze_n_layers_hate=12, #hatebert unfreeze all 12 layers\n unfreeze_n_layers_add=0, # additional bert freeze all layers\n hate_pooler=True,\n add_pooler=False\n \n ):\n super().__init__()\n \n self.hatebert_model = hatebert_model\n self.additional_model = additional_model\n \n self.temporal_cnn = temporal_cnn\n self.msa_cnn = msa_cnn\n self.selector = selector\n self.projection_mlp = projection_mlp\n \n # freeze everything on additional bert\n for p in self.additional_model.parameters():\n p.requires_grad = False\n\n # will unfreeze last n layers of additional model\n add_num_layers = len(self.additional_model.encoder.layer)\n for i in range(add_num_layers - unfreeze_n_layers_add, add_num_layers):\n for param in self.additional_model.encoder.layer[i].parameters():\n param.requires_grad = True\n\n if self.additional_model.pooler 
is not None and add_pooler:\n for p in self.additional_model.pooler.parameters():\n p.requires_grad = True\n\n # freeze everything on hatebert\n for p in self.hatebert_model.parameters():\n p.requires_grad = False\n \n # FIX: Use hatebert_model to get the correct number of layers\n hate_num_layers = len(self.hatebert_model.encoder.layer)\n for i in range(hate_num_layers - unfreeze_n_layers_hate, hate_num_layers):\n for param in self.hatebert_model.encoder.layer[i].parameters():\n param.requires_grad = True\n \n if self.hatebert_model.pooler is not None and hate_pooler:\n for p in self.hatebert_model.pooler.parameters():\n p.requires_grad = True\n\n \n def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask):\n # ================= HateBERT =================\n hate_outputs = self.hatebert_model(\n input_ids=input_ids,\n attention_mask=attention_mask,\n )\n seq_emb = hate_outputs.last_hidden_state # (B, L, H)\n cls_emb = seq_emb[:, 0, :] # (B, H)\n \n # ---- Token Selector ----\n token_probs, token_logits = self.selector(seq_emb, cls_emb, self.training)\n \n # ---- Temporal CNN on FULL embeddings (NOT masked) ----\n temporal_feat = self.temporal_cnn(seq_emb, attention_mask)\n \n # ---- Rationale-Weighted Summary Vector H_r ----\n weights = token_probs.unsqueeze(-1) # (B, L, 1)\n H_r = (seq_emb * weights).sum(dim=1) / (weights.sum(dim=1) + 1e-6)\n \n # ================= Frozen Rationale BERT =================\n add_outputs = self.additional_model(\n input_ids=additional_input_ids,\n attention_mask=additional_attention_mask,\n )\n add_seq = add_outputs.last_hidden_state\n \n # ---- Multi-Scale Attention CNN ----\n msa_feat = self.msa_cnn(add_seq, additional_attention_mask)\n \n # ================= CONCAT (4 components) =================\n concat = torch.cat([cls_emb, temporal_feat, msa_feat, H_r], dim=1)\n \n logits = self.projection_mlp(concat)\n return 
logits","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class EarlyStopping:\n \"\"\"\n Early stopping to stop training when validation loss doesn't improve.\n \"\"\"\n def __init__(self, patience=10, min_delta=1.0, mode='min', verbose=True):\n \"\"\"\n Args:\n patience (int): How many epochs to wait after last improvement.\n min_delta (float): Minimum change to qualify as an improvement.\n mode (str): 'min' for loss, 'max' for accuracy/f1.\n verbose (bool): Print messages when improvement occurs.\n \"\"\"\n self.patience = patience\n self.min_delta = min_delta\n self.mode = mode\n self.verbose = verbose\n self.counter = 0\n self.best_score = None\n self.early_stop = False\n self.best_model_state = None\n \n def __call__(self, current_score, model):\n \"\"\"\n Call this after each epoch with the validation metric and model.\n \n Args:\n current_score: Current epoch's validation metric (loss, accuracy, f1, etc.)\n model: The model to save if there's improvement\n \n Returns:\n bool: True if training should stop, False otherwise\n \"\"\"\n if self.best_score is None:\n # First epoch\n self.best_score = current_score\n self.save_checkpoint(model)\n if self.verbose:\n print(f\"Initial best score: {self.best_score:.4f}\")\n else:\n # Check if there's improvement\n if self.mode == 'min':\n improved = current_score < (self.best_score - self.min_delta)\n else: # mode == 'max'\n improved = current_score > (self.best_score + self.min_delta)\n \n if improved:\n self.best_score = current_score\n self.save_checkpoint(model)\n self.counter = 0\n if self.verbose:\n print(f\"Validation improved! New best score: {self.best_score:.4f}\")\n else:\n self.counter += 1\n if self.verbose:\n print(f\"No improvement. Patience counter: {self.counter}/{self.patience}\")\n \n if self.counter >= self.patience:\n self.early_stop = True\n if self.verbose:\n print(f\"Early stopping triggered! 
Best score: {self.best_score:.4f}\")\n \n return self.early_stop\n \n def save_checkpoint(self, model):\n \"\"\"Save model state dict\"\"\"\n import copy\n self.best_model_state = copy.deepcopy(model.state_dict())\n \n def load_best_model(self, model):\n \"\"\"Load the best model state into the model\"\"\"\n if self.best_model_state is not None:\n model.load_state_dict(self.best_model_state)\n if self.verbose:\n print(f\"Loaded best model with score: {self.best_score:.4f}\")\n return model","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"def main(args):\n torch.manual_seed(args.seed)\n torch.cuda.empty_cache()\n device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n\n\n file_map = {\n \"gab\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_GAB_dataset(85-15).csv',\n \"twitter\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_Twitter_dataset(85-15).csv',\n \"reddit\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_REDDIT_dataset(85-15).csv',\n \"youtube\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_YOUTUBE_dataset(85-15).csv',\n \"implicit\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_IMPLICIT_dataset(85-15).csv'\n }\n\n file_path = file_map[args.dataset]\n df = pd.read_csv(file_path)\n train_df = df[df['exp_split'] == 'train']\n test_df = df[df['exp_split'] == 'test']\n\n print(\"Train df: \", len(train_df))\n print(\"Test_df: \", len(test_df))\n\n import gc\n # del variables\n gc.collect()\n\n \n tokenizer = args.hate_tokenizer ## need this for tokenizing the input text in data loader\n tokenizer_bert = args.bert_tokenizer\n #Splitting training and validation testing split to test accuracy\n train_idx, val_idx = train_test_split(\n train_df.index,\n test_size=0.2,\n stratify=train_df[\"label\"],\n random_state=args.seed\n )\n \n if args.dataset == \"implicit\":\n 
train_text = train_df.loc[train_idx, \"post\"].tolist()\n val_texts = train_df.loc[val_idx, \"post\"].tolist()\n else:\n train_text = train_df.loc[train_idx, \"text\"].tolist()\n val_texts = train_df.loc[val_idx, \"text\"].tolist()\n \n add_train_text = train_df.loc[train_idx, \"Mistral_Rationales\"].tolist()\n add_val_texts = train_df.loc[val_idx, \"Mistral_Rationales\"].tolist()\n \n train_labels = train_df.loc[train_idx, \"label\"].tolist()\n val_labels = train_df.loc[val_idx, \"label\"].tolist()\n\n train_dataset = AdditionalCustomDataset(\n train_text,\n train_labels,\n add_train_text,\n tokenizer,\n tokenizer_bert,\n max_length=512\n )\n \n val_dataset = AdditionalCustomDataset(\n val_texts,\n val_labels,\n add_val_texts,\n tokenizer,\n tokenizer_bert,\n max_length=512\n )\n\n #Creating dataloader object to train the model\n # num_workers=2 for parallel data loading, pin_memory=True for faster GPU transfer\n train_dataloader = DataLoader(\n train_dataset, \n batch_size=args.batch_size, \n shuffle=True,\n num_workers=2,\n pin_memory=True if torch.cuda.is_available() else False\n )\n val_dataloader = DataLoader(\n val_dataset, \n batch_size=args.batch_size, \n shuffle=False,\n num_workers=2,\n pin_memory=True if torch.cuda.is_available() else False\n )\n hatebert_model = args.hatebert_model\n additional_model = args.additional_model\n \n \n temporal_cnn = TemporalCNN(\n hidden_size=768,\n num_filters=args.temp_filters, # 64 - 128\n kernel_sizes=(2, 3, 4),\n dropout=args.temp_dropout,\n dilation_base=args.temp_dilate # 0 - 4\n ).to(device)\n\n msa_cnn = MultiScaleAttentionCNN(\n hidden_size=768,\n num_filters=args.msa_filters, # 64 - 128\n kernel_sizes=(2, 3, 4),\n dropout=args.msa_dropout\n ).to(device)\n \n selector = GumbelTokenSelector(\n hidden_size=768,\n tau=1.0\n ).to(device)\n \n projection_mlp = ProjectionMLP(\n input_size=temporal_cnn.out_dim + msa_cnn.out_dim + 768 * 2,\n output_size=512\n ).to(device)\n\n\n\n concat_model = ConcatModel(\n 
hatebert_model=hatebert_model,\n additional_model=additional_model,\n temporal_cnn=temporal_cnn,\n msa_cnn=msa_cnn,\n selector=selector,\n projection_mlp=projection_mlp,\n unfreeze_n_layers_hate=args.hate_layers, #hatebert unfreeze all 12 layers id 12 by default\n unfreeze_n_layers_add=args.add_layers, # additional bert freeze all layers if 0 by default\n hate_pooler=args.hate_pooler, #bool to controle if pooler is frozen or not true=not frozen\n add_pooler=args.add_pooler #bool to controle if pooler is frozen or not true=not froze\n ).to(device)\n\n optimizer = AdamW(\n concat_model.parameters(),\n lr=args.lr, # 2e-5\n weight_decay=args.wd\n )\n criterion = nn.CrossEntropyLoss().to(device)\n\n # criterion = criterion.to(device)\n\n os.makedirs(\"/kaggle/working/models\", exist_ok=True)\n\n history = {\n \"train_loss\": [],\n \"val_loss\": [], #loss\n \"train_acc\": [],\n \"train_precision\": [],\n \"train_recall\": [],\n \"train_f1\": [], #train binary\n \"val_acc\": [],\n \"val_precision\": [],\n \"val_recall\": [],\n \"val_f1\": [], #validation binary\n \"train_acc_weighted\": [],\n \"train_precision_weighted\": [],\n \"train_recall_weighted\": [],\n \"train_f1_weighted\": [], # train weighted\n \"val_acc_weighted\": [],\n \"val_precision_weighted\": [],\n \"val_recall_weighted\": [],\n \"val_f1_weighted\": [], # validation weighted\n \"train_acc_macro\": [],\n \"train_precision_macro\": [],\n \"train_recall_macro\": [],\n \"train_f1_macro\": [], #train macro\n \"val_acc_macro\": [],\n \"val_precision_macro\": [],\n \"val_recall_macro\": [],\n \"val_f1_macro\": [],# validation macro\n \"epoch_time\": [],\n \"train_throughput\": [],\n \"val_confidence_mean\": [],\n \"val_confidence_std\": [],\n \"gpu_memory_mb\": [],\n }\n\n early_stopping = EarlyStopping(patience=args.patience, min_delta=0.001, mode='min', verbose=True) # early stop on loss\n\n for epoch in range(args.num_epochs):\n epoch_val_confidences = []\n epoch_start_time = time.time()\n samples_seen = 0\n 
\n concat_model.train()\n\n train_losses = []\n train_preds = []\n train_labels_epoch = []\n train_accuracy = 0\n train_epoch_size = 0\n\n with tqdm(train_dataloader, desc=f'Epoch {epoch + 1}', dynamic_ncols=True) as loop:\n for batch in loop:\n input_ids, attention_mask, additional_input_ids, additional_attention_mask, labels = batch\n \n samples_seen += labels.size(0)\n \n if torch.cuda.is_available():\n input_ids = input_ids.to(device)\n attention_mask = attention_mask.to(device)\n additional_input_ids = additional_input_ids.to(device)\n additional_attention_mask = additional_attention_mask.to(device)\n labels = labels.to(device)\n\n # Forward pass through the ConcatModel\n optimizer.zero_grad()\n outputs = concat_model(input_ids=input_ids, attention_mask=attention_mask, additional_input_ids=additional_input_ids, additional_attention_mask=additional_attention_mask)\n loss = criterion(outputs, labels)\n\n # Backward pass and optimization\n loss.backward()\n torch.nn.utils.clip_grad_norm_(concat_model.parameters(), max_norm=args.max_grad_norm)\n optimizer.step()\n \n probs = torch.softmax(outputs, dim=1)\n confidences, predictions = torch.max(probs, dim=1)\n train_preds.extend(predictions.cpu().numpy())\n train_labels_epoch.extend(labels.cpu().numpy())\n\n train_losses.append(loss.item())\n\n # Update accuracy and epoch size\n predictions = torch.argmax(outputs, dim=1)\n train_accuracy += (predictions == labels).sum().item()\n train_epoch_size += len(labels)\n \n epoch_train_time = time.time() - epoch_start_time\n train_throughput = samples_seen / epoch_train_time \n \n # Calculate train metrics (binary, weighted, macro)\n train_precision, train_recall, train_f1, _ = precision_recall_fscore_support(\n train_labels_epoch, train_preds, average='binary'\n )\n train_precision_weighted, train_recall_weighted, train_f1_weighted, _ = precision_recall_fscore_support(\n train_labels_epoch, train_preds, average='weighted'\n )\n train_precision_macro, train_recall_macro, 
train_f1_macro, _ = precision_recall_fscore_support(\n train_labels_epoch, train_preds, average='macro'\n )\n train_acc = accuracy_score(train_labels_epoch, train_preds)\n\n # Evaluation on the validation set\n concat_model.eval()\n\n val_predictions = []\n val_labels_epoch = []\n val_loss = 0\n num_batches = 0\n\n with torch.no_grad(), tqdm(val_dataloader, desc='Validation', dynamic_ncols=True) as loop:\n for batch in loop:\n input_ids, attention_mask, additional_input_ids, additional_attention_mask, labels = batch\n\n if torch.cuda.is_available():\n input_ids = input_ids.to(device)\n attention_mask = attention_mask.to(device)\n additional_input_ids = additional_input_ids.to(device)\n additional_attention_mask = additional_attention_mask.to(device)\n labels = labels.to(device)\n\n # Forward pass through the ConcatModel\n outputs = concat_model(input_ids=input_ids, attention_mask=attention_mask, additional_input_ids=additional_input_ids, additional_attention_mask=additional_attention_mask)\n loss = criterion(outputs, labels)\n val_loss += loss.item()\n num_batches += 1\n probs = torch.softmax(outputs, dim=1)\n confidences, predictions = torch.max(probs, dim=1)\n\n epoch_val_confidences.extend(confidences.cpu().numpy())\n \n val_predictions.extend(predictions.cpu().numpy())\n val_labels_epoch.extend(labels.cpu().numpy())\n \n val_loss /= num_batches\n \n # Calculate validation metrics (binary)\n val_accuracy = accuracy_score(val_labels_epoch, val_predictions)\n val_precision, val_recall, val_f1, _ = precision_recall_fscore_support(\n val_labels_epoch, val_predictions, average='binary'\n )\n # Calculate validation metrics (weighted)\n val_precision_weighted, val_recall_weighted, val_f1_weighted, _ = precision_recall_fscore_support(\n val_labels_epoch, val_predictions, average='weighted'\n )\n # Calculate validation metrics (macro)\n val_precision_macro, val_recall_macro, val_f1_macro, _ = precision_recall_fscore_support(\n val_labels_epoch, val_predictions, 
average='macro'\n )\n \n print(f\"Epoch {epoch}:\")\n print(f\" Train Accuracy: {train_acc:.4f}\")\n print(f\" Validation Accuracy: {val_accuracy:.4f}\")\n print(f\" Train Precision: {train_precision:.4f}, Recall: {train_recall:.4f}, F1: {train_f1:.4f}\")\n print(f\" Val Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1: {val_f1:.4f}\")\n print(f\" Avg. Train Loss: {sum(train_losses) / len(train_losses):.4f}\")\n print(f\" Validation Loss: {val_loss:.4f}\")\n epoch_time = time.time() - epoch_start_time\n conf_mean = np.mean(epoch_val_confidences)\n conf_std = np.std(epoch_val_confidences)\n\n # Append all train metrics to history\n history[\"train_loss\"].append(np.mean(train_losses))\n history[\"train_acc\"].append(train_acc)\n history[\"train_precision\"].append(train_precision)\n history[\"train_recall\"].append(train_recall)\n history[\"train_f1\"].append(train_f1)\n history[\"train_acc_weighted\"].append(train_acc) # accuracy is same for all averaging\n history[\"train_precision_weighted\"].append(train_precision_weighted)\n history[\"train_recall_weighted\"].append(train_recall_weighted)\n history[\"train_f1_weighted\"].append(train_f1_weighted)\n history[\"train_acc_macro\"].append(train_acc) # accuracy is same for all averaging\n history[\"train_precision_macro\"].append(train_precision_macro)\n history[\"train_recall_macro\"].append(train_recall_macro)\n history[\"train_f1_macro\"].append(train_f1_macro)\n \n # Append all validation metrics to history\n history[\"val_loss\"].append(val_loss)\n history[\"val_acc\"].append(val_accuracy)\n history[\"val_precision\"].append(val_precision)\n history[\"val_recall\"].append(val_recall)\n history[\"val_f1\"].append(val_f1)\n history[\"val_acc_weighted\"].append(val_accuracy) # accuracy is same for all averaging\n history[\"val_precision_weighted\"].append(val_precision_weighted)\n history[\"val_recall_weighted\"].append(val_recall_weighted)\n history[\"val_f1_weighted\"].append(val_f1_weighted)\n 
history[\"val_acc_macro\"].append(val_accuracy) # accuracy is same for all averaging\n history[\"val_precision_macro\"].append(val_precision_macro)\n history[\"val_recall_macro\"].append(val_recall_macro)\n history[\"val_f1_macro\"].append(val_f1_macro)\n \n # Append efficiency metrics\n history[\"epoch_time\"].append(epoch_time)\n history[\"train_throughput\"].append(train_throughput)\n history[\"val_confidence_mean\"].append(conf_mean)\n history[\"val_confidence_std\"].append(conf_std)\n \n if torch.cuda.is_available():\n history[\"gpu_memory_mb\"].append(\n torch.cuda.max_memory_allocated() / 1024**2\n )\n torch.cuda.reset_peak_memory_stats()\n else:\n history[\"gpu_memory_mb\"].append(0)\n \n print(f\" Epoch Time (s): {epoch_time:.2f}\")\n print(f\" Throughput (samples/sec): {train_throughput:.2f}\")\n print(f\" Val Confidence Mean: {conf_mean:.4f} Β± {conf_std:.4f}\")\n \n current_metric = val_loss\n if early_stopping(current_metric, concat_model):\n print(f\"\\n{'='*50}\")\n print(f\"Early stopping at epoch {epoch+1}\")\n print(f\"{'='*50}\\n\")\n break\n model = early_stopping.load_best_model(concat_model)\n torch.save(model.state_dict(), f\"/kaggle/working/models/{args.dataset}_concat_model.pt\")\n print(f\"Best model saved to /kaggle/working/models/{args.dataset}_concat_model.pt\")\n\n checkpoint = {\n \"model_state_dict\": model.state_dict(),\n \"history\": history,\n }\n torch.save(checkpoint, f\"/kaggle/working/models/{args.dataset}_concat_checkpoint.pt\")\n print(f\"Checkpoint with history saved to /kaggle/working/models/{args.dataset}_concat_checkpoint.pt\")\n\n if args.dataset == \"implicit\":\n test_texts = test_df[\"post\"].tolist()\n else:\n test_texts = test_df[\"text\"].tolist()\n \n add_test_texts = test_df[\"Mistral_Rationales\"].tolist()\n test_labels = test_df[\"label\"].tolist()\n\n test_dataset = AdditionalCustomDataset(test_texts, test_labels, add_test_texts, tokenizer, tokenizer_bert, max_length=512)\n test_dataloader = DataLoader(\n 
test_dataset, \n batch_size=args.batch_size, # Use same batch size as training for faster inference\n shuffle=False,\n num_workers=2,\n pin_memory=True if torch.cuda.is_available() else False\n )\n\n # ================= TEST EVALUATION WITH EFFICIENCY =================\n model.eval()\n test_predictions = []\n test_true_labels = []\n test_confidences = []\n\n samples_seen = 0\n test_start_time = time.time()\n \n if torch.cuda.is_available():\n torch.cuda.reset_peak_memory_stats()\n \n with torch.no_grad(), tqdm(test_dataloader, desc='Testing', dynamic_ncols=True) as loop:\n for batch in loop:\n input_ids, attention_mask, additional_input_ids, additional_attention_mask, labels = batch\n \n batch_size = labels.size(0)\n samples_seen += batch_size\n \n input_ids = input_ids.to(device)\n attention_mask = attention_mask.to(device)\n additional_input_ids = additional_input_ids.to(device)\n additional_attention_mask = additional_attention_mask.to(device)\n labels = labels.to(device)\n \n outputs = model(\n input_ids=input_ids,\n attention_mask=attention_mask,\n additional_input_ids=additional_input_ids,\n additional_attention_mask=additional_attention_mask\n )\n \n probs = torch.softmax(outputs, dim=1)\n confidences, preds = torch.max(probs, dim=1)\n \n test_confidences.extend(confidences.cpu().numpy())\n test_predictions.extend(preds.cpu().numpy())\n test_true_labels.extend(labels.cpu().numpy())\n\n \n # ================= TEST METRICS =================\n test_time = time.time() - test_start_time\n test_throughput = samples_seen / test_time\n \n accuracy = accuracy_score(test_true_labels, test_predictions)\n precision, recall, f1, _ = precision_recall_fscore_support(\n test_true_labels, test_predictions, average='weighted'\n )\n conf_mean = np.mean(test_confidences)\n conf_std = np.std(test_confidences)\n cm = confusion_matrix(test_true_labels, test_predictions)\n \n gpu_memory_mb = torch.cuda.max_memory_allocated() / 1024**2 if torch.cuda.is_available() else 0\n\n 
print(\"\\n================= FINAL TEST RESULTS =================\")\n print(f\"Dataset: {args.dataset}, Seed: {args.seed}, Epochs: {args.num_epochs}\")\n print(f\"Test Accuracy : {accuracy:.4f}\")\n print(f\"Test Precision: {precision:.4f}\")\n print(f\"Test Recall : {recall:.4f}\")\n print(f\"Test F1-score : {f1:.4f}\")\n print(f\"Test Confidence Mean Β± Std: {conf_mean:.4f} Β± {conf_std:.4f}\")\n print(f\"Test Time (s) : {test_time:.2f}\")\n print(f\"Throughput (samples/sec) : {test_throughput:.2f}\")\n print(f\"Peak GPU Memory (MB) : {gpu_memory_mb:.2f}\")\n print(\"\\nClassification Report:\")\n print(classification_report(test_true_labels, test_predictions))\n print(\"\\nConfusion Matrix:\")\n # Plot confusion matrix with seaborn\n plt.figure(figsize=(8, 6))\n sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', \n xticklabels=['Non-Hate', 'Hate'],\n yticklabels=['Non-Hate', 'Hate'],\n cbar_kws={'label': 'Count'})\n plt.title('Confusion Matrix', fontsize=14, pad=20)\n plt.ylabel('True label', fontsize=12)\n plt.xlabel('Predicted label', fontsize=12)\n plt.tight_layout()\n plt.show()\n print(\"======================================================\")\n\n return model, history","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"from argparse import Namespace\nfrom transformers import AutoTokenizer, AutoModel\n\n# Load tokenizers and models ONCE (outside objective to save time)\ndevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\nhate_tokenizer = AutoTokenizer.from_pretrained(\"GroNLP/hateBERT\")\nbert_tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n\ntorch.cuda.empty_cache()\n\n# Load fresh models for each trial (to reset weights)\nhatebert_model = AutoModel.from_pretrained(\"GroNLP/hateBERT\").to(device)\nadditional_model = AutoModel.from_pretrained(\"bert-base-uncased\").to(device)\n\n# Sample hyperparameters\nargs = Namespace(\n # Fixed\n seed=42,\n dataset=\"reddit\", # Change 
as needed: \"gab\", \"twitter\", \"reddit\", \"youtube\", \"implicit\"\n \n # Tokenizers & Models\n hate_tokenizer=hate_tokenizer,\n bert_tokenizer=bert_tokenizer,\n hatebert_model=hatebert_model,\n additional_model=additional_model,\n \n # Training hyperparameters\n batch_size=32,\n num_epochs=20,\n lr=1e-5,\n wd=0.01,\n patience=2,\n max_grad_norm=2.0,\n \n # TemporalCNN hyperparameters\n temp_filters=256,\n temp_dropout=0.1,\n temp_dilate=3,\n \n # MultiScaleAttentionCNN hyperparameters\n msa_filters=64,\n msa_dropout=0.23,\n \n # Layer unfreezing hyperparameters\n hate_layers=10,\n add_layers=0,\n hate_pooler=True,\n add_pooler=True,\n)\n \nmodel, history = main(args)\n\n\n ","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"fig, axes = plt.subplots(3, 4, figsize=(20, 15))\n\n# Row 1: Train vs Val comparisons\n# Accuracy\naxes[0, 0].plot(history['train_acc'], label='Train', marker='o')\naxes[0, 0].plot(history['val_acc'], label='Validation', marker='s')\naxes[0, 0].set_title('Accuracy')\naxes[0, 0].set_xlabel('Epoch')\naxes[0, 0].set_ylabel('Accuracy')\naxes[0, 0].legend()\naxes[0, 0].grid(True)\n\n# Loss\naxes[0, 1].plot(history['train_loss'], label='Train', marker='o')\naxes[0, 1].plot(history['val_loss'], label='Validation', marker='s')\naxes[0, 1].set_title('Loss')\naxes[0, 1].set_xlabel('Epoch')\naxes[0, 1].set_ylabel('Loss')\naxes[0, 1].legend()\naxes[0, 1].grid(True)\n\n# F1 Score\naxes[0, 2].plot(history['train_f1'], label='Train', marker='o')\naxes[0, 2].plot(history['val_f1'], label='Validation', marker='s')\naxes[0, 2].set_title('F1 Score (Binary)')\naxes[0, 2].set_xlabel('Epoch')\naxes[0, 2].set_ylabel('F1')\naxes[0, 2].legend()\naxes[0, 2].grid(True)\n\n# Precision\naxes[0, 3].plot(history['train_precision'], label='Train', marker='o')\naxes[0, 3].plot(history['val_precision'], label='Validation', marker='s')\naxes[0, 3].set_title('Precision (Binary)')\naxes[0, 3].set_xlabel('Epoch')\naxes[0, 
3].set_ylabel('Precision')\naxes[0, 3].legend()\naxes[0, 3].grid(True)\n\n# Row 2: More Train vs Val + Individual metrics\n# Recall\naxes[1, 0].plot(history['train_recall'], label='Train', marker='o')\naxes[1, 0].plot(history['val_recall'], label='Validation', marker='s')\naxes[1, 0].set_title('Recall (Binary)')\naxes[1, 0].set_xlabel('Epoch')\naxes[1, 0].set_ylabel('Recall')\naxes[1, 0].legend()\naxes[1, 0].grid(True)\n\n# Epoch Time\naxes[1, 1].plot(history['epoch_time'], marker='o', color='green')\naxes[1, 1].set_title('Epoch Time')\naxes[1, 1].set_xlabel('Epoch')\naxes[1, 1].set_ylabel('Time (s)')\naxes[1, 1].grid(True)\n\n# Train Throughput\naxes[1, 2].plot(history['train_throughput'], marker='o', color='purple')\naxes[1, 2].set_title('Train Throughput')\naxes[1, 2].set_xlabel('Epoch')\naxes[1, 2].set_ylabel('Samples/sec')\naxes[1, 2].grid(True)\n\n# Validation Confidence\naxes[1, 3].errorbar(range(len(history['val_confidence_mean'])), \n history['val_confidence_mean'], \n yerr=history['val_confidence_std'], \n marker='o', color='orange', capsize=3)\naxes[1, 3].set_title('Validation Confidence')\naxes[1, 3].set_xlabel('Epoch')\naxes[1, 3].set_ylabel('Confidence')\naxes[1, 3].grid(True)\n\n# Row 3: GPU Memory and Weighted/Macro metrics\n# GPU Memory\naxes[2, 0].plot(history['gpu_memory_mb'], marker='o', color='red')\naxes[2, 0].set_title('GPU Memory Usage')\naxes[2, 0].set_xlabel('Epoch')\naxes[2, 0].set_ylabel('Memory (MB)')\naxes[2, 0].grid(True)\n\n# Weighted F1 comparison\naxes[2, 1].plot(history['train_f1_weighted'], label='Train', marker='o')\naxes[2, 1].plot(history['val_f1_weighted'], label='Validation', marker='s')\naxes[2, 1].set_title('F1 Score (Weighted)')\naxes[2, 1].set_xlabel('Epoch')\naxes[2, 1].set_ylabel('F1')\naxes[2, 1].legend()\naxes[2, 1].grid(True)\n\n# Macro F1 comparison\naxes[2, 2].plot(history['train_f1_macro'], label='Train', marker='o')\naxes[2, 2].plot(history['val_f1_macro'], label='Validation', marker='s')\naxes[2, 
2].set_title('F1 Score (Macro)')\naxes[2, 2].set_xlabel('Epoch')\naxes[2, 2].set_ylabel('F1')\naxes[2, 2].legend()\naxes[2, 2].grid(True)\n\n# Hide last subplot\naxes[2, 3].axis('off')\n\nplt.suptitle('Training History', fontsize=16, y=1.02)\nplt.tight_layout()\nplt.show()","metadata":{"trusted":true},"outputs":[],"execution_count":null}]}
 
 
notebooks/thesis-proposed-model-vo4.ipynb ADDED
@@ -0,0 +1,1424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {
7
+ "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
8
+ "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
9
+ "trusted": true
10
+ },
11
+ "outputs": [],
12
+ "source": [
13
+ "import torch\n",
14
+ "import torch.nn as nn\n",
15
+ "import torch.nn.functional as F\n",
16
+ "from torch.utils.data import DataLoader, Dataset\n",
17
+ "from torch.optim import AdamW\n",
18
+ "from sklearn.model_selection import train_test_split\n",
19
+ "from transformers import AutoModel, AutoTokenizer\n",
20
+ "\n",
21
+ "import numpy as np\n",
22
+ "import pandas as pd\n",
23
+ "import time\n",
24
+ "import os\n",
25
+ "from tqdm.auto import tqdm\n",
26
+ "\n",
27
+ "from sklearn.metrics import (\n",
28
+ " accuracy_score,\n",
29
+ " precision_recall_fscore_support,\n",
30
+ " classification_report,\n",
31
+ " confusion_matrix\n",
32
+ ")\n",
33
+ "\n",
34
+ "import matplotlib.pyplot as plt\n",
35
+ "import seaborn as sns"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
class AdditionalCustomDataset(Dataset):
    """
    Dataset that tokenizes every sample up front.

    Both the primary texts and the auxiliary (rationale) texts are encoded
    once in ``__init__`` with their respective tokenizers, so ``__getitem__``
    is a cheap tensor slice instead of a per-sample tokenizer call.
    """
    def __init__(self, texts, labels, additional_texts, tokenizer, bert_tokenizer, max_length):
        # Store labels as a long tensor right away so downstream loss
        # functions never see raw Python ints.
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.max_length = max_length

        print("Pre-tokenizing primary texts...")
        enc_main = tokenizer(
            texts,
            max_length=max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        self.primary_input_ids = enc_main['input_ids']
        self.primary_attention_mask = enc_main['attention_mask']

        print("Pre-tokenizing additional texts...")
        enc_extra = bert_tokenizer(
            additional_texts,
            max_length=max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        self.additional_input_ids = enc_extra['input_ids']
        self.additional_attention_mask = enc_extra['attention_mask']

        print(f"Pre-tokenization complete. Dataset size: {len(self.labels)}")

        # Pin host memory so host-to-device copies can be asynchronous
        # when a GPU is present.
        if torch.cuda.is_available():
            self.primary_input_ids = self.primary_input_ids.pin_memory()
            self.primary_attention_mask = self.primary_attention_mask.pin_memory()
            self.additional_input_ids = self.additional_input_ids.pin_memory()
            self.additional_attention_mask = self.additional_attention_mask.pin_memory()
            self.labels = self.labels.pin_memory()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # One sample = (primary ids/mask, additional ids/mask, long label).
        return (
            self.primary_input_ids[idx],
            self.primary_attention_mask[idx],
            self.additional_input_ids[idx],
            self.additional_attention_mask[idx],
            self.labels[idx]
        )
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": null,
102
+ "metadata": {
103
+ "trusted": true
104
+ },
105
+ "outputs": [],
106
+ "source": [
107
class ProjectionMLP(nn.Module):
    """Two-layer classification head mapping fused features to 2-way logits."""

    def __init__(self, input_size, output_size):
        super(ProjectionMLP, self).__init__()
        stack = [
            nn.Linear(input_size, output_size),  # feature projection
            nn.ReLU(),                           # non-linearity
            nn.Linear(output_size, 2),           # binary-class logits
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Map features of shape (B, input_size) to logits of shape (B, 2)."""
        return self.layers(x)
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": null,
123
+ "metadata": {
124
+ "trusted": true
125
+ },
126
+ "outputs": [],
127
+ "source": [
128
class GumbelTokenSelector(nn.Module):
    """Scores each token against the [CLS] context and produces soft keep-probabilities."""

    def __init__(self, hidden_size, tau=1.0):
        super().__init__()
        self.tau = tau
        # Scores the pair [token ; CLS] -> one logit per token.
        self.proj = nn.Linear(hidden_size * 2, 1)

    def forward(self, token_embeddings, cls_embedding, training=True):
        """
        Args:
            token_embeddings: (B, L, H) contextual token vectors.
            cls_embedding:    (B, H) sentence-level vector.
        Returns:
            probs:  (B, L) selection probabilities.
            logits: (B, L) raw selection scores.
        """
        num_tokens = token_embeddings.size(1)

        # Pair every token with its sentence context: (B, L, 2H).
        context = cls_embedding.unsqueeze(1).expand(-1, num_tokens, -1)
        paired = torch.cat([token_embeddings, context], dim=-1)

        logits = self.proj(paired).squeeze(-1)  # (B, L)

        if training:
            # Two-way Gumbel-softmax against a zero baseline; channel 0
            # is the "keep" probability.
            two_way = torch.stack([logits, torch.zeros_like(logits)], dim=-1)
            probs = F.gumbel_softmax(two_way, tau=self.tau, hard=False)[..., 0]
        else:
            probs = torch.sigmoid(logits)

        return probs, logits
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": null,
161
+ "metadata": {},
162
+ "outputs": [],
163
+ "source": [
164
class AnnealedGumbelSelector(nn.Module):
    """
    Binary token selector with Gumbel-sigmoid sampling and temperature annealing.

    Args:
        initial_tau: Starting temperature (high = random/exploration).
        min_tau: Lowest temperature allowed (low = deterministic/exploitation).
        decay_rate: Multiplier applied to tau after every epoch.
        hard: If True, uses the Straight-Through Estimator (binary forward
            pass, soft gradients in backward).
    """
    def __init__(self, hidden_size, initial_tau=1.0, min_tau=0.1, decay_rate=0.9, hard=True):
        super().__init__()
        self.initial_tau = initial_tau
        self.min_tau = min_tau
        self.decay_rate = decay_rate
        self.hard = hard

        # Buffer: persisted with state_dict() but never optimized.
        self.register_buffer('tau', torch.tensor(initial_tau))

        # [token ; CLS] -> one selection score per token.
        self.proj = nn.Linear(hidden_size * 2, 1)

    def step_tau(self):
        """Anneal the temperature one step (call once per epoch); returns new tau."""
        cooled = max(self.min_tau, self.tau * self.decay_rate)
        self.tau.fill_(cooled)
        return self.tau.item()

    def forward(self, token_embeddings, cls_embedding, training=True):
        """
        Args:
            token_embeddings: (B, L, H)
            cls_embedding:    (B, H)
        Returns:
            probs:  (B, L) selection mask in [0, 1] (exactly 0/1 when hard).
            logits: (B, L) raw scores.
        """
        seq_len = token_embeddings.size(1)

        # Context-aware pairing: (B, L, 2H) -> per-token logits (B, L).
        context = cls_embedding.unsqueeze(1).expand(-1, seq_len, -1)
        logits = self.proj(torch.cat([token_embeddings, context], dim=-1)).squeeze(-1)

        if not training:
            # Deterministic inference: threshold (hard) or plain sigmoid (soft).
            gate = torch.sigmoid(logits)
            probs = (gate > 0.5).float() if self.hard else gate
            return probs, logits

        # Gumbel noise: -log(-log(U)); epsilons guard against log(0).
        u = torch.rand_like(logits)
        gumbel = -torch.log(-torch.log(u + 1e-9) + 1e-9)

        # Perturb logits and squash, scaled by the current temperature.
        soft = torch.sigmoid((logits + gumbel) / self.tau)

        if self.hard:
            # Straight-through: forward pass is binary, backward pass
            # flows through the soft values ((hard - soft).detach() is 0
            # in backward).
            binary = (soft > 0.5).float()
            probs = binary - soft.detach() + soft
        else:
            probs = soft

        return probs, logits
231
+ ]
232
+ },
233
+ {
234
+ "cell_type": "code",
235
+ "execution_count": null,
236
+ "metadata": {
237
+ "trusted": true
238
+ },
239
+ "outputs": [],
240
+ "source": [
241
class MultiScaleAttentionCNN(nn.Module):
    """
    Parallel 1-D convolutions at several kernel widths, each pooled by a
    learned attention over positions, concatenated into one feature vector.
    """

    def __init__(
        self,
        hidden_size=768,
        num_filters=128,
        kernel_sizes=(2, 3, 4),
        dropout=0.3,
    ):
        super().__init__()

        # One conv branch per n-gram width.
        self.convs = nn.ModuleList(
            nn.Conv1d(hidden_size, num_filters, k) for k in kernel_sizes
        )

        # Shared scorer used for attention pooling in every branch.
        self.attention_fc = nn.Linear(num_filters, 1)
        self.dropout = nn.Dropout(dropout)
        # Final feature width after concatenating all branches.
        self.out_dim = num_filters * len(kernel_sizes)

    def forward(self, x, mask):
        """
        Args:
            x:    (B, L, H) token embeddings.
            mask: (B, L) attention mask (0 = padding).
        Returns:
            (B, num_filters * len(kernel_sizes)) pooled features.
        """
        x = x.transpose(1, 2)  # Conv1d expects (B, H, L)
        branch_feats = []

        for conv in self.convs:
            h = F.relu(conv(x)).transpose(1, 2)  # (B, L', C)

            # Per-position attention scores; padded slots pushed to -inf
            # (conv shrinks the length, so the mask is truncated to L').
            scores = self.attention_fc(h).squeeze(-1)
            scores = scores.masked_fill(mask[:, :scores.size(1)] == 0, -1e9)
            alpha = F.softmax(scores, dim=1)

            # Attention-weighted sum over positions -> (B, C).
            branch_feats.append(torch.sum(h * alpha.unsqueeze(-1), dim=1))

        return self.dropout(torch.cat(branch_feats, dim=1))
281
+ ]
282
+ },
283
+ {
284
+ "cell_type": "code",
285
+ "execution_count": null,
286
+ "metadata": {},
287
+ "outputs": [],
288
+ "source": [
289
class TemporalCNN(nn.Module):
    """
    TextCNN-style encoder: multi-width 1-D convolutions followed by masked
    max- and mean-pooling over time, concatenated into a single vector.

    Input:  sequence_embeddings (B, L, H), optional attention_mask (B, L).
    Output: pooled features (B, num_filters * len(kernel_sizes) * 2).
    """
    def __init__(self, input_dim=768, num_filters=256, kernel_sizes=(2,3,4), dropout=0.3):
        super().__init__()
        self.input_dim = input_dim

        # Two pooling modes (max + mean) double the per-branch width.
        self.out_dim = num_filters * len(kernel_sizes) * 2

        # 'Same'-style padding so each filter sees neighbors on both sides.
        self.convs = nn.ModuleList(
            nn.Conv1d(
                in_channels=input_dim,
                out_channels=num_filters,
                kernel_size=k,
                padding=k // 2,
            )
            for k in kernel_sizes
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, sequence_embeddings, attention_mask=None):
        """
        Args:
            sequence_embeddings: (B, L, H)
            attention_mask: (B, L), 1 for real tokens, 0 for padding.
        """
        # Conv1d expects channels-first: (B, H, L).
        feats = sequence_embeddings.transpose(1, 2).contiguous()

        pooled = []
        for conv in self.convs:
            activ = F.relu(conv(feats))  # (B, num_filters, L_out)

            if attention_mask is None:
                # No padding info: plain global max + mean pooling.
                pooled.append(torch.max(activ, dim=2)[0])
                pooled.append(activ.mean(dim=2))
                continue

            # Align the mask with the conv output length (padding=k//2 can
            # change L slightly, e.g. even kernels produce L+1).
            L_out = activ.size(2)
            mask = attention_mask.float()
            if mask.size(1) != L_out:
                mask = F.interpolate(mask.unsqueeze(1), size=L_out, mode='nearest').squeeze(1)

            # (B, 1, L_out) so it broadcasts over the filter dimension.
            mask = mask.unsqueeze(1).to(activ.device)

            # Masked max: padded slots get a large negative value (not true
            # -inf, which would produce NaNs on fully-padded rows).
            neg_fill = torch.tensor(-1e9, device=activ.device)
            masked = torch.where(mask.bool(), activ, neg_fill)
            pooled.append(torch.max(masked, dim=2)[0])

            # Masked mean: zero out padding, divide by the valid length.
            valid = mask.sum(dim=2).clamp_min(1e-9)
            pooled.append((activ * mask).sum(dim=2) / valid)

        # Concatenate: (B, num_filters * len(kernels) * 2)
        return self.dropout(torch.cat(pooled, dim=1))
358
+ ]
359
+ },
360
+ {
361
+ "cell_type": "code",
362
+ "execution_count": null,
363
+ "metadata": {
364
+ "trusted": true
365
+ },
366
+ "outputs": [],
367
+ "source": [
368
class ConcatModel(nn.Module):
    """
    Fusion classifier combining four feature blocks:
      * the HateBERT [CLS] embedding,
      * a TemporalCNN over the full HateBERT token sequence,
      * a MultiScaleAttentionCNN over the frozen rationale encoder's tokens,
      * a rationale-weighted summary vector H_r from the token selector.

    The concatenated features are classified by `projection_mlp`. The
    rationale encoder (`additional_model`) stays fully frozen; only the last
    `unfreeze_n_layers_hate` encoder layers of HateBERT (and optionally its
    pooler) are trainable.
    """
    def __init__(
        self,
        hatebert_model,
        additional_model,
        temporal_cnn,
        msa_cnn,
        selector,
        projection_mlp,
        unfreeze_n_layers_hate=12,  # hatebert: unfreeze the last N layers (12 = all for BERT-base)
        hate_pooler=True,
    ):
        super().__init__()

        self.hatebert_model = hatebert_model
        self.additional_model = additional_model

        self.temporal_cnn = temporal_cnn
        self.msa_cnn = msa_cnn
        self.selector = selector
        self.projection_mlp = projection_mlp

        # The rationale encoder is a fixed feature extractor.
        for p in self.additional_model.parameters():
            p.requires_grad = False

        # Freeze HateBERT entirely, then selectively re-enable below.
        for p in self.hatebert_model.parameters():
            p.requires_grad = False

        # Unfreeze the last `unfreeze_n_layers_hate` encoder layers.
        # FIX: clamp the start index at 0 so requesting more layers than the
        # encoder has unfreezes everything instead of raising IndexError on
        # negative ModuleList indices.
        hate_num_layers = len(self.hatebert_model.encoder.layer)
        start = max(0, hate_num_layers - unfreeze_n_layers_hate)
        for i in range(start, hate_num_layers):
            for param in self.hatebert_model.encoder.layer[i].parameters():
                param.requires_grad = True

        if self.hatebert_model.pooler is not None and hate_pooler:
            for p in self.hatebert_model.pooler.parameters():
                p.requires_grad = True

    def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask):
        """Returns class logits of shape (B, 2)."""
        # ================= HateBERT =================
        hate_outputs = self.hatebert_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        seq_emb = hate_outputs.last_hidden_state  # (B, L, H)
        cls_emb = seq_emb[:, 0, :]                # (B, H)

        # ---- Token selector: per-token keep probabilities ----
        token_probs, token_logits = self.selector(seq_emb, cls_emb, self.training)

        # ---- Temporal CNN on the FULL embeddings (NOT masked by selector) ----
        temporal_feat = self.temporal_cnn(seq_emb, attention_mask)

        # ---- Rationale-weighted summary vector H_r ----
        weights = token_probs.unsqueeze(-1)  # (B, L, 1)
        # Weighted mean over tokens; epsilon guards an all-zero selection.
        H_r = (seq_emb * weights).sum(dim=1) / (weights.sum(dim=1) + 1e-6)

        # ================= Frozen Rationale BERT =================
        add_outputs = self.additional_model(
            input_ids=additional_input_ids,
            attention_mask=additional_attention_mask,
        )
        add_seq = add_outputs.last_hidden_state

        # ---- Multi-Scale Attention CNN over rationale tokens ----
        msa_feat = self.msa_cnn(add_seq, additional_attention_mask)

        # ================= CONCAT (4 components) =================
        concat = torch.cat([cls_emb, temporal_feat, msa_feat, H_r], dim=1)

        logits = self.projection_mlp(concat)
        return logits
444
+ ]
445
+ },
446
+ {
447
+ "cell_type": "code",
448
+ "execution_count": null,
449
+ "metadata": {
450
+ "trusted": true
451
+ },
452
+ "outputs": [],
453
+ "source": [
454
class EarlyStopping:
    """
    Stops training once the monitored validation metric stops improving.

    Keeps a deep copy of the best model weights so they can be restored
    after training ends.
    """
    def __init__(self, patience=10, min_delta=1.0, mode='min', verbose=True):
        """
        Args:
            patience (int): How many epochs to wait after last improvement.
            min_delta (float): Minimum change to qualify as an improvement.
            mode (str): 'min' for loss-like metrics, 'max' for accuracy/f1.
            verbose (bool): Print messages when improvement occurs.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.best_model_state = None

    def __call__(self, current_score, model):
        """
        Record one epoch's validation metric.

        Args:
            current_score: This epoch's validation metric (loss, accuracy, ...).
            model: The model to snapshot when the metric improves.

        Returns:
            bool: True when training should stop, False otherwise.
        """
        if self.best_score is None:
            # First epoch always establishes the baseline.
            self.best_score = current_score
            self.save_checkpoint(model)
            if self.verbose:
                print(f"Initial best score: {self.best_score:.4f}")
            return self.early_stop

        # Improvement must exceed min_delta in the monitored direction.
        if self.mode == 'min':
            improved = current_score < (self.best_score - self.min_delta)
        else:  # mode == 'max'
            improved = current_score > (self.best_score + self.min_delta)

        if improved:
            self.best_score = current_score
            self.save_checkpoint(model)
            self.counter = 0
            if self.verbose:
                print(f"Validation improved! New best score: {self.best_score:.4f}")
            return self.early_stop

        self.counter += 1
        if self.verbose:
            print(f"No improvement. Patience counter: {self.counter}/{self.patience}")

        if self.counter >= self.patience:
            self.early_stop = True
            if self.verbose:
                print(f"Early stopping triggered! Best score: {self.best_score:.4f}")

        return self.early_stop

    def save_checkpoint(self, model):
        """Snapshot the model's current weights as the best-so-far."""
        import copy
        self.best_model_state = copy.deepcopy(model.state_dict())

    def load_best_model(self, model):
        """Restore the best recorded weights into `model` and return it."""
        if self.best_model_state is not None:
            model.load_state_dict(self.best_model_state)
            if self.verbose:
                print(f"Loaded best model with score: {self.best_score:.4f}")
        return model
529
+ ]
530
+ },
531
+ {
532
+ "cell_type": "code",
533
+ "execution_count": null,
534
+ "metadata": {
535
+ "trusted": true
536
+ },
537
+ "outputs": [],
538
+ "source": [
539
+ "def main(args):\n",
540
+ " torch.manual_seed(args.seed)\n",
541
+ " torch.cuda.empty_cache()\n",
542
+ " device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
543
+ "\n",
544
+ "\n",
545
+ " file_map = {\n",
546
+ " \"gab\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_GAB_dataset(85-15).csv',\n",
547
+ " \"twitter\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_Twitter_dataset(85-15).csv',\n",
548
+ " \"reddit\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_REDDIT_dataset(85-15).csv',\n",
549
+ " \"youtube\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_YOUTUBE_dataset(85-15).csv',\n",
550
+ " \"implicit\": '/kaggle/input/datasets/jonniellm/final-dataset/Mistral_Rationales_file_IMPLICIT_dataset(85-15).csv'\n",
551
+ " }\n",
552
+ "\n",
553
+ " file_path = file_map[args.dataset]\n",
554
+ " df = pd.read_csv(file_path)\n",
555
+ " train_df = df[df['exp_split'] == 'train']\n",
556
+ " test_df = df[df['exp_split'] == 'test']\n",
557
+ "\n",
558
+ " print(\"Train df: \", len(train_df))\n",
559
+ " print(\"Test_df: \", len(test_df))\n",
560
+ "\n",
561
+ " import gc\n",
562
+ " # del variables\n",
563
+ " gc.collect()\n",
564
+ "\n",
565
+ " \n",
566
+ " tokenizer = args.hate_tokenizer ## need this for tokenizing the input text in data loader\n",
567
+ " tokenizer_bert = args.bert_tokenizer\n",
568
+ " #Splitting training and validation testing split to test accuracy\n",
569
+ " train_idx, val_idx = train_test_split(\n",
570
+ " train_df.index,\n",
571
+ " test_size=0.2,\n",
572
+ " stratify=train_df[\"label\"],\n",
573
+ " random_state=args.seed\n",
574
+ " )\n",
575
+ " \n",
576
+ " if args.dataset == \"implicit\":\n",
577
+ " train_text = train_df.loc[train_idx, \"post\"].tolist()\n",
578
+ " val_texts = train_df.loc[val_idx, \"post\"].tolist()\n",
579
+ " else:\n",
580
+ " train_text = train_df.loc[train_idx, \"text\"].tolist()\n",
581
+ " val_texts = train_df.loc[val_idx, \"text\"].tolist()\n",
582
+ " \n",
583
+ " add_train_text = train_df.loc[train_idx, \"Mistral_Rationales\"].tolist()\n",
584
+ " add_val_texts = train_df.loc[val_idx, \"Mistral_Rationales\"].tolist()\n",
585
+ " \n",
586
+ " train_labels = train_df.loc[train_idx, \"label\"].tolist()\n",
587
+ " val_labels = train_df.loc[val_idx, \"label\"].tolist()\n",
588
+ "\n",
589
+ " train_dataset = AdditionalCustomDataset(\n",
590
+ " train_text,\n",
591
+ " train_labels,\n",
592
+ " add_train_text,\n",
593
+ " tokenizer,\n",
594
+ " tokenizer_bert,\n",
595
+ " max_length=512\n",
596
+ " )\n",
597
+ " \n",
598
+ " val_dataset = AdditionalCustomDataset(\n",
599
+ " val_texts,\n",
600
+ " val_labels,\n",
601
+ " add_val_texts,\n",
602
+ " tokenizer,\n",
603
+ " tokenizer_bert,\n",
604
+ " max_length=512\n",
605
+ " )\n",
606
+ "\n",
607
+ " #Creating dataloader object to train the model\n",
608
+ " # num_workers=2 for parallel data loading, pin_memory=True for faster GPU transfer\n",
609
+ " train_dataloader = DataLoader(\n",
610
+ " train_dataset, \n",
611
+ " batch_size=args.batch_size, \n",
612
+ " shuffle=True,\n",
613
+ " num_workers=2,\n",
614
+ " pin_memory=True if torch.cuda.is_available() else False\n",
615
+ " )\n",
616
+ " val_dataloader = DataLoader(\n",
617
+ " val_dataset, \n",
618
+ " batch_size=args.batch_size, \n",
619
+ " shuffle=False,\n",
620
+ " num_workers=2,\n",
621
+ " pin_memory=True if torch.cuda.is_available() else False\n",
622
+ " )\n",
623
+ " hatebert_model = args.hatebert_model\n",
624
+ " additional_model = args.additional_model\n",
625
+ " \n",
626
+ " \n",
627
+ " temporal_cnn = TemporalCNN(\n",
628
+ " input_dim=768,\n",
629
+ " num_filters=args.temp_filters, # 64 - 128\n",
630
+ " kernel_sizes=(2, 3, 4),\n",
631
+ " dropout=args.temp_dropout,\n",
632
+ " ).to(device)\n",
633
+ "\n",
634
+ " msa_cnn = MultiScaleAttentionCNN(\n",
635
+ " hidden_size=768,\n",
636
+ " num_filters=args.msa_filters, # 64 - 128\n",
637
+ " kernel_sizes=(2, 3, 4),\n",
638
+ " dropout=args.msa_dropout\n",
639
+ " ).to(device)\n",
640
+ " \n",
641
+ " selector = AnnealedGumbelSelector(\n",
642
+ " hidden_size=768,\n",
643
+ " initial_tau=1.0, # Start hot (random exploration)\n",
644
+ " min_tau=0.1, # End cold (deterministic selection)\n",
645
+ " decay_rate=0.9, # Decrease by 10% every epoch\n",
646
+ " hard=True # Force binary 0/1 selections\n",
647
+ " ).to(device)\n",
648
+ " \n",
649
+ " projection_mlp = ProjectionMLP(\n",
650
+ " input_size=temporal_cnn.out_dim + msa_cnn.out_dim + 768 * 2,\n",
651
+ " output_size=512\n",
652
+ " ).to(device)\n",
653
+ "\n",
654
+ "\n",
655
+ "\n",
656
+ " concat_model = ConcatModel(\n",
657
+ " hatebert_model=hatebert_model,\n",
658
+ " additional_model=additional_model,\n",
659
+ " temporal_cnn=temporal_cnn,\n",
660
+ " msa_cnn=msa_cnn,\n",
661
+ " selector=selector,\n",
662
+ " projection_mlp=projection_mlp,\n",
663
+ " unfreeze_n_layers_hate=args.hate_layers, #hatebert unfreeze all 12 layers id 12 by default\n",
664
+ " hate_pooler=args.hate_pooler, #bool to controle if pooler is frozen or not true=not frozen\n",
665
+ " ).to(device)\n",
666
+ "\n",
667
+ " optimizer = AdamW(\n",
668
+ " [\n",
669
+ " {'params': concat_model.hatebert_model.parameters(), 'lr': args.hatebert_lr},\n",
670
+ " {'params': concat_model.additional_model.parameters(), 'lr': args.additional_lr},\n",
671
+ " {'params': concat_model.temporal_cnn.parameters(), 'lr': args.temp_lr},\n",
672
+ " {'params': concat_model.msa_cnn.parameters(), 'lr': args.msa_lr},\n",
673
+ " {'params': concat_model.projection_mlp.parameters(), 'lr': args.proj_lr}\n",
674
+ " ],\n",
675
+ " weight_decay=args.wd\n",
676
+ " )\n",
677
+ " criterion = nn.CrossEntropyLoss().to(device)\n",
678
+ "\n",
679
+ " # criterion = criterion.to(device)\n",
680
+ "\n",
681
+ " os.makedirs(\"/kaggle/working/models\", exist_ok=True)\n",
682
+ "\n",
683
+ " history = {\n",
684
+ " \"train_loss\": [],\n",
685
+ " \"val_loss\": [], #loss\n",
686
+ " \"train_acc\": [],\n",
687
+ " \"train_precision\": [],\n",
688
+ " \"train_recall\": [],\n",
689
+ " \"train_f1\": [], #train binary\n",
690
+ " \"val_acc\": [],\n",
691
+ " \"val_precision\": [],\n",
692
+ " \"val_recall\": [],\n",
693
+ " \"val_f1\": [], #validation binary\n",
694
+ " \"train_acc_weighted\": [],\n",
695
+ " \"train_precision_weighted\": [],\n",
696
+ " \"train_recall_weighted\": [],\n",
697
+ " \"train_f1_weighted\": [], # train weighted\n",
698
+ " \"val_acc_weighted\": [],\n",
699
+ " \"val_precision_weighted\": [],\n",
700
+ " \"val_recall_weighted\": [],\n",
701
+ " \"val_f1_weighted\": [], # validation weighted\n",
702
+ " \"train_acc_macro\": [],\n",
703
+ " \"train_precision_macro\": [],\n",
704
+ " \"train_recall_macro\": [],\n",
705
+ " \"train_f1_macro\": [], #train macro\n",
706
+ " \"val_acc_macro\": [],\n",
707
+ " \"val_precision_macro\": [],\n",
708
+ " \"val_recall_macro\": [],\n",
709
+ " \"val_f1_macro\": [],# validation macro\n",
710
+ " \"epoch_time\": [],\n",
711
+ " \"train_throughput\": [],\n",
712
+ " \"val_confidence_mean\": [],\n",
713
+ " \"val_confidence_std\": [],\n",
714
+ " \"gpu_memory_mb\": [],\n",
715
+ " }\n",
716
+ "\n",
717
+ " early_stopping = EarlyStopping(patience=args.patience, min_delta=0.001, mode='min', verbose=True) # early stop on loss\n",
718
+ "\n",
719
+ " for epoch in range(args.num_epochs):\n",
720
+ " epoch_val_confidences = []\n",
721
+ " epoch_start_time = time.time()\n",
722
+ " samples_seen = 0\n",
723
+ " \n",
724
+ " concat_model.train()\n",
725
+ "\n",
726
+ " train_losses = []\n",
727
+ " train_preds = []\n",
728
+ " train_labels_epoch = []\n",
729
+ " train_accuracy = 0\n",
730
+ " train_epoch_size = 0\n",
731
+ "\n",
732
+ " with tqdm(train_dataloader, desc=f'Epoch {epoch + 1}', dynamic_ncols=True) as loop:\n",
733
+ " for batch in loop:\n",
734
+ " input_ids, attention_mask, additional_input_ids, additional_attention_mask, labels = batch\n",
735
+ " \n",
736
+ " samples_seen += labels.size(0)\n",
737
+ " \n",
738
+ " if torch.cuda.is_available():\n",
739
+ " input_ids = input_ids.to(device)\n",
740
+ " attention_mask = attention_mask.to(device)\n",
741
+ " additional_input_ids = additional_input_ids.to(device)\n",
742
+ " additional_attention_mask = additional_attention_mask.to(device)\n",
743
+ " labels = labels.to(device)\n",
744
+ "\n",
745
+ " # Forward pass through the ConcatModel\n",
746
+ " optimizer.zero_grad()\n",
747
+ " outputs = concat_model(input_ids=input_ids, attention_mask=attention_mask, additional_input_ids=additional_input_ids, additional_attention_mask=additional_attention_mask)\n",
748
+ " loss = criterion(outputs, labels)\n",
749
+ "\n",
750
+ " # Backward pass and optimization\n",
751
+ " loss.backward()\n",
752
+ " torch.nn.utils.clip_grad_norm_(concat_model.parameters(), max_norm=args.max_grad_norm)\n",
753
+ " optimizer.step()\n",
754
+ " \n",
755
+ " probs = torch.softmax(outputs, dim=1)\n",
756
+ " confidences, predictions = torch.max(probs, dim=1)\n",
757
+ " train_preds.extend(predictions.cpu().numpy())\n",
758
+ " train_labels_epoch.extend(labels.cpu().numpy())\n",
759
+ "\n",
760
+ " train_losses.append(loss.item())\n",
761
+ "\n",
762
+ " # Update accuracy and epoch size\n",
763
+ " predictions = torch.argmax(outputs, dim=1)\n",
764
+ " train_accuracy += (predictions == labels).sum().item()\n",
765
+ " train_epoch_size += len(labels)\n",
766
+ " \n",
767
+ " epoch_train_time = time.time() - epoch_start_time\n",
768
+ " train_throughput = samples_seen / epoch_train_time \n",
769
+ " \n",
770
+ " # Calculate train metrics (binary, weighted, macro)\n",
771
+ " train_precision, train_recall, train_f1, _ = precision_recall_fscore_support(\n",
772
+ " train_labels_epoch, train_preds, average='binary'\n",
773
+ " )\n",
774
+ " train_precision_weighted, train_recall_weighted, train_f1_weighted, _ = precision_recall_fscore_support(\n",
775
+ " train_labels_epoch, train_preds, average='weighted'\n",
776
+ " )\n",
777
+ " train_precision_macro, train_recall_macro, train_f1_macro, _ = precision_recall_fscore_support(\n",
778
+ " train_labels_epoch, train_preds, average='macro'\n",
779
+ " )\n",
780
+ " train_acc = accuracy_score(train_labels_epoch, train_preds)\n",
781
+ "\n",
782
+ " # Evaluation on the validation set\n",
783
+ " concat_model.eval()\n",
784
+ "\n",
785
+ " val_predictions = []\n",
786
+ " val_labels_epoch = []\n",
787
+ " val_loss = 0\n",
788
+ " num_batches = 0\n",
789
+ "\n",
790
+ " with torch.no_grad(), tqdm(val_dataloader, desc='Validation', dynamic_ncols=True) as loop:\n",
791
+ " for batch in loop:\n",
792
+ " input_ids, attention_mask, additional_input_ids, additional_attention_mask, labels = batch\n",
793
+ "\n",
794
+ " if torch.cuda.is_available():\n",
795
+ " input_ids = input_ids.to(device)\n",
796
+ " attention_mask = attention_mask.to(device)\n",
797
+ " additional_input_ids = additional_input_ids.to(device)\n",
798
+ " additional_attention_mask = additional_attention_mask.to(device)\n",
799
+ " labels = labels.to(device)\n",
800
+ "\n",
801
+ " # Forward pass through the ConcatModel\n",
802
+ " outputs = concat_model(input_ids=input_ids, attention_mask=attention_mask, additional_input_ids=additional_input_ids, additional_attention_mask=additional_attention_mask)\n",
803
+ " loss = criterion(outputs, labels)\n",
804
+ " val_loss += loss.item()\n",
805
+ " num_batches += 1\n",
806
+ " probs = torch.softmax(outputs, dim=1)\n",
807
+ " confidences, predictions = torch.max(probs, dim=1)\n",
808
+ "\n",
809
+ " epoch_val_confidences.extend(confidences.cpu().numpy())\n",
810
+ " \n",
811
+ " val_predictions.extend(predictions.cpu().numpy())\n",
812
+ " val_labels_epoch.extend(labels.cpu().numpy())\n",
813
+ " \n",
814
+ " val_loss /= num_batches\n",
815
+ " \n",
816
+ " # Calculate validation metrics (binary)\n",
817
+ " val_accuracy = accuracy_score(val_labels_epoch, val_predictions)\n",
818
+ " val_precision, val_recall, val_f1, _ = precision_recall_fscore_support(\n",
819
+ " val_labels_epoch, val_predictions, average='binary'\n",
820
+ " )\n",
821
+ " # Calculate validation metrics (weighted)\n",
822
+ " val_precision_weighted, val_recall_weighted, val_f1_weighted, _ = precision_recall_fscore_support(\n",
823
+ " val_labels_epoch, val_predictions, average='weighted'\n",
824
+ " )\n",
825
+ " # Calculate validation metrics (macro)\n",
826
+ " val_precision_macro, val_recall_macro, val_f1_macro, _ = precision_recall_fscore_support(\n",
827
+ " val_labels_epoch, val_predictions, average='macro'\n",
828
+ " )\n",
829
+ " \n",
830
+ " print(f\"Epoch {epoch}:\")\n",
831
+ " print(f\" Train Accuracy: {train_acc:.4f}\")\n",
832
+ " print(f\" Validation Accuracy: {val_accuracy:.4f}\")\n",
833
+ " print(f\" Train Precision: {train_precision:.4f}, Recall: {train_recall:.4f}, F1: {train_f1:.4f}\")\n",
834
+ " print(f\" Val Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1: {val_f1:.4f}\")\n",
835
+ " print(f\" Avg. Train Loss: {sum(train_losses) / len(train_losses):.4f}\")\n",
836
+ " print(f\" Validation Loss: {val_loss:.4f}\")\n",
837
+ " epoch_time = time.time() - epoch_start_time\n",
838
+ " conf_mean = np.mean(epoch_val_confidences)\n",
839
+ " conf_std = np.std(epoch_val_confidences)\n",
840
+ "\n",
841
+ " # Append all train metrics to history\n",
842
+ " history[\"train_loss\"].append(np.mean(train_losses))\n",
843
+ " history[\"train_acc\"].append(train_acc)\n",
844
+ " history[\"train_precision\"].append(train_precision)\n",
845
+ " history[\"train_recall\"].append(train_recall)\n",
846
+ " history[\"train_f1\"].append(train_f1)\n",
847
+ " history[\"train_acc_weighted\"].append(train_acc) # accuracy is same for all averaging\n",
848
+ " history[\"train_precision_weighted\"].append(train_precision_weighted)\n",
849
+ " history[\"train_recall_weighted\"].append(train_recall_weighted)\n",
850
+ " history[\"train_f1_weighted\"].append(train_f1_weighted)\n",
851
+ " history[\"train_acc_macro\"].append(train_acc) # accuracy is same for all averaging\n",
852
+ " history[\"train_precision_macro\"].append(train_precision_macro)\n",
853
+ " history[\"train_recall_macro\"].append(train_recall_macro)\n",
854
+ " history[\"train_f1_macro\"].append(train_f1_macro)\n",
855
+ " \n",
856
+ " # Append all validation metrics to history\n",
857
+ " history[\"val_loss\"].append(val_loss)\n",
858
+ " history[\"val_acc\"].append(val_accuracy)\n",
859
+ " history[\"val_precision\"].append(val_precision)\n",
860
+ " history[\"val_recall\"].append(val_recall)\n",
861
+ " history[\"val_f1\"].append(val_f1)\n",
862
+ " history[\"val_acc_weighted\"].append(val_accuracy) # accuracy is same for all averaging\n",
863
+ " history[\"val_precision_weighted\"].append(val_precision_weighted)\n",
864
+ " history[\"val_recall_weighted\"].append(val_recall_weighted)\n",
865
+ " history[\"val_f1_weighted\"].append(val_f1_weighted)\n",
866
+ " history[\"val_acc_macro\"].append(val_accuracy) # accuracy is same for all averaging\n",
867
+ " history[\"val_precision_macro\"].append(val_precision_macro)\n",
868
+ " history[\"val_recall_macro\"].append(val_recall_macro)\n",
869
+ " history[\"val_f1_macro\"].append(val_f1_macro)\n",
870
+ " \n",
871
+ " # Append efficiency metrics\n",
872
+ " history[\"epoch_time\"].append(epoch_time)\n",
873
+ " history[\"train_throughput\"].append(train_throughput)\n",
874
+ " history[\"val_confidence_mean\"].append(conf_mean)\n",
875
+ " history[\"val_confidence_std\"].append(conf_std)\n",
876
+ " \n",
877
+ " if torch.cuda.is_available():\n",
878
+ " history[\"gpu_memory_mb\"].append(\n",
879
+ " torch.cuda.max_memory_allocated() / 1024**2\n",
880
+ " )\n",
881
+ " torch.cuda.reset_peak_memory_stats()\n",
882
+ " else:\n",
883
+ " history[\"gpu_memory_mb\"].append(0)\n",
884
+ " \n",
885
+ " print(f\" Epoch Time (s): {epoch_time:.2f}\")\n",
886
+ " print(f\" Throughput (samples/sec): {train_throughput:.2f}\")\n",
887
+ " print(f\" Val Confidence Mean: {conf_mean:.4f} Β± {conf_std:.4f}\")\n",
888
+ "\n",
889
+ " new_tau = concat_model.selector.step_tau()\n",
890
+ " print(f\" Updated Gumbel Tau: {new_tau:.4f}\")\n",
891
+ " \n",
892
+ " current_metric = val_loss\n",
893
+ " if early_stopping(current_metric, concat_model):\n",
894
+ " print(f\"\\n{'='*50}\")\n",
895
+ " print(f\"Early stopping at epoch {epoch+1}\")\n",
896
+ " print(f\"{'='*50}\\n\")\n",
897
+ " break\n",
898
+ " model = early_stopping.load_best_model(concat_model)\n",
899
+ " torch.save(model.state_dict(), f\"/kaggle/working/models/{args.dataset}_concat_model.pt\")\n",
900
+ " print(f\"Best model saved to /kaggle/working/models/{args.dataset}_concat_model.pt\")\n",
901
+ "\n",
902
+ " checkpoint = {\n",
903
+ " \"history\": history,\n",
904
+ " }\n",
905
+ " torch.save(checkpoint, f\"/kaggle/working/models/{args.dataset}_concat_checkpoint.pt\")\n",
906
+ " print(f\"Checkpoint with history saved to /kaggle/working/models/{args.dataset}_concat_checkpoint.pt\")\n",
907
+ "\n",
908
+ " if args.dataset == \"implicit\":\n",
909
+ " test_texts = test_df[\"post\"].tolist()\n",
910
+ " else:\n",
911
+ " test_texts = test_df[\"text\"].tolist()\n",
912
+ " \n",
913
+ " add_test_texts = test_df[\"Mistral_Rationales\"].tolist()\n",
914
+ " test_labels = test_df[\"label\"].tolist()\n",
915
+ "\n",
916
+ " test_dataset = AdditionalCustomDataset(test_texts, test_labels, add_test_texts, tokenizer, tokenizer_bert, max_length=512)\n",
917
+ " test_dataloader = DataLoader(\n",
918
+ " test_dataset, \n",
919
+ " batch_size=args.batch_size, # Use same batch size as training for faster inference\n",
920
+ " shuffle=False,\n",
921
+ " num_workers=2,\n",
922
+ " pin_memory=True if torch.cuda.is_available() else False\n",
923
+ " )\n",
924
+ "\n",
925
+ " # ================= TEST EVALUATION WITH EFFICIENCY =================\n",
926
+ " model.eval()\n",
927
+ " test_predictions = []\n",
928
+ " test_true_labels = []\n",
929
+ " test_confidences = []\n",
930
+ "\n",
931
+ " samples_seen = 0\n",
932
+ " test_start_time = time.time()\n",
933
+ " \n",
934
+ " if torch.cuda.is_available():\n",
935
+ " torch.cuda.reset_peak_memory_stats()\n",
936
+ " \n",
937
+ " with torch.no_grad(), tqdm(test_dataloader, desc='Testing', dynamic_ncols=True) as loop:\n",
938
+ " for batch in loop:\n",
939
+ " input_ids, attention_mask, additional_input_ids, additional_attention_mask, labels = batch\n",
940
+ " \n",
941
+ " batch_size = labels.size(0)\n",
942
+ " samples_seen += batch_size\n",
943
+ " \n",
944
+ " input_ids = input_ids.to(device)\n",
945
+ " attention_mask = attention_mask.to(device)\n",
946
+ " additional_input_ids = additional_input_ids.to(device)\n",
947
+ " additional_attention_mask = additional_attention_mask.to(device)\n",
948
+ " labels = labels.to(device)\n",
949
+ " \n",
950
+ " outputs = model(\n",
951
+ " input_ids=input_ids,\n",
952
+ " attention_mask=attention_mask,\n",
953
+ " additional_input_ids=additional_input_ids,\n",
954
+ " additional_attention_mask=additional_attention_mask\n",
955
+ " )\n",
956
+ " \n",
957
+ " probs = torch.softmax(outputs, dim=1)\n",
958
+ " confidences, preds = torch.max(probs, dim=1)\n",
959
+ " \n",
960
+ " test_confidences.extend(confidences.cpu().numpy())\n",
961
+ " test_predictions.extend(preds.cpu().numpy())\n",
962
+ " test_true_labels.extend(labels.cpu().numpy())\n",
963
+ "\n",
964
+ " \n",
965
+ " # ================= TEST METRICS =================\n",
966
+ " test_time = time.time() - test_start_time\n",
967
+ " test_throughput = samples_seen / test_time\n",
968
+ " \n",
969
+ " accuracy = accuracy_score(test_true_labels, test_predictions)\n",
970
+ " precision, recall, f1, _ = precision_recall_fscore_support(\n",
971
+ " test_true_labels, test_predictions, average='weighted'\n",
972
+ " )\n",
973
+ " conf_mean = np.mean(test_confidences)\n",
974
+ " conf_std = np.std(test_confidences)\n",
975
+ " cm = confusion_matrix(test_true_labels, test_predictions)\n",
976
+ " \n",
977
+ " gpu_memory_mb = torch.cuda.max_memory_allocated() / 1024**2 if torch.cuda.is_available() else 0\n",
978
+ "\n",
979
+ " print(\"\\n================= FINAL TEST RESULTS =================\")\n",
980
+ " print(f\"Dataset: {args.dataset}, Seed: {args.seed}, Epochs: {args.num_epochs}\")\n",
981
+ " print(f\"Test Accuracy : {accuracy:.4f}\")\n",
982
+ " print(f\"Test Precision: {precision:.4f}\")\n",
983
+ " print(f\"Test Recall : {recall:.4f}\")\n",
984
+ " print(f\"Test F1-score : {f1:.4f}\")\n",
985
+ " print(f\"Test Confidence Mean Β± Std: {conf_mean:.4f} Β± {conf_std:.4f}\")\n",
986
+ " print(f\"Test Time (s) : {test_time:.2f}\")\n",
987
+ " print(f\"Throughput (samples/sec) : {test_throughput:.2f}\")\n",
988
+ " print(f\"Peak GPU Memory (MB) : {gpu_memory_mb:.2f}\")\n",
989
+ " print(\"\\nClassification Report:\")\n",
990
+ " print(classification_report(test_true_labels, test_predictions))\n",
991
+ " print(\"\\nConfusion Matrix:\")\n",
992
+ " # Plot confusion matrix with seaborn\n",
993
+ " plt.figure(figsize=(8, 6))\n",
994
+ " sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', \n",
995
+ " xticklabels=['Non-Hate', 'Hate'],\n",
996
+ " yticklabels=['Non-Hate', 'Hate'],\n",
997
+ " cbar_kws={'label': 'Count'})\n",
998
+ " plt.title('Confusion Matrix', fontsize=14, pad=20)\n",
999
+ " plt.ylabel('True label', fontsize=12)\n",
1000
+ " plt.xlabel('Predicted label', fontsize=12)\n",
1001
+ " plt.tight_layout()\n",
1002
+ " plt.show()\n",
1003
+ " print(\"======================================================\")\n",
1004
+ "\n",
1005
+ " return model, history"
1006
+ ]
1007
+ },
1008
+ {
1009
+ "cell_type": "code",
1010
+ "execution_count": null,
1011
+ "metadata": {
1012
+ "trusted": true
1013
+ },
1014
+ "outputs": [],
1015
+ "source": [
1016
+ "from argparse import Namespace\n",
1017
+ "from transformers import AutoTokenizer, AutoModel\n",
1018
+ "\n",
1019
+ "# Load tokenizers and models ONCE (outside objective to save time)\n",
1020
+ "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
1021
+ "hate_tokenizer = AutoTokenizer.from_pretrained(\"GroNLP/hateBERT\")\n",
1022
+ "bert_tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
1023
+ "\n",
1024
+ "torch.cuda.empty_cache()\n",
1025
+ "\n",
1026
+ "# Load fresh models for each trial (to reset weights)\n",
1027
+ "hatebert_model = AutoModel.from_pretrained(\"GroNLP/hateBERT\").to(device)\n",
1028
+ "additional_model = AutoModel.from_pretrained(\"bert-base-uncased\").to(device)\n",
1029
+ "\n",
1030
+ "# Sample hyperparameters\n",
1031
+ "args = Namespace(\n",
1032
+ " # Fixed\n",
1033
+ " seed=42,\n",
1034
+ " dataset=\"reddit\", # Change as needed: \"gab\", \"twitter\", \"reddit\", \"youtube\", \"implicit\"\n",
1035
+ "\n",
1036
+ " # Tokenizers & Models\n",
1037
+ " hate_tokenizer=hate_tokenizer,\n",
1038
+ " bert_tokenizer=bert_tokenizer,\n",
1039
+ " hatebert_model=hatebert_model,\n",
1040
+ " additional_model=additional_model,\n",
1041
+ "\n",
1042
+ " # Training hyperparameters\n",
1043
+ " batch_size=32,\n",
1044
+ " num_epochs=20,\n",
1045
+ " wd=0.01, # Weight decay\n",
1046
+ " patience=5, # Early stopping patience\n",
1047
+ " max_grad_norm=2.0, # Gradient clipping\n",
1048
+ "\n",
1049
+ " # Learning Rates (Missing but required by optimizer in main)\n",
1050
+ " hatebert_lr=2e-5, # Lower LR for pre-trained models\n",
1051
+ " additional_lr=2e-5, # Usually kept low or 0 if frozen\n",
1052
+ " temp_lr=1e-3, # Higher LR for initialized layers (CNN)\n",
1053
+ " msa_lr=1e-3, # Higher LR for initialized layers (CNN)\n",
1054
+ " proj_lr=1e-3, # Higher LR for Head (MLP)\n",
1055
+ "\n",
1056
+ " # TemporalCNN hyperparameters\n",
1057
+ " temp_filters=256,\n",
1058
+ " temp_dropout=0.1,\n",
1059
+ " temp_dilate=3, # Note: This is in args, but TemporalCNN class doesn't use it. \n",
1060
+ " # Code fix required in main() or TemporalCNN class.\n",
1061
+ "\n",
1062
+ " # MultiScaleAttentionCNN hyperparameters\n",
1063
+ " msa_filters=64,\n",
1064
+ " msa_dropout=0.23,\n",
1065
+ "\n",
1066
+ " # Layer unfreezing hyperparameters\n",
1067
+ " hate_layers=10, # How many layers of HateBERT to unfreeze\n",
1068
+ " add_layers=0, # How many layers of Add. BERT to unfreeze\n",
1069
+ " hate_pooler=True,\n",
1070
+ " add_pooler=True,\n",
1071
+ ") \n",
1072
+ "model, history = main(args)\n",
1073
+ "\n",
1074
+ "\n",
1075
+ " "
1076
+ ]
1077
+ },
1078
+ {
1079
+ "cell_type": "code",
1080
+ "execution_count": null,
1081
+ "metadata": {},
1082
+ "outputs": [],
1083
+ "source": [
1084
+ "import optuna\n",
1085
+ "from optuna.pruners import MedianPruner\n",
1086
+ "\n",
1087
+ "# Load tokenizers and models ONCE (outside objective to save time)\n",
1088
+ "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
1089
+ "hate_tokenizer = AutoTokenizer.from_pretrained(\"GroNLP/hateBERT\")\n",
1090
+ "bert_tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
1091
+ "\n",
1092
+ "def objective(trial):\n",
1093
+ " \"\"\"\n",
1094
+ " Optuna objective function that samples hyperparameters and trains the model.\n",
1095
+ " Returns: tuple of (validation_loss, validation_accuracy)\n",
1096
+ " - Minimize validation loss\n",
1097
+ " - Maximize validation accuracy\n",
1098
+ " \"\"\"\n",
1099
+ " # Clear GPU cache before each trial\n",
1100
+ " torch.cuda.empty_cache()\n",
1101
+ " \n",
1102
+ " # Load fresh models for each trial (to reset weights)\n",
1103
+ " hatebert_model = AutoModel.from_pretrained(\"GroNLP/hateBERT\").to(device)\n",
1104
+ " additional_model = AutoModel.from_pretrained(\"bert-base-uncased\").to(device)\n",
1105
+ " \n",
1106
+ " # ========== SAMPLE HYPERPARAMETERS USING OPTUNA ==========\n",
1107
+ " # Training hyperparameters\n",
1108
+ " batch_size = trial.suggest_int(\"batch_size\", 16, 64, step=16) # 16, 32, 48, 64\n",
1109
+ " num_epochs = trial.suggest_int(\"num_epochs\", 15, 30, step=5) # 15, 20, 25, 30\n",
1110
+ " weight_decay = trial.suggest_float(\"weight_decay\", 1e-4, 1e-1, log=True)\n",
1111
+ " max_grad_norm = trial.suggest_float(\"max_grad_norm\", 0.5, 3.0)\n",
1112
+ " patience = trial.suggest_int(\"patience\", 3, 8)\n",
1113
+ " \n",
1114
+ " # Learning Rates (use log scale for better sampling across orders of magnitude)\n",
1115
+ " hatebert_lr = trial.suggest_float(\"hatebert_lr\", 1e-6, 5e-5, log=True)\n",
1116
+ " additional_lr = trial.suggest_float(\"additional_lr\", 1e-6, 5e-5, log=True)\n",
1117
+ " temp_lr = trial.suggest_float(\"temp_lr\", 1e-5, 1e-2, log=True)\n",
1118
+ " msa_lr = trial.suggest_float(\"msa_lr\", 1e-5, 1e-2, log=True)\n",
1119
+ " proj_lr = trial.suggest_float(\"proj_lr\", 1e-5, 1e-2, log=True)\n",
1120
+ " \n",
1121
+ " # TemporalCNN hyperparameters\n",
1122
+ " temp_filters = trial.suggest_int(\"temp_filters\", 128, 512, step=64) # 128, 192, 256, ...\n",
1123
+ " temp_dropout = trial.suggest_float(\"temp_dropout\", 0.0, 0.5, step=0.05) # 0.0, 0.05, 0.1, ...\n",
1124
+ " kernel_size_2_weight = trial.suggest_float(\"kernel_size_weight\", 0.3, 1.0) # Weight for inclusion\n",
1125
+ " \n",
1126
+ " # MultiScaleAttentionCNN hyperparameters\n",
1127
+ " msa_filters = trial.suggest_int(\"msa_filters\", 32, 256, step=32) # 32, 64, 96, 128, ...\n",
1128
+ " msa_dropout = trial.suggest_float(\"msa_dropout\", 0.0, 0.5, step=0.05)\n",
1129
+ " \n",
1130
+ " # Layer unfreezing hyperparameters\n",
1131
+ " hate_layers = trial.suggest_int(\"hate_layers\", 4, 12) # How many layers to unfreeze (0-12)\n",
1132
+ " hate_pooler = trial.suggest_categorical(\"hate_pooler\", [True, False])\n",
1133
+ " \n",
1134
+ " # Create args Namespace with sampled hyperparameters\n",
1135
+ " args = Namespace(\n",
1136
+ " # Fixed\n",
1137
+ " seed=42,\n",
1138
+ " dataset=\"reddit\", # Change as needed\n",
1139
+ " \n",
1140
+ " # Tokenizers & Models\n",
1141
+ " hate_tokenizer=hate_tokenizer,\n",
1142
+ " bert_tokenizer=bert_tokenizer,\n",
1143
+ " hatebert_model=hatebert_model,\n",
1144
+ " additional_model=additional_model,\n",
1145
+ " \n",
1146
+ " # Sampled training hyperparameters\n",
1147
+ " batch_size=batch_size,\n",
1148
+ " num_epochs=num_epochs,\n",
1149
+ " wd=weight_decay,\n",
1150
+ " patience=patience,\n",
1151
+ " max_grad_norm=max_grad_norm,\n",
1152
+ " \n",
1153
+ " # Sampled learning rates\n",
1154
+ " hatebert_lr=hatebert_lr,\n",
1155
+ " additional_lr=additional_lr,\n",
1156
+ " temp_lr=temp_lr,\n",
1157
+ " msa_lr=msa_lr,\n",
1158
+ " proj_lr=proj_lr,\n",
1159
+ " \n",
1160
+ " # Sampled TemporalCNN hyperparameters\n",
1161
+ " temp_filters=temp_filters,\n",
1162
+ " temp_dropout=temp_dropout,\n",
1163
+ " temp_dilate=3, # Fixed (not optimized)\n",
1164
+ " \n",
1165
+ " # Sampled MultiScaleAttentionCNN hyperparameters\n",
1166
+ " msa_filters=msa_filters,\n",
1167
+ " msa_dropout=msa_dropout,\n",
1168
+ " \n",
1169
+ " # Sampled layer unfreezing hyperparameters\n",
1170
+ " hate_layers=hate_layers,\n",
1171
+ " add_layers=0, # Fixed (frozen)\n",
1172
+ " hate_pooler=hate_pooler,\n",
1173
+ " add_pooler=False, # Fixed\n",
1174
+ " )\n",
1175
+ " \n",
1176
+ " try:\n",
1177
+ " print(f\"\\n{'='*60}\")\n",
1178
+ " print(f\"Trial {trial.number} started\")\n",
1179
+ " print(f\"Batch Size: {batch_size}, Epochs: {num_epochs}, Patience: {patience}\")\n",
1180
+ " print(f\"HateBERT LR: {hatebert_lr:.2e}, Temp LR: {temp_lr:.2e}, MSA LR: {msa_lr:.2e}\")\n",
1181
+ " print(f\"Temp Filters: {temp_filters}, MSA Filters: {msa_filters}\")\n",
1182
+ " print(f\"{'='*60}\\n\")\n",
1183
+ " \n",
1184
+ " model, history = main(args)\n",
1185
+ " \n",
1186
+ " # Get best validation loss and corresponding accuracy\n",
1187
+ " val_loss = min(history[\"val_loss\"])\n",
1188
+ " val_loss_index = history[\"val_loss\"].index(val_loss)\n",
1189
+ " val_acc = history['val_acc'][val_loss_index]\n",
1190
+ " val_f1_macro = history['val_f1_macro'][val_loss_index]\n",
1191
+ " \n",
1192
+ " print(f\"\\nTrial {trial.number} completed:\")\n",
1193
+ " print(f\" Best Val Loss: {val_loss:.4f}\")\n",
1194
+ " print(f\" Val Acc at Best Loss: {val_acc:.4f}\")\n",
1195
+ " print(f\" Val F1-Macro at Best Loss: {val_f1_macro:.4f}\")\n",
1196
+ " \n",
1197
+ " # Return tuple for multi-objective optimization\n",
1198
+ " # Optuna will minimize val_loss and maximize val_acc\n",
1199
+ " return val_loss, val_acc\n",
1200
+ " \n",
1201
+ " except Exception as e:\n",
1202
+ " print(f\"\\n❌ Trial {trial.number} FAILED with error: {e}\")\n",
1203
+ " import traceback\n",
1204
+ " traceback.print_exc()\n",
1205
+ " return float(\"inf\"), 0.0 # Return worst case for failed trials\n",
1206
+ "\n",
1207
+ "\n",
1208
+ "# ========== CREATE AND RUN OPTUNA STUDY ==========\n",
1209
+ "# Using in-memory storage for simplicity (change to PostgreSQL if you need persistence)\n",
1210
+ "study = optuna.create_study(\n",
1211
+ " storage=None, # Use in-memory storage\n",
1212
+ " study_name=\"proposed-modelv4-optim\",\n",
1213
+ " directions=[\"minimize\", \"maximize\"], # Minimize val_loss, Maximize val_acc\n",
1214
+ " pruner=MedianPruner(\n",
1215
+ " n_startup_trials=5, # Don't prune the first 5 trials\n",
1216
+ " n_warmup_steps=3, # Warmup steps before pruning kicks in\n",
1217
+ " interval_steps=1 # Check for pruning after each epoch\n",
1218
+ " ),\n",
1219
+ " sampler=optuna.samplers.TPESampler(seed=42), # Tree-structured Parzen Estimator\n",
1220
+ " load_if_exists=False,\n",
1221
+ ")\n",
1222
+ "\n",
1223
+ "print(\"\\n\" + \"=\"*80)\n",
1224
+ "print(\"STARTING OPTUNA HYPERPARAMETER OPTIMIZATION\")\n",
1225
+ "print(\"Multi-Objective: Minimize Val Loss | Maximize Val Accuracy\")\n",
1226
+ "print(\"=\"*80 + \"\\n\")\n",
1227
+ "\n",
1228
+ "study.optimize(\n",
1229
+ " objective,\n",
1230
+ " n_trials=50, # Reduced for testing; increase for production\n",
1231
+ " timeout=None, # Optional: timeout in seconds\n",
1232
+ " show_progress_bar=True,\n",
1233
+ " callbacks=[\n",
1234
+ " optuna.study.MaxTrialsCallback(max_trials=50),\n",
1235
+ " ]\n",
1236
+ ")\n",
1237
+ "\n",
1238
+ "# ========== PRINT RESULTS ==========\n",
1239
+ "print(\"\\n\" + \"=\"*80)\n",
1240
+ "print(\"OPTIMIZATION COMPLETE\")\n",
1241
+ "print(\"=\"*80)\n",
1242
+ "\n",
1243
+ "# For multi-objective optimization, there are multiple Pareto-optimal trials\n",
1244
+ "print(f\"\\nNumber of Pareto-optimal trials: {len(study.best_trials)}\")\n",
1245
+ "print(\"\\nPareto-optimal trials (ranked by loss):\")\n",
1246
+ "print(\"-\" * 80)\n",
1247
+ "\n",
1248
+ "# Sort by first objective (val_loss) for display\n",
1249
+ "best_trials_sorted = sorted(study.best_trials, key=lambda t: t.values[0])\n",
1250
+ "\n",
1251
+ "for i, trial in enumerate(best_trials_sorted[:5]): # Show top 5\n",
1252
+ " print(f\"\\nTrial #{trial.number}:\")\n",
1253
+ " print(f\" Validation Loss: {trial.values[0]:.4f}\")\n",
1254
+ " print(f\" Validation Accuracy: {trial.values[1]:.4f}\")\n",
1255
+ " print(f\" Key Hyperparameters:\")\n",
1256
+ " print(f\" - Batch Size: {trial.params.get('batch_size')}\")\n",
1257
+ " print(f\" - Temp Filters: {trial.params.get('temp_filters')}\")\n",
1258
+ " print(f\" - MSA Filters: {trial.params.get('msa_filters')}\")\n",
1259
+ " print(f\" - HateBERT LR: {trial.params.get('hatebert_lr'):.2e}\")\n",
1260
+ " print(f\" - Temp LR: {trial.params.get('temp_lr'):.2e}\")\n",
1261
+ " print(f\" - Hate Layers: {trial.params.get('hate_layers')}\")\n",
1262
+ "\n",
1263
+ "# Print top 5 trials as DataFrame\n",
1264
+ "print(f\"\\n\" + \"=\"*80)\n",
1265
+ "print(\"TOP 5 TRIALS (by validation loss)\")\n",
1266
+ "print(\"=\"*80)\n",
1267
+ "trials_df = study.trials_dataframe().sort_values('values_0').head(5)\n",
1268
+ "print(trials_df[['number', 'values_0', 'values_1', 'state', 'params_batch_size', 'params_temp_filters', 'params_msa_filters']].to_string())\n"
1269
+ ]
1270
+ },
1271
+ {
1272
+ "cell_type": "code",
1273
+ "execution_count": null,
1274
+ "metadata": {
1275
+ "trusted": true
1276
+ },
1277
+ "outputs": [],
1278
+ "source": [
1279
+ "fig, axes = plt.subplots(3, 4, figsize=(20, 15))\n",
1280
+ "\n",
1281
+ "# Row 1: Train vs Val comparisons\n",
1282
+ "# Accuracy\n",
1283
+ "axes[0, 0].plot(history['train_acc'], label='Train', marker='o')\n",
1284
+ "axes[0, 0].plot(history['val_acc'], label='Validation', marker='s')\n",
1285
+ "axes[0, 0].set_title('Accuracy')\n",
1286
+ "axes[0, 0].set_xlabel('Epoch')\n",
1287
+ "axes[0, 0].set_ylabel('Accuracy')\n",
1288
+ "axes[0, 0].legend()\n",
1289
+ "axes[0, 0].grid(True)\n",
1290
+ "\n",
1291
+ "# Loss\n",
1292
+ "axes[0, 1].plot(history['train_loss'], label='Train', marker='o')\n",
1293
+ "axes[0, 1].plot(history['val_loss'], label='Validation', marker='s')\n",
1294
+ "axes[0, 1].set_title('Loss')\n",
1295
+ "axes[0, 1].set_xlabel('Epoch')\n",
1296
+ "axes[0, 1].set_ylabel('Loss')\n",
1297
+ "axes[0, 1].legend()\n",
1298
+ "axes[0, 1].grid(True)\n",
1299
+ "\n",
1300
+ "# F1 Score\n",
1301
+ "axes[0, 2].plot(history['train_f1'], label='Train', marker='o')\n",
1302
+ "axes[0, 2].plot(history['val_f1'], label='Validation', marker='s')\n",
1303
+ "axes[0, 2].set_title('F1 Score (Binary)')\n",
1304
+ "axes[0, 2].set_xlabel('Epoch')\n",
1305
+ "axes[0, 2].set_ylabel('F1')\n",
1306
+ "axes[0, 2].legend()\n",
1307
+ "axes[0, 2].grid(True)\n",
1308
+ "\n",
1309
+ "# Precision\n",
1310
+ "axes[0, 3].plot(history['train_precision'], label='Train', marker='o')\n",
1311
+ "axes[0, 3].plot(history['val_precision'], label='Validation', marker='s')\n",
1312
+ "axes[0, 3].set_title('Precision (Binary)')\n",
1313
+ "axes[0, 3].set_xlabel('Epoch')\n",
1314
+ "axes[0, 3].set_ylabel('Precision')\n",
1315
+ "axes[0, 3].legend()\n",
1316
+ "axes[0, 3].grid(True)\n",
1317
+ "\n",
1318
+ "# Row 2: More Train vs Val + Individual metrics\n",
1319
+ "# Recall\n",
1320
+ "axes[1, 0].plot(history['train_recall'], label='Train', marker='o')\n",
1321
+ "axes[1, 0].plot(history['val_recall'], label='Validation', marker='s')\n",
1322
+ "axes[1, 0].set_title('Recall (Binary)')\n",
1323
+ "axes[1, 0].set_xlabel('Epoch')\n",
1324
+ "axes[1, 0].set_ylabel('Recall')\n",
1325
+ "axes[1, 0].legend()\n",
1326
+ "axes[1, 0].grid(True)\n",
1327
+ "\n",
1328
+ "# Epoch Time\n",
1329
+ "axes[1, 1].plot(history['epoch_time'], marker='o', color='green')\n",
1330
+ "axes[1, 1].set_title('Epoch Time')\n",
1331
+ "axes[1, 1].set_xlabel('Epoch')\n",
1332
+ "axes[1, 1].set_ylabel('Time (s)')\n",
1333
+ "axes[1, 1].grid(True)\n",
1334
+ "\n",
1335
+ "# Train Throughput\n",
1336
+ "axes[1, 2].plot(history['train_throughput'], marker='o', color='purple')\n",
1337
+ "axes[1, 2].set_title('Train Throughput')\n",
1338
+ "axes[1, 2].set_xlabel('Epoch')\n",
1339
+ "axes[1, 2].set_ylabel('Samples/sec')\n",
1340
+ "axes[1, 2].grid(True)\n",
1341
+ "\n",
1342
+ "# Validation Confidence\n",
1343
+ "axes[1, 3].errorbar(range(len(history['val_confidence_mean'])), \n",
1344
+ " history['val_confidence_mean'], \n",
1345
+ " yerr=history['val_confidence_std'], \n",
1346
+ " marker='o', color='orange', capsize=3)\n",
1347
+ "axes[1, 3].set_title('Validation Confidence')\n",
1348
+ "axes[1, 3].set_xlabel('Epoch')\n",
1349
+ "axes[1, 3].set_ylabel('Confidence')\n",
1350
+ "axes[1, 3].grid(True)\n",
1351
+ "\n",
1352
+ "# Row 3: GPU Memory and Weighted/Macro metrics\n",
1353
+ "# GPU Memory\n",
1354
+ "axes[2, 0].plot(history['gpu_memory_mb'], marker='o', color='red')\n",
1355
+ "axes[2, 0].set_title('GPU Memory Usage')\n",
1356
+ "axes[2, 0].set_xlabel('Epoch')\n",
1357
+ "axes[2, 0].set_ylabel('Memory (MB)')\n",
1358
+ "axes[2, 0].grid(True)\n",
1359
+ "\n",
1360
+ "# Weighted F1 comparison\n",
1361
+ "axes[2, 1].plot(history['train_f1_weighted'], label='Train', marker='o')\n",
1362
+ "axes[2, 1].plot(history['val_f1_weighted'], label='Validation', marker='s')\n",
1363
+ "axes[2, 1].set_title('F1 Score (Weighted)')\n",
1364
+ "axes[2, 1].set_xlabel('Epoch')\n",
1365
+ "axes[2, 1].set_ylabel('F1')\n",
1366
+ "axes[2, 1].legend()\n",
1367
+ "axes[2, 1].grid(True)\n",
1368
+ "\n",
1369
+ "# Macro F1 comparison\n",
1370
+ "axes[2, 2].plot(history['train_f1_macro'], label='Train', marker='o')\n",
1371
+ "axes[2, 2].plot(history['val_f1_macro'], label='Validation', marker='s')\n",
1372
+ "axes[2, 2].set_title('F1 Score (Macro)')\n",
1373
+ "axes[2, 2].set_xlabel('Epoch')\n",
1374
+ "axes[2, 2].set_ylabel('F1')\n",
1375
+ "axes[2, 2].legend()\n",
1376
+ "axes[2, 2].grid(True)\n",
1377
+ "\n",
1378
+ "# Hide last subplot\n",
1379
+ "axes[2, 3].axis('off')\n",
1380
+ "\n",
1381
+ "plt.suptitle('Training History', fontsize=16, y=1.02)\n",
1382
+ "plt.tight_layout()\n",
1383
+ "plt.show()"
1384
+ ]
1385
+ }
1386
+ ],
1387
+ "metadata": {
1388
+ "kaggle": {
1389
+ "accelerator": "gpu",
1390
+ "dataSources": [
1391
+ {
1392
+ "databundleVersionId": 15785969,
1393
+ "datasetId": 9546304,
1394
+ "sourceId": 14919503,
1395
+ "sourceType": "datasetVersion"
1396
+ }
1397
+ ],
1398
+ "dockerImageVersionId": 31286,
1399
+ "isGpuEnabled": true,
1400
+ "isInternetEnabled": true,
1401
+ "language": "python",
1402
+ "sourceType": "notebook"
1403
+ },
1404
+ "kernelspec": {
1405
+ "display_name": "Python 3",
1406
+ "language": "python",
1407
+ "name": "python3"
1408
+ },
1409
+ "language_info": {
1410
+ "codemirror_mode": {
1411
+ "name": "ipython",
1412
+ "version": 3
1413
+ },
1414
+ "file_extension": ".py",
1415
+ "mimetype": "text/x-python",
1416
+ "name": "python",
1417
+ "nbconvert_exporter": "python",
1418
+ "pygments_lexer": "ipython3",
1419
+ "version": "3.12.12"
1420
+ }
1421
+ },
1422
+ "nbformat": 4,
1423
+ "nbformat_minor": 4
1424
+ }
src/app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gc
 
2
 
3
  import streamlit as st
4
  from hatespeech_model import predict_hatespeech, load_model_from_hf, predict_hatespeech_from_file, get_rationale_from_mistral, preprocess_rationale_mistral
@@ -24,6 +25,16 @@ def load_cached_model(model_type="altered"):
24
  """Load and cache the model"""
25
  return load_model_from_hf(model_type=model_type)
26
 
 
 
 
 
 
 
 
 
 
 
27
  # Custom CSS
28
  st.markdown("""
29
  <style>
@@ -156,21 +167,24 @@ classify_button = st.button("πŸ” Analyze Text", type="primary", use_container_w
156
 
157
  if classify_button:
158
  if user_input and user_input.strip():
 
 
 
159
  with st.spinner('πŸ”„ Generating rationale from Mistral AI...'):
160
  # --- Step 1: Get rationale from Mistral ---
161
  try:
162
- raw_rationale = get_rationale_from_mistral(user_input)
163
  cleaned_rationale = preprocess_rationale_mistral(raw_rationale)
164
  print(f"Raw rationale from Mistral: {raw_rationale}")
165
  except Exception as e:
166
  st.error(f"❌ Error generating/processing rationale: {str(e)}")
167
- cleaned_rationale = user_input # fallback to raw input
168
 
169
  with st.spinner('πŸ”„ Analyzing text with models...'):
170
  # Run enhanced model
171
  enhanced_start = time.time()
172
  enhanced_model_result = predict_hatespeech(
173
- text=user_input,
174
  rationale=cleaned_rationale, # use cleaned rationale
175
  model=enhanced_model,
176
  tokenizer_hatebert=enhanced_tokenizer_hatebert,
@@ -184,7 +198,7 @@ if classify_button:
184
  # Run base model
185
  base_start = time.time()
186
  base_model_result = predict_hatespeech(
187
- text=user_input,
188
  rationale=cleaned_rationale, # use cleaned rationale
189
  model=base_model,
190
  tokenizer_hatebert=base_tokenizer_hatebert,
 
1
  import gc
2
+ import re
3
 
4
  import streamlit as st
5
  from hatespeech_model import predict_hatespeech, load_model_from_hf, predict_hatespeech_from_file, get_rationale_from_mistral, preprocess_rationale_mistral
 
25
  """Load and cache the model"""
26
  return load_model_from_hf(model_type=model_type)
27
 
28
def clean_user_input(text):
    """Sanitize raw user text before classification.

    Strips URLs (http/https and bare ``www.`` links), drops every character
    that is not an ASCII letter, digit, whitespace, or an exclamation point
    (kept because '!' can carry emphasis signal), then collapses runs of
    whitespace into single spaces and trims the ends.

    Args:
        text: The raw input string from the user.

    Returns:
        The cleaned string (may be empty if the input was only URLs/symbols).
    """
    no_urls = re.sub(r'https?://\S+|www\.\S+', '', text)
    ascii_kept = re.sub(r'[^a-zA-Z0-9\s!]', '', no_urls)
    return re.sub(r'\s+', ' ', ascii_kept).strip()
+
38
  # Custom CSS
39
  st.markdown("""
40
  <style>
 
167
 
168
  if classify_button:
169
  if user_input and user_input.strip():
170
+ # Clean the input text
171
+ cleaned_input = clean_user_input(user_input)
172
+
173
  with st.spinner('πŸ”„ Generating rationale from Mistral AI...'):
174
  # --- Step 1: Get rationale from Mistral ---
175
  try:
176
+ raw_rationale = get_rationale_from_mistral(cleaned_input)
177
  cleaned_rationale = preprocess_rationale_mistral(raw_rationale)
178
  print(f"Raw rationale from Mistral: {raw_rationale}")
179
  except Exception as e:
180
  st.error(f"❌ Error generating/processing rationale: {str(e)}")
181
+ cleaned_rationale = cleaned_input # fallback to cleaned input
182
 
183
  with st.spinner('πŸ”„ Analyzing text with models...'):
184
  # Run enhanced model
185
  enhanced_start = time.time()
186
  enhanced_model_result = predict_hatespeech(
187
+ text=cleaned_input,
188
  rationale=cleaned_rationale, # use cleaned rationale
189
  model=enhanced_model,
190
  tokenizer_hatebert=enhanced_tokenizer_hatebert,
 
198
  # Run base model
199
  base_start = time.time()
200
  base_model_result = predict_hatespeech(
201
+ text=cleaned_input,
202
  rationale=cleaned_rationale, # use cleaned rationale
203
  model=base_model,
204
  tokenizer_hatebert=base_tokenizer_hatebert,