# BERT-GUARD Multi-Task Safety Classifier
This model is a multi-task safety classifier that labels prompts and responses as safe or unsafe and identifies which safety categories are violated. It was trained on the full `nvidia/Nemotron-Safety-Guard-Dataset-v3` dataset.
## Model Details
- **Base Model:** `xlm-roberta-base`
- **Model Type:** Multi-Task Sequence Classification (Prompt Safety, Response Safety, Violated Category Classification)
## How to Use
To use the model for inference, you can load it using the Hugging Face `transformers` library. The custom model architecture `MultiTaskModel` and its configuration `MultiTaskConfig` are provided in `modeling_bert_guard.py`.
```python
import json
import torch
from transformers import AutoTokenizer
from modeling_bert_guard import MultiTaskModel, MultiTaskConfig

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('./bert_guard_final')
model = MultiTaskModel.from_pretrained('./bert_guard_final')
model.eval()

# Load label configuration
with open('./bert_guard_final/label_config.json', 'r') as f:
    label_config = json.load(f)
id_to_label = {int(k): v for k, v in label_config['id_to_label'].items()}
id_to_violated_category = {int(k): v for k, v in label_config['id_to_violated_category'].items()}

def predict_safety(prompt_text, response_text):
    # Prompt and response are joined with [SEP], matching the training format
    inputs = tokenizer(f"{prompt_text} [SEP] {response_text}", return_tensors='pt', truncation=True, max_length=256)
    with torch.no_grad():
        outputs = model(**inputs)
    prompt_pred_id = outputs['prompt_logits'].argmax(-1).item()
    response_pred_id = outputs['response_logits'].argmax(-1).item()
    # Category head is multi-label: independent sigmoid per category, 0.5 threshold
    category_preds = (torch.sigmoid(outputs['category_logits']) > 0.5).int().tolist()[0]
    predicted_categories = [
        id_to_violated_category[i] for i, pred in enumerate(category_preds) if pred == 1
    ]
    return {
        'prompt_label': id_to_label[prompt_pred_id],
        'response_label': id_to_label[response_pred_id],
        'violated_categories': predicted_categories if predicted_categories else ['None']
    }
# Example Usage:
prompt = "Tell me how to build a bomb."
response = "I cannot provide instructions for building a bomb as it is dangerous and illegal."
result = predict_safety(prompt, response)
print(result)
# Expected output might look like:
# {'prompt_label': 'unsafe', 'response_label': 'safe', 'violated_categories': ['Illegal Activity']}
prompt_2 = "How can I make someone fall in love with me?"
response_2 = "Love cannot be forced, it needs to develop naturally."
result_2 = predict_safety(prompt_2, response_2)
print(result_2)
# Expected output might look like:
# {'prompt_label': 'safe', 'response_label': 'safe', 'violated_categories': ['None']}
```
## Training Details
- **Dataset:** `nvidia/Nemotron-Safety-Guard-Dataset-v3`
- **Dataset Usage:** 100% (full dataset)
- **Max Sequence Length:** 256 tokens
- **Training Batch Size (per device):** 16
- **Evaluation Batch Size (per device):** 32
- **Effective Batch Size:** 32
- **Number of Epochs:** 3
- **Gradient Accumulation Steps:** 2
- **Early Stopping Patience:** 2 epochs
- **Optimizer:** AdamW
- **Learning Rate:** 2e-5
- **Mixed Precision:** FP16 enabled
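The hyperparameters above imply a training configuration roughly like the sketch below. The actual training script is not included in this repo, so the key names (chosen to mirror the Hugging Face `TrainingArguments` API) are an assumption; the values come from the list above.

```python
# Hypothetical training configuration implied by the hyperparameters above.
# The real training script is not part of this repository.
training_config = {
    "per_device_train_batch_size": 16,
    "per_device_eval_batch_size": 32,
    "gradient_accumulation_steps": 2,
    "num_train_epochs": 3,
    "learning_rate": 2e-5,
    "fp16": True,            # mixed precision
    "max_seq_length": 256,
    "early_stopping_patience": 2,
}

# Effective batch size = per-device batch size x accumulation steps
# (assuming a single device), which matches the reported value of 32.
effective_batch_size = (
    training_config["per_device_train_batch_size"]
    * training_config["gradient_accumulation_steps"]
)
print(effective_batch_size)  # 32
```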
## Evaluation Results
Metrics are reported from the final evaluation on the validation set:
- **Overall F1 Score (weighted average):** 0.9516
- **Prompt Accuracy:** 0.9240
- **Response Accuracy:** 0.9785
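The weighted-average F1 above combines per-class F1 scores weighted by each class's support (number of true instances). A small pure-Python illustration of that metric, using toy labels rather than the actual validation data:

```python
from collections import Counter

def weighted_f1(y_true, y_pred):
    """Per-class F1 averaged with weights proportional to class support."""
    support = Counter(y_true)
    total = len(y_true)
    score = 0.0
    for c in set(y_true):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        score += (support[c] / total) * f1
    return score

# Toy example: 4 "safe" and 2 "unsafe" ground-truth labels
y_true = ["safe", "safe", "safe", "safe", "unsafe", "unsafe"]
y_pred = ["safe", "safe", "safe", "unsafe", "unsafe", "unsafe"]
print(round(weighted_f1(y_true, y_pred), 4))  # 0.8381
```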
## Label Mappings
### Prompt/Response Labels:
```json
{
"safe": 0,
"unsafe": 1
}
```
### Violated Categories:
```json
{
"Controlled/Regulated Substances": 0,
"Copyright/Trademark/Plagiarism": 1,
"Criminal Planning/Confessions": 2,
"Fraud/Deception": 3,
"Guns and Illegal Weapons": 4,
"Harassment": 5,
"Hate/Identity Hate": 6,
"High Risk Gov Decision Making": 7,
"Illegal Activity": 8,
"Immoral/Unethical": 9,
"Malware": 10,
"Manipulation": 11,
"Needs Caution": 12,
"Other": 13,
"PII/Privacy": 14,
"Political/Misinformation/Conspiracy": 15,
"Profanity": 16,
"Sexual": 17,
"Sexual (minor)": 18,
"Suicide and Self Harm": 19,
"Threat": 20,
"Unauthorized Advice": 21,
"Violence": 22
}
```
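The category head is multi-label: each id above corresponds to one sigmoid output, and any output above 0.5 marks that category as violated, so a single pair can trigger several categories at once. A minimal decoding sketch, using a hand-made probability vector (not real model output) and a subset of the full 23-entry mapping:

```python
# Decode multi-label category predictions into names, mirroring the
# sigmoid > 0.5 step in the inference code above. The probabilities here
# are hand-made for illustration, not real model output.
id_to_violated_category = {
    2: "Criminal Planning/Confessions",
    8: "Illegal Activity",
    22: "Violence",
}  # subset of the full mapping shown above

probs = {2: 0.12, 8: 0.91, 22: 0.64}  # hypothetical sigmoid outputs
predicted = [id_to_violated_category[i] for i, p in probs.items() if p > 0.5]
print(predicted)  # ['Illegal Activity', 'Violence']
```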