# bert-guard-multi-task / modeling_bert_guard.py
# Uploaded by jainsatyam26
# Commit: "Add BERT-GUARD Multi-Task model, tokenizer, and config with README.md" (c5260a8, verified)
import torch
import torch.nn as nn
from transformers import XLMRobertaModel, PreTrainedModel, PretrainedConfig
class MultiTaskConfig(PretrainedConfig):
    """Configuration for the BERT-GUARD multi-task safety classifier.

    Stores the output sizes of the model's three classification heads:
    prompt safety, response safety, and harm category.
    """

    model_type = "bert_guard"

    def __init__(self, num_prompt_labels=2, num_response_labels=2, num_categories=13, **kwargs):
        # Let the HF base class consume the generic config kwargs first.
        super().__init__(**kwargs)
        # Head sizes: binary prompt/response safety plus 13 harm categories.
        self.num_prompt_labels = num_prompt_labels
        self.num_response_labels = num_response_labels
        self.num_categories = num_categories
class MultiTaskModel(PreTrainedModel):
    """BERT-GUARD multi-task classifier on top of an XLM-RoBERTa encoder.

    Three linear heads share the encoder's first-token representation:
    prompt safety, response safety, and harm-category classification.
    """

    config_class = MultiTaskConfig

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): calling from_pretrained() here downloads the base
        # weights on every instantiation, and they are overwritten anyway
        # when the full model is loaded via MultiTaskModel.from_pretrained().
        # Kept for backward compatibility; the checkpoint name is now
        # overridable through the config instead of being hard-coded.
        base_model_name = getattr(config, 'base_model_name', 'xlm-roberta-base')
        self.bert = XLMRobertaModel.from_pretrained(base_model_name)
        hidden_size = self.bert.config.hidden_size
        # Dropout probability is config-overridable; default matches the
        # original hard-coded 0.1.
        self.dropout = nn.Dropout(getattr(config, 'classifier_dropout', 0.1))
        self.prompt_classifier = nn.Linear(hidden_size, config.num_prompt_labels)
        self.response_classifier = nn.Linear(hidden_size, config.num_response_labels)
        self.category_classifier = nn.Linear(hidden_size, config.num_categories)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Encode the input and apply the three classification heads.

        Args:
            input_ids: token id tensor, shape (batch, seq_len).
            attention_mask: optional padding mask. Now defaults to None
                (encoder attends everywhere) instead of being required.
            **kwargs: ignored extras (e.g. token_type_ids from a tokenizer).

        Returns:
            dict with 'prompt_logits', 'response_logits', 'category_logits'.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Pool using the first token's (<s>/[CLS]) final hidden state.
        pooled_output = self.dropout(outputs.last_hidden_state[:, 0, :])
        return {
            'prompt_logits': self.prompt_classifier(pooled_output),
            'response_logits': self.response_classifier(pooled_output),
            'category_logits': self.category_classifier(pooled_output)
        }