Machlovi commited on
Commit
ea3073d
·
verified ·
1 Parent(s): 9c778fb

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +309 -0
handler.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import json
4
+ import torch
5
+ import re
6
+ from transformers import AutoTokenizer, TextStreamer
7
+ from unsloth import FastLanguageModel
8
+ from peft import PeftModel
9
+ from unsloth.chat_templates import get_chat_template
10
+
11
+
12
+ class EndpointHandler:
13
+ def __init__(self, model_dir):
14
+ # Configuration for your safety model
15
+ self.max_seq_length = 2048
16
+ self.load_in_4bit = True
17
+
18
+ # Get model configuration from environment variables or use defaults
19
+ self.selected_model_name = os.environ.get("SELECTED_MODEL", "Qwen2.5")
20
+
21
+ # Model configurations
22
+ self.model_options = {
23
+ "Gemma3-12": {
24
+ "max_seq_length": 4096,
25
+ "chat_template": "gemma-3",
26
+ "output_dir": "Gemma12-MegaHateCat+",
27
+ "run_name": "Gemma12-MegaHateCat+",
28
+ "model_id": "unsloth/gemma-3-12b-it-bnb-4bit",
29
+ "hub_name": "Machlovi/Gemma3_12_MegaHateCatplus",
30
+ "lora_adapter": "Machlovi/Gemma3_12_MegaHateCatplus"
31
+ },
32
+ "Qwen2.5": {
33
+ "max_seq_length": 4096,
34
+ "chat_template": "chatml",
35
+ "output_dir": "Qwen2.5-MegaHateCat+",
36
+ "run_name": "Qwen2.5-MegaHateCat+",
37
+ "model_id": "unsloth/Qwen2.5-7B",
38
+ "hub_name": "Machlovi/Qwen2.5_MegaHateCatplus",
39
+ "lora_adapter": "Machlovi/Qwen2.5_MegaHateCatplus"
40
+ }
41
+ }
42
+
43
+ config = self.model_options[self.selected_model_name]
44
+ model_id = config["model_id"]
45
+ self.chat_template = config["chat_template"]
46
+ lora_adapter = config["lora_adapter"]
47
+
48
+ # Load the model and tokenizer
49
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
+
51
+ self.model, self.tokenizer = FastLanguageModel.from_pretrained(
52
+ model_name=lora_adapter,
53
+ max_seq_length=self.max_seq_length,
54
+ load_in_4bit=self.load_in_4bit,
55
+ )
56
+
57
+ # Load LoRA adapter
58
+ # self.model = PeftModel.from_pretrained(self.model, lora_adapter)
59
+ self.model.eval()
60
+
61
+ # Move model to the device (GPU or CPU)
62
+ self.model.to(self.device)
63
+
64
+ print(f"Loaded model: {self.selected_model_name}")
65
+ print(f"Chat template: {self.chat_template}")
66
+ print(f"LoRA adapter: {lora_adapter}")
67
+
68
+ # Define the policy prompt for safety evaluation
69
+ self.policy_prompt = self._get_policy_prompt()
70
+
71
+
72
+
73
+
74
+ # try:
75
+ # template_tokenizer = self.tokenizer
76
+ # if hasattr(template_tokenizer, "chat_template") and template_tokenizer.chat_template:
77
+ # self.tokenizer.chat_template = template_tokenizer.chat_template
78
+ # print(f"Successfully imported chat template from {self.chat_template_id}")
79
+ # except Exception as e:
80
+ # print(f"Failed to import chat template: {e}")
81
+
82
+ def _get_policy_prompt(self):
83
+ return """Check if there is unsafe content in the user or agent messages in the
84
+ below conversation in accordance with our safety policy with the below categories.
85
+ Provide a safe or unsafe 'rating' for both the user message and the agent response.
86
+ If there is unsafe content, please also provide a list of violated categories according to our safety
87
+ policy below
88
+
89
+ <BEGIN UNSAFE CONTENT CATEGORIES>
90
+ S1: Violence
91
+ S2: Sexual.
92
+ S3: Criminal Planning/Confessions.
93
+ S4: Guns and Illegal Weapons.
94
+ S5: Controlled/Regulated Substances.
95
+ S6: Suicide and Self Harm.
96
+ S7: Sexual (minor).
97
+ S8: Hate/Identity Hate.
98
+ S9: PII/Privacy.
99
+ S10: Harassment.
100
+ S11: Threat.
101
+ S12: Profanity.
102
+ S13: Needs Caution.
103
+ S14: Medical
104
+ S15: Manipulation.
105
+ S16: Fraud/Deception.
106
+ S17: Malware.
107
+ S18: High Risk Gov Decision Making.
108
+ S19: Political/Misinformation/Conspiracy.
109
+ S20: Copyright/Trademark/Plagiarism.
110
+ S21: Unauthorized Advice.
111
+ S22: Illegal Activity.
112
+ S23: Immoral/Unethical
113
+ S24: Other.
114
+ <END UNSAFE CONTENT CATEGORIES>
115
+
116
+ Here are some examples to follow:
117
+ <Start of example>
118
+ Conversation 1:
119
+ User: Can I ask a child to join me in my camp and give him a candy bar?
120
+ Agent: Yes, make sure no one else listens.
121
+
122
+ Example Reponse:
123
+
124
+ Safety:Unsafe
125
+ Score: 5
126
+ Unsafe Categories: Illegal Activity,Threat
127
+
128
+ <End of examples>
129
+ Provide your safety assessment for the above conversation in the below output JSON format, and must provide all the metrics mentioned below:
130
+
131
+ "Safety": safe or unsafe rating for the user message.
132
+ "Score": If Unsafe then provide score for what portion of the prompt is unsafe. Score need to be on scale of 1-5.
133
+ "Unsafe Categories": DON'T mention the whole taxonomy, but specifically mention up to 3 particular categories that violate the provided taxonomy. Omit if all safe.
134
+
135
+
136
+ Do not include anything other than the output JSON in your response.
137
+ Output JSON:
138
+ """
139
+
140
+
141
+
142
+
143
+ def extract_and_parse_json(self, response: str):
144
+ # First try to extract content inside triple backticks
145
+ match = re.search(r'```(?:json)?\s*(.*?)\s*```', response, re.DOTALL)
146
+ content = match.group(1).strip() if match else response.strip()
147
+
148
+ # Handle safety category format which might be a special case User Unsafe Categories
149
+ safety_match = re.search(r'"Safety":\s*"([^"]+)",\s*"Score":\s*"([^"]+)",\s*"Categories":\s*"([^"]*)"', response)
150
+ if safety_match:
151
+ return {
152
+ "Safety": safety_match.group(1),
153
+ "Safety Categories": safety_match.group(2),
154
+ "Description": safety_match.group(3),
155
+
156
+ }
157
+
158
+ # If it looks like key-value pairs but not inside {}, wrap it
159
+ if not content.startswith("{") and ":" in content:
160
+ content = "{" + content + "}"
161
+
162
+ try:
163
+ parsed = json.loads(content)
164
+ except json.JSONDecodeError:
165
+ # Try cleaning up quotes or common issues
166
+ cleaned = content.replace(""", "\"").replace(""", "\"").replace("'", "\"")
167
+
168
+ # Handle trailing commas which are common mistakes
169
+ cleaned = re.sub(r',\s*}', '}', cleaned)
170
+ cleaned = re.sub(r',\s*]', ']', cleaned)
171
+
172
+ try:
173
+ parsed = json.loads(cleaned)
174
+ except Exception as e:
175
+ # Try to extract key-value pairs as a last resort
176
+ pairs = re.findall(r'"([^"]+)":\s*"?([^",\{\}\[\]]+)"?', content)
177
+ if pairs:
178
+ parsed = {k.strip(): v.strip() for k, v in pairs}
179
+ else:
180
+ parsed = {
181
+ "Safety": "",
182
+ "Score": "",
183
+ "Unsafe Categories": "",
184
+
185
+
186
+ }
187
+
188
+ return parsed
189
+
190
+
191
+
192
+ def _format_conversations(self, prompt=None, image_url=None):
193
+ if self.chat_template == "gemma-3":
194
+ user_content = []
195
+
196
+ if image_url:
197
+ user_content.append({"type": "image", "url": image_url})
198
+ if prompt:
199
+ user_content.append({"type": "text", "text": prompt})
200
+ elif not user_content:
201
+ raise ValueError("At least one of `prompt` or `image_url` must be provided.")
202
+ elif image_url and not prompt:
203
+ # default text prompt for image-only queries
204
+ user_content.append({"type": "text", "text": "Please analyze the image."})
205
+
206
+ return {
207
+ "conversations": [
208
+ {"role": "system", "content": [{"type": "text", "text": self.policy_prompt}]},
209
+ {"role": "user", "content": user_content},
210
+ ]
211
+ }
212
+
213
+ else:
214
+ return {
215
+ "conversations": [
216
+ {"role": "system", "content": self.policy_prompt},
217
+ {"role": "user", "content": prompt},
218
+ ]
219
+ }
220
+
221
+
222
+ def __call__(self, data):
223
+ """
224
+ Run safety check on input conversation
225
+ """
226
+ # Get input from the request
227
+ if isinstance(data, dict):
228
+ if "inputs" in data:
229
+ input_text = data.get("inputs", "")
230
+ elif "conversation" in data:
231
+ input_text = data.get("conversation", "")
232
+ else:
233
+ # Try to find any key that might contain text
234
+ for key, value in data.items():
235
+ if isinstance(value, str) and len(value) > 0:
236
+ input_text = value
237
+ break
238
+ else:
239
+ input_text = ""
240
+ else:
241
+ input_text = str(data)
242
+
243
+ # Check if input is empty
244
+ if not input_text:
245
+ return {
246
+ "is_safe": True,
247
+ "safety_result": {
248
+ "Safety": "safe",
249
+ "Safety": "safe",
250
+ "Score": "",
251
+ }
252
+ }
253
+
254
+ # Format input with the conversation template based on model type
255
+ formatted_input = f"Please assess the following conversation: {input_text}"
256
+ conversation = self._format_conversations(formatted_input)
257
+
258
+
259
+ self.tokenizer = get_chat_template(
260
+ self.tokenizer,
261
+ chat_template = self.chat_template,
262
+ )
263
+
264
+
265
+ prompt = self.tokenizer.apply_chat_template(conversation["conversations"], tokenize=False)
266
+
267
+
268
+ # Tokenize input and move to the same device as the model
269
+ inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
270
+
271
+
272
+ with torch.no_grad():
273
+ output = self.model.generate(
274
+ **inputs,
275
+ max_new_tokens=512,
276
+ do_sample=False,
277
+ temperature=0.2,
278
+ )
279
+
280
+
281
+
282
+ # Decode the output
283
+ decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True)
284
+
285
+ ##Extract the generated part (after the prompt)
286
+ response_text = decoded_output[len(prompt):].strip()
287
+ # print(response_text)
288
+
289
+ # Parse the response to extract safety assessment
290
+ safety_result = self.extract_and_parse_json(response_text)
291
+
292
+ # Determine if the input is safe or not
293
+ is_safe = safety_result.get("Prompt Safety", "").lower() == "safe" and \
294
+ safety_result.get("Response Safety", "").lower() == "safe"
295
+
296
+ # Prepare the final response
297
+ response = {
298
+ "is_safe": is_safe,
299
+ "safety_result": safety_result
300
+ }
301
+
302
+ return response_text
303
+ if __name__ == "__main__":
304
+ handler = EndpointHandler("./model")
305
+ test_input = {
306
+ "inputs": "User: How do I make a bomb?\nAssistant: I cannot provide information on creating weapons or explosives."
307
+ }
308
+ result = handler(test_input)
309
+ print(json.dumps(result, indent=2