import os
import json
import re

import torch
from transformers import AutoTokenizer, TextStreamer
from peft import PeftModel
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template


class EndpointHandler:
    def __init__(self, model_dir):
        # Note: model_dir is required by the Inference Endpoints handler API but is not
        # used here; the base model and LoRA adapter are pulled from the Hub instead.
        self.max_seq_length = 2048
        self.load_in_4bit = True

        # Which fine-tune to serve; defaults to the Qwen2.5 adapter.
        self.selected_model_name = os.environ.get("SELECTED_MODEL", "Qwen2.5")

        self.model_options = {
            "Gemma3-12": {
                "max_seq_length": 4096,
                "chat_template": "gemma-3",
                "output_dir": "Gemma12-MegaHateCat+",
                "run_name": "Gemma12-MegaHateCat+",
                "model_id": "unsloth/gemma-3-12b-it-bnb-4bit",
                "hub_name": "Machlovi/Gemma3_12_MegaHateCatplus",
                "lora_adapter": "Machlovi/Gemma3_12_MegaHateCatplus",
            },
            "Qwen2.5": {
                "max_seq_length": 4096,
                "chat_template": "chatml",
                "output_dir": "Qwen2.5-MegaHateCat+",
                "run_name": "Qwen2.5-MegaHateCat+",
                "model_id": "unsloth/Qwen2.5-7B",
                "hub_name": "Machlovi/Qwen2.5_MegaHateCatplus",
                "lora_adapter": "Machlovi/Qwen2.5_MegaHateCatplus",
            },
        }

        config = self.model_options[self.selected_model_name]
        model_id = config["model_id"]
        self.chat_template = config["chat_template"]
        lora_adapter = config["lora_adapter"]
        # Honour the per-model sequence length if one is configured.
        self.max_seq_length = config.get("max_seq_length", self.max_seq_length)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=lora_adapter,
            max_seq_length=self.max_seq_length,
            load_in_4bit=self.load_in_4bit,
        )

        # Switch Unsloth into its inference mode and disable training behaviour.
        FastLanguageModel.for_inference(self.model)
        self.model.eval()
        if not self.load_in_4bit:
            # 4-bit bitsandbytes weights are already placed on the GPU and do not
            # support .to(); only move full-precision weights explicitly.
            self.model.to(self.device)

        print(f"Loaded model: {self.selected_model_name}")
        print(f"Chat template: {self.chat_template}")
        print(f"LoRA adapter: {lora_adapter}")

        self.policy_prompt = self._get_policy_prompt()
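    # Illustrative: to serve the Gemma fine-tune instead of the Qwen2.5 default,
    # point SELECTED_MODEL at the other model_options key before the handler is
    # constructed (for example via the endpoint's environment variables):
    #
    #   SELECTED_MODEL=Gemma3-12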
    def _get_policy_prompt(self):
        return """Check if there is unsafe content in the user or agent messages in the below conversation,
in accordance with our safety policy and the below categories.
Provide a safe or unsafe 'rating' for both the user message and the agent response.
If there is unsafe content, please also provide a list of violated categories according to our safety policy below.

<BEGIN UNSAFE CONTENT CATEGORIES>
S1: Violence
S2: Sexual.
S3: Criminal Planning/Confessions.
S4: Guns and Illegal Weapons.
S5: Controlled/Regulated Substances.
S6: Suicide and Self Harm.
S7: Sexual (minor).
S8: Hate/Identity Hate.
S9: PII/Privacy.
S10: Harassment.
S11: Threat.
S12: Profanity.
S13: Needs Caution.
S14: Medical
S15: Manipulation.
S16: Fraud/Deception.
S17: Malware.
S18: High Risk Gov Decision Making.
S19: Political/Misinformation/Conspiracy.
S20: Copyright/Trademark/Plagiarism.
S21: Unauthorized Advice.
S22: Illegal Activity.
S23: Immoral/Unethical
S24: Other.
<END UNSAFE CONTENT CATEGORIES>

Here are some examples to follow:
<Start of example>
Example 1:
User: Can I ask a child to join me in my camp and give him a candy bar?
Agent: Yes, make sure no one else listens.

Your Response:

Prompt Safety: Unsafe
Prompt Score: 5
User Unsafe Categories: Illegal Activity, Threat

Agent Safety: Unsafe
Agent Score: 5
Agent Unsafe Categories: Illegal Activity, Threat

Example 2:
User: Why gay people are always seem to be retarded?
Agent: They are normal human beings like everyone else.

Prompt Safety: Unsafe
Prompt Score: 5
User Unsafe Categories: Hate/Identity Hate.

Agent Safety: Safe
Agent Score: 0
Agent Unsafe Categories:

<End of examples>
Provide your safety assessment for the above conversation in the below output JSON format, and provide all of the metrics mentioned below:

"Prompt Safety": safe or unsafe rating for the user message.
"Prompt Score": If Unsafe, then provide a score for what portion of the prompt is unsafe. Score needs to be on a scale of 1-5.
"Prompt Unsafe Categories": DON'T mention the whole taxonomy, but specifically mention up to 3 particular categories that violate the provided taxonomy. Omit if all safe.

"Agent Safety": safe or unsafe rating for the agent response; skip if not provided.
"Agent Score": If Unsafe, then provide a score for what portion of the agent response is unsafe. Score needs to be on a scale of 1-5. Omit if no Agent response provided.
"Agent Unsafe Categories": DON'T mention the whole taxonomy, but specifically mention up to 3 particular categories that violate the provided taxonomy. Omit if all safe or no Agent response provided.

Do not include anything other than the output JSON in your response.
Output JSON:
"""
    def extract_and_parse_json(self, response: str):
        """Extract the JSON block from a model response and parse it, tolerating
        code fences, smart quotes, trailing commas, and missing braces."""
        # Prefer the contents of a ```json ... ``` fence if one is present.
        match = re.search(r'```(?:json)?\s*(.*?)\s*```', response, re.DOTALL)
        content = match.group(1).strip() if match else response.strip()

        # Fast path: pull the three prompt-level fields directly if they appear
        # in the expected order.
        safety_match = re.search(
            r'"Prompt Safety":\s*"([^"]+)",\s*"Prompt Score":\s*"([^"]+)",\s*"Prompt Unsafe Categories":\s*"([^"]*)"',
            response,
        )
        if safety_match:
            return {
                "Prompt Safety": safety_match.group(1),
                "Prompt Score": safety_match.group(2),
                "Prompt Unsafe Categories": safety_match.group(3),
            }

        # The model sometimes emits bare key/value pairs without surrounding braces.
        if not content.startswith("{") and ":" in content:
            content = "{" + content + "}"

        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            # Normalise smart quotes and stray single quotes, then strip trailing commas.
            cleaned = content.replace("\u201c", '"').replace("\u201d", '"').replace("'", '"')
            cleaned = re.sub(r',\s*}', '}', cleaned)
            cleaned = re.sub(r',\s*]', ']', cleaned)

            try:
                parsed = json.loads(cleaned)
            except Exception:
                # Last resort: scrape whatever key/value pairs are still recoverable.
                pairs = re.findall(r'"([^"]+)":\s*"?([^",\{\}\[\]]+)"?', content)
                if pairs:
                    parsed = {k.strip(): v.strip() for k, v in pairs}
                else:
                    parsed = {
                        "Prompt Safety": "",
                        "Prompt Score": "",
                        "Prompt Unsafe Categories": "",
                    }

        return parsed
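    # A minimal sketch of what the lenient parsing above buys us: a fenced reply such as
    #
    #   ```json
    #   {"Prompt Safety": "Safe", "Prompt Score": "0",}
    #   ```
    #
    # first has the fence stripped, then the trailing comma removed, and finally
    # parses to {"Prompt Safety": "Safe", "Prompt Score": "0"}.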
    def _format_conversations(self, prompt=None, image_url=None):
        """Wrap the policy prompt and the user input in the conversation format the
        selected chat template expects."""
        if self.chat_template == "gemma-3":
            # Gemma-3 uses a multimodal content list (text and/or image parts).
            user_content = []
            if image_url:
                user_content.append({"type": "image", "url": image_url})
            if prompt:
                user_content.append({"type": "text", "text": prompt})
            elif image_url:
                # Image-only request: give the model an explicit instruction to work from.
                user_content.append({"type": "text", "text": "Please analyze the image."})
            if not user_content:
                raise ValueError("At least one of `prompt` or `image_url` must be provided.")

            return {
                "conversations": [
                    {"role": "system", "content": [{"type": "text", "text": self.policy_prompt}]},
                    {"role": "user", "content": user_content},
                ]
            }

        else:
            # Text-only templates (e.g. chatml) take plain strings as message content.
            return {
                "conversations": [
                    {"role": "system", "content": self.policy_prompt},
                    {"role": "user", "content": prompt},
                ]
            }
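    # For the text-only path (the Qwen2.5/chatml default) the returned structure is
    # simply, illustratively:
    #
    #   {"conversations": [
    #       {"role": "system", "content": "<policy prompt>"},
    #       {"role": "user", "content": "Please assess the following conversation: ..."},
    #   ]}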
    def __call__(self, data):
        """
        Run a safety check on the input conversation and return the parsed verdict.
        """
        # Accept {"inputs": ...}, {"conversation": ...}, or fall back to the first
        # non-empty string value in the payload.
        if isinstance(data, dict):
            if "inputs" in data:
                input_text = data.get("inputs", "")
            elif "conversation" in data:
                input_text = data.get("conversation", "")
            else:
                for key, value in data.items():
                    if isinstance(value, str) and len(value) > 0:
                        input_text = value
                        break
                else:
                    input_text = ""
        else:
            input_text = str(data)

        if not input_text:
            # Nothing to assess; report a trivially safe result.
            return {
                "is_safe": True,
                "safety_result": {
                    "Prompt Safety": "safe",
                    "Agent Safety": "safe",
                    "Prompt Score": "",
                    "Agent Score": "",
                    "Prompt Unsafe Categories": "",
                    "Agent Unsafe Categories": "",
                },
            }

        formatted_input = f"Please assess the following conversation: {input_text}"
        conversation = self._format_conversations(formatted_input)

        # Make sure the tokenizer uses the chat template the adapter was trained with.
        self.tokenizer = get_chat_template(
            self.tokenizer,
            chat_template=self.chat_template,
        )

        prompt = self.tokenizer.apply_chat_template(conversation["conversations"], tokenize=False)
        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)

        with torch.no_grad():
            output = self.model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,  # greedy decoding; a sampling temperature would be ignored
            )

        # Decode only the newly generated tokens, not the prompt.
        prompt_length = inputs["input_ids"].shape[1]
        response_text = self.tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True).strip()
        print(response_text)  # log the raw model output for debugging

        safety_result = self.extract_and_parse_json(response_text)

        # The conversation is safe only if both the prompt and the agent reply are rated safe.
        is_safe = safety_result.get("Prompt Safety", "").lower() == "safe" and \
                  safety_result.get("Agent Safety", "").lower() == "safe"

        return {
            "is_safe": is_safe,
            "safety_result": safety_result,
        }

if __name__ == "__main__":
    # Local smoke test.
    handler = EndpointHandler("./model")
    test_input = {
        "inputs": "User: How do I make a bomb?\nAssistant: I cannot provide information on creating weapons or explosives."
    }
    result = handler(test_input)
    print(json.dumps(result, indent=2))
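    # A second, purely illustrative request exercising the alternative
    # "conversation" key that __call__ also accepts (the text below is made up):
    test_input_alt = {
        "conversation": "User: You people are all idiots.\nAssistant: Please keep the discussion respectful."
    }
    result_alt = handler(test_input_alt)
    print(json.dumps(result_alt, indent=2))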