import torch
from torch import nn
from transformers import DistilBertModel


class QuestionTypeClassifier(nn.Module):
    """Classify tokenized questions into one of ``num_types`` categories.

    A pretrained DistilBERT encoder produces per-token hidden states. The
    hidden state at sequence position 0 is passed through a single linear
    layer that emits one unnormalized score (logit) per question type.

    Args:
        num_types: Number of output classes (``out_features`` of the head).
        pretrained_model_name: Checkpoint name or path handed to
            ``DistilBertModel.from_pretrained``. Defaults to
            ``"distilbert-base-uncased"``, the checkpoint this class always
            used before the parameter existed.
    """

    def __init__(
        self,
        num_types,
        *,
        pretrained_model_name="distilbert-base-uncased",
    ):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained(pretrained_model_name)
        # Size the head from the loaded config rather than a constant so it
        # stays correct for whichever checkpoint is selected.
        self.fc = nn.Linear(self.distilbert.config.hidden_size, num_types)

    def forward(self, input_ids, attention_mask=None):
        """Return per-class logits for a batch of token ids.

        Args:
            input_ids: Token-id tensor forwarded unchanged to DistilBERT;
                presumably shaped (batch, seq_len) — TODO confirm with callers.
            attention_mask: Optional mask with the same shape as
                ``input_ids`` (1 = attend, 0 = padding). ``None`` is forwarded
                as-is, in which case DistilBERT attends to every position.

        Returns:
            Tensor of shape (batch, num_types) holding raw logits; no softmax
            or other normalization is applied here.
        """
        outputs = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # last_hidden_state is indexed as (batch, seq_len, hidden). Position 0
        # is presumably the [CLS] token prepended by the tokenizer — confirm
        # that inputs are encoded with special tokens enabled.
        cls_embedding = outputs.last_hidden_state[:, 0, :]
        return self.fc(cls_embedding)