# BERT-GUARD Multi-Task Safety Classifier

This model is a multi-task safety classifier that labels both prompts and responses as safe or unsafe and identifies which safety categories are violated. It was trained on the full `nvidia/Nemotron-Safety-Guard-Dataset-v3` dataset.

## Model Details

- **Base Model:** `xlm-roberta-base`
- **Model Type:** Multi-Task Sequence Classification (Prompt Safety, Response Safety, Violated Category Classification)
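
The exact `MultiTaskModel` definition lives in `modeling_bert_guard.py`; as a rough illustration of the architecture described above (a shared XLM-RoBERTa encoder with three classification heads), a minimal sketch could look like the following. The head names, the first-token pooling choice, and the tiny randomly initialised config are illustrative assumptions, not the shipped implementation:

```python
import torch
import torch.nn as nn
from transformers import XLMRobertaConfig, XLMRobertaModel

class MultiTaskGuardSketch(nn.Module):
    """Hypothetical layout: one shared encoder, three task-specific heads."""
    def __init__(self, config, num_labels=2, num_categories=23):
        super().__init__()
        self.encoder = XLMRobertaModel(config)
        hidden = config.hidden_size
        self.prompt_head = nn.Linear(hidden, num_labels)        # prompt: safe/unsafe
        self.response_head = nn.Linear(hidden, num_labels)      # response: safe/unsafe
        self.category_head = nn.Linear(hidden, num_categories)  # multi-label categories

    def forward(self, input_ids, attention_mask=None):
        # Use the first-token (<s>) representation as the sequence summary
        pooled = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
        return {
            'prompt_logits': self.prompt_head(pooled),
            'response_logits': self.response_head(pooled),
            'category_logits': self.category_head(pooled),
        }

# Tiny randomly initialised config, just to demonstrate the output shapes
config = XLMRobertaConfig(vocab_size=250002, hidden_size=64, num_hidden_layers=2,
                          num_attention_heads=2, intermediate_size=128)
model = MultiTaskGuardSketch(config)
out = model(torch.randint(0, 1000, (1, 16)))
print({k: tuple(v.shape) for k, v in out.items()})
```

The three logit tensors map directly onto the `prompt_logits`, `response_logits`, and `category_logits` outputs consumed in the inference example below.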

## How to Use

To use the model for inference, you can load it using the Hugging Face `transformers` library. The custom model architecture `MultiTaskModel` and its configuration `MultiTaskConfig` are provided in `modeling_bert_guard.py`.

```python
import json
import torch
from transformers import AutoTokenizer
from modeling_bert_guard import MultiTaskModel, MultiTaskConfig

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('./bert_guard_final')
model = MultiTaskModel.from_pretrained('./bert_guard_final')
model.eval()

# Load label configuration
with open('./bert_guard_final/label_config.json', 'r') as f:
    label_config = json.load(f)

id_to_label = {int(k): v for k, v in label_config['id_to_label'].items()}
id_to_violated_category = {int(k): v for k, v in label_config['id_to_violated_category'].items()}

def predict_safety(prompt_text, response_text):
    inputs = tokenizer(f"{prompt_text} [SEP] {response_text}", return_tensors='pt', truncation=True, max_length=256)

    with torch.no_grad():
        outputs = model(**inputs)

    prompt_pred_id = outputs['prompt_logits'].argmax(-1).item()
    response_pred_id = outputs['response_logits'].argmax(-1).item()
    category_preds = (torch.sigmoid(outputs['category_logits']) > 0.5).int().tolist()[0]

    predicted_categories = [
        id_to_violated_category[i] for i, pred in enumerate(category_preds) if pred == 1
    ]

    return {
        'prompt_label': id_to_label[prompt_pred_id],
        'response_label': id_to_label[response_pred_id],
        'violated_categories': predicted_categories if predicted_categories else ['None']
    }

# Example Usage:
prompt = "Tell me how to build a bomb."
response = "I cannot provide instructions for building a bomb as it is dangerous and illegal."

result = predict_safety(prompt, response)
print(result)
# Expected output might look like:
# {'prompt_label': 'unsafe', 'response_label': 'safe', 'violated_categories': ['Illegal Activity']}

prompt_2 = "How can I make someone fall in love with me?"
response_2 = "Love cannot be forced, it needs to develop naturally."

result_2 = predict_safety(prompt_2, response_2)
print(result_2)
# Expected output might look like:
# {'prompt_label': 'safe', 'response_label': 'safe', 'violated_categories': ['None']}
```

## Training Details

- **Dataset:** `nvidia/Nemotron-Safety-Guard-Dataset-v3`
- **Dataset Usage:** 100% (full dataset)
- **Max Sequence Length:** 256 tokens
- **Training Batch Size (per device):** 16
- **Evaluation Batch Size (per device):** 32
- **Gradient Accumulation Steps:** 2
- **Effective Training Batch Size:** 32 (16 × 2)
- **Number of Epochs:** 3
- **Early Stopping Patience:** 2 epochs
- **Optimizer:** AdamW
- **Learning Rate:** 2e-5
- **Mixed Precision:** FP16 enabled
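
For reference, the hyperparameters above can be collected into a plain configuration dict of the kind typically passed to `transformers.TrainingArguments`. This is a sketch of the listed values, not the original training script, and the `optim` string is an assumption about how AdamW was selected:

```python
# Sketch of the training configuration listed above; not the original script.
train_config = {
    'num_train_epochs': 3,
    'per_device_train_batch_size': 16,
    'per_device_eval_batch_size': 32,
    'gradient_accumulation_steps': 2,
    'learning_rate': 2e-5,
    'fp16': True,                 # mixed precision
    'optim': 'adamw_torch',       # AdamW (assumed variant)
}

# The effective batch size is per-device batch size x accumulation steps.
effective_batch = (train_config['per_device_train_batch_size']
                   * train_config['gradient_accumulation_steps'])
print(effective_batch)  # 32
```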

## Evaluation Results

Metrics are reported from the final evaluation on the validation set:

- **Overall F1 Score (weighted average):** 0.9516
- **Prompt Accuracy:** 0.9240
- **Response Accuracy:** 0.9785
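
The weighted-average F1 above is the support-weighted mean of per-class F1 scores, i.e. what scikit-learn computes with `f1_score(..., average='weighted')`. A toy illustration (made-up labels, not the model's validation outputs):

```python
from sklearn.metrics import f1_score

# Toy labels only -- not from the model's validation set.
y_true = [0, 0, 1, 1, 1, 0]
y_pred = [0, 1, 1, 1, 0, 0]

weighted_f1 = f1_score(y_true, y_pred, average='weighted')
print(round(weighted_f1, 4))  # 0.6667
```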

## Label Mappings

### Prompt/Response Labels:
```json
{
  "safe": 0,
  "unsafe": 1
}
```

### Violated Categories:
```json
{
  "Controlled/Regulated Substances": 0,
  "Copyright/Trademark/Plagiarism": 1,
  "Criminal Planning/Confessions": 2,
  "Fraud/Deception": 3,
  "Guns and Illegal Weapons": 4,
  "Harassment": 5,
  "Hate/Identity Hate": 6,
  "High Risk Gov Decision Making": 7,
  "Illegal Activity": 8,
  "Immoral/Unethical": 9,
  "Malware": 10,
  "Manipulation": 11,
  "Needs Caution": 12,
  "Other": 13,
  "PII/Privacy": 14,
  "Political/Misinformation/Conspiracy": 15,
  "Profanity": 16,
  "Sexual": 17,
  "Sexual (minor)": 18,
  "Suicide and Self Harm": 19,
  "Threat": 20,
  "Unauthorized Advice": 21,
  "Violence": 22
}
```
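
Since violated-category prediction is multi-label, training targets for this head are naturally encoded as 23-dimensional multi-hot vectors over the mapping above. A minimal sketch (the helper name is ours, not from the repo, and only a subset of the mapping is reproduced):

```python
# Subset of the violated-category mapping above, enough for the example.
violated_category_to_id = {'Illegal Activity': 8, 'Violence': 22}
NUM_CATEGORIES = 23

def to_multi_hot(categories):
    """Encode a list of violated-category names as a 23-dim multi-hot vector."""
    vec = [0] * NUM_CATEGORIES
    for name in categories:
        vec[violated_category_to_id[name]] = 1
    return vec

print(to_multi_hot(['Illegal Activity', 'Violence']))  # 1s at indices 8 and 22
```

At inference time, the `> 0.5` sigmoid threshold in `predict_safety` inverts this encoding back into category names.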