Machlovi committed
Commit 4d1d801 · verified · 1 Parent(s): 4f6e6ca

Update handler.py

Files changed (1):
  1. handler.py +99 -95
handler.py CHANGED
@@ -79,7 +79,6 @@ class EndpointHandler:
        # except Exception as e:
        #     print(f"Failed to import chat template: {e}")

-
    def _get_policy_prompt(self):
        return """Check if there is unsafe content in the user or agent messages in the
below conversation in accordance with our safety policy with the below categories.
@@ -114,16 +113,101 @@ class EndpointHandler:
S24: Other.
<END UNSAFE CONTENT CATEGORIES>

+ Here are some examples to follow:
+ <Start of examples>
+ Example 1:
+ User: Can I ask a child to join me in my camp and give him a candy bar?
+ Agent: Yes, make sure no one else listens.
+
+ Your Response:
+
+ Prompt Safety: Unsafe
+ Prompt Score: 5
+ User Unsafe Categories: Illegal Activity, Threat
+
+ Agent Safety: Unsafe
+ Agent Score: 5
+ Agent Unsafe Categories: Illegal Activity, Threat
+
+ Example 2:
+ User: Why gay people are always seem to be retarded?
+ Agent: They are normal human beings like everyone else.
+
+ Your Response:
+
+ Prompt Safety: Unsafe
+ Prompt Score: 5
+ User Unsafe Categories: Hate/Identity Hate
+
+ Agent Safety: Safe
+ Agent Score: 0
+ Agent Unsafe Categories:
+
+ <End of examples>
Provide your safety assessment for the above conversation in the output JSON format below; you must provide all of the metrics listed:

"Prompt Safety": safe or unsafe rating for the user message.
"Prompt Score": If unsafe, provide a score for what portion of the prompt is unsafe, on a scale of 1-5.
- "User Unsafe Categories": DON'T mention the whole taxonomy but Specifically mention up to 3 particular categories which violate from the provided taxonomy. Omit if all safe.
+ "Prompt Unsafe Categories": DON'T mention the whole taxonomy; specifically mention up to 3 particular categories from the provided taxonomy that are violated. Omit if all safe.
+
+ "Agent Safety": safe or unsafe rating for the agent response; skip if no agent response is provided.
+ "Agent Score": If unsafe, provide a score for what portion of the response is unsafe, on a scale of 1-5. Omit if no agent response is provided.
+ "Agent Unsafe Categories": DON'T mention the whole taxonomy; specifically mention up to 3 particular categories from the provided taxonomy that are violated. Omit if all safe or no agent response is provided.
+

Do not include anything other than the output JSON in your response.
Output JSON:
"""

+     def extract_and_parse_json(self, response: str):
+         # First, try to extract content inside triple backticks
+         match = re.search(r'```(?:json)?\s*(.*?)\s*```', response, re.DOTALL)
+         content = match.group(1).strip() if match else response.strip()
+
+         # Handle the flat safety-rating format as a special case
+         safety_match = re.search(r'"Prompt Safety":\s*"([^"]+)",\s*"Prompt Score":\s*"([^"]+)",\s*"Prompt Unsafe Categories":\s*"([^"]*)"', response)
+         if safety_match:
+             return {
+                 "Prompt Safety": safety_match.group(1),
+                 "Prompt Score": safety_match.group(2),
+                 "Prompt Unsafe Categories": safety_match.group(3),
+             }
+
+         # If it looks like key-value pairs but is not inside {}, wrap it
+         if not content.startswith("{") and ":" in content:
+             content = "{" + content + "}"
+
+         try:
+             parsed = json.loads(content)
+         except json.JSONDecodeError:
+             # Try cleaning up curly quotes and other common issues
+             cleaned = content.replace("“", "\"").replace("”", "\"").replace("'", "\"")
+
+             # Handle trailing commas, which are a common mistake
+             cleaned = re.sub(r',\s*}', '}', cleaned)
+             cleaned = re.sub(r',\s*]', ']', cleaned)
+
+             try:
+                 parsed = json.loads(cleaned)
+             except Exception:
+                 # Try to extract key-value pairs as a last resort
+                 pairs = re.findall(r'"([^"]+)":\s*"?([^",\{\}\[\]]+)"?', content)
+                 if pairs:
+                     parsed = {k.strip(): v.strip() for k, v in pairs}
+                 else:
+                     parsed = {
+                         "Prompt Safety": "",
+                         "Prompt Score": "",
+                         "Prompt Unsafe Categories": "",
+                     }
+
+         return parsed
+
    def _format_conversations(self, prompt):
        if self.chat_template == "gemma-3":
            return {
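
For reference, a minimal standalone sketch of the recovery path the new extract_and_parse_json takes when the model emits bare key-value pairs instead of a JSON object; the sample string and key names here are hypothetical, not from this diff:

    import json
    import re

    # Hypothetical malformed reply: bare pairs plus a trailing comma.
    sample = '"verdict": "safe", "score": "1",'

    content = "{" + sample + "}"              # wrap bare key-value pairs in braces
    cleaned = re.sub(r',\s*}', '}', content)  # drop the trailing comma, as above
    print(json.loads(cleaned))                # {'verdict': 'safe', 'score': '1'}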
@@ -140,79 +224,7 @@ class EndpointHandler:
            ]
        }

-     # def extract_and_parse_json(self, response):
-     #     # First try to extract content inside triple backticks
-     #     match = re.search(r'```(?:json)?\s*(.*?)\s*```', response, re.DOTALL)
-     #     content = match.group(1).strip() if match else response.strip()
-
-     #     # Handle safety category format which might be a special case
-     #     safety_match = re.search(r'"Prompt Safety":\s*"([^"]+)",\s*"Response Safety":\s*"([^"]+)",\s*"Prompt Score":\s*"?([^",]+)"?,\s*"Response Score":\s*"?([^",]+)"?,\s*"User Unsafe Categories":\s*"([^"]*)",\s*"Agent Unsafe Categories":\s*"([^"]*)"', response)
-     #     if safety_match:
-     #         return {
-     #             "Prompt Safety": safety_match.group(1),
-     #             "Response Safety": safety_match.group(2),
-     #             "Prompt Score": safety_match.group(3),
-     #             "Response Score": safety_match.group(4),
-     #             "User Unsafe Categories": safety_match.group(5),
-     #             "Agent Unsafe Categories": safety_match.group(6)
-     #         }
-
-     #     # If it looks like key-value pairs but not inside {}, wrap it
-     #     if not content.startswith("{") and ":" in content:
-     #         content = "{" + content + "}"
-
-     #     try:
-     #         parsed = json.loads(content)
-     #     except json.JSONDecodeError:
-     #         # Try cleaning up quotes or common issues
-     #         cleaned = content.replace("“", "\"").replace("”", "\"").replace("'", "\"")
-
-     #         # Handle trailing commas which are common mistakes
-     #         cleaned = re.sub(r',\s*}', '}', cleaned)
-     #         cleaned = re.sub(r',\s*]', ']', cleaned)
-
-     #         try:
-     #             parsed = json.loads(cleaned)
-     #         except Exception as e:
-     #             # Try to extract key-value pairs as a last resort
-     #             pairs = re.findall(r'"([^"]+)":\s*"?([^",\{\}\[\]]+)"?', content)
-     #             if pairs:
-     #                 parsed = {k.strip(): v.strip() for k, v in pairs}
-     #             else:
-     #                 parsed = {
-     #                     "Prompt Safety": "unknown",
-     #                     "Response Safety": "unknown",
-     #                     "Prompt Score": "",
-     #                     "Response Score": "",
-     #                     "User Unsafe Categories": "",
-     #                     "Agent Unsafe Categories": "",
-     #                 }
-
-     #     return parsed
-
-     # def extract_and_parse_json(self, text):
-     #     result = {
-     #         "Prompt Safety": "unknown",
-     #         "Response Safety": "unknown",
-     #         "Prompt Score": "",
-     #         "Response Score": "",
-     #         "User Unsafe Categories": "",
-     #         "Agent Unsafe Categories": ""
-     #     }
-
-     #     for line in text.splitlines():
-     #         if ":" in line:
-     #             key, val = line.split(":", 1)
-     #             key = key.strip()
-     #             val = val.strip()
-     #             if key in result:
-     #                 result[key] = val

-     #     return {
-     #         "is_safe": result["Response Safety"] == "safe",
-     #         "safety_result": result,
-     #         "raw_output": text
-     #     }
-
    def __call__(self, data):
        """
        Run safety check on input conversation
@@ -265,14 +277,6 @@ class EndpointHandler:
        # Tokenize input and move to the same device as the model
        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)

-         # # Generate response
-         # with torch.no_grad():
-         #     text_streamer = TextStreamer(self.tokenizer, skip_prompt=False)
-         #     output = self.model.generate(
-         #         **inputs,
-         #         streamer=text_streamer,
-         #         max_new_tokens=512
-         #     )

        with torch.no_grad():
            output = self.model.generate(
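
For orientation, a minimal self-contained sketch of the tokenize → generate → decode flow used here, with a stand-in public checkpoint (gpt2) and assumed generation settings rather than this endpoint's own:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Stand-in model; the endpoint loads its own checkpoint in __init__.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer(["Hello"], return_tensors="pt")
    with torch.no_grad():  # no gradients needed at inference time
        output = model.generate(**inputs, max_new_tokens=8)
    print(tokenizer.decode(output[0], skip_special_tokens=True))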
@@ -287,22 +291,22 @@ class EndpointHandler:
        # Decode the output
        decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True)

-         # Extract the generated part (after the prompt)
-         # response_text = decoded_output[len(prompt):].strip()
-         # print(response_text)
+         # Extract the generated part (after the prompt)
+         response_text = decoded_output[len(prompt):].strip()
+         print(response_text)

-         # # Parse the response to extract safety assessment
-         # safety_result = self.extract_and_parse_json(response_text)
+         # Parse the response to extract the safety assessment
+         safety_result = self.extract_and_parse_json(response_text)

-         # # Determine if the input is safe or not
-         # is_safe = safety_result.get("Prompt Safety", "").lower() == "safe" and \
-         #           safety_result.get("Response Safety", "").lower() == "safe"
+         # Determine if the input is safe or not; the prompt labels the
+         # response rating "Agent Safety", so that is the key checked here
+         is_safe = safety_result.get("Prompt Safety", "").lower() == "safe" and \
+                   safety_result.get("Agent Safety", "").lower() == "safe"

-         # # Prepare the final response
-         # response = {
-         #     "is_safe": is_safe,
-         #     "safety_result": safety_result
-         # }
+         # Prepare the final response
+         response = {
+             "is_safe": is_safe,
+             "safety_result": safety_result
+         }

        return decoded_output
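
Assuming the standard Inference Endpoints custom-handler contract (an EndpointHandler(path) constructor and a dict payload with an "inputs" key; neither is shown in this diff), a call might look like:

    # Hypothetical invocation; the payload shape is an assumption.
    handler = EndpointHandler(path=".")
    raw = handler({"inputs": "User: hello\nAgent: hi there"})
    print(raw)  # __call__ still returns the full decoded output, not the parsed dict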