| import torch | |
| import torch.nn as nn | |
| from transformers import PreTrainedModel, AutoModel | |
| from .model_config import CustomConfig | |
class LogRegClassifier(nn.Module):
    """Logistic-regression head: a single linear unit squashed by a sigmoid.

    Args:
        transformer_output_dim: Size of the last dimension of the features
            passed to ``forward``.
    """

    def __init__(self, transformer_output_dim):
        super().__init__()
        # One output unit -> one score per input row.
        self.linear = nn.Linear(transformer_output_dim, 1)

    def forward(self, x):
        """Map features ``x`` to probabilities in (0, 1).

        The output keeps the leading dimensions of ``x``; its last dimension
        has size 1. Note that a sigmoid is applied here, so callers receive
        probabilities rather than raw logits.
        """
        scores = self.linear(x)
        return scores.sigmoid()
class CombinedModel(PreTrainedModel):
    """Transformer backbone followed by a logistic-regression head.

    The transformer's ``last_hidden_state`` is reduced to the vector at
    position 0 of its second dimension. That vector is then scored by
    ``LogRegClassifier``, which returns sigmoid probabilities.

    Args:
        config: A ``CustomConfig`` providing ``transformer_type`` (passed to
            ``AutoModel.from_pretrained``) and ``transformer_output_dim``
            (input width of the classifier head).
    """

    config_class = CustomConfig

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): pretrained backbone weights are loaded on every
        # construction. Presumably this also happens when
        # ``CombinedModel.from_pretrained`` rebuilds the model before loading
        # saved weights. Confirm this extra load is intended.
        self.transformer = AutoModel.from_pretrained(config.transformer_type)
        # NOTE(review): assumes config.transformer_output_dim equals the
        # feature width of the transformer's last_hidden_state. A mismatch
        # only surfaces at the first forward pass. TODO confirm or validate.
        self.classifier = LogRegClassifier(config.transformer_output_dim)

    def forward(self, input_ids, attention_mask=None):
        """Score each input sequence.

        Args:
            input_ids: Token ids forwarded to the transformer.
            attention_mask: Optional mask forwarded unchanged to the
                transformer. ``None`` (the default) leaves masking to the
                transformer's own default behaviour. Presumably this means
                every token is attended to; confirm for the chosen backbone.

        Returns:
            Sigmoid probabilities from the classifier head, with a trailing
            dimension of size 1 (presumably shape (batch, 1)).
        """
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        # last_hidden_state is indexed as rank 3. Taking index 0 on dim 1
        # keeps one vector per sequence. For encoder models this is
        # presumably the [CLS]-style token; verify for the backbone in use.
        pooled_output = outputs.last_hidden_state[:, 0, :]
        return self.classifier(pooled_output)