jainsatyam26 committed · verified
Commit 30c7d87 · 1 Parent(s): 07247eb

Upload guard-safety-classifier model
README.md ADDED
---
language: en
license: apache-2.0
tags:
- safety-classifier
- content-moderation
- multi-task
- deberta-v3
- text-classification
datasets:
- budecosystem/guardrail-training-data
metrics:
- accuracy
- f1
---
# 🛡️ Guard Safety Classifier

A multi-task safety classifier based on **DeBERTa-v3-small**, trained on 3.9M+ samples for content moderation and safety detection.
## 🎯 Model Tasks

This model performs **three simultaneous predictions**:

1. **Binary Safety Classification** (`is_safe`)
   - ✅ Safe content
   - ⚠️ Unsafe content

2. **Single-Label Category Classification** (`category`)
   - Identifies the primary safety-concern category

3. **Multi-Label Categories** (`categories`)
   - Detects multiple safety issues simultaneously
## 📊 Performance Metrics

| Metric | Score |
|--------|-------|
| **is_safe Accuracy** | 92.76% |
| **category F1** | 0.5037 |
| **categories F1** | 0.9068 |
| **Test Loss** | 1.0233 |
## 🚀 Quick Start

```python
import json
import pickle

import torch
from transformers import AutoTokenizer

# Load the tokenizer from the hub
model_name = "YOUR_USERNAME/guard-safety-classifier"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Head sizes come from the shipped config.json (num_categories=26, num_multi_labels=28)
with open("config.json") as f:
    config = json.load(f)

# The model class is not bundled with this repo; a compatible sketch
# is shown after this snippet.
from your_model_file import MultiTaskSafetyClassifier

model = MultiTaskSafetyClassifier(
    model_name="microsoft/deberta-v3-small",
    num_categories=config["num_categories"],
    num_multi_labels=config["num_multi_labels"],
)

# Load the trained weights
model.load_state_dict(torch.load("model_weights.pt", map_location="cpu"))
model.eval()

# Load the fitted label encoders (presumably a LabelEncoder and a MultiLabelBinarizer)
with open("label_encoders.pkl", "rb") as f:
    encoders = pickle.load(f)
le_category = encoders["le_category"]
mlb = encoders["mlb"]

# Inference
text = "Your text here"
inputs = tokenizer(text, return_tensors="pt", max_length=128,
                   truncation=True, padding=True)

with torch.no_grad():
    outputs = model(**inputs)

# Index 1 of the is_safe head is taken to be the "safe" class
is_safe = torch.softmax(outputs["is_safe"], dim=1)[0][1].item() > 0.5
category = le_category.inverse_transform([outputs["category"].argmax(1).item()])[0]
categories = mlb.inverse_transform(
    (torch.sigmoid(outputs["categories"]) > 0.5).cpu().numpy()
)[0]

print(f"Is Safe: {is_safe}")
print(f"Category: {category}")
print(f"Categories: {list(categories)}")
```
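Note that `MultiTaskSafetyClassifier` is imported from a local file that is not shipped with this repository. The following is a minimal sketch of a compatible architecture, assuming a shared DeBERTa-v3 encoder whose `[CLS]` hidden state feeds three linear heads sized to match `config.json` (2-way `is_safe`, 26-way `category`, 28-way `categories`); the original training code may differ.

```python
import torch.nn as nn
from transformers import AutoModel

class MultiTaskSafetyClassifier(nn.Module):
    """Sketch: shared DeBERTa encoder with three classification heads."""

    def __init__(self, model_name, num_categories, num_multi_labels):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden = self.encoder.config.hidden_size  # 768 for deberta-v3-small
        self.is_safe_head = nn.Linear(hidden, 2)                     # binary safety
        self.category_head = nn.Linear(hidden, num_categories)      # single-label (26)
        self.categories_head = nn.Linear(hidden, num_multi_labels)  # multi-label (28)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        out = self.encoder(input_ids=input_ids,
                           attention_mask=attention_mask,
                           token_type_ids=token_type_ids)
        pooled = out.last_hidden_state[:, 0]  # [CLS] token representation
        return {
            "is_safe": self.is_safe_head(pooled),
            "category": self.category_head(pooled),
            "categories": self.categories_head(pooled),
        }
```

The head names and pooling strategy here are assumptions; the layer names in the shipped `model_weights.pt` state dict are authoritative.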
## 🏗️ Model Architecture

- **Base Model**: `microsoft/deberta-v3-small` (141M parameters)
- **Hidden Size**: 768
- **Max Sequence Length**: 128 tokens
- **Training Framework**: PyTorch + Transformers
## 📚 Training Details

- **Dataset**: [budecosystem/guardrail-training-data](https://huggingface.co/datasets/budecosystem/guardrail-training-data)
- **Training Samples**: 3,182,844
- **Validation Samples**: 397,855
- **Test Samples**: 397,856
- **Batch Size**: 64
- **Learning Rate**: 2e-5
- **Epochs**: 1
- **Optimizer**: AdamW with linear warmup
- **Loss Weights**: 1.0 (`is_safe`) / 1.0 (`category`) / 0.5 (`categories`), per `config.json`; a sketch of the combined objective follows this list
- **Hardware**: NVIDIA Tesla T4 (16GB)
- **Training Time**: ~8 hours
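The three heads are presumably optimized jointly with a weighted sum of per-task losses, using the `w_is_safe=1.0`, `w_category=1.0`, `w_categories=0.5` weights recorded in `config.json`. A minimal sketch, assuming cross-entropy for the two single-label heads and binary cross-entropy for the multi-label head (the exact objective is not published):

```python
import torch.nn.functional as F

def multitask_loss(outputs, targets,
                   w_is_safe=1.0, w_category=1.0, w_categories=0.5):
    """Weighted sum of per-task losses; default weights match config.json."""
    loss_safe = F.cross_entropy(outputs["is_safe"], targets["is_safe"])    # 2-way
    loss_cat = F.cross_entropy(outputs["category"], targets["category"])   # 26-way
    loss_multi = F.binary_cross_entropy_with_logits(
        outputs["categories"], targets["categories"].float())              # 28 sigmoids
    return w_is_safe * loss_safe + w_category * loss_cat + w_categories * loss_multi
```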
## 🏷️ Categories

The model can identify the following safety categories:

```python
[
    "animal_abuse",
    "benign",
    "child_abuse",
    "code_vulnerabilities",
    "controversial_topics_politics",
    "cwe_compliance",
    "dangerous_expert_advice",
    "discrimination_stereotype_injustice",
    "drug_abuse_weapons_banned_substance",
    "financial_crime_property_crime_theft",
    "fraud_deception_misinformation",
    "gender_bias",
    "hate_speech_offensive_language",
    "jailbreak_prompt_injection",
    "malware_hacking_cyberattack",
    "misinformation_regarding_ethics_laws_and_safety",
    "mitre_compliance",
    "non_violent_unethical_behavior",
    "orientation_bias",
    "privacy_violation",
    "race_bias",
    "religious_bias",
    "self_harm",
    "sexually_explicit_adult_content",
    "terrorism_organized_crime",
    "violence_aiding_and_abetting_incitement"
]
```
## 🔢 Multi-Label Classes

The multi-label head is trained over the following 28 classes (the `multi_label_classes` list from `config.json`); note that these are single characters rather than category names:

```python
[
    " ", ",", "_", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k",
    "l", "m", "n", "o", "p", "r", "s", "t", "u", "v", "w", "x", "y", "z"
]
```
## ⚙️ Configuration

The full model configuration, including the class lists and test metrics, is shipped as `config.json`; see the example after this paragraph.
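For example, the label spaces and reported test metrics can be read back directly (key names match the file shown below):

```python
import json

# config.json here is the training metadata file shipped with the model,
# not a Hugging Face AutoConfig file.
with open("config.json") as f:
    cfg = json.load(f)

print(cfg["num_categories"])               # 26
print(cfg["category_classes"][:3])         # ['animal_abuse', 'benign', 'child_abuse']
print(cfg["test_metrics"]["is_safe_acc"])  # 0.9276...
```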
## 📄 License

Apache 2.0

## 🙏 Acknowledgments

- Base model: [microsoft/deberta-v3-small](https://huggingface.co/microsoft/deberta-v3-small)
- Training data: [budecosystem/guardrail-training-data](https://huggingface.co/datasets/budecosystem/guardrail-training-data)

## 📮 Contact

For questions or issues, please open an issue on the model repository.
added_tokens.json ADDED
{
  "[MASK]": 128000
}
config.json ADDED
{
  "model_name": "microsoft/deberta-v3-small",
  "max_len": 128,
  "batch_size": 64,
  "epochs": 1,
  "lr": 2e-05,
  "weight_decay": 0.01,
  "warmup_steps": 500,
  "grad_clip": 1.0,
  "seed": 42,
  "w_is_safe": 1.0,
  "w_category": 1.0,
  "w_categories": 0.5,
  "save_steps": 200,
  "eval_steps": 500,
  "num_categories": 26,
  "num_multi_labels": 28,
  "category_classes": [
    "animal_abuse",
    "benign",
    "child_abuse",
    "code_vulnerabilities",
    "controversial_topics_politics",
    "cwe_compliance",
    "dangerous_expert_advice",
    "discrimination_stereotype_injustice",
    "drug_abuse_weapons_banned_substance",
    "financial_crime_property_crime_theft",
    "fraud_deception_misinformation",
    "gender_bias",
    "hate_speech_offensive_language",
    "jailbreak_prompt_injection",
    "malware_hacking_cyberattack",
    "misinformation_regarding_ethics_laws_and_safety",
    "mitre_compliance",
    "non_violent_unethical_behavior",
    "orientation_bias",
    "privacy_violation",
    "race_bias",
    "religious_bias",
    "self_harm",
    "sexually_explicit_adult_content",
    "terrorism_organized_crime",
    "violence_aiding_and_abetting_incitement"
  ],
  "multi_label_classes": [
    " ", ",", "_", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k",
    "l", "m", "n", "o", "p", "r", "s", "t", "u", "v", "w", "x", "y", "z"
  ],
  "best_val_loss": 1.0249000663187966,
  "test_metrics": {
    "loss": 1.0232949212993905,
    "is_safe_acc": 0.9276446754604681,
    "category_f1": 0.5036962280648937,
    "categories_f1": 0.9067776039136755
  }
}
label_encoders.pkl ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:3ebba49ff1eca26a2905f9dc7e4af61c6a68ed079e0c3c3917e8c87db8dba609
size 5415
model_weights.pt ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:4664c6f76d143d0cdbab46aca62014a06fd0d299b911f85c598d65ef8e6d0ccc
size 567685355
special_tokens_map.json ADDED
{
  "bos_token": "[CLS]",
  "cls_token": "[CLS]",
  "eos_token": "[SEP]",
  "mask_token": "[MASK]",
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "unk_token": {
    "content": "[UNK]",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}
spm.model ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:c679fbf93643d19aab7ee10c0b99e460bdbc02fedf34b92b05af343b4af586fd
size 2464616
tokenizer.json ADDED
The diff for this file is too large to render inline.
 
tokenizer_config.json ADDED
{
  "added_tokens_decoder": {
    "0": {
      "content": "[PAD]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "[CLS]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "[SEP]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "3": {
      "content": "[UNK]",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "128000": {
      "content": "[MASK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "[CLS]",
  "clean_up_tokenization_spaces": true,
  "cls_token": "[CLS]",
  "do_lower_case": false,
  "eos_token": "[SEP]",
  "mask_token": "[MASK]",
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "sp_model_kwargs": {},
  "split_by_punct": false,
  "tokenizer_class": "DebertaV2Tokenizer",
  "unk_token": "[UNK]",
  "vocab_type": "spm"
}