import torch.nn as nn
from transformers import AutoModel


class CustomMPRNAForSequenceClassification(nn.Module):
    """Sequence-classification head on top of a pretrained backbone."""

    def __init__(self, base_model, num_labels):
        super().__init__()
        self.base_model = base_model
        self.num_labels = num_labels
        self.classifier = nn.Linear(base_model.config.hidden_size, num_labels)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask=None, labels=None):
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        # Pool by taking the hidden state of the first ([CLS]) token.
        pooled_output = outputs[0][:, 0, :]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        # Match the Hugging Face convention of returning loss alongside logits
        # when labels are provided.
        return {"logits": logits, "loss": loss} if loss is not None else {"logits": logits}