{ "cells": [ { "cell_type": "code", "id": "initial_id", "metadata": { "collapsed": true, "ExecuteTime": { "end_time": "2025-03-25T16:04:47.234614Z", "start_time": "2025-03-25T16:04:47.228876Z" } }, "source": [ "import torch\n", "import torch.nn as nn\n", "from transformers import BertForSequenceClassification, BertTokenizerFast, BertModel, BertPreTrainedModel, BertConfig\n", "from transformers.modeling_outputs import BaseModelOutput, SequenceClassifierOutput\n", "from typing import Optional, Tuple, Union" ], "outputs": [], "execution_count": 46 }, { "metadata": { "ExecuteTime": { "end_time": "2025-03-25T16:04:47.392034Z", "start_time": "2025-03-25T16:04:47.379839Z" } }, "cell_type": "code", "source": [ "class BertConvModel(BertPreTrainedModel):\n", " def __init__(self, config: BertConfig):\n", " super().__init__(config)\n", " self.encoder = BertModel(config)\n", " self.conv3 = nn.Conv1d(\n", " in_channels=config.hidden_size,\n", " out_channels=256,\n", " kernel_size=3,\n", " padding=1,\n", " )\n", " self.conv5 = nn.Conv1d(\n", " in_channels=config.hidden_size,\n", " out_channels=256,\n", " kernel_size=5,\n", " padding=2,\n", " )\n", " self.conv7 = nn.Conv1d(\n", " in_channels=config.hidden_size,\n", " out_channels=256,\n", " kernel_size=7,\n", " padding=3,\n", " )\n", " self.conv_bn = nn.BatchNorm1d(256*3)\n", " self.linear = nn.Linear(256*3, config.hidden_size)\n", " self.act = nn.GELU()\n", " self.layernorm = nn.LayerNorm(config.hidden_size)\n", "\n", " def forward(self, input_ids, attention_mask, token_type_ids):\n", " encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)\n", " last_hidden_state = encoder_outputs.last_hidden_state # [B, L, H]\n", "\n", " hidden_conv = last_hidden_state.permute(0, 2, 1) # [B, H, L]\n", "\n", " combined = torch.cat([\n", " self.conv3(hidden_conv),\n", " self.conv5(hidden_conv),\n", " self.conv7(hidden_conv),\n", " ], dim=1).permute(0,2, 1) # [B, L, H]\n", " fused = self.linear(combined)\n", " fused = self.act(fused)\n", "\n", " output = last_hidden_state + fused\n", " output = self.layernorm(output)\n", "\n", " return BaseModelOutput(\n", " last_hidden_state=output\n", " )" ], "id": "34d786f5b97b8bab", "outputs": [], "execution_count": 47 }, { "metadata": { "ExecuteTime": { "end_time": "2025-03-25T16:04:47.507570Z", "start_time": "2025-03-25T16:04:47.490208Z" } }, "cell_type": "code", "source": [ "class BertConvForSequenceClassification(BertPreTrainedModel):\n", " def __init__(self, config: BertConfig):\n", " super().__init__(config)\n", " self.config = config\n", " self.num_labels = config.num_labels\n", " self.bert_conv = BertConvModel(config)\n", " classifier_dropout = (\n", " config.classifier_dropout if config.classifier_dropout is not None\n", " else config.hidden_dropout_prob\n", " )\n", " self.dropout = nn.Dropout(classifier_dropout)\n", " self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n", "\n", " self.post_init()\n", "\n", " def forward(\n", " self,\n", " input_ids: Optional[torch.Tensor] = None,\n", " attention_mask: Optional[torch.Tensor] = None,\n", " token_type_ids: Optional[torch.Tensor] = None,\n", " labels: Optional[torch.Tensor] = None,\n", " return_dict: Optional[bool] = None,\n", " ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:\n", " return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n", "\n", " outputs = self.bert_conv(\n", " input_ids=input_ids,\n", " attention_mask=attention_mask,\n", " token_type_ids=token_type_ids\n", " )\n", "\n", " last_hidden_state = outputs.last_hidden_state\n", " pooled_output = last_hidden_state[:, 0, :]\n", " pooled_output = self.dropout(pooled_output)\n", " logits = self.classifier(pooled_output)\n", "\n", " loss = None\n", " if labels is not None:\n", " if self.config.problem_type is None:\n", " if self.num_labels == 1:\n", " self.config.problem_type = \"regression\"\n", " elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n", " self.config.problem_type = \"single_label_classification\"\n", " else:\n", " self.config.problem_type = \"multi_label_classification\"\n", "\n", " if self.config.problem_type == \"regression\":\n", " loss_fct = nn.MSELoss()\n", " if self.num_labels == 1:\n", " loss = loss_fct(logits.squeeze(), labels.squeeze())\n", " else:\n", " loss = loss_fct(logits, labels)\n", " elif self.config.problem_type == \"single_label_classification\":\n", " loss_fct = nn.CrossEntropyLoss()\n", " loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n", " elif self.config.problem_type == \"multi_label_classification\":\n", " loss_fct = nn.BCEWithLogitsLoss()\n", " loss = loss_fct(logits, labels)\n", "\n", " if not return_dict:\n", " output = (logits,) + outputs[2:]\n", " return ((loss,) + output) if loss is not None else output\n", "\n", " return SequenceClassifierOutput(\n", " loss=loss,\n", " logits=logits,\n", " hidden_states=outputs.hidden_states,\n", " attentions=outputs.attentions,\n", " )" ], "id": "e1afead74e5d56c8", "outputs": [], "execution_count": 48 }, { "metadata": { "ExecuteTime": { "end_time": "2025-03-25T16:04:47.593331Z", "start_time": "2025-03-25T16:04:47.588584Z" } }, "cell_type": "code", "source": "from datasets import load_dataset, concatenate_datasets, DatasetDict", "id": "ef15760c46f3148b", "outputs": [], "execution_count": 49 }, { "metadata": { "ExecuteTime": { "end_time": "2025-03-25T14:49:08.446489Z", "start_time": "2025-03-25T14:49:03.966024Z" } }, "cell_type": "code", "source": [ "mnli = load_dataset(\"bias-amplified-splits/mnli\", \"minority_examples\")\n", "mnli" ], "id": "dec1c1ae07c4474", "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train.biased: Dataset({\n", " features: ['premise', 'hypothesis', 'label', 'idx'],\n", " num_rows: 309873\n", " })\n", " train.anti_biased: Dataset({\n", " features: ['premise', 'hypothesis', 'label', 'idx'],\n", " num_rows: 82829\n", " })\n", " validation_matched.biased: Dataset({\n", " features: ['premise', 'hypothesis', 'label', 'idx'],\n", " num_rows: 7771\n", " })\n", " validation_matched.anti_biased: Dataset({\n", " features: ['premise', 'hypothesis', 'label', 'idx'],\n", " num_rows: 2044\n", " })\n", " validation_mismatched.biased: Dataset({\n", " features: ['premise', 'hypothesis', 'label', 'idx'],\n", " num_rows: 7797\n", " })\n", " validation_mismatched.anti_biased: Dataset({\n", " features: ['premise', 'hypothesis', 'label', 'idx'],\n", " num_rows: 2035\n", " })\n", "})" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 7 }, { "metadata": { "ExecuteTime": { "end_time": "2025-03-25T14:49:08.554250Z", "start_time": "2025-03-25T14:49:08.506856Z" } }, "cell_type": "code", "source": [ "train = concatenate_datasets([mnli[\"train.biased\"], mnli[\"train.anti_biased\"]])\n", "\n", "val_matched_biased = mnli[\"validation_matched.biased\"]\n", "val_matched_anti_biased = mnli[\"validation_matched.anti_biased\"]\n", "val_matched = concatenate_datasets([val_matched_biased, val_matched_anti_biased])\n", "\n", "val_mismatched_biased = mnli[\"validation_mismatched.biased\"]\n", "val_mismatched_anti_biased = mnli[\"validation_mismatched.anti_biased\"]\n", "val_mismatched = concatenate_datasets([val_mismatched_biased, val_mismatched_anti_biased])\n", "\n", "test = concatenate_datasets([val_matched, val_mismatched])" ], "id": "f7a87126395bf25d", "outputs": [], "execution_count": 8 }, { "metadata": { "ExecuteTime": { "end_time": "2025-03-25T14:49:08.594195Z", "start_time": "2025-03-25T14:49:08.575454Z" } }, "cell_type": "code", "source": [ "data = DatasetDict({\n", " \"train\": train,\n", " \"test\": test,\n", "}).remove_columns(['idx'])\n", "data" ], "id": "bdeb7ea17acd9bd4", "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['premise', 'hypothesis', 'label'],\n", " num_rows: 392702\n", " })\n", " test: Dataset({\n", " features: ['premise', 'hypothesis', 'label'],\n", " num_rows: 19647\n", " })\n", "})" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 9 }, { "metadata": { "ExecuteTime": { "end_time": "2025-03-25T14:49:08.967254Z", "start_time": "2025-03-25T14:49:08.636482Z" } }, "cell_type": "code", "source": "tokenizer = BertTokenizerFast.from_pretrained(\"bert-base-uncased\")", "id": "cb332f816eb96ca6", "outputs": [], "execution_count": 10 }, { "metadata": { "ExecuteTime": { "end_time": "2025-03-25T14:49:09.013451Z", "start_time": "2025-03-25T14:49:08.996755Z" } }, "cell_type": "code", "source": [ "premise = \"The cat sat on the mat.\"\n", "hypothesis = \"The cat was sitting on the mat.\"\n", "\n", "tokenizer.decode(tokenizer(premise, hypothesis, padding=True)['input_ids'])" ], "id": "e6b12359e9a054fc", "outputs": [ { "data": { "text/plain": [ "'[CLS] the cat sat on the mat. [SEP] the cat was sitting on the mat. [SEP]'" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 11 }, { "metadata": { "ExecuteTime": { "end_time": "2025-03-25T14:49:09.066234Z", "start_time": "2025-03-25T14:49:09.061293Z" } }, "cell_type": "code", "source": [ "def preprocess(examples):\n", " return tokenizer(examples['premise'], examples['hypothesis'], truncation=\"longest_first\", max_length=512)" ], "id": "403351acf45cb794", "outputs": [], "execution_count": 12 }, { "metadata": { "ExecuteTime": { "end_time": "2025-03-25T14:49:16.810416Z", "start_time": "2025-03-25T14:49:09.115611Z" } }, "cell_type": "code", "source": "tokenized_data = data.map(preprocess, batched=True, num_proc=20, remove_columns=(\"premise\", \"hypothesis\"))", "id": "737168d34408b655", "outputs": [ { "data": { "text/plain": [ "Map (num_proc=20): 0%| | 0/392702 [00:00" ], "text/html": [ "\n", "
\n", " \n", " \n", " [77/77 15:51]\n", "
\n", " " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'eval_loss': 1.1944094896316528,\n", " 'eval_model_preparation_time': 0.007,\n", " 'eval_accuracy': 0.3641268387031099,\n", " 'eval_precision': 0.3050162608329799,\n", " 'eval_recall': 0.3641268387031099,\n", " 'eval_f1': 0.29583067778201166,\n", " 'eval_runtime': 19.8257,\n", " 'eval_samples_per_second': 990.988,\n", " 'eval_steps_per_second': 3.884}" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 59 }, { "metadata": { "ExecuteTime": { "end_time": "2025-03-25T17:16:14.590823Z", "start_time": "2025-03-25T16:12:26.732033Z" } }, "cell_type": "code", "source": "trainer.train()", "id": "46524fd4a95af711", "outputs": [ { "data": { "text/plain": [ "" ], "text/html": [ "\n", "
\n", " \n", " \n", " [20000/20000 1:03:47, Epoch 3/4]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining LossValidation LossModel Preparation TimeAccuracyPrecisionRecallF1
50000.5671000.4349570.0070000.8318320.8369410.8318320.832825
100000.3689000.4244740.0070000.8438950.8459850.8438950.844391
150000.2755000.5013430.0070000.8445560.8472590.8445560.845071
200000.2010000.5515700.0070000.8456760.8484080.8456760.846358

" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=20000, training_loss=0.35313146667480466, metrics={'train_runtime': 3827.3631, 'train_samples_per_second': 334.434, 'train_steps_per_second': 5.226, 'total_flos': 7.376417927681814e+16, 'train_loss': 0.35313146667480466, 'epoch': 3.259452411994785})" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 60 }, { "metadata": {}, "cell_type": "markdown", "source": "Result:", "id": "3531b1a18e32af40" }, { "metadata": {}, "cell_type": "markdown", "source": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining LossValidation LossModel Preparation TimeAccuracyPrecisionRecallF1
50000.5671000.4349570.0070000.8318320.8369410.8318320.832825
100000.3689000.4244740.0070000.8438950.8459850.8438950.844391
150000.2755000.5013430.0070000.8445560.8472590.8445560.845071
200000.2010000.5515700.0070000.8456760.8484080.8456760.846358
\n" ], "id": "db1ed297ef851eab" }, { "metadata": {}, "cell_type": "markdown", "source": "Comparison", "id": "e484fedcd827fd96" }, { "metadata": { "ExecuteTime": { "end_time": "2025-03-25T17:16:15.342756Z", "start_time": "2025-03-25T17:16:14.703412Z" } }, "cell_type": "code", "source": "model = BertForSequenceClassification.from_pretrained(\"bert-base-uncased\", num_labels=3, id2label=id2label, label2id=label2id)", "id": "8b3a585df6e993c8", "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "execution_count": 61 }, { "metadata": { "ExecuteTime": { "end_time": "2025-03-25T17:16:15.555113Z", "start_time": "2025-03-25T17:16:15.392354Z" } }, "cell_type": "code", "source": [ "training_args = TrainingArguments(\n", " output_dir=\"./compare\",\n", " overwrite_output_dir=True,\n", " eval_strategy=\"steps\",\n", " logging_strategy=\"steps\",\n", " save_strategy=\"steps\",\n", " save_steps=5000,\n", " eval_steps=5000,\n", " logging_steps=5000,\n", " max_steps=20000,\n", " learning_rate=3e-5,\n", " weight_decay=0.001,\n", " adam_epsilon=1e-8,\n", " warmup_steps=1000,\n", " report_to=\"tensorboard\",\n", " per_device_train_batch_size=64,\n", " #gradient_accumulation_steps=2,\n", " per_device_eval_batch_size=256,\n", " fp16=True,\n", ")\n", "\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=tokenized_data['train'],\n", " eval_dataset=tokenized_data['test'],\n", " processing_class=tokenizer,\n", " data_collator=data_collator,\n", " #preprocess_logits_for_metrics=preprocess_logits_for_metrics,\n", " compute_metrics=compute_metrics,\n", " #optimizers=(optimizer, scheduler),\n", ")" ], "id": "be0ec82ebb4c18ee", "outputs": [], "execution_count": 62 }, { "metadata": { "ExecuteTime": { "end_time": "2025-03-25T17:16:36.635075Z", "start_time": "2025-03-25T17:16:15.574263Z" } }, "cell_type": "code", "source": "trainer.evaluate()", "id": "157359d28e31f33f", "outputs": [ { "data": { "text/plain": [ "" ], "text/html": [ "\n", "

\n", " \n", " \n", " [77/77 16:10]\n", "
\n", " " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'eval_loss': 1.173392415046692,\n", " 'eval_model_preparation_time': 0.0034,\n", " 'eval_accuracy': 0.3155189087392477,\n", " 'eval_precision': 0.31114208439248486,\n", " 'eval_recall': 0.3155189087392477,\n", " 'eval_f1': 0.1570748637959829,\n", " 'eval_runtime': 21.0427,\n", " 'eval_samples_per_second': 933.671,\n", " 'eval_steps_per_second': 3.659}" ] }, "execution_count": 63, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 63 }, { "metadata": { "ExecuteTime": { "end_time": "2025-03-25T18:19:14.283378Z", "start_time": "2025-03-25T17:16:36.663666Z" } }, "cell_type": "code", "source": "trainer.train()", "id": "96899639b391ead", "outputs": [ { "data": { "text/plain": [ "" ], "text/html": [ "\n", "
\n", " \n", " \n", " [20000/20000 1:02:36, Epoch 3/4]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining LossValidation LossModel Preparation TimeAccuracyPrecisionRecallF1
50000.5664000.4378720.0034000.8313740.8362920.8313740.832473
100000.3692000.4260020.0034000.8434370.8463170.8434370.844099
150000.2768000.4815460.0034000.8428260.8454950.8428260.843323
200000.2036000.5296400.0034000.8440470.8466020.8440470.844737

" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=20000, training_loss=0.35399990234375, metrics={'train_runtime': 3757.0942, 'train_samples_per_second': 340.689, 'train_steps_per_second': 5.323, 'total_flos': 7.083480128775549e+16, 'train_loss': 0.35399990234375, 'epoch': 3.259452411994785})" ] }, "execution_count": 64, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 64 }, { "metadata": {}, "cell_type": "markdown", "source": "Result:", "id": "6d9604187a3d84c5" }, { "metadata": {}, "cell_type": "markdown", "source": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining LossValidation LossModel Preparation TimeAccuracyPrecisionRecallF1
50000.5664000.4378720.0034000.8313740.8362920.8313740.832473
100000.3692000.4260020.0034000.8434370.8463170.8434370.844099
150000.2768000.4815460.0034000.8428260.8454950.8428260.843323
200000.2036000.5296400.0034000.8440470.8466020.8440470.844737
" ], "id": "a32ebe5d6ce99f98" }, { "metadata": {}, "cell_type": "markdown", "source": "ChromaDB Embedding Function", "id": "f37bfcfe59d88a95" }, { "metadata": {}, "cell_type": "code", "outputs": [], "execution_count": null, "source": "from chromadb import Documents, EmbeddingFunction, Embeddings", "id": "291aa9e620dc571d" }, { "metadata": {}, "cell_type": "code", "outputs": [], "execution_count": null, "source": [ "class BertConvEmbeddingFunction(EmbeddingFunction):\n", " def __init__(self, model_path, device=None):\n", " super().__init__()\n", " self.model = BertConvModel.from_pretrained(model_path)\n", " self.tokenizer = BertTokenizerFast.from_pretrained(\"bert-base-uncased\")\n", " self.device = device or torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " self.model.to(self.device)\n", " self.model.eval()\n", "\n", " def __call__(self, input: Documents) -> Embeddings:\n", " encoded_input = self.tokenizer(\n", " input,\n", " padding=True,\n", " truncation=True,\n", " max_length=512,\n", " return_tensors=\"pt\",\n", " )\n", " encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}\n", "\n", " with torch.no_grad():\n", " outputs = self.model(**encoded_input, return_dict=True)\n", "\n", " embeddings = outputs.last_hidden_state.cpu().tolist()\n", " return embeddings" ], "id": "142c0fcc0a92667a" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }