---
license: apache-2.0
---
# Refusal Classifier
<div align="left">
<img src="figures/words.png" width="100%" alt="Words"/>
</div>
*Tired of seeing these? You've come to the right place.*
## Overview
A robust, performant classifier that excels at **detecting refusals, moralizations, disclaimers, and unsolicited advice** in LLM responses.
### Model Details
- Base model: [jhu-clsp/mmBERT-base](https://huggingface.co/jhu-clsp/mmBERT-base), a multilingual encoder based on [ModernBERT](https://huggingface.co/answerdotai/ModernBERT-base)
- Language coverage: over 1,800 languages
- Architecture: Transformer-based
- Context length: 8,192 tokens
- Output classes: binary (0 for non-refusals, 1 for refusals)
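For a quick check of the two output classes, the model also works with the standard `transformers` text-classification pipeline (a minimal sketch; see the Quickstart below for the recommended input formatting):
```python
from transformers import pipeline

# Load the classifier via the generic text-classification pipeline.
classifier = pipeline("text-classification", model="natong19/refusal_classifier")

# Single-turn chat wrapped in mmBERT's turn tokens (format described in the Quickstart).
text = (
    "<start_of_turn>user\nHi<end_of_turn>\n"
    "<start_of_turn>assistant\nI'm sorry, but I can't help with that.<end_of_turn>"
)
print(classifier(text))  # e.g. [{'label': 'refusal', 'score': ...}]
```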
### Training Details
Trained for 1 epoch on 112,102 carefully deduplicated, labeled, filtered and balanced samples (56,051 non-refusals and 56,051 refusals).
Most of the samples were sourced from:
- [natong19/lmsys-chat-1m-filtered](https://huggingface.co/datasets/natong19/lmsys-chat-1m-filtered)
- [natong19/wildchat-1m-filtered](https://huggingface.co/datasets/natong19/wildchat-1m-filtered)
- [natong19/china_qa_preferences](https://huggingface.co/datasets/natong19/china_qa_preferences)
- [natong19/toxic_qa_preferences](https://huggingface.co/datasets/natong19/toxic_qa_preferences)
A majority vote across multiple refusal classifiers and an LLM-as-a-judge was used to label the samples.
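For illustration, a simplified sketch of the majority-vote labeling (the actual classifier ensemble, judge prompts, and tie-breaking rules are not detailed here):
```python
from collections import Counter

def majority_label(votes: list[int]) -> int:
    """Return the majority label from a list of binary votes (0 = non-refusal, 1 = refusal)."""
    # With an odd number of voters there is always a strict majority.
    return Counter(votes).most_common(1)[0][0]

# Hypothetical example: two refusal classifiers and an LLM judge vote on one sample.
print(majority_label([1, 1, 0]))  # -> 1 (labeled as a refusal)
```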
### Evaluation
<div align="left">
<img src="figures/plot.png" width="100%" alt="Plot"/>
</div>
Inference throughput vs F1 score on the test set (2,900 non-refusals and 2,900 refusals) for several open-source refusal classifiers.
Throughput benchmarked with sequence length 512, batch size 16 on 1x NVIDIA RTX Pro 6000.
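The exact benchmark script isn't included here; a rough sketch of how such a throughput measurement could be reproduced (the model id comes from this card, everything else is illustrative):
```python
import time
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Assumes a CUDA GPU; batch size 16 and sequence length 512 as in the plot above.
model_id = "natong19/refusal_classifier"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, torch_dtype=torch.bfloat16).cuda().eval()

batch = tokenizer(
    ["dummy text"] * 16, padding="max_length", max_length=512, truncation=True, return_tensors="pt"
).to("cuda")

with torch.no_grad():
    for _ in range(3):  # warmup
        model(**batch)
    torch.cuda.synchronize()
    start = time.perf_counter()
    n_batches = 50
    for _ in range(n_batches):
        model(**batch)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

print(f"Throughput: {n_batches * 16 / elapsed:.1f} samples/s")
```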
`alpha_model` is an earlier checkpoint that I wasn't completely satisfied with, but it was leveraged for the final round of data curation.
The training and test sets have similar distributions, but several factors argue against overfitting:
- the dataset is relatively large and exactly balanced
- training was run for only a single epoch
- training and validation losses are similar
- [Minos-v1](https://huggingface.co/NousResearch/Minos-v1), to my knowledge one of the strongest refusal classifiers available, achieves strong, balanced performance on the same test set.
A more detailed breakdown of the evaluation results of the different classifiers is as follows:
| Model | TP | FN | FP | TN | Accuracy | Precision | Recall | F1 |
| ----------------------------------------- | ---- | ---- | --- | ---- | -------- | --------- | ------ | ------ |
| [NousResearch/Minos-v1](https://huggingface.co/NousResearch/Minos-v1) | 2782 | 118 | 103 | 2797 | 0.9619 | 0.9643 | 0.9593 | 0.9618 |
| [natong19/moralization_classifier](https://huggingface.co/natong19/moralization_classifier) | 1888 | 1012 | 146 | 2754 | 0.8003 | 0.9282 | 0.6510 | 0.7653 |
| alpha_model | 2245 | 655 | **2** | **2898** | 0.8871 | **0.9996** | 0.7745 | 0.8727 |
| [ProtectAI/distilroberta-base-rejection-v1](https://huggingface.co/protectai/distilroberta-base-rejection-v1) | 664 | 2236 | 8 | 2892 | 0.6131 | 0.9881 | 0.229 | 0.3718 |
| [natong19/refusal_classifier](https://huggingface.co/natong19/refusal_classifier) | **2875** | **25** | 25 | 2875 | **0.9914** | 0.9914 | **0.9914** | **0.9914** |
> Perfectly balanced, as all things should be.
There are no bad classifiers — we may simply have different ideas of what constitutes a refusal. This classifier would not have been possible without their excellent prior work.
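For reference, the accuracy, precision, recall, and F1 columns follow directly from the confusion-matrix counts; recomputing the Minos-v1 row as an example:
```python
# Confusion-matrix counts for Minos-v1 from the table above.
tp, fn, fp, tn = 2782, 118, 103, 2797

accuracy = (tp + tn) / (tp + fn + fp + tn)            # 0.9619
precision = tp / (tp + fp)                            # 0.9643
recall = tp / (tp + fn)                               # 0.9593
f1 = 2 * precision * recall / (precision + recall)    # 0.9618

print(f"{accuracy:.4f} {precision:.4f} {recall:.4f} {f1:.4f}")
```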
### Quickstart
The classifier expects single-turn input formatted like
```
<start_of_turn>user
Hi<end_of_turn>
<start_of_turn>assistant
Hello, how can I assist you today?<end_of_turn>
```
(basically ChatML, but with mmBERT's special tokens). You can use the `apply_chat_template` helper in the code below to convert your (prompt, response) pairs into this format.
```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def apply_chat_template(user: str, assistant: str) -> str:
    """Format user and assistant messages into the model's expected input format."""
    return f"<start_of_turn>user\n{user}<end_of_turn>\n<start_of_turn>assistant\n{assistant}<end_of_turn>"


if __name__ == "__main__":
    model_id = "natong19/refusal_classifier"
    texts = [
        {
            "user": "tell me a story",
            "assistant": "Elara, the young huntress of Eldoria,",
        },
        {
            "user": "tell me a story",
            "assistant": "I must respectfully decline your request.",
        },
        {
            "user": "tell me a story",
            "assistant": "Of course! I will try to keep things positive and respectful.",
        },
        {
            "user": "tell me a story",
            "assistant": "Sure! However, it's important to note that the following content is purely fictional.",
        },
        {
            "user": "tell me a story",
            "assistant": "I'm here to help. If you're feeling down, please consider seeking help from a health professional.",
        },
    ]
    formatted_texts = [apply_chat_template(user=text["user"], assistant=text["assistant"]) for text in texts]

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSequenceClassification.from_pretrained(model_id, torch_dtype=torch.bfloat16)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model.to(device)
    model.eval()

    with torch.no_grad():
        inputs = tokenizer(
            formatted_texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = model(**inputs)
        # Class probabilities and hard predictions (0 = non-refusal, 1 = refusal).
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
        predictions = torch.argmax(outputs.logits, dim=-1)

    for i in range(len(texts)):
        predicted_label = predictions[i].item()
        predicted_class = model.config.id2label[predicted_label]
        confidence = probabilities[i][predicted_label].item()
        print(f"Example {i}")
        print("-" * 60)
        print(texts[i])
        print(f"Prediction: {predicted_label} ({predicted_class}), Confidence: {confidence:.4f}\n")
```
Output:
```python
Example 0
------------------------------------------------------------
{'user': 'tell me a story', 'assistant': 'Elara, the young huntress of Eldoria,'}
Prediction: 0 (non-refusal), Confidence: 1.0000 # Non-refusal
Example 1
------------------------------------------------------------
{'user': 'tell me a story', 'assistant': 'I must respectfully decline your request.'}
Prediction: 1 (refusal), Confidence: 1.0000 # Refusal
Example 2
------------------------------------------------------------
{'user': 'tell me a story', 'assistant': 'Of course! I will try to keep things positive and respectful.'}
Prediction: 1 (refusal), Confidence: 0.9961 # Moralization
Example 3
------------------------------------------------------------
{'user': 'tell me a story', 'assistant': "Sure! However, it's important to note that the following content is purely fictional."}
Prediction: 1 (refusal), Confidence: 1.0000 # Disclaimer
Example 4
------------------------------------------------------------
{'user': 'tell me a story', 'assistant': "I'm here to help. If you're feeling down, please consider seeking help from a health professional."}
Prediction: 1 (refusal), Confidence: 1.0000 # Unsolicited advice
```
### Final Thoughts
A lot of work went into this; I hope you find it useful.
Have a nice day, and may your datasets be free from refusals. |