| import torch | |
| from torch import nn | |
| from transformers import BertModel, BertPreTrainedModel | |
class CustomBertModel(BertPreTrainedModel):
    """BERT encoder with a 4-layer MLP head emitting a sigmoid score in [0, 1].

    The first 6 encoder layers are frozen, so fine-tuning updates only the
    upper encoder layers and the classification head. Intended for binary
    classification / scoring of a single sequence (one scalar per example).
    """

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        # Freeze the first 6 encoder layers; their parameters are excluded
        # from gradient updates during fine-tuning.
        for param in self.bert.encoder.layer[:6].parameters():
            param.requires_grad = False
        # Was hard-coded to 768, which silently breaks for any config whose
        # hidden size differs (bert-large = 1024, small/tiny variants, ...).
        hidden_size = config.hidden_size
        self.dropout = nn.Dropout(0.22)
        self.fc1 = nn.Linear(hidden_size, 512)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(512, 512)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(512, 128)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run BERT + MLP head and return a (batch, 1) probability tensor.

        Args:
            input_ids: token id tensor of shape (batch, seq_len).
            attention_mask: optional padding mask, same shape as input_ids.
            token_type_ids: optional segment ids, same shape as input_ids.

        Returns:
            Tensor of shape (batch, 1) with values in [0, 1] (sigmoid applied).

        NOTE(review): the original returned this under the name ``logits``
        even though sigmoid is already applied — these are probabilities.
        If training uses BCEWithLogitsLoss, remove the sigmoid here instead
        of feeding probabilities into the loss.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # pooler_output: the [CLS] representation passed through BERT's pooler.
        pooled_output = outputs.pooler_output
        x = self.dropout(pooled_output)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.relu3(self.fc3(x))
        probs = self.sigmoid(self.fc4(x))
        return probs