Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
| from transformers import BertPreTrainedModel, BertModel | |
class MyBERTClassifier(BertPreTrainedModel):
    """BERT encoder followed by a two-layer MLP classification head.

    The head is applied to the encoder's pooled output
    (``outputs.pooler_output``) and is
    ``Linear(hidden, hidden // 2) -> ReLU -> Dropout(0.2) -> Linear(hidden // 2, num_labels)``.
    ``forward`` returns the head's raw output (no softmax / sigmoid applied).
    """

    def __init__(self, config):
        """Build the BERT encoder and the classification head.

        Args:
            config: BERT model configuration. ``config.hidden_size`` sizes the
                head's layers and ``config.num_labels`` sets the number of
                output logits.
        """
        super().__init__(config)
        self.bert = BertModel(config)
        hidden_size = config.hidden_size
        num_labels = config.num_labels
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size // 2, num_labels),
        )
        # transformers' post-initialisation hook (weight init / final setup);
        # called last so it also sees the custom head defined above.
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, **kwargs):
        """Encode the inputs with BERT and return classification logits.

        Args:
            input_ids: Forwarded unchanged to ``self.bert``.
            attention_mask: Forwarded unchanged to ``self.bert``.
            token_type_ids: Forwarded unchanged to ``self.bert``.
            **kwargs: Extra keyword arguments forwarded to ``self.bert``.
                A caller-supplied ``return_dict`` is ignored: the pooled output
                is read by attribute, so the encoder is always asked for an
                output object.

        Returns:
            ``self.classifier`` applied to ``outputs.pooler_output``; the last
            dimension has ``config.num_labels`` entries.
        """
        # return_dict is always passed explicitly below. Letting a caller's
        # return_dict through **kwargs as well raised
        # "TypeError: got multiple values for keyword argument 'return_dict'".
        kwargs.pop("return_dict", None)
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
            **kwargs,
        )
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits