Spaces:
Running
Running
| import torch.nn as nn | |
| from bert import BERT | |
class BERTForClassification(nn.Module):
    """
    Sequence-classification head on top of a BERT encoder.

    A single linear layer maps the encoder output at sequence position 0
    (presumably the [CLS] token -- confirm against the tokenizer) to
    ``n_labels`` unnormalized scores (logits).
    """

    def __init__(self, bert: "BERT", vocab_size, n_labels):
        """
        :param bert: BERT encoder to fine-tune; must expose a ``hidden``
            attribute equal to the size of its output feature dimension.
        :param vocab_size: not used by this class; accepted only so that
            existing callers keep working.
        :param n_labels: number of output classes (width of the logits).
        """
        super().__init__()
        self.bert = bert
        # Projects one encoder feature vector to one score per class.
        self.linear = nn.Linear(self.bert.hidden, n_labels)

    def forward(self, x, segment_label):
        """
        Encode the input and classify it from the first sequence position.

        :param x: token ids, passed unchanged to the encoder.
        :param segment_label: segment ids, passed unchanged to the encoder.
        :return: tuple ``(hidden, logits)``. ``hidden`` is the raw encoder
            output (assumed shape ``(batch, seq_len, bert.hidden)`` -- TODO
            confirm). ``logits`` is ``self.linear(hidden[:, 0])``: raw class
            scores with no softmax applied.
        """
        hidden = self.bert(x, segment_label)
        # Index 0 along dim 1 -- presumably the [CLS] summary token.
        logits = self.linear(hidden[:, 0])
        return hidden, logits