| import torch |
| from torch import nn |
| from transformers import DistilBertModel |
|
|
class QuestionTypeClassifier(nn.Module):
    """Classify a question into one of ``num_types`` categories.

    A pretrained DistilBERT backbone encodes the input tokens; the hidden
    state at the first position (the [CLS] token) is projected through a
    single linear layer to produce per-class logits.
    """

    def __init__(self, num_types):
        super().__init__()
        # Pretrained encoder. The classification head reads its hidden size
        # from the encoder config so the two always stay in sync.
        self.distilbert = DistilBertModel.from_pretrained(
            "distilbert-base-uncased"
        )
        self.fc = nn.Linear(self.distilbert.config.hidden_size, num_types)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class logits of shape (batch, num_types)."""
        encoded = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Position 0 is the [CLS] token — use its hidden state as the
        # sequence-level summary for classification.
        sequence_summary = encoded.last_hidden_state[:, 0, :]
        return self.fc(sequence_summary)
|
|