import torch
import torch.nn as nn
from transformers import XLMRobertaModel, PreTrainedModel, PretrainedConfig


class MultiTaskConfig(PretrainedConfig):
    """Configuration for the multi-task guard classifier."""

    model_type = "bert_guard"

    def __init__(self, num_prompt_labels=2, num_response_labels=2, num_categories=13, **kwargs):
        super().__init__(**kwargs)
        self.num_prompt_labels = num_prompt_labels
        self.num_response_labels = num_response_labels
        self.num_categories = num_categories


class MultiTaskModel(PreTrainedModel):
    """XLM-RoBERTa backbone with three linear heads that share one pooled
    representation: prompt labels, response labels, and categories."""

    config_class = MultiTaskConfig

    def __init__(self, config):
        super().__init__(config)
        # Shared encoder, initialized from pretrained XLM-R weights.
        self.bert = XLMRobertaModel.from_pretrained('xlm-roberta-base')
        hidden_size = self.bert.config.hidden_size

        self.dropout = nn.Dropout(0.1)
        # One classification head per task, all reading the same pooled vector.
        self.prompt_classifier = nn.Linear(hidden_size, config.num_prompt_labels)
        self.response_classifier = nn.Linear(hidden_size, config.num_response_labels)
        self.category_classifier = nn.Linear(hidden_size, config.num_categories)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Pool by taking the first (<s>) token's hidden state, then apply dropout.
        pooled_output = self.dropout(outputs.last_hidden_state[:, 0, :])
        return {
            'prompt_logits': self.prompt_classifier(pooled_output),
            'response_logits': self.response_classifier(pooled_output),
            'category_logits': self.category_classifier(pooled_output),
        }
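

# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not part of the model definition above):
# runs one untrained forward pass to show the output shapes. The tokenizer
# choice mirrors the backbone checkpoint; the sample text is illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
    config = MultiTaskConfig(num_prompt_labels=2, num_response_labels=2, num_categories=13)
    model = MultiTaskModel(config)
    model.eval()

    # Tokenize a single example; padding/truncation keep batches rectangular.
    batch = tokenizer(
        ["Is this prompt safe?"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

    # Each head emits one logit vector per input; softmax converts logits
    # into a per-task probability distribution.
    for name, logits in outputs.items():
        probs = torch.softmax(logits, dim=-1)
        print(f"{name}: shape={tuple(logits.shape)}")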