# Source snapshot: commit bcf7dc4 (859 bytes).
import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel, AutoConfig
class BERTDiseaseClassifier(nn.Module):
    """Symptom classifier on top of a pretrained Hugging Face encoder.

    The encoder named by ``model_type`` is loaded with ``AutoModel``. The
    hidden state at position 0 (the [CLS] token for BERT-style tokenizers)
    is passed through dropout and a linear layer that emits one logit per
    symptom.

    Args:
        model_type: Model name or path passed to ``AutoModel.from_pretrained``.
        num_symps: Number of output labels (symptoms).
    """

    def __init__(self, model_type: str, num_symps: int) -> None:
        super().__init__()
        self.model_type = model_type
        self.num_symps = num_symps
        # multi-label binary classification: one logit per symptom
        # (presumably paired with a sigmoid/BCE loss by the caller — confirm)
        self.encoder = AutoModel.from_pretrained(model_type)
        config = self.encoder.config
        # BERT-style configs expose ``hidden_dropout_prob``; DistilBERT-style
        # configs name the same setting ``dropout``. If neither exists this
        # raises AttributeError, as before.
        dropout_prob = getattr(config, "hidden_dropout_prob", None)
        if dropout_prob is None:
            dropout_prob = config.dropout
        self.dropout = nn.Dropout(dropout_prob)
        self.clf = nn.Linear(config.hidden_size, num_symps)

    def forward(
        self, input_ids=None, attention_mask=None, token_type_ids=None, **kwargs
    ) -> torch.Tensor:
        """Compute per-symptom logits for a batch.

        Args:
            input_ids: Token ids fed to the encoder.
            attention_mask: Optional attention mask, forwarded to the encoder.
            token_type_ids: Optional segment ids. They are forwarded only when
                given, so encoders without this parameter still work.
            **kwargs: Accepted and ignored, so a whole batch dict with extra
                keys (e.g. labels) can be unpacked into this call.

        Returns:
            Logits with the encoder output's leading (batch) dimension and
            ``num_symps`` columns.
        """
        # Pass arguments by keyword: positional order differs across encoder
        # architectures (e.g. DistilBERT's third positional parameter is
        # ``head_mask``, not ``token_type_ids``).
        encoder_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids
        outputs = self.encoder(**encoder_inputs)
        x = outputs.last_hidden_state[:, 0, :]  # [CLS] pooling
        x = self.dropout(x)
        logits = self.clf(x)
        return logits