from typing import Optional

import torch
from torch import nn
from transformers import BertModel, BertPreTrainedModel


class MyBERTClassifier(BertPreTrainedModel):
    """Sequence classifier: BERT pooled output -> MLP head -> per-label logits.

    The head is ``Linear(hidden_size, hidden_size // 2) -> ReLU ->
    Dropout(HEAD_DROPOUT) -> Linear(hidden_size // 2, num_labels)``. It is
    kept as a single ``nn.Sequential`` so parameter names
    (``classifier.0.*`` and ``classifier.3.*``) stay compatible with
    checkpoints saved from this class.

    Args:
        config: A BERT configuration. ``hidden_size`` and ``num_labels`` size
            the classification head; the whole config builds ``BertModel``.
    """

    # Dropout probability between the two head layers; subclasses may override.
    HEAD_DROPOUT = 0.2

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)

        hidden_size = config.hidden_size
        num_labels = config.num_labels
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(self.HEAD_DROPOUT),
            nn.Linear(hidden_size // 2, num_labels),
        )

        # Let the Transformers base class run its post-construction hooks
        # (e.g. weight initialization) now that all submodules exist.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Encode the inputs with BERT and return classification logits.

        Args:
            input_ids: Token ids, passed through to ``BertModel``.
            attention_mask: Attention mask, passed through to ``BertModel``.
            token_type_ids: Segment ids, passed through to ``BertModel``.
            **kwargs: Extra keyword arguments forwarded to ``BertModel``
                (for example ``position_ids`` or ``inputs_embeds``). A
                caller-supplied ``return_dict`` is ignored, because this method
                always requests structured output to read ``pooler_output``
                and always returns the logits tensor.

        Returns:
            Logits from the classifier head applied to BERT's pooled output;
            the last dimension is ``config.num_labels``. No loss is computed.
        """
        # ``return_dict=True`` is passed explicitly below; drop any value the
        # caller put in kwargs so the two cannot collide ("got multiple values
        # for keyword argument 'return_dict'").
        kwargs.pop("return_dict", None)

        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
            **kwargs,
        )
        pooled_output = outputs.pooler_output
        return self.classifier(pooled_output)