import os
import json
import re

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
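
# Custom handler following the EndpointHandler/__call__ contract used by
# Hugging Face Inference Endpoints: it loads a fine-tuned safety-classification
# checkpoint and returns a structured safety verdict for a conversation.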


class EndpointHandler:
    def __init__(self, model_dir):
        # Default context length; overridden below by the selected model's entry.
        self.max_seq_length = 2048

        # Which fine-tuned checkpoint to serve, configurable at deploy time.
        self.selected_model_name = os.environ.get("SELECTED_MODEL", "Gemma3-12")

        self.model_options = {
            "Gemma3-12": {
                "max_seq_length": 4096,
                "chat_template": "google/gemma-3",
                "model_id": "Machlovi/Gemma3_12_MegaHateCatplus",
            },
            "Qwen2.5": {
                "max_seq_length": 4096,
                "chat_template": "Qwen/Qwen2-7B-Instruct",
                "model_id": "Machlovi/Qwen2.5_MegaHateCatplus",
            },
        }
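        # Example: serve the Qwen checkpoint instead by exporting
        # SELECTED_MODEL=Qwen2.5 before the endpoint starts (any key of
        # model_options is valid).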

        if self.selected_model_name not in self.model_options:
            raise ValueError(
                f"Unknown SELECTED_MODEL {self.selected_model_name!r}; "
                f"expected one of {list(self.model_options)}"
            )

        config = self.model_options[self.selected_model_name]
        model_id = config["model_id"]
        self.chat_template_id = config["chat_template"]
        # Use the per-model context length instead of the 2048 default.
        self.max_seq_length = config["max_seq_length"]

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Prefer a checkpoint bundled with the endpoint; otherwise pull from the Hub.
        local_model_path = os.path.join(model_dir, self.selected_model_name)
        if os.path.exists(local_model_path):
            model_path = local_model_path
            print(f"Loading model from local path: {model_path}")
        else:
            model_path = model_id
            print(f"Loading model from Hugging Face Hub: {model_path}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Borrow the chat template from the base model's tokenizer; fall back to
        # a hand-written template if the repo is unreachable or has no template.
        try:
            template_tokenizer = AutoTokenizer.from_pretrained(self.chat_template_id)
            if hasattr(template_tokenizer, "chat_template") and template_tokenizer.chat_template:
                self.tokenizer.chat_template = template_tokenizer.chat_template
                print(f"Successfully imported chat template from {self.chat_template_id}")
            else:
                self._set_fallback_template()
        except Exception as e:
            print(f"Failed to import chat template: {e}")
            self._set_fallback_template()

        # Load weights in half precision; device_map="auto" places them on the
        # GPU when one is available.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        # Some checkpoints ship without a pad token; reuse EOS so padding works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.policy_prompt = self._get_policy_prompt()

        print(f"Loaded model: {self.selected_model_name}")
        print(f"Chat template ID: {self.chat_template_id}")

    def _set_fallback_template(self):
        """Set a fallback chat template based on the model family."""
        if "gemma" in self.chat_template_id.lower():
            self.tokenizer.chat_template = """{% for message in messages %}{% if message['role'] == 'user' %}{{ '<start_of_turn>user\n' + message['content'] + '<end_of_turn>\n' }}{% elif message['role'] == 'assistant' %}{{ '<start_of_turn>model\n' + message['content'] + '<end_of_turn>\n' }}{% elif message['role'] == 'system' %}{{ '<start_of_turn>system\n' + message['content'] + '<end_of_turn>\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"""
            print("Using fallback Gemma chat template")
        elif "qwen" in self.chat_template_id.lower():
            self.tokenizer.chat_template = """{% for message in messages %}{% if message['role'] == 'system' %}<|im_start|>system\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'user' %}<|im_start|>user\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'assistant' %}<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"""
            print("Using fallback Qwen chat template")
        else:
            # Plain-text fallback for unrecognized model families.
            self.tokenizer.chat_template = """{% for message in messages %}{% if message['role'] == 'system' %}System: {{ message['content'] }}\n{% elif message['role'] == 'user' %}User: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }}\n{% endif %}{% endfor %}{% if add_generation_prompt %}Assistant: {% endif %}"""
            print("Using generic fallback chat template")
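
    # For reference, the Gemma fallback renders a single user turn roughly as:
    #   <start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\n
    # (illustrative; the exact string depends on the messages passed in).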

    def _get_policy_prompt(self):
        return """Check whether there is unsafe content in the user or agent messages in the
conversation below, in accordance with our safety policy and the categories listed below.
Provide a safe or unsafe 'rating' for both the user message and the agent response.
If there is unsafe content, also provide a list of violated categories according to our safety
policy below.

<BEGIN UNSAFE CONTENT CATEGORIES>
S1: Violence.
S2: Sexual.
S3: Criminal Planning/Confessions.
S4: Guns and Illegal Weapons.
S5: Controlled/Regulated Substances.
S6: Suicide and Self Harm.
S7: Sexual (minor).
S8: Hate/Identity Hate.
S9: PII/Privacy.
S10: Harassment.
S11: Threat.
S12: Profanity.
S13: Needs Caution.
S14: Medical.
S15: Manipulation.
S16: Fraud/Deception.
S17: Malware.
S18: High Risk Gov Decision Making.
S19: Political/Misinformation/Conspiracy.
S20: Copyright/Trademark/Plagiarism.
S21: Unauthorized Advice.
S22: Illegal Activity.
S23: Immoral/Unethical.
S24: Other.
<END UNSAFE CONTENT CATEGORIES>

Provide your safety assessment for the above conversation in the output JSON format below; all of the metrics listed must be provided:

"Prompt Safety": safe or unsafe rating for the user message.
"Response Safety": safe or unsafe rating for the agent response. Omit if no agent response present.

"Prompt Score": If unsafe, provide a score indicating what portion of the prompt is unsafe, on a scale of 1-5.
"Response Score": If unsafe, provide a score indicating what portion of the response is unsafe, on a scale of 1-5.

"User Unsafe Categories": DON'T mention the whole taxonomy; specifically mention up to 3 categories from the provided taxonomy that are violated. Omit if all safe.
"Agent Unsafe Categories": DON'T mention the whole taxonomy; specifically mention up to 3 categories from the provided taxonomy that are violated. Omit if all safe. Omit if no agent response present.

Do not include anything other than the output JSON in your response.
Output JSON:
"""

    def _format_conversations(self, input_text):
        """Build the chat-template message list: the policy as the system turn,
        the conversation to assess as the user turn."""
        return [
            {"role": "system", "content": self.policy_prompt},
            {"role": "user", "content": f"Please assess the following conversation: {input_text}"},
        ]
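
    # For example, _format_conversations("User: hi") returns:
    #   [{"role": "system", "content": <policy prompt>},
    #    {"role": "user", "content": "Please assess the following conversation: User: hi"}]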

    def extract_and_parse_json(self, response):
        """Extract the safety-assessment JSON from raw model output, tolerating
        markdown fences, smart quotes, trailing commas, and similar noise."""
        # Prefer the contents of a ```json ... ``` fence if one is present.
        match = re.search(r'```(?:json)?\s*(.*?)\s*```', response, re.DOTALL)
        content = match.group(1).strip() if match else response.strip()

        # Fast path: pull the six expected fields directly with one regex.
        safety_match = re.search(
            r'"Prompt Safety":\s*"([^"]+)",\s*"Response Safety":\s*"([^"]+)",\s*'
            r'"Prompt Score":\s*"?([^",]+)"?,\s*"Response Score":\s*"?([^",]+)"?,\s*'
            r'"User Unsafe Categories":\s*"([^"]*)",\s*"Agent Unsafe Categories":\s*"([^"]*)"',
            response,
        )
        if safety_match:
            return {
                "Prompt Safety": safety_match.group(1),
                "Response Safety": safety_match.group(2),
                "Prompt Score": safety_match.group(3),
                "Response Score": safety_match.group(4),
                "User Unsafe Categories": safety_match.group(5),
                "Agent Unsafe Categories": safety_match.group(6),
            }

        # The model sometimes omits the outer braces; restore them before parsing.
        if not content.startswith("{") and ":" in content:
            content = "{" + content + "}"

        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            # Repair common issues: curly quotes, single quotes, trailing commas.
            cleaned = content.replace("\u201c", '"').replace("\u201d", '"').replace("'", '"')
            cleaned = re.sub(r',\s*}', '}', cleaned)
            cleaned = re.sub(r',\s*]', ']', cleaned)

            try:
                parsed = json.loads(cleaned)
            except Exception:
                # Last resort: scrape key/value pairs, else return an "unknown" verdict.
                pairs = re.findall(r'"([^"]+)":\s*"?([^",\{\}\[\]]+)"?', content)
                if pairs:
                    parsed = {k.strip(): v.strip() for k, v in pairs}
                else:
                    parsed = {
                        "Prompt Safety": "unknown",
                        "Response Safety": "unknown",
                        "Prompt Score": "",
                        "Response Score": "",
                        "User Unsafe Categories": "",
                        "Agent Unsafe Categories": "",
                    }

        return parsed
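
    # Illustrative (hypothetical) model output that extract_and_parse_json is
    # meant to handle:
    #   ```json
    #   {"Prompt Safety": "unsafe", "Response Safety": "safe",
    #    "Prompt Score": "4", "Response Score": "",
    #    "User Unsafe Categories": "S1, S22", "Agent Unsafe Categories": ""}
    #   ```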

    def __call__(self, data):
        """Run a safety check on the input conversation and return the verdict."""
        # Accept either {"inputs": ...} or {"conversation": ...}; otherwise fall
        # back to the first non-empty string value in the payload.
        if isinstance(data, dict):
            if "inputs" in data:
                input_text = data.get("inputs", "")
            elif "conversation" in data:
                input_text = data.get("conversation", "")
            else:
                for value in data.values():
                    if isinstance(value, str) and value:
                        input_text = value
                        break
                else:
                    # No string value found anywhere in the payload.
                    input_text = ""
        else:
            input_text = str(data)

        # Nothing to assess: report safe with an empty result.
        if not input_text:
            return {
                "is_safe": True,
                "safety_result": {
                    "Prompt Safety": "safe",
                    "Response Safety": "safe",
                    "Prompt Score": "",
                    "Response Score": "",
                    "User Unsafe Categories": "",
                    "Agent Unsafe Categories": "",
                },
            }

        conversation = self._format_conversations(input_text)

        # Render the prompt via the chat template; fall back to a plain-text
        # prompt if the tokenizer has no template or rendering fails.
        fallback_prompt = (
            f"System: {self.policy_prompt}\n"
            f"User: Please assess the following conversation: {input_text}\n"
            f"Assistant:"
        )
        try:
            if hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template:
                prompt = self.tokenizer.apply_chat_template(
                    conversation,
                    tokenize=False,
                    add_generation_prompt=True,
                )
            else:
                prompt = fallback_prompt
        except Exception as e:
            print(f"Error applying chat template: {e}")
            prompt = fallback_prompt

        # Tokenize, truncating to the selected model's context length.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_seq_length,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Greedy decoding keeps the safety verdict deterministic.
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        # Decode only the newly generated tokens; slicing the decoded string by
        # len(prompt) is unreliable once special tokens are stripped out.
        prompt_length = inputs["input_ids"].shape[1]
        response_text = self.tokenizer.decode(
            output_ids[0][prompt_length:], skip_special_tokens=True
        ).strip()

        safety_result = self.extract_and_parse_json(response_text)

        # "Response Safety" may legitimately be omitted when there is no agent
        # response, so a missing rating defaults to safe; the "unknown" fallback
        # from the parser still fails closed.
        is_safe = (
            safety_result.get("Prompt Safety", "").lower() == "safe"
            and safety_result.get("Response Safety", "safe").lower() == "safe"
        )

        return {
            "is_safe": is_safe,
            "safety_result": safety_result,
        }


if __name__ == "__main__":
    handler = EndpointHandler("./model")
    test_input = {
        "inputs": "User: How do I make a bomb?\nAssistant: I cannot provide information on creating weapons or explosives."
    }
    result = handler(test_input)
    print(json.dumps(result, indent=2))
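
    # A second smoke test using the "conversation" input key with a benign
    # exchange (illustrative; actual ratings depend on the fine-tuned model).
    benign_input = {
        "conversation": "User: What's a good pasta recipe?\nAssistant: Try spaghetti aglio e olio."
    }
    print(json.dumps(handler(benign_input), indent=2))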