File size: 1,505 Bytes
c5260a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torch.nn as nn
from transformers import XLMRobertaModel, PreTrainedModel, PretrainedConfig

class MultiTaskConfig(PretrainedConfig):
    model_type = "bert_guard"
    
    def __init__(self, num_prompt_labels=2, num_response_labels=2, num_categories=13, **kwargs):
        super().__init__(**kwargs)
        self.num_prompt_labels = num_prompt_labels
        self.num_response_labels = num_response_labels
        self.num_categories = num_categories

class MultiTaskModel(PreTrainedModel):
    config_class = MultiTaskConfig
    
    def __init__(self, config):
        super().__init__(config)
        self.bert = XLMRobertaModel.from_pretrained('xlm-roberta-base')
        hidden_size = self.bert.config.hidden_size
        
        self.dropout = nn.Dropout(0.1)
        self.prompt_classifier = nn.Linear(hidden_size, config.num_prompt_labels)
        self.response_classifier = nn.Linear(hidden_size, config.num_response_labels)
        self.category_classifier = nn.Linear(hidden_size, config.num_categories)
        
    def forward(self, input_ids, attention_mask, **kwargs):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self.dropout(outputs.last_hidden_state[:, 0, :])
        
        return {
            'prompt_logits': self.prompt_classifier(pooled_output),
            'response_logits': self.response_classifier(pooled_output),
            'category_logits': self.category_classifier(pooled_output)
        }