import torch
import torch.nn as nn
from transformers import RobertaModel, RobertaConfig, PreTrainedModel


class RadBertForSequenceClassification(PreTrainedModel):
    """RoBERTa encoder with a single linear classification layer on top.

    The classifier is applied to the encoder's pooler output when one is
    returned, otherwise to the hidden state of the first token. ``forward``
    returns the raw classifier output (logits); it computes no loss and does
    not wrap the result in an output dataclass.
    """

    config_class = RobertaConfig
    # Name of the attribute holding the encoder (``self.model`` below).
    base_model_prefix = "model"

    def __init__(self, config):
        """Build the encoder and the linear classifier from *config*.

        Args:
            config: A ``RobertaConfig``. ``config.num_labels`` sets the number
                of output logits; 2 is used if the attribute is missing.
        """
        super().__init__(config)
        num_labels = getattr(config, "num_labels", 2)
        self.model = RobertaModel(config)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        **kwargs,
    ):
        """Encode the inputs and return classification logits.

        Args:
            input_ids: Token ids, forwarded to ``RobertaModel``.
            attention_mask: Attention mask, forwarded to ``RobertaModel``.
            token_type_ids: Segment ids, forwarded to ``RobertaModel``.
            **kwargs: Any other ``RobertaModel.forward`` keyword arguments.
                A caller-supplied ``return_dict`` is ignored (see below).

        Returns:
            The output of ``self.classifier``: a tensor whose last dimension
            is ``num_labels``.
        """
        # The code below reads ``.pooler_output`` / ``.last_hidden_state`` by
        # attribute, which only works on a ModelOutput object. Without forcing
        # ``return_dict=True`` the encoder returns a plain tuple whenever
        # ``config.use_return_dict`` is False or the caller passes
        # ``return_dict=False``, and this method would raise AttributeError.
        # This method always returns the logits tensor itself, so the caller's
        # return_dict preference has nothing to control here.
        kwargs.pop("return_dict", None)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
            **kwargs,
        )
        pooled_output = outputs.pooler_output
        if pooled_output is None:
            # No pooler output (e.g. an encoder without a pooling layer):
            # fall back to the first-token hidden state.
            pooled_output = outputs.last_hidden_state[:, 0]
        return self.classifier(pooled_output)