import torch
import torch.nn as nn
from transformers import XLMRobertaModel, PreTrainedModel, PretrainedConfig


class MultiTaskConfig(PretrainedConfig):
    model_type = "bert_guard"

    def __init__(self, num_prompt_labels=2, num_response_labels=2, num_categories=13, **kwargs):
        super().__init__(**kwargs)
        self.num_prompt_labels = num_prompt_labels
        self.num_response_labels = num_response_labels
        self.num_categories = num_categories


class MultiTaskModel(PreTrainedModel):
    config_class = MultiTaskConfig

    def __init__(self, config):
        super().__init__(config)
        # Shared XLM-RoBERTa backbone; all three heads read its pooled output.
        self.bert = XLMRobertaModel.from_pretrained('xlm-roberta-base')
        hidden_size = self.bert.config.hidden_size
        self.dropout = nn.Dropout(0.1)
        # One linear head per task: prompt safety, response safety, and category.
        self.prompt_classifier = nn.Linear(hidden_size, config.num_prompt_labels)
        self.response_classifier = nn.Linear(hidden_size, config.num_response_labels)
        self.category_classifier = nn.Linear(hidden_size, config.num_categories)

    def forward(self, input_ids, attention_mask, **kwargs):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Pool by taking the hidden state of the first (<s>/CLS) token.
        pooled_output = self.dropout(outputs.last_hidden_state[:, 0, :])
        return {
            'prompt_logits': self.prompt_classifier(pooled_output),
            'response_logits': self.response_classifier(pooled_output),
            'category_logits': self.category_classifier(pooled_output),
        }
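

# A minimal usage sketch, not part of the original file: instantiate the model
# with its default config and run a dummy forward pass. The tokenizer name
# mirrors the backbone loaded above; the input sentence is an arbitrary example.
if __name__ == "__main__":
    from transformers import XLMRobertaTokenizer

    tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
    model = MultiTaskModel(MultiTaskConfig())
    model.eval()

    encoded = tokenizer("How do I reset my password?", return_tensors="pt")
    with torch.no_grad():
        logits = model(input_ids=encoded["input_ids"],
                       attention_mask=encoded["attention_mask"])

    # Each head returns a [batch_size, num_labels] tensor of logits.
    print({name: tuple(t.shape) for name, t in logits.items()})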