| import torch | |
| from torch import nn | |
| from transformers import AutoModel | |
class BERTDiseaseClassifier(nn.Module):
    """Linear classification head on top of a Hugging Face ``AutoModel`` encoder.

    The encoder's first-token hidden state (the ``[CLS]`` position for
    BERT-style tokenizers) goes through dropout and a single linear layer,
    producing ``num_symps`` unnormalized scores per example.

    Args:
        model_type: Model name or path forwarded to ``AutoModel.from_pretrained``.
        num_symps: Number of output logits produced per example.
    """

    def __init__(self, model_type, num_symps) -> None:
        super().__init__()
        self.model_type = model_type
        self.num_symps = num_symps
        self.encoder = AutoModel.from_pretrained(model_type)
        config = self.encoder.config
        # Not every architecture exposes ``hidden_dropout_prob`` (DistilBERT
        # calls it ``dropout``); fall back to BERT's documented default of 0.1.
        dropout_prob = getattr(
            config, "hidden_dropout_prob", getattr(config, "dropout", 0.1)
        )
        self.dropout = nn.Dropout(dropout_prob)
        self.clf = nn.Linear(config.hidden_size, num_symps)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, **kwargs):
        """Return classification logits for a batch of tokenized inputs.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Optional attention mask passed to the encoder.
            token_type_ids: Optional segment ids. They are forwarded only when
                given, because some encoders (e.g. DistilBERT) do not accept them.
            **kwargs: Accepted and ignored. Presumably this lets a full batch
                dict (which may carry extra keys such as labels) be unpacked
                into this call; confirm against callers.

        Returns:
            Logits of shape ``(batch, num_symps)``.
        """
        # Pass everything by keyword: encoders loaded through AutoModel do not
        # share one positional signature, so positional passing can route
        # tensors into the wrong parameter (e.g. DistilBERT's ``head_mask``).
        encoder_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids
        outputs = self.encoder(**encoder_inputs)
        # First-token pooling ([CLS] for BERT-style tokenizers).
        x = outputs.last_hidden_state[:, 0, :]
        x = self.dropout(x)
        logits = self.clf(x)
        return logits