In [46]:
import torch
import torch.nn as nn
from transformers import BertForSequenceClassification, BertTokenizerFast, BertModel, BertPreTrainedModel, BertConfig
from transformers.modeling_outputs import BaseModelOutput, SequenceClassifierOutput
from typing import Optional, Tuple, Union

In [47]:
class BertConvModel(BertPreTrainedModel):
    """BERT encoder followed by a multi-scale Conv1d residual branch.

    The encoder's last hidden state is passed through three parallel 1-D
    convolutions (kernel sizes 3, 5 and 7, each with "same" padding and 256
    output channels). Their outputs are concatenated along the channel axis,
    projected back to ``config.hidden_size``, passed through GELU, added to
    the encoder output and layer-normalised.
    """

    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.encoder = BertModel(config)
        self.conv3 = nn.Conv1d(
            in_channels=config.hidden_size,
            out_channels=256,
            kernel_size=3,
            padding=1,
        )
        self.conv5 = nn.Conv1d(
            in_channels=config.hidden_size,
            out_channels=256,
            kernel_size=5,
            padding=2,
        )
        self.conv7 = nn.Conv1d(
            in_channels=config.hidden_size,
            out_channels=256,
            kernel_size=7,
            padding=3,
        )
        # NOTE(review): this BatchNorm is registered but never applied in
        # forward(). It is kept so the parameter set (and therefore saved
        # checkpoint keys) stays unchanged; either wire it in or drop it.
        self.conv_bn = nn.BatchNorm1d(256*3)
        self.linear = nn.Linear(256*3, config.hidden_size)
        self.act = nn.GELU()
        self.layernorm = nn.LayerNorm(config.hidden_size)

        # Same finalisation step the classification wrapper performs, so a
        # stand-alone BertConvModel gets the config-driven weight init too.
        self.post_init()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, return_dict=None):
        """Encode ``input_ids`` and apply the convolutional residual branch.

        Args:
            input_ids: token ids forwarded to the BERT encoder.
            attention_mask: optional mask forwarded to the BERT encoder.
            token_type_ids: optional segment ids forwarded to the BERT encoder.
            return_dict: when explicitly ``False`` a 1-tuple
                ``(last_hidden_state,)`` is returned; otherwise (``None`` or
                ``True``) a ``BaseModelOutput`` is returned, which is what
                this method has always produced.

        Returns:
            ``BaseModelOutput`` (or a 1-tuple) whose ``last_hidden_state`` has
            the same shape as the encoder's last hidden state.
        """
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Index 0 is the last hidden state for both tuple and ModelOutput returns.
        last_hidden_state = encoder_outputs[0]  # [B, L, H]

        # Conv1d convolves over the last axis, so move the sequence axis there.
        hidden_conv = last_hidden_state.permute(0, 2, 1)  # [B, H, L]

        combined = torch.cat([
            self.conv3(hidden_conv),
            self.conv5(hidden_conv),
            self.conv7(hidden_conv),
        ], dim=1).permute(0, 2, 1)  # [B, L, 3*256]
        fused = self.linear(combined)  # [B, L, H]
        fused = self.act(fused)

        # Residual connection around the convolutional branch.
        output = last_hidden_state + fused
        output = self.layernorm(output)

        if return_dict is not None and not return_dict:
            return (output,)

        return BaseModelOutput(
            last_hidden_state=output
        )

In [48]:
class BertConvForSequenceClassification(BertPreTrainedModel):
    """Sequence classifier on top of ``BertConvModel``.

    The hidden state at sequence position 0 is taken as the pooled
    representation, passed through dropout and a linear layer that produces
    ``config.num_labels`` logits.
    """

    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert_conv = BertConvModel(config)
        # Fall back to the general hidden dropout when no classifier-specific
        # rate is configured.
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """Compute logits and, when ``labels`` is given, the matching loss.

        Args:
            input_ids, attention_mask, token_type_ids: forwarded unchanged to
                ``self.bert_conv``.
            labels: optional targets. If ``config.problem_type`` is unset it is
                inferred (and stored on the config) from ``num_labels`` and the
                label dtype: regression, single-label or multi-label.
            return_dict: overrides ``config.use_return_dict`` when not ``None``.

        Returns:
            ``SequenceClassifierOutput``, or a tuple
            ``(loss?, logits, *extra_encoder_outputs)`` when ``return_dict`` is
            false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert_conv(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )

        last_hidden_state = outputs.last_hidden_state
        # Position 0 of every sequence serves as the pooled representation.
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # The inner model's output has no pooled tensor at index 1, so
            # everything after the last hidden state (index 0) is an "extra".
            # Slicing from 2 would silently drop the first extra field.
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [49]:
from datasets import load_dataset, concatenate_datasets, DatasetDict

In [7]:
# Bias-amplified MNLI: every split comes in a "biased" and an "anti_biased" part.
mnli = load_dataset(
    path="bias-amplified-splits/mnli",
    name="minority_examples",
)
mnli

DatasetDict({
    train.biased: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 309873
    })
    train.anti_biased: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 82829
    })
    validation_matched.biased: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 7771
    })
    validation_matched.anti_biased: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 2044
    })
    validation_mismatched.biased: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 7797
    })
    validation_mismatched.anti_biased: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 2035
    })
})

In [8]:
# Re-join the biased / anti-biased halves of each split.
train = concatenate_datasets(
    [mnli[part] for part in ("train.biased", "train.anti_biased")]
)

val_matched_biased, val_matched_anti_biased = (
    mnli["validation_matched.biased"],
    mnli["validation_matched.anti_biased"],
)
val_matched = concatenate_datasets([val_matched_biased, val_matched_anti_biased])

val_mismatched_biased, val_mismatched_anti_biased = (
    mnli["validation_mismatched.biased"],
    mnli["validation_mismatched.anti_biased"],
)
val_mismatched = concatenate_datasets([val_mismatched_biased, val_mismatched_anti_biased])

# Both validation sets (matched followed by mismatched) form the test set.
test = concatenate_datasets([val_matched, val_mismatched])

In [9]:
# Bundle the two splits and drop the example index column.
train_test_splits = DatasetDict({"train": train, "test": test})
data = train_test_splits.remove_columns(['idx'])
data

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 392702
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 19647
    })
})

In [10]:
# Fast tokenizer for the same "bert-base-uncased" checkpoint the models below start from.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

In [11]:
# Sanity check: encode one sentence pair and decode it back to see how the
# tokenizer lays out the two segments and its special tokens.
premise = "The cat sat on the mat."
hypothesis = "The cat was sitting on the mat."

encoded_pair = tokenizer(premise, hypothesis, padding=True)
tokenizer.decode(encoded_pair['input_ids'])

'[CLS] the cat sat on the mat. [SEP] the cat was sitting on the mat. [SEP]'

In [12]:
def preprocess(examples):
    """Tokenize a batch of premise/hypothesis pairs.

    Reads the ``'premise'`` and ``'hypothesis'`` entries of ``examples`` and
    encodes them as sentence pairs with the notebook-level ``tokenizer``,
    truncating longest-first to at most 512 tokens. No padding is applied here.
    """
    premises = examples['premise']
    hypotheses = examples['hypothesis']
    return tokenizer(
        premises,
        hypotheses,
        truncation="longest_first",
        max_length=512,
    )

In [13]:
# Tokenize every split in batches across 20 worker processes; the raw text
# columns are dropped once they have been encoded.
tokenized_data = data.map(
    preprocess,
    batched=True,
    num_proc=20,
    remove_columns=("premise", "hypothesis"),
)

Map (num_proc=20):   0%|          | 0/392702 [00:00<?, ? examples/s]

Map (num_proc=20):   0%|          | 0/19647 [00:00<?, ? examples/s]

In [14]:
from transformers import DataCollatorWithPadding
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [15]:
# Collator that pads at batch time using the 'longest' strategy.
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
    padding='longest',
    max_length=512,
)

In [16]:
def compute_metrics(pred):
    """Compute accuracy and weighted precision/recall/F1 for an evaluation run.

    Args:
        pred: object exposing ``label_ids`` (reference labels) and
            ``predictions`` (class scores with the class axis last, or a
            tuple whose first element holds those scores).

    Returns:
        dict with the keys ``'accuracy'``, ``'precision'``, ``'recall'`` and
        ``'f1'``.
    """
    labels = pred.label_ids

    scores = pred.predictions
    # Be tolerant of a tuple of outputs: the class scores come first.
    if isinstance(scores, tuple):
        scores = scores[0]
    preds = scores.argmax(-1)

    # zero_division=0 keeps the numeric result scikit-learn already falls back
    # to when a class is never predicted (e.g. before training), but without
    # emitting an undefined-metric warning on every evaluation.
    weighted = {'average': 'weighted', 'zero_division': 0}

    return {
        'accuracy': accuracy_score(labels, preds),
        'precision': precision_score(labels, preds, **weighted),
        'recall': recall_score(labels, preds, **weighted),
        'f1': f1_score(labels, preds, **weighted),
    }

In [17]:
# Class index <-> class name; label2id is derived so the two can never drift apart.
id2label = dict(enumerate(("entailment", "neutral", "contradiction")))
label2id = {name: idx for idx, name in id2label.items()}

In [57]:
# Assemble the classifier around pretrained BERT weights:
#   1. load the pretrained BertModel,
#   2. swap it into a freshly constructed BertConvModel,
#   3. swap that into a freshly constructed classification wrapper.
# Only the BERT encoder is pretrained; the conv branch and classifier head
# keep their fresh initialisation.
config = BertConfig.from_pretrained(
    "bert-base-uncased",
    num_labels=3,
    id2label=id2label,
    label2id=label2id,
)
pretrained_bert = BertModel.from_pretrained('bert-base-uncased', config=config)

encoder = BertConvModel(config)
encoder.encoder = pretrained_bert

model = BertConvForSequenceClassification(config)
model.bert_conv = encoder
model

BertConvForSequenceClassification(
  (bert_conv): BertConvModel(
    (encoder): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSdpaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out

In [20]:
from transformers import TrainingArguments, Trainer, get_linear_schedule_with_warmup
from torch.optim import Adam

In [38]:
# Hand-built optimizer and linear warmup/decay schedule.
# NOTE(review): the Trainer cells below leave `optimizers=(optimizer, scheduler)`
# commented out, so these two objects are not used by the training runs.
optimizer = Adam(
    model.parameters(),
    lr=2.5e-5,
    betas=(0.9, 0.999),
    eps=1e-06,
    weight_decay=0.01,
)
scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=60000,
)

In [58]:
# Fine-tuning recipe for the BERT+conv classifier. The Trainer builds its own
# optimizer and scheduler from these arguments (no `optimizers=` is passed).
training_args = TrainingArguments(
    # --- bookkeeping ---
    output_dir="./output",
    overwrite_output_dir=True,
    report_to="tensorboard",
    # --- evaluate, log and checkpoint on one shared 5000-step cadence ---
    eval_strategy="steps",
    eval_steps=5000,
    logging_strategy="steps",
    logging_steps=5000,
    save_strategy="steps",
    save_steps=5000,
    # --- optimisation ---
    max_steps=20000,
    warmup_steps=1000,
    learning_rate=3e-5,
    weight_decay=0.001,
    adam_epsilon=1e-8,
    fp16=True,
    # --- batching ---
    per_device_train_batch_size=64,
    per_device_eval_batch_size=256,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    processing_class=tokenizer,
    train_dataset=tokenized_data['train'],
    eval_dataset=tokenized_data['test'],
    compute_metrics=compute_metrics,
)

In [59]:
# Metrics on the test split before any fine-tuning step has been taken.
trainer.evaluate()

{'eval_loss': 1.1944094896316528,
 'eval_model_preparation_time': 0.007,
 'eval_accuracy': 0.3641268387031099,
 'eval_precision': 0.3050162608329799,
 'eval_recall': 0.3641268387031099,
 'eval_f1': 0.29583067778201166,
 'eval_runtime': 19.8257,
 'eval_samples_per_second': 990.988,
 'eval_steps_per_second': 3.884}

In [60]:
# Fine-tune the BERT+conv classifier; evaluation runs every `eval_steps` steps.
trainer.train()

Step,Training Loss,Validation Loss,Model Preparation Time,Accuracy,Precision,Recall,F1
5000,0.5671,0.434957,0.007,0.831832,0.836941,0.831832,0.832825
10000,0.3689,0.424474,0.007,0.843895,0.845985,0.843895,0.844391
15000,0.2755,0.501343,0.007,0.844556,0.847259,0.844556,0.845071
20000,0.201,0.55157,0.007,0.845676,0.848408,0.845676,0.846358


TrainOutput(global_step=20000, training_loss=0.35313146667480466, metrics={'train_runtime': 3827.3631, 'train_samples_per_second': 334.434, 'train_steps_per_second': 5.226, 'total_flos': 7.376417927681814e+16, 'train_loss': 0.35313146667480466, 'epoch': 3.259452411994785})

Result:

<table>
            <thead>
                <tr>
                    <th>Step</th>
                    <th>Training Loss</th>
                    <th>Validation Loss</th>
                    <th>Model Preparation Time</th>
                    <th>Accuracy</th>
                    <th>Precision</th>
                    <th>Recall</th>
                    <th>F1</th>
                </tr>
            </thead>
            <tbody>
                <tr>
                    <td>5000</td>
                    <td>0.567100</td>
                    <td>0.434957</td>
                    <td>0.007000</td>
                    <td>0.831832</td>
                    <td>0.836941</td>
                    <td>0.831832</td>
                    <td>0.832825</td>
                </tr>
                <tr>
                    <td>10000</td>
                    <td>0.368900</td>
                    <td>0.424474</td>
                    <td>0.007000</td>
                    <td>0.843895</td>
                    <td>0.845985</td>
                    <td>0.843895</td>
                    <td>0.844391</td>
                </tr>
                <tr>
                    <td>15000</td>
                    <td>0.275500</td>
                    <td>0.501343</td>
                    <td>0.007000</td>
                    <td>0.844556</td>
                    <td>0.847259</td>
                    <td>0.844556</td>
                    <td>0.845071</td>
                </tr>
                <tr>
                    <td>20000</td>
                    <td>0.201000</td>
                    <td>0.551570</td>
                    <td>0.007000</td>
                    <td>0.845676</td>
                    <td>0.848408</td>
                    <td>0.845676</td>
                    <td>0.846358</td>
                </tr>
            </tbody>
        </table>


Comparison: baseline `BertForSequenceClassification` fine-tuned with the same training arguments

In [61]:
# Baseline: the stock BERT sequence classifier with the same label set.
# Rebinding `model` here replaces the BERT+conv classifier in the namespace.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3, id2label=id2label, label2id=label2id)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [62]:
# Same recipe as the BERT+conv run, written to a separate output directory so
# the two runs can be compared. The Trainer again builds its own optimizer and
# scheduler (no `optimizers=` is passed).
training_args = TrainingArguments(
    # --- bookkeeping ---
    output_dir="./compare",
    overwrite_output_dir=True,
    report_to="tensorboard",
    # --- evaluate, log and checkpoint on one shared 5000-step cadence ---
    eval_strategy="steps",
    eval_steps=5000,
    logging_strategy="steps",
    logging_steps=5000,
    save_strategy="steps",
    save_steps=5000,
    # --- optimisation ---
    max_steps=20000,
    warmup_steps=1000,
    learning_rate=3e-5,
    weight_decay=0.001,
    adam_epsilon=1e-8,
    fp16=True,
    # --- batching ---
    per_device_train_batch_size=64,
    per_device_eval_batch_size=256,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    processing_class=tokenizer,
    train_dataset=tokenized_data['train'],
    eval_dataset=tokenized_data['test'],
    compute_metrics=compute_metrics,
)

In [63]:
# Baseline metrics on the test split before any fine-tuning step has been taken.
trainer.evaluate()

{'eval_loss': 1.173392415046692,
 'eval_model_preparation_time': 0.0034,
 'eval_accuracy': 0.3155189087392477,
 'eval_precision': 0.31114208439248486,
 'eval_recall': 0.3155189087392477,
 'eval_f1': 0.1570748637959829,
 'eval_runtime': 21.0427,
 'eval_samples_per_second': 933.671,
 'eval_steps_per_second': 3.659}

In [64]:
# Fine-tune the baseline classifier; evaluation runs every `eval_steps` steps.
trainer.train()

Step,Training Loss,Validation Loss,Model Preparation Time,Accuracy,Precision,Recall,F1
5000,0.5664,0.437872,0.0034,0.831374,0.836292,0.831374,0.832473
10000,0.3692,0.426002,0.0034,0.843437,0.846317,0.843437,0.844099
15000,0.2768,0.481546,0.0034,0.842826,0.845495,0.842826,0.843323
20000,0.2036,0.52964,0.0034,0.844047,0.846602,0.844047,0.844737


TrainOutput(global_step=20000, training_loss=0.35399990234375, metrics={'train_runtime': 3757.0942, 'train_samples_per_second': 340.689, 'train_steps_per_second': 5.323, 'total_flos': 7.083480128775549e+16, 'train_loss': 0.35399990234375, 'epoch': 3.259452411994785})

Result:

<table>
  <thead>
    <tr>
      <th>Step</th>
      <th>Training Loss</th>
      <th>Validation Loss</th>
      <th>Model Preparation Time</th>
      <th>Accuracy</th>
      <th>Precision</th>
      <th>Recall</th>
      <th>F1</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>5000</td>
      <td>0.566400</td>
      <td>0.437872</td>
      <td>0.003400</td>
      <td>0.831374</td>
      <td>0.836292</td>
      <td>0.831374</td>
      <td>0.832473</td>
    </tr>
    <tr>
      <td>10000</td>
      <td>0.369200</td>
      <td>0.426002</td>
      <td>0.003400</td>
      <td>0.843437</td>
      <td>0.846317</td>
      <td>0.843437</td>
      <td>0.844099</td>
    </tr>
    <tr>
      <td>15000</td>
      <td>0.276800</td>
      <td>0.481546</td>
      <td>0.003400</td>
      <td>0.842826</td>
      <td>0.845495</td>
      <td>0.842826</td>
      <td>0.843323</td>
    </tr>
    <tr>
      <td>20000</td>
      <td>0.203600</td>
      <td>0.529640</td>
      <td>0.003400</td>
      <td>0.844047</td>
      <td>0.846602</td>
      <td>0.844047</td>
      <td>0.844737</td>
    </tr>
  </tbody>
</table>

ChromaDB Embedding Function

In [None]:
from chromadb import Documents, EmbeddingFunction, Embeddings

In [None]:
class BertConvEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function backed by a ``BertConvModel``.

    Each document is tokenized, run through the model in inference mode, and
    represented by the hidden state at sequence position 0 — the same
    position the classification head in this notebook pools from.
    """

    def __init__(self, model_path, device=None):
        """Load the model and tokenizer and move the model to ``device``.

        Args:
            model_path: location passed to ``BertConvModel.from_pretrained``.
                NOTE(review): presumably this must be a checkpoint whose
                parameter names match ``BertConvModel`` — confirm how the
                fine-tuned weights are exported.
            device: target device; defaults to CUDA when available, else CPU.
        """
        super().__init__()
        self.model = BertConvModel.from_pretrained(model_path)
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, input: Documents) -> Embeddings:
        """Embed ``input`` and return one vector (list of floats) per document."""
        documents = list(input)
        if not documents:
            return []

        encoded_input = self.tokenizer(
            documents,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}

        with torch.no_grad():
            # No return_dict keyword: BertConvModel.forward as defined in this
            # notebook only takes the three tokenizer fields and already
            # returns a BaseModelOutput.
            outputs = self.model(**encoded_input)

        # Reduce [B, L, H] to [B, H] so each document gets a single vector
        # instead of a per-token matrix.
        embeddings = outputs.last_hidden_state[:, 0, :].cpu().tolist()
        return embeddings