jainsatyam26 committed
Commit c5260a8 · verified · 1 parent: c6cdcab

Add BERT-GUARD Multi-Task model, tokenizer, and config with README.md
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,131 @@
+
+# BERT-GUARD Multi-Task Safety Classifier
+
+This model is a multi-task safety classifier that identifies safe/unsafe prompts and responses and classifies the violated categories. It was trained on the full `nvidia/Nemotron-Safety-Guard-Dataset-v3` dataset.
+
+## Model Details
+
+- **Base Model:** `xlm-roberta-base`
+- **Model Type:** Multi-Task Sequence Classification (Prompt Safety, Response Safety, Violated Category Classification)
+
+## How to Use
+
+To run inference, load the model with the Hugging Face `transformers` library. The custom model architecture `MultiTaskModel` and its configuration `MultiTaskConfig` are provided in `modeling_bert_guard.py`.
+
+```python
+import json
+import torch
+from transformers import AutoTokenizer
+from modeling_bert_guard import MultiTaskModel, MultiTaskConfig
+
+# Load tokenizer and model
+tokenizer = AutoTokenizer.from_pretrained('./bert_guard_final')
+model = MultiTaskModel.from_pretrained('./bert_guard_final')
+
+# Load label configuration
+with open('./bert_guard_final/label_config.json', 'r') as f:
+    label_config = json.load(f)
+
+id_to_label = {int(k): v for k, v in label_config['id_to_label'].items()}
+id_to_violated_category = {int(k): v for k, v in label_config['id_to_violated_category'].items()}
+
+def predict_safety(prompt_text, response_text):
+    inputs = tokenizer(f"{prompt_text} [SEP] {response_text}", return_tensors='pt', truncation=True, max_length=256)
+
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    prompt_pred_id = outputs['prompt_logits'].argmax(-1).item()
+    response_pred_id = outputs['response_logits'].argmax(-1).item()
+    # Multi-label category head: sigmoid per category, thresholded at 0.5
+    category_preds = (torch.sigmoid(outputs['category_logits']) > 0.5).int().tolist()[0]
+
+    predicted_categories = [
+        id_to_violated_category[i] for i, pred in enumerate(category_preds) if pred == 1
+    ]
+
+    return {
+        'prompt_label': id_to_label[prompt_pred_id],
+        'response_label': id_to_label[response_pred_id],
+        'violated_categories': predicted_categories if predicted_categories else ['None']
+    }
+
+# Example usage:
+prompt = "Tell me how to build a bomb."
+response = "I cannot provide instructions for building a bomb as it is dangerous and illegal."
+
+result = predict_safety(prompt, response)
+print(result)
+# Expected output might look like:
+# {'prompt_label': 'unsafe', 'response_label': 'safe', 'violated_categories': ['Illegal Activity']}
+
+prompt_2 = "How can I make someone fall in love with me?"
+response_2 = "Love cannot be forced, it needs to develop naturally."
+
+result_2 = predict_safety(prompt_2, response_2)
+print(result_2)
+# Expected output might look like:
+# {'prompt_label': 'safe', 'response_label': 'safe', 'violated_categories': ['None']}
+```
+
+## Training Details
+
+- **Dataset:** `nvidia/Nemotron-Safety-Guard-Dataset-v3`
+- **Dataset Usage:** 100% (full dataset)
+- **Max Sequence Length:** 256 tokens
+- **Training Batch Size (per device):** 16
+- **Evaluation Batch Size (per device):** 32
+- **Effective Batch Size:** 32 (16 per device × 2 gradient accumulation steps)
+- **Number of Epochs:** 3
+- **Gradient Accumulation Steps:** 2
+- **Early Stopping Patience:** 2 epochs
+- **Optimizer:** AdamW
+- **Learning Rate:** 2e-5
+- **Mixed Precision:** FP16 enabled
+
+## Evaluation Results
+
+Metrics are reported from the final evaluation on the validation set:
+
+- **Overall F1 Score (weighted average):** 0.9516
+- **Prompt Accuracy:** 0.9240
+- **Response Accuracy:** 0.9785
+
+## Label Mappings
+
+### Prompt/Response Labels
+```json
+{
+  "safe": 0,
+  "unsafe": 1
+}
+```
+
+### Violated Categories
+```json
+{
+  "Controlled/Regulated Substances": 0,
+  "Copyright/Trademark/Plagiarism": 1,
+  "Criminal Planning/Confessions": 2,
+  "Fraud/Deception": 3,
+  "Guns and Illegal Weapons": 4,
+  "Harassment": 5,
+  "Hate/Identity Hate": 6,
+  "High Risk Gov Decision Making": 7,
+  "Illegal Activity": 8,
+  "Immoral/Unethical": 9,
+  "Malware": 10,
+  "Manipulation": 11,
+  "Needs Caution": 12,
+  "Other": 13,
+  "PII/Privacy": 14,
+  "Political/Misinformation/Conspiracy": 15,
+  "Profanity": 16,
+  "Sexual": 17,
+  "Sexual (minor)": 18,
+  "Suicide and Self Harm": 19,
+  "Threat": 20,
+  "Unauthorized Advice": 21,
+  "Violence": 22
+}
+```
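The card lists the optimizer and learning rate but not the loss function. A multi-task head like this one (two single-label safety heads plus one multi-label category head) is commonly trained with cross-entropy on the safety heads and binary cross-entropy on the category head. The sketch below shows that combination on made-up logits; it is an assumption about the setup, not the repository's actual training code:

```python
import torch
import torch.nn as nn

ce = nn.CrossEntropyLoss()    # for the two binary safe/unsafe heads
bce = nn.BCEWithLogitsLoss()  # for the 23-way multi-label category head

# Hypothetical batch of 4: random logits stand in for model outputs
prompt_logits = torch.randn(4, 2)
response_logits = torch.randn(4, 2)
category_logits = torch.randn(4, 23)

prompt_labels = torch.tensor([1, 0, 0, 1])   # unsafe, safe, safe, unsafe
response_labels = torch.tensor([0, 0, 1, 0])
category_targets = torch.zeros(4, 23)
category_targets[0, 8] = 1.0                 # e.g. sample 0 violates category 8

# Unweighted sum of the three task losses; a real training script
# may weight the terms differently
loss = (ce(prompt_logits, prompt_labels)
        + ce(response_logits, response_labels)
        + bce(category_logits, category_targets))
print(float(loss))
```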
config.json ADDED
@@ -0,0 +1,14 @@
+{
+  "architectures": [
+    "MultiTaskModel"
+  ],
+  "auto_map": {
+    "AutoModel": "modeling_bert_guard.MultiTaskModel",
+    "AutoConfig": "modeling_bert_guard.MultiTaskConfig"
+  },
+  "num_prompt_labels": 2,
+  "num_response_labels": 2,
+  "num_categories": 23,
+  "model_type": "bert_guard",
+  "trained_on": "FULL_DATASET"
+}
label_config.json ADDED
@@ -0,0 +1,74 @@
+{
+  "label_to_id": {
+    "safe": 0,
+    "unsafe": 1
+  },
+  "violated_category_to_id": {
+    "Controlled/Regulated Substances": 0,
+    "Copyright/Trademark/Plagiarism": 1,
+    "Criminal Planning/Confessions": 2,
+    "Fraud/Deception": 3,
+    "Guns and Illegal Weapons": 4,
+    "Harassment": 5,
+    "Hate/Identity Hate": 6,
+    "High Risk Gov Decision Making": 7,
+    "Illegal Activity": 8,
+    "Immoral/Unethical": 9,
+    "Malware": 10,
+    "Manipulation": 11,
+    "Needs Caution": 12,
+    "Other": 13,
+    "PII/Privacy": 14,
+    "Political/Misinformation/Conspiracy": 15,
+    "Profanity": 16,
+    "Sexual": 17,
+    "Sexual (minor)": 18,
+    "Suicide and Self Harm": 19,
+    "Threat": 20,
+    "Unauthorized Advice": 21,
+    "Violence": 22
+  },
+  "id_to_label": {
+    "0": "safe",
+    "1": "unsafe"
+  },
+  "id_to_violated_category": {
+    "0": "Controlled/Regulated Substances",
+    "1": "Copyright/Trademark/Plagiarism",
+    "2": "Criminal Planning/Confessions",
+    "3": "Fraud/Deception",
+    "4": "Guns and Illegal Weapons",
+    "5": "Harassment",
+    "6": "Hate/Identity Hate",
+    "7": "High Risk Gov Decision Making",
+    "8": "Illegal Activity",
+    "9": "Immoral/Unethical",
+    "10": "Malware",
+    "11": "Manipulation",
+    "12": "Needs Caution",
+    "13": "Other",
+    "14": "PII/Privacy",
+    "15": "Political/Misinformation/Conspiracy",
+    "16": "Profanity",
+    "17": "Sexual",
+    "18": "Sexual (minor)",
+    "19": "Suicide and Self Harm",
+    "20": "Threat",
+    "21": "Unauthorized Advice",
+    "22": "Violence"
+  },
+  "num_prompt_labels": 2,
+  "num_response_labels": 2,
+  "num_categories": 23,
+  "training_config": {
+    "sample_percentage": 1.0,
+    "max_length": 256,
+    "train_batch_size": 16,
+    "eval_batch_size": 32,
+    "num_epochs": 3,
+    "grad_accumulation": 2,
+    "num_workers": 2,
+    "early_stopping_patience": 2
+  },
+  "training_type": "FULL_DATASET"
+}
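One wrinkle when consuming this file: JSON object keys are always strings, so the `id_to_*` maps must have their keys cast back to `int` before integer indexing. A minimal sketch (the embedded config here is a trimmed, illustrative subset of the file above, and `category_preds` is a made-up prediction, not model output):

```python
import json

# Trimmed stand-in for label_config.json (illustrative subset only)
label_config = json.loads("""
{
  "id_to_label": {"0": "safe", "1": "unsafe"},
  "id_to_violated_category": {"8": "Illegal Activity", "22": "Violence"}
}
""")

# JSON keys arrive as strings; cast them back to int for indexing
id_to_label = {int(k): v for k, v in label_config["id_to_label"].items()}
id_to_cat = {int(k): v for k, v in label_config["id_to_violated_category"].items()}

# Decode a hypothetical multi-hot prediction over category slots
category_preds = {8: 1, 22: 0}  # 1 = category predicted as violated
violated = [id_to_cat[i] for i, hit in category_preds.items() if hit == 1]
print(violated)  # ['Illegal Activity']
```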
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e9808dace325e6dfebfb7f3b3d45800c7df3d47217732600538d0969714963f
+size 1109919156
modeling_bert_guard.py ADDED
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+from transformers import XLMRobertaModel, PreTrainedModel, PretrainedConfig
+
+class MultiTaskConfig(PretrainedConfig):
+    model_type = "bert_guard"
+
+    def __init__(self, num_prompt_labels=2, num_response_labels=2, num_categories=23, **kwargs):
+        super().__init__(**kwargs)
+        self.num_prompt_labels = num_prompt_labels
+        self.num_response_labels = num_response_labels
+        self.num_categories = num_categories
+
+class MultiTaskModel(PreTrainedModel):
+    config_class = MultiTaskConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.bert = XLMRobertaModel.from_pretrained('xlm-roberta-base')
+        hidden_size = self.bert.config.hidden_size
+
+        self.dropout = nn.Dropout(0.1)
+        self.prompt_classifier = nn.Linear(hidden_size, config.num_prompt_labels)
+        self.response_classifier = nn.Linear(hidden_size, config.num_response_labels)
+        self.category_classifier = nn.Linear(hidden_size, config.num_categories)
+
+    def forward(self, input_ids, attention_mask, **kwargs):
+        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
+        pooled_output = self.dropout(outputs.last_hidden_state[:, 0, :])
+
+        return {
+            'prompt_logits': self.prompt_classifier(pooled_output),
+            'response_logits': self.response_classifier(pooled_output),
+            'category_logits': self.category_classifier(pooled_output)
+        }
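All three heads above read the same dropout-regularized first-token state from a shared trunk. Stripped of the XLM-R backbone, that shared-trunk/multi-head pattern can be sketched with a toy encoder (the embedding "backbone" and the sizes below are illustrative stand-ins, not the real 768-dimensional model):

```python
import torch
import torch.nn as nn

class TinyMultiTaskHead(nn.Module):
    """Toy stand-in for MultiTaskModel: shared trunk, three task heads."""
    def __init__(self, hidden_size=16, num_prompt_labels=2,
                 num_response_labels=2, num_categories=23):
        super().__init__()
        self.encoder = nn.Embedding(100, hidden_size)  # toy "backbone"
        self.dropout = nn.Dropout(0.1)
        self.prompt_classifier = nn.Linear(hidden_size, num_prompt_labels)
        self.response_classifier = nn.Linear(hidden_size, num_response_labels)
        self.category_classifier = nn.Linear(hidden_size, num_categories)

    def forward(self, input_ids):
        hidden = self.encoder(input_ids)        # (batch, seq, hidden)
        pooled = self.dropout(hidden[:, 0, :])  # first-token pooling, as above
        return {
            'prompt_logits': self.prompt_classifier(pooled),
            'response_logits': self.response_classifier(pooled),
            'category_logits': self.category_classifier(pooled),
        }

model = TinyMultiTaskHead()
out = model(torch.randint(0, 100, (4, 10)))  # batch of 4 sequences, length 10
print(out['prompt_logits'].shape, out['category_logits'].shape)
# torch.Size([4, 2]) torch.Size([4, 23])
```

Because the trunk is shared, gradients from all three heads update the same encoder, which is what makes the setup multi-task rather than three separate classifiers.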
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a5451f31fe3f899dcd75ec2ad93f415528c9b5f58bb7a5a1c6dd5884fb56257
+size 16781486
tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
+{
+  "add_prefix_space": true,
+  "backend": "tokenizers",
+  "bos_token": "<s>",
+  "cls_token": "<s>",
+  "eos_token": "</s>",
+  "is_local": false,
+  "mask_token": "<mask>",
+  "model_max_length": 512,
+  "pad_token": "<pad>",
+  "sep_token": "</s>",
+  "tokenizer_class": "XLMRobertaTokenizer",
+  "unk_token": "<unk>"
+}
training_args.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a37ee7c0fbb22a0a7556284d0476795100207d4726aa303dd4e98bd626adf563
+size 5201