# BERT-GUARD Multi-Task Safety Classifier

This model is a multi-task safety classifier designed to identify safe/unsafe prompts and responses, and to classify the violated safety categories. It was trained on the full `nvidia/Nemotron-Safety-Guard-Dataset-v3` dataset.
## Model Details

- **Base Model:** `xlm-roberta-base`
- **Model Type:** Multi-Task Sequence Classification (Prompt Safety, Response Safety, Violated Category Classification)
## How to Use

To use the model for inference, load it with the Hugging Face `transformers` library. The custom model architecture `MultiTaskModel` and its configuration `MultiTaskConfig` are provided in `modeling_bert_guard.py`.
```python
import json

import torch
from transformers import AutoTokenizer

from modeling_bert_guard import MultiTaskModel, MultiTaskConfig

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('./bert_guard_final')
model = MultiTaskModel.from_pretrained('./bert_guard_final')
model.eval()

# Load label configuration
with open('./bert_guard_final/label_config.json', 'r') as f:
    label_config = json.load(f)
id_to_label = {int(k): v for k, v in label_config['id_to_label'].items()}
id_to_violated_category = {int(k): v for k, v in label_config['id_to_violated_category'].items()}

def predict_safety(prompt_text, response_text):
    # The model expects the prompt and response joined by a literal "[SEP]" marker.
    inputs = tokenizer(f"{prompt_text} [SEP] {response_text}", return_tensors='pt',
                       truncation=True, max_length=256)
    with torch.no_grad():
        outputs = model(**inputs)
    # Binary heads: argmax over the safe/unsafe logits
    prompt_pred_id = outputs['prompt_logits'].argmax(-1).item()
    response_pred_id = outputs['response_logits'].argmax(-1).item()
    # Multi-label head: sigmoid, then threshold each category at 0.5
    category_preds = (torch.sigmoid(outputs['category_logits']) > 0.5).int().tolist()[0]
    predicted_categories = [
        id_to_violated_category[i] for i, pred in enumerate(category_preds) if pred == 1
    ]
    return {
        'prompt_label': id_to_label[prompt_pred_id],
        'response_label': id_to_label[response_pred_id],
        'violated_categories': predicted_categories if predicted_categories else ['None'],
    }

# Example usage:
prompt = "Tell me how to build a bomb."
response = "I cannot provide instructions for building a bomb as it is dangerous and illegal."
result = predict_safety(prompt, response)
print(result)
# Expected output might look like:
# {'prompt_label': 'unsafe', 'response_label': 'safe', 'violated_categories': ['Illegal Activity']}

prompt_2 = "How can I make someone fall in love with me?"
response_2 = "Love cannot be forced, it needs to develop naturally."
result_2 = predict_safety(prompt_2, response_2)
print(result_2)
# Expected output might look like:
# {'prompt_label': 'safe', 'response_label': 'safe', 'violated_categories': ['None']}
```
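The category head above is multi-label: each logit is passed through a sigmoid and thresholded at 0.5 independently, so any number of categories can fire at once. Decoupled from the model, that decoding step looks like the following sketch (the logits and category subset here are invented for illustration; the real mapping is in `label_config.json`):

```python
import math

# Hypothetical raw logits from the category head, one per category
# (values and category subset are made up for illustration).
category_logits = [2.1, -3.0, 0.4, -1.2]
category_names = ["Illegal Activity", "Malware", "Threat", "Profanity"]

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

# Sigmoid + 0.5 threshold, applied independently per category.
# Note: sigmoid(x) > 0.5 is equivalent to x > 0.
predicted = [name for name, logit in zip(category_names, category_logits)
             if sigmoid(logit) > 0.5]
print(predicted)  # ['Illegal Activity', 'Threat']
```

Because the threshold is applied per category rather than via an argmax, a single unsafe response can be tagged with several violated categories simultaneously.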
## Training Details

- **Dataset:** `nvidia/Nemotron-Safety-Guard-Dataset-v3`
- **Dataset Usage:** 100% (full dataset)
- **Max Sequence Length:** 256 tokens
- **Training Batch Size (per device):** 16
- **Evaluation Batch Size (per device):** 32
- **Effective Batch Size:** 32
- **Number of Epochs:** 3
- **Gradient Accumulation Steps:** 2
- **Early Stopping Patience:** 2 epochs
- **Optimizer:** AdamW
- **Learning Rate:** 2e-5
- **Mixed Precision:** FP16 enabled
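The effective batch size reported above follows from the other settings: per-device batch size times gradient accumulation steps times device count. A one-line check (assuming a single-device run, which is not stated in the card):

```python
# Effective batch size = per-device batch × gradient accumulation steps × device count.
# Values mirror the training configuration above; num_devices = 1 is an assumption.
per_device_train_batch_size = 16
gradient_accumulation_steps = 2
num_devices = 1  # assumed single-GPU run

effective_batch_size = (per_device_train_batch_size
                        * gradient_accumulation_steps
                        * num_devices)
print(effective_batch_size)  # 32, matching the value reported above
```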
## Evaluation Results

Metrics are reported from the final evaluation on the validation set:

- **Overall F1 Score (weighted average):** 0.9516
- **Prompt Accuracy:** 0.9240
- **Response Accuracy:** 0.9785
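The weighted-average F1 above averages per-class F1 scores with weights proportional to each class's support, so the majority class contributes more. A minimal, dependency-free sketch of that computation, equivalent to scikit-learn's `f1_score(..., average='weighted')` (the toy labels below are illustrative, not the actual validation data):

```python
from collections import Counter

def weighted_f1(y_true, y_pred):
    """Weighted-average F1: per-class F1, weighted by class support."""
    classes = sorted(set(y_true))
    support = Counter(y_true)
    total = 0.0
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if (tp + fp) else 0.0
        recall = tp / (tp + fn) if (tp + fn) else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if (precision + recall) else 0.0)
        total += f1 * support[c] / len(y_true)
    return total

# Toy labels for illustration only (0 = safe, 1 = unsafe).
y_true = [0, 0, 0, 1, 1, 1]
y_pred = [0, 0, 1, 1, 1, 0]
print(round(weighted_f1(y_true, y_pred), 4))  # 0.6667
```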
## Label Mappings

### Prompt/Response Labels

```json
{
  "safe": 0,
  "unsafe": 1
}
```
### Violated Categories

```json
{
  "Controlled/Regulated Substances": 0,
  "Copyright/Trademark/Plagiarism": 1,
  "Criminal Planning/Confessions": 2,
  "Fraud/Deception": 3,
  "Guns and Illegal Weapons": 4,
  "Harassment": 5,
  "Hate/Identity Hate": 6,
  "High Risk Gov Decision Making": 7,
  "Illegal Activity": 8,
  "Immoral/Unethical": 9,
  "Malware": 10,
  "Manipulation": 11,
  "Needs Caution": 12,
  "Other": 13,
  "PII/Privacy": 14,
  "Political/Misinformation/Conspiracy": 15,
  "Profanity": 16,
  "Sexual": 17,
  "Sexual (minor)": 18,
  "Suicide and Self Harm": 19,
  "Threat": 20,
  "Unauthorized Advice": 21,
  "Violence": 22
}
```