---
license: apache-2.0
---

# Refusal Classifier

<div align="left">
<img src="figures/words.png" width="100%" alt="Words"/>
</div>

*Tired of seeing these? You've come to the right place.*

## Overview

A robust, performant classifier that excels at **detecting refusals, moralizations, disclaimers, and unsolicited advice** in LLM responses.

### Model Details

- Base model: [jhu-clsp/mmBERT-base](https://huggingface.co/jhu-clsp/mmBERT-base), a multilingual encoder based on [ModernBERT](https://huggingface.co/answerdotai/ModernBERT-base)
- Language coverage: over 1,800 languages
- Architecture: Transformer-based
- Context length: 8,192 tokens
- Output classes: binary (0 for non-refusals, 1 for refusals)
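
The label names ship with the checkpoint's config, so the mapping can be inspected directly:

```python
from transformers import AutoConfig

# Inspect the binary label mapping stored in the checkpoint's config
config = AutoConfig.from_pretrained("natong19/refusal_classifier")
print(config.id2label)  # {0: 'non-refusal', 1: 'refusal'}
```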

### Training Details

Trained for 1 epoch on 112,102 carefully deduplicated, labeled, filtered and balanced samples (56,051 non-refusals and 56,051 refusals).
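
The curation pipeline itself is not released; the sketch below only illustrates the two mechanical steps named above (exact deduplication, then downsampling to an exact 50/50 split) on hypothetical records, as an assumption about how such a pipeline might look:

```python
import random

def dedup_and_balance(samples: list[dict]) -> list[dict]:
    """Hypothetical sketch: exact dedup on (user, assistant, label),
    then downsample the majority class to a 50/50 split.
    The real curation pipeline is not released."""
    seen, unique = set(), []
    for s in samples:
        key = (s["user"], s["assistant"], s["label"])
        if key not in seen:
            seen.add(key)
            unique.append(s)
    refusals = [s for s in unique if s["label"] == 1]
    non_refusals = [s for s in unique if s["label"] == 0]
    n = min(len(refusals), len(non_refusals))
    random.seed(0)  # reproducible downsampling
    return random.sample(refusals, n) + random.sample(non_refusals, n)
```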

Most of the samples were sourced from:
- [natong19/lmsys-chat-1m-filtered](https://huggingface.co/datasets/natong19/lmsys-chat-1m-filtered)
- [natong19/wildchat-1m-filtered](https://huggingface.co/datasets/natong19/wildchat-1m-filtered)
- [natong19/china_qa_preferences](https://huggingface.co/datasets/natong19/china_qa_preferences)
- [natong19/toxic_qa_preferences](https://huggingface.co/datasets/natong19/toxic_qa_preferences)

Samples were labeled by majority vote across multiple refusal classifiers together with an LLM-as-a-judge.
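
As an illustration only (the actual voting setup is not released), majority voting over per-sample votes reduces to:

```python
from collections import Counter

def majority_vote(votes: list[int]) -> int:
    """Most common label among annotator votes (illustrative sketch)."""
    return Counter(votes).most_common(1)[0][0]

# Two classifiers flag a refusal, the LLM judge disagrees:
print(majority_vote([1, 1, 0]))  # -> 1
```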

### Evaluation
<div align="left">
<img src="figures/plot.png" width="100%" alt="Plot"/>
</div>
Inference throughput vs. F1 score on the test set (2,900 non-refusals and 2,900 refusals) for several open-source refusal classifiers.
Throughput was benchmarked at sequence length 512 with batch size 16 on 1x NVIDIA RTX Pro 6000.
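
To reproduce a comparable throughput number on your own hardware, here is a rough sketch of the setup (batch size 16, sequence length 512); absolute numbers will vary by GPU:

```python
import time

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "natong19/refusal_classifier"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = (
    AutoModelForSequenceClassification.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    .cuda()
    .eval()
)

# Fixed-shape batch: 16 sequences padded to 512 tokens
batch = tokenizer(
    ["placeholder text"] * 16,
    padding="max_length",
    max_length=512,
    truncation=True,
    return_tensors="pt",
).to("cuda")

with torch.no_grad():
    for _ in range(3):  # warmup iterations
        model(**batch)
    torch.cuda.synchronize()
    start = time.perf_counter()
    n_iters = 50
    for _ in range(n_iters):
        model(**batch)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

print(f"Throughput: {n_iters * 16 / elapsed:.1f} samples/s")
```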

`alpha_model` is an earlier checkpoint that I wasn't completely satisfied with, but it was leveraged for the final round of data curation.

The training and test sets have similar distributions, but several factors argue against overfitting:
- the dataset is relatively large and exactly balanced
- training ran for only a single epoch
- train and validation losses are similar
- [Minos-v1](https://huggingface.co/NousResearch/Minos-v1), one of the strongest refusal classifiers available to my knowledge, also achieves strong, balanced performance on the same test set.

A more detailed breakdown of the evaluation results for each classifier:

| Model                                     | TP   | FN   | FP  | TN   | Accuracy | Precision | Recall | F1     |
| ----------------------------------------- | ---- | ---- | --- | ---- | -------- | --------- | ------ | ------ |
| [NousResearch/Minos-v1](https://huggingface.co/NousResearch/Minos-v1)                     | 2782 | 118  | 103 | 2797 | 0.9619   | 0.9643    | 0.9593 | 0.9618 |
| [natong19/moralization_classifier](https://huggingface.co/natong19/moralization_classifier)          | 1888 | 1012 | 146 | 2754 | 0.8003   | 0.9282    | 0.6510 | 0.7653 |
| alpha_model                               | 2245 | 655  | **2**   | **2898** | 0.8871   | **0.9996**    | 0.7745 | 0.8727 |
| [ProtectAI/distilroberta-base-rejection-v1](https://huggingface.co/protectai/distilroberta-base-rejection-v1) | 664  | 2236 | 8   | 2892 | 0.6131   | 0.9881    | 0.2290 | 0.3718 |
| [natong19/refusal_classifier](https://huggingface.co/natong19/refusal_classifier)               | **2875** | **25**   | 25  | 2875 | **0.9914**   | 0.9914    | **0.9914** | **0.9914** |

> Perfectly balanced, as all things should be.
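
Each metric in the table follows directly from the confusion counts; for example, recomputing the Minos-v1 row:

```python
# Recompute the headline metrics from a row of the table above,
# using the Minos-v1 confusion counts.
tp, fn, fp, tn = 2782, 118, 103, 2797

accuracy = (tp + tn) / (tp + fn + fp + tn)
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)

print(f"{accuracy:.4f} {precision:.4f} {recall:.4f} {f1:.4f}")
# -> 0.9619 0.9643 0.9593 0.9618
```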

There are no bad classifiers — we may simply have different ideas of what constitutes a refusal. This classifier would not have been possible without their excellent prior work.

### Quickstart
The classifier expects single-turn input formatted like

```text
<start_of_turn>user
Hi<end_of_turn>
<start_of_turn>assistant
Hello, how can I assist you today?<end_of_turn>
```

This is essentially ChatML, but with mmBERT's special tokens. Use the `apply_chat_template` helper in the code below to convert your (prompt, response) pairs into this format.

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def apply_chat_template(user: str, assistant: str) -> str:
    """Format user and assistant messages into model input format."""
    return f"<start_of_turn>user\n{user}<end_of_turn>\n<start_of_turn>assistant\n{assistant}<end_of_turn>"


if __name__ == "__main__":
    model_id = "natong19/refusal_classifier"

    texts = [
        {
            "user": "tell me a story",
            "assistant": "Elara, the young huntress of Eldoria,",
        },
        {
            "user": "tell me a story",
            "assistant": "I must respectfully decline your request.",
        },
        {
            "user": "tell me a story",
            "assistant": "Of course! I will try to keep things positive and respectful.",
        },
        {
            "user": "tell me a story",
            "assistant": "Sure! However, it's important to note that the following content is purely fictional.",
        },
        {
            "user": "tell me a story",
            "assistant": "I'm here to help. If you're feeling down, please consider seeking help from a health professional.",
        },
    ]

    formatted_texts = [apply_chat_template(user=text["user"], assistant=text["assistant"]) for text in texts]

    # Load the tokenizer and the classifier in bf16
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSequenceClassification.from_pretrained(model_id, torch_dtype=torch.bfloat16)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model.to(device)
    model.eval()

    # Tokenize all examples as a single batch and run one forward pass
    with torch.no_grad():
        inputs = tokenizer(
            formatted_texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
        predictions = torch.argmax(outputs.logits, dim=-1)

    for i in range(len(texts)):
        predicted_label = predictions[i].item()
        predicted_class = model.config.id2label[predicted_label]
        confidence = probabilities[i][predicted_label].item()
        text = texts[i]

        print(f"Example {i}")
        print("-" * 60)
        print(text)
        print(f"Prediction: {predicted_label} ({predicted_class}), Confidence: {confidence:.4f}\n")
```

Output:

```text
Example 0
------------------------------------------------------------
{'user': 'tell me a story', 'assistant': 'Elara, the young huntress of Eldoria,'}
Prediction: 0 (non-refusal), Confidence: 1.0000 # Non-refusal

Example 1
------------------------------------------------------------
{'user': 'tell me a story', 'assistant': 'I must respectfully decline your request.'}
Prediction: 1 (refusal), Confidence: 1.0000 # Refusal

Example 2
------------------------------------------------------------
{'user': 'tell me a story', 'assistant': 'Of course! I will try to keep things positive and respectful.'}
Prediction: 1 (refusal), Confidence: 0.9961 # Moralization

Example 3
------------------------------------------------------------
{'user': 'tell me a story', 'assistant': "Sure! However, it's important to note that the following content is purely fictional."}
Prediction: 1 (refusal), Confidence: 1.0000 # Disclaimer

Example 4
------------------------------------------------------------
{'user': 'tell me a story', 'assistant': "I'm here to help. If you're feeling down, please consider seeking help from a health professional."}
Prediction: 1 (refusal), Confidence: 1.0000 # Unsolicited advice
```
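
For quick one-off checks, the high-level `pipeline` API also works, as long as you apply the same chat formatting first (a minimal sketch):

```python
from transformers import pipeline

# Same formatting helper as in the Quickstart above
def apply_chat_template(user: str, assistant: str) -> str:
    return f"<start_of_turn>user\n{user}<end_of_turn>\n<start_of_turn>assistant\n{assistant}<end_of_turn>"

classifier = pipeline("text-classification", model="natong19/refusal_classifier")
text = apply_chat_template(user="tell me a story", assistant="I must respectfully decline your request.")
print(classifier(text))  # expected label: 'refusal'
```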

### Final Thoughts
A lot of work went into this; I hope you like it.
Have a nice day, and may your datasets be free from refusals.