Improve language tag

#2
by lbourdois - opened
Files changed (1)
  1. README.md +118 -104
README.md CHANGED
@@ -1,105 +1,119 @@
- ---
- license: apache-2.0
- library_name: transformers
- pipeline_tag: text-classification
- base_model:
- - Qwen/Qwen2.5-1.5B
- ---
-
- ## Overview
- A brief description of what this model does and how it’s unique or relevant:
-
- - **Goal**: Safety classification of input text sequences.
- - **Model Description**: DuoGuard-1.5B-transfer is a multilingual, decoder-only LLM-based classifier specifically designed for safety content moderation across 12 distinct subcategories. Each forward pass produces a 12-dimensional logits vector, where each dimension corresponds to a specific content risk area, such as violent crimes, hate, or sexual content. By applying a sigmoid function to these logits, users obtain a multi-label probability distribution, which allows for fine-grained detection of potentially unsafe or disallowed content.
- For simplified binary moderation tasks, the model can produce a single “safe”/“unsafe” label by taking the maximum of the 12 subcategory probabilities and comparing it to a given threshold (e.g., 0.5). If the maximum probability across all categories is above the threshold, the content is deemed “unsafe.” Otherwise, it is considered “safe.”
-
- DuoGuard-1.5B-transfer is built upon Qwen2.5-1.5B, a multilingual large language model supporting 29 languages, including Chinese, English, French, Spanish, Portuguese, German, Italian, Russian, Japanese, Korean, Vietnamese, Thai, and Arabic. We directly leverage the training data developed for DuoGuard-0.5B to train Qwen2.5-1.5B and obtain DuoGuard-1.5B-transfer. Thus, it is specialized (fine-tuned) for safety content moderation primarily in English, French, German, and Spanish, while still retaining the broader language coverage inherited from the Qwen2.5 base model. It is provided with open weights.
- ## How to Use
- A quick code snippet or set of instructions on how to load and use the model in an application:
- ```python
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
- import torch
-
- # 1. Initialize the tokenizer
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B")
- tokenizer.pad_token = tokenizer.eos_token
-
- # 2. Load the DuoGuard-1.5B-transfer model
- model = AutoModelForSequenceClassification.from_pretrained(
-     "DuoGuard/DuoGuard-1.5B-transfer",
-     torch_dtype=torch.bfloat16
- ).to('cuda:0')
-
- # 3. Define a sample prompt to test
- prompt = "How to kill a python process?"
-
- # 4. Tokenize the prompt
- inputs = tokenizer(
-     prompt,
-     return_tensors="pt",
-     truncation=True,
-     max_length=512  # adjust as needed
- ).to('cuda:0')
-
- # 5. Run the model (inference)
- with torch.no_grad():
-     outputs = model(**inputs)
- # DuoGuard outputs a 12-dimensional logits vector (one logit per subcategory).
- logits = outputs.logits  # shape: (batch_size, 12)
- probabilities = torch.sigmoid(logits)  # element-wise sigmoid
-
- # 6. Multi-label predictions (one for each category)
- threshold = 0.5
- category_names = [
-     "Violent crimes",
-     "Non-violent crimes",
-     "Sex-related crimes",
-     "Child sexual exploitation",
-     "Specialized advice",
-     "Privacy",
-     "Intellectual property",
-     "Indiscriminate weapons",
-     "Hate",
-     "Suicide and self-harm",
-     "Sexual content",
-     "Jailbreak prompts",
- ]
-
- # Extract probabilities for the single prompt (batch_size = 1)
- prob_vector = probabilities[0].tolist()  # shape: (12,)
-
- predicted_labels = []
- for cat_name, prob in zip(category_names, prob_vector):
-     label = 1 if prob > threshold else 0
-     predicted_labels.append(label)
-
- # 7. Overall binary classification: "safe" vs. "unsafe"
- # We consider the prompt "unsafe" if ANY category is above the threshold.
- max_prob = max(prob_vector)
- overall_label = 1 if max_prob > threshold else 0  # 1 => unsafe, 0 => safe
-
- # 8. Print results
- print(f"Prompt: {prompt}\n")
- print(f"Multi-label Probabilities (threshold={threshold}):")
- for cat_name, prob, label in zip(category_names, prob_vector, predicted_labels):
-     print(f" - {cat_name}: {prob:.3f} (label={label})")
-
- print(f"\nMaximum probability across all categories: {max_prob:.3f}")
- print(f"Overall Prompt Classification => {'UNSAFE' if overall_label == 1 else 'SAFE'}")
- ```
-
- You can find the code at https://github.com/yihedeng9/DuoGuard.
-
- ### Citation
-
- ```bibtex
- @misc{deng2025duoguardtwoplayerrldrivenframework,
-     title={DuoGuard: A Two-Player RL-Driven Framework for Multilingual LLM Guardrails},
-     author={Yihe Deng and Yu Yang and Junkai Zhang and Wei Wang and Bo Li},
-     year={2025},
-     eprint={2502.05163},
-     archivePrefix={arXiv},
-     primaryClass={cs.CL},
-     url={https://arxiv.org/abs/2502.05163},
- }
- ```
1
+ ---
2
+ license: apache-2.0
3
+ library_name: transformers
4
+ pipeline_tag: text-classification
5
+ base_model:
6
+ - Qwen/Qwen2.5-1.5B
7
+ language:
8
+ - zho
9
+ - eng
10
+ - fra
11
+ - spa
12
+ - por
13
+ - deu
14
+ - ita
15
+ - rus
16
+ - jpn
17
+ - kor
18
+ - vie
19
+ - tha
20
+ - ara
21
+ ---
22
+
23
+ ## Overview
24
+ A brief description of what this model does and how it’s unique or relevant:
25
+
26
+ - **Goal**: Classification upon safety of the input text sequences.
27
+ - **Model Description**: DuoGuard-1.5B-transfer is a multilingual, decoder-only LLM-based classifier specifically designed for safety content moderation across 12 distinct subcategories. Each forward pass produces a 12-dimensional logits vector, where each dimension corresponds to a specific content risk area, such as violent crimes, hate, or sexual content. By applying a sigmoid function to these logits, users obtain a multi-label probability distribution, which allows for fine-grained detection of potentially unsafe or disallowed content.
28
+ For simplified binary moderation tasks, the model can be used to produce a single “safe”/“unsafe” label by taking the maximum of the 12 subcategory probabilities and comparing it to a given threshold (e.g., 0.5). If the maximum probability across all categories is above the threshold, the content is deemed “unsafe.” Otherwise, it is considered “safe.”
29
+
30
+ DuoGuard-1B-Llama-3.2-transfer is built upon Llama-3.2-1B, a multilingual large language model supporting 29 languages—including Chinese, English, French, Spanish, Portuguese, German, Italian, Russian, Japanese, Korean, Vietnamese, Thai, and Arabic. We directly leverage the training data developed fro DuoGuard-0.5B to train Llama-3.2-1B and obtain DuoGuard-1.5B-transfer. Thus, it is specialized (fine-tuned) for safety content moderation primarily in English, French, German, and Spanish, while still retaining the broader language coverage inherited from the Qwen2.5 base model. It is provided with open weights.
31
+ ## How to Use
32
+ A quick code snippet or set of instructions on how to load and use the model in an application:
33
+ ```python
34
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
35
+ import torch
36
+
37
+ # 1. Initialize the tokenizer
38
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B")
39
+ tokenizer.pad_token = tokenizer.eos_token
40
+
41
+ # 2. Load the DuoGuard-0.5B model
42
+ model = AutoModelForSequenceClassification.from_pretrained(
43
+ "DuoGuard/DuoGuard-1.5B-transfer",
44
+ torch_dtype=torch.bfloat16
45
+ ).to('cuda:0')
46
+
47
+ # 3. Define a sample prompt to test
48
+ prompt = "How to kill a python process?"
49
+
50
+ # 4. Tokenize the prompt
51
+ inputs = tokenizer(
52
+ prompt,
53
+ return_tensors="pt",
54
+ truncation=True,
55
+ max_length=512 # adjust as needed
56
+ ).to('cuda:0')
57
+
58
+ # 5. Run the model (inference)
59
+ with torch.no_grad():
60
+ outputs = model(**inputs)
61
+ # DuoGuard outputs a 12-dimensional vector (one probability per subcategory).
62
+ logits = outputs.logits # shape: (batch_size, 12)
63
+ probabilities = torch.sigmoid(logits) # element-wise sigmoid
64
+
65
+ # 6. Multi-label predictions (one for each category)
66
+ threshold = 0.5
67
+ category_names = [
68
+ "Violent crimes",
69
+ "Non-violent crimes",
70
+ "Sex-related crimes",
71
+ "Child sexual exploitation",
72
+ "Specialized advice",
73
+ "Privacy",
74
+ "Intellectual property",
75
+ "Indiscriminate weapons",
76
+ "Hate",
77
+ "Suicide and self-harm",
78
+ "Sexual content",
79
+ "Jailbreak prompts",
80
+ ]
81
+
82
+ # Extract probabilities for the single prompt (batch_size = 1)
83
+ prob_vector = probabilities[0].tolist() # shape: (12,)
84
+
85
+ predicted_labels = []
86
+ for cat_name, prob in zip(category_names, prob_vector):
87
+ label = 1 if prob > threshold else 0
88
+ predicted_labels.append(label)
89
+
90
+ # 7. Overall binary classification: "safe" vs. "unsafe"
91
+ # We consider the prompt "unsafe" if ANY category is above the threshold.
92
+ max_prob = max(prob_vector)
93
+ overall_label = 1 if max_prob > threshold else 0 # 1 => unsafe, 0 => safe
94
+
95
+ # 8. Print results
96
+ print(f"Prompt: {prompt}\n")
97
+ print(f"Multi-label Probabilities (threshold={threshold}):")
98
+ for cat_name, prob, label in zip(category_names, prob_vector, predicted_labels):
99
+ print(f" - {cat_name}: {prob:.3f}")
100
+
101
+ print(f"\nMaximum probability across all categories: {max_prob:.3f}")
102
+ print(f"Overall Prompt Classification => {'UNSAFE' if overall_label == 1 else 'SAFE'}")
103
+ ```
104
+
105
+ You can find the code at https://github.com/yihedeng9/DuoGuard.
106
+
107
+ ### Citation
108
+
109
+ ```plaintext
110
+ @misc{deng2025duoguardtwoplayerrldrivenframework,
111
+ title={DuoGuard: A Two-Player RL-Driven Framework for Multilingual LLM Guardrails},
112
+ author={Yihe Deng and Yu Yang and Junkai Zhang and Wei Wang and Bo Li},
113
+ year={2025},
114
+ eprint={2502.05163},
115
+ archivePrefix={arXiv},
116
+ primaryClass={cs.CL},
117
+ url={https://arxiv.org/abs/2502.05163},
118
+ }
119
  ```
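
The post-processing described in the README (sigmoid over the 12 logits, per-category thresholding, and the max-probability binary verdict) can be sketched as a small standalone helper. This is an illustrative sketch only: the function name `moderate` and the dummy logits are not from the model card, and no model download is required since it operates on a plain logits vector in the same 12-category order as the snippet above.

```python
import math

CATEGORIES = [
    "Violent crimes", "Non-violent crimes", "Sex-related crimes",
    "Child sexual exploitation", "Specialized advice", "Privacy",
    "Intellectual property", "Indiscriminate weapons", "Hate",
    "Suicide and self-harm", "Sexual content", "Jailbreak prompts",
]

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def moderate(logits, threshold=0.5):
    """Turn a 12-dim logits vector into probabilities, per-category flags,
    and an overall safe/unsafe verdict (unsafe if ANY category exceeds
    the threshold, i.e. max probability > threshold)."""
    probs = [sigmoid(x) for x in logits]
    flags = {cat: p > threshold for cat, p in zip(CATEGORIES, probs)}
    verdict = "UNSAFE" if max(probs) > threshold else "SAFE"
    return probs, flags, verdict

# Dummy logits: all strongly negative => every probability near 0 => SAFE.
probs, flags, verdict = moderate([-5.0] * 12)
print(verdict)  # SAFE
```

Feeding `probabilities[0].tolist()` from the README snippet into `moderate` reproduces steps 6 and 7 without duplicating the thresholding logic inline.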