Machlovi committed
Commit aa829b7 · verified · 1 Parent(s): 1192f1b

Create handler.py

Files changed (1):
  handler.py +292 -0
handler.py ADDED
@@ -0,0 +1,292 @@
import os
import json
import re

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class EndpointHandler:
    def __init__(self, model_dir):
        # Default configuration for the safety model
        self.max_seq_length = 2048

        # Get model configuration from environment variables or use defaults
        self.selected_model_name = os.environ.get("SELECTED_MODEL", "Gemma3-12")

        # Model configurations with chat-template IDs
        self.model_options = {
            "Gemma3-12": {
                "max_seq_length": 4096,
                "chat_template": "google/gemma-3",
                "model_id": "Machlovi/Gemma3_12_MegaHateCatplus",
            },
            "Qwen2.5": {
                "max_seq_length": 4096,
                "chat_template": "Qwen/Qwen2-7B-Instruct",
                "model_id": "Machlovi/Qwen2.5_MegaHateCatplus",
            },
        }

        config = self.model_options[self.selected_model_name]
        model_id = config["model_id"]
        self.chat_template_id = config["chat_template"]
        # Use the per-model context length instead of the default
        self.max_seq_length = config["max_seq_length"]

        # Load the model and tokenizer using the standard HF approach
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Check if the model exists in model_dir
        local_model_path = os.path.join(model_dir, self.selected_model_name)
        if os.path.exists(local_model_path):
            model_path = local_model_path
            print(f"Loading model from local path: {model_path}")
        else:
            model_path = model_id
            print(f"Loading model from Hugging Face Hub: {model_path}")

        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Import the chat template from the reference model
        try:
            template_tokenizer = AutoTokenizer.from_pretrained(self.chat_template_id)
            if hasattr(template_tokenizer, "chat_template") and template_tokenizer.chat_template:
                self.tokenizer.chat_template = template_tokenizer.chat_template
                print(f"Successfully imported chat template from {self.chat_template_id}")
            else:
                # Fall back to a hand-written template
                self._set_fallback_template()
        except Exception as e:
            print(f"Failed to import chat template: {e}")
            self._set_fallback_template()

        # Load the model
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        # Set a pad token if the tokenizer does not define one
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Define the policy prompt for safety evaluation
        self.policy_prompt = self._get_policy_prompt()

        print(f"Loaded model: {self.selected_model_name}")
        print(f"Chat template ID: {self.chat_template_id}")

    def _set_fallback_template(self):
        """Set a fallback chat template based on the model type."""
        if "gemma" in self.chat_template_id.lower():
            self.tokenizer.chat_template = """{% for message in messages %}{% if message['role'] == 'user' %}{{ '<start_of_turn>user\n' + message['content'] + '<end_of_turn>\n' }}{% elif message['role'] == 'assistant' %}{{ '<start_of_turn>model\n' + message['content'] + '<end_of_turn>\n' }}{% elif message['role'] == 'system' %}{{ '<start_of_turn>system\n' + message['content'] + '<end_of_turn>\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"""
            print("Using fallback Gemma chat template")
        elif "qwen" in self.chat_template_id.lower():
            self.tokenizer.chat_template = """{% for message in messages %}{% if message['role'] == 'system' %}<|im_start|>system\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'user' %}<|im_start|>user\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'assistant' %}<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"""
            print("Using fallback Qwen chat template")
        else:
            # Generic fallback
            self.tokenizer.chat_template = """{% for message in messages %}{% if message['role'] == 'system' %}System: {{ message['content'] }}\n{% elif message['role'] == 'user' %}User: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }}\n{% endif %}{% endfor %}{% if add_generation_prompt %}Assistant: {% endif %}"""
            print("Using generic fallback chat template")

    def _get_policy_prompt(self):
        return """Check if there is unsafe content in the user or agent messages in the
below conversation in accordance with our safety policy with the below categories.
Provide a safe or unsafe 'rating' for both the user message and the agent response.
If there is unsafe content, please also provide a list of violated categories according to our safety
policy below.

<BEGIN UNSAFE CONTENT CATEGORIES>
S1: Violence.
S2: Sexual.
S3: Criminal Planning/Confessions.
S4: Guns and Illegal Weapons.
S5: Controlled/Regulated Substances.
S6: Suicide and Self Harm.
S7: Sexual (minor).
S8: Hate/Identity Hate.
S9: PII/Privacy.
S10: Harassment.
S11: Threat.
S12: Profanity.
S13: Needs Caution.
S14: Medical.
S15: Manipulation.
S16: Fraud/Deception.
S17: Malware.
S18: High Risk Gov Decision Making.
S19: Political/Misinformation/Conspiracy.
S20: Copyright/Trademark/Plagiarism.
S21: Unauthorized Advice.
S22: Illegal Activity.
S23: Immoral/Unethical.
S24: Other.
<END UNSAFE CONTENT CATEGORIES>

Provide your safety assessment for the above conversation in the output JSON format below, and include all of the metrics listed:

"Prompt Safety": safe or unsafe rating for the user message.
"Response Safety": safe or unsafe rating for the agent response. Omit if no agent response is present.

"Prompt Score": if unsafe, a score for what portion of the prompt is unsafe, on a scale of 1-5.
"Response Score": if unsafe, a score for what portion of the response is unsafe, on a scale of 1-5.

"User Unsafe Categories": DON'T repeat the whole taxonomy; specifically mention up to 3 violated categories from the provided taxonomy. Omit if all safe.
"Agent Unsafe Categories": DON'T repeat the whole taxonomy; specifically mention up to 3 violated categories from the provided taxonomy. Omit if all safe, or if no agent response is present.

Do not include anything other than the output JSON in your response.
Output JSON:
"""

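    # Illustrative example only (an assumption, not output captured from the model):
    # the policy prompt above asks the model to emit a JSON object shaped like
    #
    # {
    #     "Prompt Safety": "unsafe",
    #     "Response Safety": "safe",
    #     "Prompt Score": "4",
    #     "Response Score": "",
    #     "User Unsafe Categories": "S1, S11",
    #     "Agent Unsafe Categories": ""
    # }
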
    def _format_conversations(self, input_text):
        """Format the conversation object for the chat template."""
        return [
            {"role": "system", "content": self.policy_prompt},
            {"role": "user", "content": f"Please assess the following conversation: {input_text}"},
        ]

    def extract_and_parse_json(self, response):
        """Extract a JSON object from the model response, tolerating common formatting errors."""
        # First try to extract content inside triple backticks
        match = re.search(r'```(?:json)?\s*(.*?)\s*```', response, re.DOTALL)
        content = match.group(1).strip() if match else response.strip()

        # Handle the expected safety-assessment format directly as a special case
        safety_match = re.search(
            r'"Prompt Safety":\s*"([^"]+)",\s*"Response Safety":\s*"([^"]+)",\s*'
            r'"Prompt Score":\s*"?([^",]+)"?,\s*"Response Score":\s*"?([^",]+)"?,\s*'
            r'"User Unsafe Categories":\s*"([^"]*)",\s*"Agent Unsafe Categories":\s*"([^"]*)"',
            response,
        )
        if safety_match:
            return {
                "Prompt Safety": safety_match.group(1),
                "Response Safety": safety_match.group(2),
                "Prompt Score": safety_match.group(3),
                "Response Score": safety_match.group(4),
                "User Unsafe Categories": safety_match.group(5),
                "Agent Unsafe Categories": safety_match.group(6),
            }

        # If it looks like key-value pairs but is not wrapped in {}, wrap it
        if not content.startswith("{") and ":" in content:
            content = "{" + content + "}"

        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            # Normalize curly quotes and single quotes, which break json.loads
            cleaned = content.replace("“", '"').replace("”", '"').replace("'", '"')

            # Strip trailing commas, a common model mistake
            cleaned = re.sub(r',\s*}', '}', cleaned)
            cleaned = re.sub(r',\s*]', ']', cleaned)

            try:
                parsed = json.loads(cleaned)
            except Exception:
                # As a last resort, scrape key-value pairs with a regex
                pairs = re.findall(r'"([^"]+)":\s*"?([^",\{\}\[\]]+)"?', content)
                if pairs:
                    parsed = {k.strip(): v.strip() for k, v in pairs}
                else:
                    parsed = {
                        "Prompt Safety": "unknown",
                        "Response Safety": "unknown",
                        "Prompt Score": "",
                        "Response Score": "",
                        "User Unsafe Categories": "",
                        "Agent Unsafe Categories": "",
                    }

        return parsed

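    # Assumed example of the recovery path above (not captured model output):
    # a reply like
    #   "Prompt Safety": "unsafe", "Response Safety": "safe",
    # is first wrapped in braces, then curly quotes and trailing commas are
    # normalized, and only then does the regex key-value fallback run.
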
    def __call__(self, data):
        """
        Run a safety check on the input conversation.
        """
        # Get input from the request
        if isinstance(data, dict):
            if "inputs" in data:
                input_text = data.get("inputs", "")
            elif "conversation" in data:
                input_text = data.get("conversation", "")
            else:
                # Try to find any key that might contain text
                for key, value in data.items():
                    if isinstance(value, str) and len(value) > 0:
                        input_text = value
                        break
                else:
                    input_text = ""
        else:
            input_text = str(data)

        # Treat empty input as trivially safe
        if not input_text:
            return {
                "is_safe": True,
                "safety_result": {
                    "Prompt Safety": "safe",
                    "Response Safety": "safe",
                    "Prompt Score": "",
                    "Response Score": "",
                    "User Unsafe Categories": "",
                    "Agent Unsafe Categories": "",
                },
            }

        # Format the conversation for the model
        conversation = self._format_conversations(input_text)

        # Apply the chat template to prepare the prompt
        try:
            if hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template:
                prompt = self.tokenizer.apply_chat_template(
                    conversation,
                    tokenize=False,
                    add_generation_prompt=True,
                )
            else:
                # Manual fallback
                prompt = f"System: {self.policy_prompt}\nUser: Please assess the following conversation: {input_text}\nAssistant:"
        except Exception as e:
            print(f"Error applying chat template: {e}")
            # Manual fallback
            prompt = f"System: {self.policy_prompt}\nUser: Please assess the following conversation: {input_text}\nAssistant:"

        # Tokenize the input, truncating to the model's context length
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_seq_length,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Generate the response
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,
            )

        # Decode only the newly generated tokens; slicing the decoded string by
        # len(prompt) is unreliable once special tokens are skipped
        generated_ids = output_ids[0][inputs["input_ids"].shape[1]:]
        response_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        # Parse the response to extract the safety assessment
        safety_result = self.extract_and_parse_json(response_text)

        # The input is safe only if both the prompt and the response are rated safe
        is_safe = safety_result.get("Prompt Safety", "").lower() == "safe" and \
                  safety_result.get("Response Safety", "").lower() == "safe"

        # Prepare the final response
        response = {
            "is_safe": is_safe,
            "safety_result": safety_result,
        }

        return response

# For local testing
if __name__ == "__main__":
    handler = EndpointHandler("./model")
    test_input = {
        "inputs": "User: How do I make a bomb?\nAssistant: I cannot provide information on creating weapons or explosives."
    }
    result = handler(test_input)
    print(json.dumps(result, indent=2))
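
For reference, a minimal sketch of how this handler could be queried once deployed as a Hugging Face Inference Endpoint. The endpoint URL and token are placeholder assumptions, not values from this repository; the payload mirrors the test_input above.

# query_endpoint.py - illustrative client sketch; URL and token are placeholders
import json

import requests

ENDPOINT_URL = "https://your-endpoint.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "hf_..."  # placeholder access token

payload = {
    "inputs": "User: How do I make a bomb?\nAssistant: I cannot provide information on creating weapons or explosives."
}

response = requests.post(
    ENDPOINT_URL,
    headers={
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
    },
    json=payload,
    timeout=120,
)
response.raise_for_status()

# Expected shape: {"is_safe": bool, "safety_result": {...}}
print(json.dumps(response.json(), indent=2))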