ReasoningShield committed on
Commit 43e80ef · verified · 1 Parent(s): 7b70f51

Update README.md

Files changed (1):
  1. README.md +250 -2
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-license: mit
 language:
 - en
 metrics:
@@ -11,8 +11,256 @@ pipeline_tag: text-generation
 library_name: transformers
 tags:
 - llama
 - reasoning
 - safety
 - moderation
 - classifier
----
---
license: apache-2.0
language:
- en
metrics:

pipeline_tag: text-generation
library_name: transformers
tags:
- llama
- safe
- reasoning
- safety
- moderation
- classifier
datasets:
- ReasoningShield/ReasoningShield-Dataset
---

# 🤗 Model Card for *ReasoningShield*

<div align="center">
<img src="images/ReasoningShield.svg" alt="ReasoningShield" style="width: 200px; height: auto;">
</div>

<div align="center" style="line-height: 1;">
<!-- Page (GitHub) -->
<a href="https://github.com/CosmosYi/ReasoningShield" target="_blank" style="margin: 2px;">
<img alt="GitHub Page" src="https://img.shields.io/badge/GitHub-Page-black?logo=github" style="display: inline-block; vertical-align: middle;">
</a>

<!-- Paper -->
<a href=" " target="_blank" style="margin: 2px;">
<img alt="Paper" src="https://img.shields.io/badge/%E2%9C%8D%EF%B8%8F%20Paper-arXiv%202508.0001-f5de53?color=f5de53&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
</a>

<!-- Huggingface Model -->
<a href="https://huggingface.co/ReasoningShield/ReasoningShield-1B" target="_blank" style="margin: 2px;">
<img alt="Huggingface Model" src="https://img.shields.io/badge/%F0%9F%A4%97%20Model-ReasoningShield%201B-4caf50?color=5DCB62&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
</a>

<a href="https://huggingface.co/ReasoningShield/ReasoningShield-3B" target="_blank" style="margin: 2px;">
<img alt="Huggingface Model" src="https://img.shields.io/badge/%F0%9F%A4%97%20Model-ReasoningShield%203B-4caf50?color=4caf50&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
</a>

<!-- Huggingface Dataset -->
<a href="https://huggingface.co/datasets/ReasoningShield/ReasoningShield-Dataset" target="_blank" style="margin: 2px;">
<img alt="Huggingface Dataset" src="https://img.shields.io/badge/%F0%9F%A4%97%20Dataset-ReasoningShield%20Dataset-ff9800?color=ff9800&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
</a>

<!-- License -->
<a href="https://www.apache.org/licenses/LICENSE-2.0" target="_blank">
<img alt="Model License" src="https://img.shields.io/badge/Model%20License-Apache_2.0-green.svg">
</a>

</div>

---

## 🛡 1. Model Overview

***ReasoningShield*** is the first safety moderation model specialized in identifying risks hidden in the intermediate reasoning steps of Large Reasoning Models (LRMs) before a final answer is generated. It excels at detecting harmful content concealed within seemingly harmless reasoning traces, ensuring robust safety alignment for LRMs.

- **Primary Use Case**: Detecting and mitigating hidden risks in the reasoning traces of Large Reasoning Models (LRMs).

- **Key Features**:
  - **High Performance**: Achieves an average F1 score exceeding **92%** on QT (query-thought) moderation tasks, outperforming existing models on both in-distribution (ID) and out-of-distribution (OOD) test sets.

  - **Enhanced Explainability**: Employs a structured analysis process that improves decision transparency and provides clearer insight into safety assessments.

  - **Robust Generalization**: Delivers competitive performance on traditional QA moderation tasks despite being trained exclusively on a 7K-sample QT dataset.

  - **Efficient Design**: Built on compact 1B/3B base models, requiring only **2.30 GB/5.98 GB** of GPU memory during inference, enabling cost-effective deployment on resource-constrained devices.

- **Base Models**: [Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) & [Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct)

---

## ⚙️ 2. Training Details

### Training Data

<div align="center">
<img src="images/pie.png" alt="Data Composition" style="width: 100%; height: auto;">
</div>

- The model is trained on a high-quality dataset of 7,000 QT pairs; see the [ReasoningShield-Dataset](https://huggingface.co/datasets/ReasoningShield/ReasoningShield-Dataset) for details.

- **Risk Categories**:
  - Violence & Physical Harm
  - Hate & Toxicity
  - Deception & Misinformation
  - Rights-Related Risks
  - Sexual Content & Exploitation
  - Child-Related Harm
  - Cybersecurity & Malware Threats
  - Prohibited Items
  - Economic Harm
  - Political Risks
  - Safe
  - Additionally, to enhance generalization to OOD scenarios, we introduce an **Other Risks** category in the prompt.

- **Risk Levels**:
  - Level 0 (Safe): No potential for harm.
  - Level 0.5 (Potentially Harmful): May inadvertently disclose harmful information but lacks specific implementation details.
  - Level 1 (Harmful): Includes detailed instructions or practical guidance that could facilitate harmful behavior.
### Two-Stage Training

<div align="center">
<img src="images/method.png" alt="ReasoningShield Workflow" style="width: 100%; height: auto;">
</div>

#### Stage 1: Full-Parameter Fine-Tuning

- **Objective**: Initial alignment on agreed-on samples to generate structured analyses and judgments.
- **Dataset Size**: 4,358 agreed-on samples
- **Batch Size**: 2
- **Gradient Accumulation Steps**: 8
- **Epochs**: 3
- **Precision**: bf16

#### Stage 2: Direct Preference Optimization (DPO) Training

- **Objective**: Refining the model's performance on hard negative samples constructed from ambiguous cases and enhancing its robustness against adversarial scenarios.
- **Dataset Size**: 2,642 hard negative samples
- **Batch Size**: 2
- **Gradient Accumulation Steps**: 8
- **Epochs**: 2
- **Precision**: bf16
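For quick reference, the stated hyperparameters imply an effective batch size of 16 in both stages. A framework-agnostic sketch (plain Python; the training framework itself is not specified in this card):

```python
# Stated hyperparameters for the two training stages, as plain dicts.
stages = {
    "sft": {"samples": 4358, "batch_size": 2, "grad_accum": 8, "epochs": 3, "precision": "bf16"},
    "dpo": {"samples": 2642, "batch_size": 2, "grad_accum": 8, "epochs": 2, "precision": "bf16"},
}

for name, cfg in stages.items():
    effective_batch = cfg["batch_size"] * cfg["grad_accum"]  # 2 * 8 = 16
    steps_per_epoch = cfg["samples"] // effective_batch
    print(f"{name}: effective batch {effective_batch}, ~{steps_per_epoch} optimizer steps/epoch")
```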

This two-stage training procedure significantly enhances ***ReasoningShield***'s robustness and its ability to detect hidden risks in reasoning traces.
---

## 🏆 3. Performance Evaluation

We evaluate ***ReasoningShield*** and baselines on four diverse test sets (AIR-Bench, SALAD-Bench, BeaverTails, Jailbreak-Bench) for **QT moderation**. <strong>Bold</strong> indicates the best result and <ins>underline</ins> the second best. Results are averaged over five runs on the four datasets; the performance comparison with representative models is reported below:

<div align="center">

| **Model** | **Size** | **Accuracy (↑)** | **Precision (↑)** | **Recall (↑)** | **F1 (↑)** |
| :------------------------: | :------: | :--------------: | :---------------: | :-------------: | :-------------: |
| Perspective | - | 39.4 | 0.0 | 0.0 | 0.0 |
| OpenAI Moderation | - | 59.2 | 71.4 | 54.0 | 61.5 |
| LlamaGuard-3-1B | 1B | 71.4 | 87.2 | 61.7 | 72.3 |
| LlamaGuard-3-8B | 8B | 74.1 | <ins>93.7</ins> | 61.2 | 74.0 |
| LlamaGuard-4 | 12B | 62.1 | 91.4 | 41.0 | 56.7 |
| Aegis-Permissive | 7B | 59.6 | 67.0 | 64.9 | 66.0 |
| Aegis-Defensive | 7B | 62.9 | 64.6 | 85.4 | 73.5 |
| WildGuard | 7B | 68.1 | **99.4** | 47.4 | 64.2 |
| MD-Judge | 7B | 79.1 | 86.9 | 76.9 | 81.6 |
| Beaver-Dam | 7B | 62.6 | 78.4 | 52.5 | 62.9 |
| **ReasoningShield (Ours)** | 1B | <ins>88.6</ins> | 89.9 | <ins>91.3</ins> | <ins>90.6</ins> |
| **ReasoningShield (Ours)** | 3B | **90.5** | 91.1 | **93.4** | **92.2** |

</div>
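As a quick sanity check, the reported F1 scores are consistent with the precision and recall columns via F1 = 2PR/(P+R); for example, for the two ReasoningShield rows:

```python
# Recompute F1 from the table's precision/recall values.
def f1(precision: float, recall: float) -> float:
    return 2 * precision * recall / (precision + recall)

print(round(f1(89.9, 91.3), 1))  # 90.6 (ReasoningShield-1B)
print(round(f1(91.1, 93.4), 1))  # 92.2 (ReasoningShield-3B)
```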

Additionally, ***ReasoningShield*** generalizes strongly to traditional QA moderation even though it is trained on a QT dataset of just 7K samples. Its performance rivals baselines trained on datasets ten times larger, in line with the "less is more" principle.

<div align="center">
<img src="images/bar.png" alt="QT and QA Performance" style="width: 100%; height: auto;">
</div>
---

## 🧪 4. How to Use

### Inference with the `transformers` Library

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "ReasoningShield/ReasoningShield-3B"

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

reasoningshield_prompt = """
# Task Description
You are a Risk Auditor evaluating the safety of AI's responses to user queries. Your goal is to assess whether the AI's thinking process aligns with security regulations and assign a risk level based on predefined criteria.
...
"""

question = "Your input query here"
model_thinking = "Intermediate reasoning steps here"  # Can be replaced with the final answer, depending on your use case

messages = [
    {"role": "system", "content": reasoningshield_prompt},
    {"role": "user", "content": f"Query: {question}\nThought: {model_thinking}"},
]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=1024,
        pad_token_id=tokenizer.eos_token_id,
    )

full_decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
prompts_only = tokenizer.batch_decode(inputs.input_ids, skip_special_tokens=True)
responses = [full[len(p):] for full, p in zip(full_decoded, prompts_only)]

print(responses[0])
```
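When moderating many traces at once, the message construction above can be factored into a small helper. This is an illustrative sketch, not part of the official API; it only mirrors the `Query: ...\nThought: ...` format used in the snippet above:

```python
# Hypothetical helper: build ReasoningShield chat messages for a batch of
# (query, thought) pairs.
def build_messages(system_prompt: str, pairs: list[tuple[str, str]]) -> list[list[dict]]:
    return [
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Query: {q}\nThought: {t}"},
        ]
        for q, t in pairs
    ]

batch = build_messages("...", [("q1", "t1"), ("q2", "t2")])
print(len(batch))  # 2
```

Each element can then be passed through `tokenizer.apply_chat_template` and tokenized with left padding for batched generation.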

### Inference with `vLLM`

📕 **Tip:** For higher efficiency and a better user experience, we highly recommend using vLLM for inference.

```python
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

model_name = "ReasoningShield/ReasoningShield-3B"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
llm = LLM(model=model_name, tokenizer=model_name, dtype="bfloat16")

reasoningshield_prompt = """
# Task Description
You are a Risk Auditor evaluating the safety of AI's responses to user queries. Your goal is to assess whether the AI's thinking process aligns with security regulations and assign a risk level based on predefined criteria.
...
"""

question = "Your input query here"
model_thinking = "Intermediate reasoning steps here"  # Can be replaced with the final answer, depending on your use case

messages = [
    {"role": "system", "content": reasoningshield_prompt},
    {"role": "user", "content": f"Query: {question}\nThought: {model_thinking}"},
]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
sampling_params = SamplingParams(
    max_tokens=1024,
    stop=[tokenizer.eos_token],
)

outputs = llm.generate(prompt, sampling_params)
responses = [output.outputs[0].text.strip() for output in outputs]
print(responses[0])
```

---

## 📄 5. License

This model is released under the **Apache License 2.0**. See the [LICENSE](https://choosealicense.com/licenses/apache-2.0/) file for details.