from typing import Optional

import torch
from torch import nn
from transformers import AutoModel


class BERTDiseaseClassifier(nn.Module):
    """Linear classification head on top of a pretrained Hugging Face encoder.

    The encoder is loaded with ``AutoModel.from_pretrained(model_type)``. The
    hidden state at the first token position (the ``[CLS]`` position for
    BERT-style tokenizers) is passed through dropout and a single linear layer
    producing one raw logit per output. No activation is applied; pair the
    output with a loss that expects logits.

    Args:
        model_type: Model identifier or local path understood by
            ``AutoModel.from_pretrained``.
        num_symps: Number of output logits (one per symptom/class).
    """

    # Used when the encoder config exposes no recognised dropout setting;
    # matches the default ``hidden_dropout_prob`` of ``BertConfig``.
    _DEFAULT_DROPOUT = 0.1

    def __init__(self, model_type: str, num_symps: int) -> None:
        super().__init__()
        self.model_type = model_type
        self.num_symps = num_symps
        self.encoder = AutoModel.from_pretrained(model_type)
        config = self.encoder.config
        self.dropout = nn.Dropout(self._resolve_dropout(config))
        self.clf = nn.Linear(config.hidden_size, num_symps)

    @classmethod
    def _resolve_dropout(cls, config) -> float:
        """Return the encoder's hidden-layer dropout probability.

        BERT/RoBERTa-style configs name it ``hidden_dropout_prob``; others
        (e.g. DistilBERT) call it ``dropout``. Falls back to the BERT default
        when neither attribute holds a number.
        """
        for attr in ("hidden_dropout_prob", "dropout"):
            value = getattr(config, attr, None)
            if isinstance(value, (int, float)):
                return float(value)
        return cls._DEFAULT_DROPOUT

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Compute classification logits for a batch.

        Args:
            input_ids: Token ids, presumably ``(batch, seq_len)``.
            attention_mask: Optional padding mask matching ``input_ids``.
            token_type_ids: Optional segment ids. Forwarded only when given,
                because some encoders (e.g. DistilBERT) do not accept them.
            **kwargs: Ignored. Presumably accepted so callers can unpack a
                whole batch dict (possibly containing extra keys) directly.

        Returns:
            Logits of shape ``(batch, num_symps)``.
        """
        # Pass inputs by keyword: positional order differs between
        # architectures (DistilBERT's third positional parameter is
        # ``head_mask``, not ``token_type_ids``).
        encoder_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids
        # Force a ModelOutput so ``.last_hidden_state`` works even if the
        # config sets ``use_return_dict=False``.
        outputs = self.encoder(**encoder_inputs, return_dict=True)
        # last_hidden_state is rank 3 (batch, seq_len, hidden); take token 0.
        first_token = outputs.last_hidden_state[:, 0, :]
        return self.clf(self.dropout(first_token))