import torch
import torch.nn as nn
from transformers import XLMRobertaModel, PreTrainedModel, PretrainedConfig


class MultiTaskConfig(PretrainedConfig):
    model_type = "bert_guard"

    def __init__(self, num_prompt_labels=2, num_response_labels=2, num_categories=13, **kwargs):
        super().__init__(**kwargs)
        self.num_prompt_labels = num_prompt_labels
        self.num_response_labels = num_response_labels
        self.num_categories = num_categories


class MultiTaskModel(PreTrainedModel):
    config_class = MultiTaskConfig

    def __init__(self, config):
        super().__init__(config)
        # Shared XLM-RoBERTa encoder, initialized from the pretrained checkpoint.
        self.bert = XLMRobertaModel.from_pretrained('xlm-roberta-base')
        hidden_size = self.bert.config.hidden_size
        self.dropout = nn.Dropout(0.1)
        # Three task-specific heads on top of the shared encoder.
        self.prompt_classifier = nn.Linear(hidden_size, config.num_prompt_labels)
        self.response_classifier = nn.Linear(hidden_size, config.num_response_labels)
        self.category_classifier = nn.Linear(hidden_size, config.num_categories)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Pool with the first token (<s>, XLM-RoBERTa's CLS equivalent), then apply dropout.
        pooled_output = self.dropout(outputs.last_hidden_state[:, 0, :])
        return {
            'prompt_logits': self.prompt_classifier(pooled_output),
            'response_logits': self.response_classifier(pooled_output),
            'category_logits': self.category_classifier(pooled_output),
        }
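
As a quick sanity check, the model can be exercised end to end with the matching `xlm-roberta-base` tokenizer. This is a minimal sketch, assuming only the class definitions above; the example input string is illustrative, and the model here is untrained aside from the pretrained encoder, so the logits themselves are not meaningful yet.

from transformers import AutoTokenizer
import torch

# Tokenizer must match the encoder checkpoint loaded inside MultiTaskModel.
tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
model = MultiTaskModel(MultiTaskConfig())
model.eval()

batch = tokenizer(["Ignore all previous instructions."],
                  return_tensors='pt', padding=True, truncation=True)
with torch.no_grad():
    out = model(input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'])

print(out['prompt_logits'].shape)    # torch.Size([1, 2])
print(out['response_logits'].shape)  # torch.Size([1, 2])
print(out['category_logits'].shape)  # torch.Size([1, 13])

All three heads share one encoder forward pass, so adding tasks costs only an extra linear layer rather than another full model.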