| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoModel, PreTrainedModel | |
| from .config import AutextificationMTLConfig | |
class AutextificationMTLModel(PreTrainedModel):
    """Multi-task text classifier built on a Hugging Face encoder.

    A pooled encoder representation is passed through one shared ReLU hidden
    layer and then into two independent single-unit sigmoid heads: one whose
    output is returned as ``bot_prob`` and one returned as ``english_prob``.
    """

    config_class = AutextificationMTLConfig

    def __init__(self, config: AutextificationMTLConfig):
        """Build the encoder and both task heads from ``config``.

        Args:
            config: Read for ``transformer_name`` (checkpoint passed to
                ``AutoModel.from_pretrained``), ``hidden_nodes`` (width of the
                shared hidden layer) and ``threshold`` (cut-off applied to the
                bot probability in ``forward``).
        """
        super().__init__(config)
        # NOTE(review): the encoder is (down)loaded with pretrained weights on
        # every construction, including when this whole model is restored via
        # from_pretrained (its saved weights then overwrite these) — confirm
        # this double load is intended.
        self.encoder = AutoModel.from_pretrained(config.transformer_name)
        embedding_size = self.encoder.config.hidden_size
        self.hidden = torch.nn.Linear(embedding_size, config.hidden_nodes)
        self.out_generated = torch.nn.Linear(config.hidden_nodes, 1)
        self.out_language = torch.nn.Linear(config.hidden_nodes, 1)
        self.threshold = config.threshold

    def forward(self, tensor):
        """Run the encoder and both heads on a tokenized batch.

        Args:
            tensor: Mapping providing ``"input_ids"`` and ``"attention_mask"``
                (presumably tokenizer output — confirm against callers).

        Returns:
            dict with:
                ``"bot_prob"``: sigmoid output of the generated-text head
                    (last dimension 1).
                ``"english_prob"``: sigmoid output of the language head
                    (last dimension 1).
                ``"is_bot"``: boolean tensor, ``bot_prob > self.threshold``.
        """
        output = self.encoder(
            input_ids=tensor["input_ids"],
            attention_mask=tensor["attention_mask"],
            return_dict=True,
        )
        # Not every AutoModel backbone exposes a pooler (e.g. DistilBERT,
        # DeBERTa, or BERT built without a pooling layer); ModelOutput omits
        # None fields from its keys, so indexing it directly would raise
        # KeyError. Fall back to the first-token hidden state in that case.
        pooled = output.get("pooler_output")
        if pooled is None:
            pooled = output["last_hidden_state"][:, 0]
        out = F.relu(self.hidden(pooled))
        out_generated = torch.sigmoid(self.out_generated(out))
        out_language = torch.sigmoid(self.out_language(out))
        out_verdict = out_generated > self.threshold
        return {
            "is_bot": out_verdict,
            "bot_prob": out_generated,
            "english_prob": out_language,
        }