Spaces:
Sleeping
Sleeping
File size: 979 Bytes
fd363e5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 | import torch
from torch import nn
from transformers import BertPreTrainedModel, BertModel
class MyBERTClassifier(BertPreTrainedModel):
    """BERT encoder with a two-layer MLP classification head.

    The head maps the encoder's pooled output (``pooler_output``) through
    ``Linear(hidden, hidden // 2) -> ReLU -> Dropout -> Linear(hidden // 2,
    num_labels)``. The layer order inside ``self.classifier`` determines the
    checkpoint parameter names (``classifier.0.*`` / ``classifier.3.*``), so
    changing it breaks loading of previously saved weights.
    """

    def __init__(self, config):
        """Build the encoder and classification head.

        Args:
            config: A BERT configuration. ``hidden_size`` and ``num_labels``
                size the head. If ``config.classifier_dropout`` is set it is
                used as the head's dropout rate; otherwise 0.2 is used.
        """
        super().__init__(config)
        self.bert = BertModel(config)
        hidden_size = config.hidden_size
        num_labels = config.num_labels
        # Honour the standard BertConfig knob when present; 0.2 preserves the
        # original hard-coded rate for configs that leave it unset (None).
        classifier_dropout = getattr(config, "classifier_dropout", None)
        dropout_prob = 0.2 if classifier_dropout is None else classifier_dropout
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_size // 2, num_labels),
        )
        # transformers hook that runs weight initialisation once all
        # submodules exist.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs,
    ):
        """Run the encoder and the classification head.

        Args:
            input_ids: Forwarded unchanged to ``BertModel.forward``.
            attention_mask: Forwarded unchanged to ``BertModel.forward``.
            token_type_ids: Forwarded unchanged to ``BertModel.forward``.
            labels: Optional targets. When given, a loss is computed (see
                ``_compute_loss``) and ``(loss, logits)`` is returned. Taking
                it as a named parameter keeps it out of ``**kwargs``;
                ``BertModel.forward`` has no ``labels`` parameter and would
                reject it.
            **kwargs: Extra keyword arguments for ``BertModel.forward``
                (e.g. ``position_ids``, ``inputs_embeds``). Any
                ``return_dict`` value is overridden with True, because this
                method reads ``outputs.pooler_output`` by attribute.

        Returns:
            The head's logits (last dimension ``num_labels``) when ``labels``
            is None; otherwise the tuple ``(loss, logits)``, with the loss
            first as in the transformers tuple-output convention.
        """
        # Set return_dict inside kwargs rather than as a separate keyword:
        # passing it both ways raises "got multiple values for keyword
        # argument 'return_dict'". kwargs is a fresh per-call dict, so this
        # does not mutate anything the caller owns.
        kwargs["return_dict"] = True
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs,
        )
        logits = self.classifier(outputs.pooler_output)
        if labels is None:
            return logits
        return self._compute_loss(logits, labels), logits

    def _compute_loss(self, logits, labels):
        """Return the training loss of ``logits`` against ``labels``.

        Follows ``BertForSequenceClassification``: ``config.problem_type`` is
        used when set. Otherwise the problem type is inferred as regression
        when there is a single label, single-label classification for
        integer labels, and multi-label classification for anything else.
        Unlike the upstream model, ``self.config`` is not modified.

        Raises:
            ValueError: If ``config.problem_type`` holds an unknown value.
        """
        num_labels = self.config.num_labels
        problem_type = getattr(self.config, "problem_type", None)
        if problem_type is None:
            if num_labels == 1:
                problem_type = "regression"
            elif labels.dtype in (torch.long, torch.int):
                problem_type = "single_label_classification"
            else:
                problem_type = "multi_label_classification"

        if problem_type == "regression":
            loss_fct = nn.MSELoss()
            if num_labels == 1:
                return loss_fct(logits.squeeze(), labels.squeeze())
            return loss_fct(logits, labels)
        if problem_type == "single_label_classification":
            return nn.CrossEntropyLoss()(
                logits.view(-1, num_labels), labels.view(-1)
            )
        if problem_type == "multi_label_classification":
            return nn.BCEWithLogitsLoss()(logits, labels)
        raise ValueError(f"Unsupported problem_type: {problem_type!r}")