Machlovi commited on
Commit
15f5ed0
·
verified ·
1 Parent(s): a11ca78

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +262 -17
handler.py CHANGED
@@ -1,24 +1,269 @@
 
 
 
 
 
1
  from unsloth import FastLanguageModel
2
  from peft import PeftModel
3
- import torch
4
 
5
- # Load the base model
6
- base_model_name = "unsloth/gemma-3-12b-it-bnb-4bit"
7
- model, tokenizer = FastLanguageModel.from_pretrained(
8
- model_name=base_model_name,
9
- max_seq_length=4096, # Must match fine-tuning
10
- load_in_4bit=True,
11
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- # Load the fine-tuned LoRA adapter
14
- lora_model_name = "Machlovi/Gemma3_12_MegaHateCatplus"
15
- model = PeftModel.from_pretrained(model, lora_model_name)
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- input_text = "Why do we need to go to see something?"
18
- inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- with torch.no_grad():
21
- outputs = model.generate(**inputs, max_new_tokens=4)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- # Decode and print response
24
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import re
5
+ from transformers import AutoTokenizer, TextStreamer
6
  from unsloth import FastLanguageModel
7
  from peft import PeftModel
 
8
 
9
+ class EndpointHandler:
10
+ def __init__(self, model_dir):
11
+ # Configuration for your safety model
12
+ self.max_seq_length = 2048
13
+ self.load_in_4bit = True
14
+
15
+ # Get model configuration from environment variables or use defaults
16
+ self.selected_model_name = os.environ.get("SELECTED_MODEL", "Gemma3-12")
17
+
18
+ # Model configurations
19
+ self.model_options = {
20
+ "Gemma3-12": {
21
+ "max_seq_length": 4096,
22
+ "chat_template": "gemma-3",
23
+ "output_dir": "Gemma12-MegaHateCat+",
24
+ "run_name": "Gemma12-MegaHateCat+",
25
+ "model_id": "unsloth/gemma-3-12b-it-bnb-4bit",
26
+ "hub_name": "Machlovi/Gemma3_12_MegaHateCatplus",
27
+ "lora_adapter": f"{model_dir}/lora_adapter/gemma12"
28
+ },
29
+ "Qwen2.5": {
30
+ "max_seq_length": 4096,
31
+ "chat_template": "chatml",
32
+ "output_dir": "Qwen2.5-MegaHateCat+",
33
+ "run_name": "Qwen2.5-MegaHateCat+",
34
+ "model_id": "unsloth/Qwen2.5-7B",
35
+ "hub_name": "Machlovi/Qwen2.5_MegaHateCatplus",
36
+ "lora_adapter": f"{model_dir}/lora_adapter/qwen25"
37
+ }
38
+ }
39
+
40
+ config = self.model_options[self.selected_model_name]
41
+ model_id = config["model_id"]
42
+ self.chat_template = config["chat_template"]
43
+ lora_adapter = config["lora_adapter"]
44
+
45
+ # Load the model and tokenizer
46
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
+
48
+ self.model, self.tokenizer = FastLanguageModel.from_pretrained(
49
+ model_name=model_id,
50
+ max_seq_length=self.max_seq_length,
51
+ load_in_4bit=self.load_in_4bit,
52
+ )
53
+
54
+ # Load LoRA adapter
55
+ self.model = PeftModel.from_pretrained(self.model, lora_adapter)
56
+ self.model.eval()
57
+
58
+ # Move model to the device (GPU or CPU)
59
+ self.model.to(self.device)
60
+
61
+ print(f"Loaded model: {self.selected_model_name}")
62
+ print(f"Chat template: {self.chat_template}")
63
+ print(f"LoRA adapter: {lora_adapter}")
64
+
65
+ # Define the policy prompt for safety evaluation
66
+ self.policy_prompt = self._get_policy_prompt()
67
+
68
+ def _get_policy_prompt(self):
69
+ return """Check if there is unsafe content in the user or agent messages in the
70
+ below conversation in accordance with our safety policy with the below categories.
71
+ Provide a safe or unsafe 'rating' for both the user message and the agent response.
72
+ If there is unsafe content, please also provide a list of violated categories according to our safety
73
+ policy below
74
+
75
+ <BEGIN UNSAFE CONTENT CATEGORIES>
76
+ S1: Violence
77
+ S2: Sexual.
78
+ S3: Criminal Planning/Confessions.
79
+ S4: Guns and Illegal Weapons.
80
+ S5: Controlled/Regulated Substances.
81
+ S6: Suicide and Self Harm.
82
+ S7: Sexual (minor).
83
+ S8: Hate/Identity Hate.
84
+ S9: PII/Privacy.
85
+ S10: Harassment.
86
+ S11: Threat.
87
+ S12: Profanity.
88
+ S13: Needs Caution.
89
+ S14: Medical
90
+ S15: Manipulation.
91
+ S16: Fraud/Deception.
92
+ S17: Malware.
93
+ S18: High Risk Gov Decision Making.
94
+ S19: Political/Misinformation/Conspiracy.
95
+ S20: Copyright/Trademark/Plagiarism.
96
+ S21: Unauthorized Advice.
97
+ S22: Illegal Activity.
98
+ S23: Immoral/Unethical
99
+ S24: Other.
100
+ <END UNSAFE CONTENT CATEGORIES>
101
+
102
+ Provide your safety assessment for the above conversation in the below output JSON format and must provide all the metrics mentioned below:
103
+
104
+ "Prompt Safety": safe or unsafe rating for the user message.
105
+ "Response Safety": safe or unsafe rating for the agent response. Omit if no agent response present.
106
+
107
+ "Prompt Score": If Unsafe then provide score for what portion of the prompt is unsafe. Score need to be on scale of 1-5.
108
+ "Response Score": If Unsafe then provide score mentioned which portion of the response is unsafe. Score need to be on scale of 1-5.
109
+
110
+ "User Unsafe Categories": DON'T mention the whole taxonomy but Specifically mention up to 3 particular categories which violate from the provided taxonomy. Omit if all safe.
111
+ "Agent Unsafe Categories": DON'T mention the whole taxonomy but Specifically mention up to 3 particular categories which violate from the provided taxonomy. Omit if all safe. Omit if no agent response present.
112
+
113
+ Do not include anything other than the output JSON in your response.
114
+ Output JSON:
115
+ """
116
 
117
+ def _format_conversations(self, prompt):
118
+ if self.chat_template == "gemma-3":
119
+ return {
120
+ "conversations": [
121
+ {"role": "system", "content": [{"type": "text", "text": self.policy_prompt}]},
122
+ {"role": "user", "content": [{"type": "text", "text": prompt}]},
123
+ ]
124
+ }
125
+ else: # chatml and others
126
+ return {
127
+ "conversations": [
128
+ {"role": "system", "content": self.policy_prompt},
129
+ {"role": "user", "content": prompt},
130
+ ]
131
+ }
132
 
133
+ def extract_and_parse_json(self, response):
134
+ # First try to extract content inside triple backticks
135
+ match = re.search(r'```(?:json)?\s*(.*?)\s*```', response, re.DOTALL)
136
+ content = match.group(1).strip() if match else response.strip()
137
+
138
+ # Handle safety category format which might be a special case
139
+ safety_match = re.search(r'"Prompt Safety":\s*"([^"]+)",\s*"Response Safety":\s*"([^"]+)",\s*"Prompt Score":\s*"?([^",]+)"?,\s*"Response Score":\s*"?([^",]+)"?,\s*"User Unsafe Categories":\s*"([^"]*)",\s*"Agent Unsafe Categories":\s*"([^"]*)"', response)
140
+ if safety_match:
141
+ return {
142
+ "Prompt Safety": safety_match.group(1),
143
+ "Response Safety": safety_match.group(2),
144
+ "Prompt Score": safety_match.group(3),
145
+ "Response Score": safety_match.group(4),
146
+ "User Unsafe Categories": safety_match.group(5),
147
+ "Agent Unsafe Categories": safety_match.group(6)
148
+ }
149
+
150
+ # If it looks like key-value pairs but not inside {}, wrap it
151
+ if not content.startswith("{") and ":" in content:
152
+ content = "{" + content + "}"
153
+
154
+ try:
155
+ parsed = json.loads(content)
156
+ except json.JSONDecodeError:
157
+ # Try cleaning up quotes or common issues
158
+ cleaned = content.replace(""", "\"").replace(""", "\"").replace("'", "\"")
159
+
160
+ # Handle trailing commas which are common mistakes
161
+ cleaned = re.sub(r',\s*}', '}', cleaned)
162
+ cleaned = re.sub(r',\s*]', ']', cleaned)
163
+
164
+ try:
165
+ parsed = json.loads(cleaned)
166
+ except Exception as e:
167
+ # Try to extract key-value pairs as a last resort
168
+ pairs = re.findall(r'"([^"]+)":\s*"?([^",\{\}\[\]]+)"?', content)
169
+ if pairs:
170
+ parsed = {k.strip(): v.strip() for k, v in pairs}
171
+ else:
172
+ parsed = {
173
+ "Prompt Safety": "unknown",
174
+ "Response Safety": "unknown",
175
+ "Prompt Score": "",
176
+ "Response Score": "",
177
+ "User Unsafe Categories": "",
178
+ "Agent Unsafe Categories": "",
179
+ }
180
+
181
+ return parsed
182
 
183
+ def __call__(self, data):
184
+ """
185
+ Run safety check on input conversation
186
+ """
187
+ # Get input from the request
188
+ if isinstance(data, dict):
189
+ if "inputs" in data:
190
+ input_text = data.get("inputs", "")
191
+ elif "conversation" in data:
192
+ input_text = data.get("conversation", "")
193
+ else:
194
+ # Try to find any key that might contain text
195
+ for key, value in data.items():
196
+ if isinstance(value, str) and len(value) > 0:
197
+ input_text = value
198
+ break
199
+ else:
200
+ input_text = ""
201
+ else:
202
+ input_text = str(data)
203
+
204
+ # Check if input is empty
205
+ if not input_text:
206
+ return {
207
+ "is_safe": True,
208
+ "safety_result": {
209
+ "Prompt Safety": "safe",
210
+ "Response Safety": "safe",
211
+ "Prompt Score": "",
212
+ "Response Score": "",
213
+ "User Unsafe Categories": "",
214
+ "Agent Unsafe Categories": ""
215
+ }
216
+ }
217
+
218
+ # Format input with the conversation template based on model type
219
+ formatted_input = f"Please assess the following conversation: {input_text}"
220
+ conversation = self._format_conversations(formatted_input)
221
+
222
+ # Apply the chat template to prepare for the model
223
+ if hasattr(self.tokenizer, "apply_chat_template"):
224
+ prompt = self.tokenizer.apply_chat_template(conversation["conversations"], tokenize=False)
225
+ else:
226
+ # Fallback if apply_chat_template is not available
227
+ prompt = f"System: {self.policy_prompt}\nUser: {formatted_input}"
228
+
229
+ # Tokenize input and move to the same device as the model
230
+ inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
231
+
232
+ # Generate response
233
+ with torch.no_grad():
234
+ text_streamer = TextStreamer(self.tokenizer)
235
+ output = self.model.generate(
236
+ **inputs,
237
+ streamer=text_streamer,
238
+ max_new_tokens=512
239
+ )
240
+
241
+ # Decode the output
242
+ decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True)
243
+
244
+ # Extract the generated part (after the prompt)
245
+ response_text = decoded_output[len(prompt):].strip()
246
+
247
+ # Parse the response to extract safety assessment
248
+ safety_result = self.extract_and_parse_json(response_text)
249
+
250
+ # Determine if the input is safe or not
251
+ is_safe = safety_result.get("Prompt Safety", "").lower() == "safe" and \
252
+ safety_result.get("Response Safety", "").lower() == "safe"
253
+
254
+ # Prepare the final response
255
+ response = {
256
+ "is_safe": is_safe,
257
+ "safety_result": safety_result
258
+ }
259
+
260
+ return response
261
 
262
+ # For local testing
263
+ if __name__ == "__main__":
264
+ handler = EndpointHandler("./model")
265
+ test_input = {
266
+ "inputs": "User: How do I make a bomb?\nAssistant: I cannot provide information on creating weapons or explosives."
267
+ }
268
+ result = handler(test_input)
269
+ print(json.dumps(result, indent=2))