| import torch |
| from torch import nn |
| from transformers import AutoModel, PreTrainedModel, AutoConfig |
|
|
class BERTDiseaseClassifier(nn.Module):
    """Pretrained transformer encoder with a linear head on the first token.

    The encoder is loaded with ``AutoModel.from_pretrained(model_type)``. The
    hidden state at sequence position 0 (the ``[CLS]`` token for BERT-style
    tokenizers) is passed through dropout and then a linear layer producing
    ``num_symps`` raw logits per example. No sigmoid/softmax is applied here.

    Args:
        model_type: Model name or path accepted by ``AutoModel.from_pretrained``.
        num_symps: Number of output logits (out_features of the linear head).
        dropout_prob: Optional dropout probability for the head. When omitted,
            the encoder config's ``hidden_dropout_prob`` is used, falling back
            to ``dropout`` for configs that use that name (e.g. DistilBERT).
    """

    def __init__(self, model_type, num_symps, *, dropout_prob=None) -> None:
        super().__init__()
        self.model_type = model_type
        self.num_symps = num_symps

        self.encoder = AutoModel.from_pretrained(model_type)
        config = self.encoder.config
        if dropout_prob is None:
            dropout_prob = self._config_dropout(config)
        self.dropout = nn.Dropout(dropout_prob)
        self.clf = nn.Linear(config.hidden_size, num_symps)

    @staticmethod
    def _config_dropout(config):
        """Return the hidden dropout probability declared by an encoder config.

        Raises:
            AttributeError: if *config* defines neither ``hidden_dropout_prob``
                nor ``dropout``.
        """
        for attr in ("hidden_dropout_prob", "dropout"):
            value = getattr(config, attr, None)
            if value is not None:
                return value
        raise AttributeError(
            f"{type(config).__name__} has neither 'hidden_dropout_prob' nor "
            "'dropout'; pass dropout_prob explicitly"
        )

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, **kwargs):
        """Encode a batch and return classification logits.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            token_type_ids: Segment ids; forwarded only when not ``None``.
            **kwargs: Accepted for caller convenience and ignored.

        Returns:
            Output of the linear head applied to the first-token hidden state,
            one row per batch element with ``num_symps`` columns.
        """
        # Pass by keyword: positional order is architecture-specific
        # (DistilBertModel's third positional parameter is head_mask), and
        # some encoders do not accept token_type_ids at all.
        encoder_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids
        outputs = self.encoder(**encoder_inputs)

        # Hidden state at sequence position 0 for every batch element.
        first_token = outputs.last_hidden_state[:, 0, :]
        first_token = self.dropout(first_token)
        return self.clf(first_token)