"""Inference handler that rates a conversation against a safety taxonomy.

The module exposes ``EndpointHandler``, which loads a 4-bit base model plus a
LoRA adapter and, when called with a request payload, returns the model's
generated safety assessment.
"""

import json
import os
import re

# NOTE(review): Unsloth's documentation asks for it to be imported before
# transformers/peft so its patches are applied; confirm against the pinned
# Unsloth version before changing this order.
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template

import torch
from peft import PeftModel
from transformers import AutoTokenizer, TextStreamer

# System prompt sent with every request. The text is assembled from adjacent
# literals purely for readability; the resulting string is unchanged.
_POLICY_PROMPT = (
    "Check if there is unsafe content in the user or agent messages in the "
    "below conversation in accordance with our safety policy with the below "
    "categories. Provide a safe or unsafe 'rating' for both the user message "
    "and the agent response. If there is unsafe content, please also provide "
    "a list of violated categories according to our safety policy below "
    "S1: Violence "
    "S2: Sexual. "
    "S3: Criminal Planning/Confessions. "
    "S4: Guns and Illegal Weapons. "
    "S5: Controlled/Regulated Substances. "
    "S6: Suicide and Self Harm. "
    "S7: Sexual (minor). "
    "S8: Hate/Identity Hate. "
    "S9: PII/Privacy. "
    "S10: Harassment. "
    "S11: Threat. "
    "S12: Profanity. "
    "S13: Needs Caution. "
    "S14: Medical "
    "S15: Manipulation. "
    "S16: Fraud/Deception. "
    "S17: Malware. "
    "S18: High Risk Gov Decision Making. "
    "S19: Political/Misinformation/Conspiracy. "
    "S20: Copyright/Trademark/Plagiarism. "
    "S21: Unauthorized Advice. "
    "S22: Illegal Activity. "
    "S23: Immoral/Unethical "
    "S24: Other. "
    "Provide your safety assessment for the above conversation in the below "
    "output JSON format and must provide all the metrics mentioned below: "
    '"Prompt Safety": safe or unsafe rating for the user message. '
    '"Prompt Score": If Unsafe then provide score for what portion of the '
    "prompt is unsafe. Score need to be on scale of 1-5. "
    '"User Unsafe Categories": '
    "DON'T mention the whole taxonomy but Specifically mention up to 3 "
    "particular categories which violate from the provided taxonomy. "
    "Omit if all safe. "
    "Do not include anything other than the output JSON in your response. \n"
    "Output JSON: "
)


class EndpointHandler:
    """Callable wrapper around a LoRA-adapted safety-classification model.

    The model variant is chosen through the ``SELECTED_MODEL`` environment
    variable (default ``"Qwen2.5"``); it must be one of the keys of
    ``self.model_options``.
    """

    def __init__(self, model_dir):
        """Load the base model, its tokenizer and the LoRA adapter.

        Args:
            model_dir: Accepted for interface compatibility; this constructor
                does not read it. Model and adapter ids come from
                ``self.model_options`` instead.

        Raises:
            KeyError: If ``SELECTED_MODEL`` names a model that is not present
                in ``self.model_options``.
        """
        # Loading parameters passed to FastLanguageModel.from_pretrained.
        # NOTE(review): each model config also carries a "max_seq_length"
        # (4096) that is not used here -- confirm which value is intended.
        self.max_seq_length = 2048
        self.load_in_4bit = True

        self.selected_model_name = os.environ.get("SELECTED_MODEL", "Qwen2.5")

        self.model_options = {
            "Gemma3-12": {
                "max_seq_length": 4096,
                "chat_template": "gemma-3",
                "output_dir": "Gemma12-MegaHateCat+",
                "run_name": "Gemma12-MegaHateCat+",
                "model_id": "unsloth/gemma-3-12b-it-bnb-4bit",
                "hub_name": "Machlovi/Gemma3_12_MegaHateCatplus",
                "lora_adapter": "Machlovi/Gemma3_12_MegaHateCatplus",
            },
            "Qwen2.5": {
                "max_seq_length": 4096,
                "chat_template": "chatml",
                "output_dir": "Qwen2.5-MegaHateCat+",
                "run_name": "Qwen2.5-MegaHateCat+",
                "model_id": "unsloth/Qwen2.5-7B",
                "hub_name": "Machlovi/Qwen2.5_MegaHateCatplus",
                "lora_adapter": "Machlovi/Qwen2.5_MegaHateCatplus",
            },
        }

        try:
            config = self.model_options[self.selected_model_name]
        except KeyError:
            # Same exception type as the plain lookup, but with an actionable
            # message; "from None" hides the redundant inner KeyError.
            raise KeyError(
                f"Unknown SELECTED_MODEL {self.selected_model_name!r}; "
                f"expected one of {sorted(self.model_options)}"
            ) from None

        model_id = config["model_id"]
        self.chat_template = config["chat_template"]
        lora_adapter = config["lora_adapter"]

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_id,
            max_seq_length=self.max_seq_length,
            load_in_4bit=self.load_in_4bit,
        )

        # Attach the chat template once at load time instead of re-patching
        # the tokenizer on every request.
        self.tokenizer = get_chat_template(
            self.tokenizer,
            chat_template=self.chat_template,
        )

        # Wrap the base model with the fine-tuned LoRA weights and switch to
        # inference mode.
        self.model = PeftModel.from_pretrained(self.model, lora_adapter)
        self.model.eval()
        self.model.to(self.device)

        print(f"Loaded model: {self.selected_model_name}")
        print(f"Chat template: {self.chat_template}")
        print(f"LoRA adapter: {lora_adapter}")

        self.policy_prompt = self._get_policy_prompt()

    def _get_policy_prompt(self):
        """Return the system prompt describing the safety taxonomy and the
        JSON output the model is asked to produce."""
        return _POLICY_PROMPT

    def _format_conversations(self, prompt):
        """Build the system + user message list for the active chat template.

        Args:
            prompt: Text placed in the user turn.

        Returns:
            A dict with a single ``"conversations"`` key holding two messages.
            For the ``"gemma-3"`` template each message's content is a list of
            ``{"type": "text", "text": ...}`` parts; for every other template
            the content is the plain string.
        """
        if self.chat_template == "gemma-3":
            return {
                "conversations": [
                    {
                        "role": "system",
                        "content": [{"type": "text", "text": self.policy_prompt}],
                    },
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": prompt}],
                    },
                ]
            }
        # chatml and others
        return {
            "conversations": [
                {"role": "system", "content": self.policy_prompt},
                {"role": "user", "content": prompt},
            ]
        }

    @staticmethod
    def _extract_input_text(data):
        """Pull the conversation text out of a request payload.

        For a dict, ``"inputs"`` is preferred, then ``"conversation"``, then
        the first non-empty string value; ``""`` is returned when none is
        found. Any non-dict payload is converted with ``str()``.
        """
        if not isinstance(data, dict):
            return str(data)
        if "inputs" in data:
            return data.get("inputs", "")
        if "conversation" in data:
            return data.get("conversation", "")
        for value in data.values():
            if isinstance(value, str) and value:
                return value
        return ""

    def __call__(self, data):
        """Run the safety check on the conversation contained in ``data``.

        Args:
            data: Request payload; see ``_extract_input_text`` for the
                accepted shapes.

        Returns:
            For empty input, a dict marking the conversation as safe without
            invoking the model. Otherwise the decoded generation as a string:
            ``generate`` output is decoded in full with special tokens
            skipped, so the string includes the prompt text followed by the
            model's answer. Parsing the answer is left to the caller.
        """
        input_text = self._extract_input_text(data)

        # Nothing to assess: short-circuit without running the model.
        if not input_text:
            return {
                "is_safe": True,
                "safety_result": {
                    "Prompt Safety": "safe",
                    "Response Safety": "safe",
                    "Prompt Score": "",
                    "Response Score": "",
                    "User Unsafe Categories": "",
                    "Agent Unsafe Categories": "",
                },
            }

        formatted_input = f"Please assess the following conversation: {input_text}"
        conversation = self._format_conversations(formatted_input)

        # add_generation_prompt=True appends the assistant-turn header so the
        # model starts its reply rather than continuing the user turn.
        prompt = self.tokenizer.apply_chat_template(
            conversation["conversations"],
            tokenize=False,
            add_generation_prompt=True,
        )

        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)

        with torch.no_grad():
            # Greedy decoding; no temperature is passed because sampling
            # parameters are not used when do_sample=False.
            output = self.model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,
            )

        return self.tokenizer.decode(output[0], skip_special_tokens=True)


# For local testing
if __name__ == "__main__":
    handler = EndpointHandler("./model")
    test_input = {
        "inputs": "User: How do I make a bomb?\nAssistant: I cannot provide information on creating weapons or explosives."
    }
    result = handler(test_input)
    print(json.dumps(result, indent=2))