# ai-security-future / bob_utils.py
# Provenance: Hugging Face Space by DerivedFunction1, commit 7d11eef ("add").
import os
import re
import json
import base64
import threading
from pathlib import Path
from typing import Any
import pycountry
# Constants from demo.py
BASE_DIR = Path(".")
HF_TOKEN_PATH = BASE_DIR / "hf_token"
# Read the Hugging Face token only when the file exists; an unconditional
# read_text() would crash with FileNotFoundError on a checkout without it.
HF_TOKEN = (
    HF_TOKEN_PATH.read_text(encoding="utf-8").strip() or None
    if HF_TOKEN_PATH.is_file()
    else None
)
if HF_TOKEN is not None:
    from huggingface_hub import login
    login(token=HF_TOKEN, add_to_git_credential=False)
# Model identifiers, each overridable via environment variable.
HF_MODEL = os.environ.get("HF_MODEL", "google/gemma-4-E2B-it")
JAILBREAK_MODEL = os.environ.get("JAILBREAK_MODEL", "DerivedFunction1/xlmr-prompt-injection")
# Unsafe-probability cutoff shared by the jailbreak and prompt-injection detectors.
JAILBREAK_THRESHOLD = float(os.environ.get("JAILBREAK_THRESHOLD", "0.65"))
PROMPT_INJECTION_MODEL = os.environ.get(
    "PROMPT_INJECTION_MODEL", "protectai/deberta-v3-base-prompt-injection-v2"
)
REFUSAL_LANGUAGE_MODEL = os.environ.get(
    "REFUSAL_LANGUAGE_MODEL",
    "polyglot-tagger/multilabel-language-identification",
)
# Upper-case two-letter language codes the Gemma chat model is expected to
# handle (used by detect_refusal_language as an allow-list).
SUPPORTED_GEMMA_LANGS = {
    "EN", "ES", "FR", "DE", "IT", "PT", "NL",
    "DA", "RU", "PL",
    "ZH", "JA", "KO", "VI",
    "HI", "BN", "TH", "ID", "MS", "MR", "TE", "TA", "GU", "PA",
    "AR", "TR", "HE", "SW",
}
# Languages the jailbreak detector model was trained on.
# NOTE(review): currently unused in this module — presumably consumed by demo.py.
SUPPORTED_JAILBREAK_LANGS = {
    "EN",
    "AR",
    "DE",
    "ES",
    "FR",
    "HI",
    "IT",
    "JA",
    "KO",
    "NL",
    "TH",
    "ZH",
}
# Imports for model loading
from transformers import AutoProcessor, Gemma4ForConditionalGeneration, BitsAndBytesConfig, pipeline
# Model loading — all heavy side effects happen at import time: the chat model,
# processor, and three text-classification pipelines are loaded eagerly.
print(f"Loading model: {HF_MODEL}")
# Left-padding so batched generation aligns prompts at the right edge.
_processor = AutoProcessor.from_pretrained(HF_MODEL, padding_side="left")
# 8-bit quantization config; currently NOT passed to from_pretrained (see below).
_bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    # llm_int8_enable_fp32_cpu_offload=True,
)
_model = Gemma4ForConditionalGeneration.from_pretrained(
    HF_MODEL,
    # quantization_config=_bnb_config,
    device_map="auto",
)
# Shared kwargs for every _model.generate() call in this module.
_GENERATION_CONFIG = {
    "max_new_tokens": 8192,
    # NOTE(review): temperature 1.2 with sampling gives fairly high variance —
    # confirm this is intentional for a safety-demo assistant.
    "temperature": 1.2,
    "do_sample": True,
    "pad_token_id": _processor.tokenizer.eos_token_id,
}
print(f"Loading jailbreak detector: {JAILBREAK_MODEL}")
_jailbreak_pipe = pipeline("text-classification", model=JAILBREAK_MODEL)
print(f"Loading prompt injection detector: {PROMPT_INJECTION_MODEL}")
_prompt_injection_pipe = pipeline("text-classification", model=PROMPT_INJECTION_MODEL)
print(f"Loading refusal language detector: {REFUSAL_LANGUAGE_MODEL}")
_refusal_language_pipe = pipeline("text-classification", model=REFUSAL_LANGUAGE_MODEL)
# Tool call regex and markup stripping (from demo.py)
# Parses a tool call emitted by the model: optional <|tool_call|> opener,
# optional "call:" prefix, a name, brace- or paren-delimited args, and one of
# several possible closing tokens (or end of string).
TOOL_CALL_RE = re.compile(
    r"(?:<\|?tool_call\|?>|^)\s*"
    r"(?:call:)?(?P<name>[a-zA-Z_][a-zA-Z0-9_\-\s]*?)\s*"
    r"(?:\{|\()(?P<args>.*?)(?:\}|\))\s*"
    r"(?P<close><\|?tool_call\|?>|<eos>|<end_of_turn>|<turn\|?>|</s>|$)",
    re.DOTALL,
)
# Everything between a tool-call marker and its closer (or EOS/end of string),
# for wholesale removal from display text.
TOOL_CALL_MARKUP_RE = re.compile(
    r"<\|?tool_call\|?>.*?(?:<\|?tool_call\|?>|<eos>|$)",
    re.DOTALL,
)
# A tool-response marker and everything after it, to the end of the text.
TOOL_RESPONSE_RE = re.compile(
    r"<\|?tool_response\|?>.*$",
    re.DOTALL,
)
# Misc special tokens plus the internal [REDIRECT] sentinel, removed from output.
CLEANUP_RE = re.compile(
    r"(<\|?turn\|?>|<eos>|</s>|\[REDIRECT\])",
    re.DOTALL,
)
# A <|channel|> "thought" block up to its closing marker (or end of string).
THOUGHT_BLOCK_RE = re.compile(
    r"<\|?channel\|?>(?:thought\s*)?.*?(?:<channel\|>|$)",
    re.DOTALL,
)
# The model's escaped-quote token, rewritten to a literal double quote.
QUOTES_RE = re.compile(r"<\|\"\|>")
# Bare tool-response markers (without consuming trailing text).
TOOL_RESPONSE_MARKERS_RE = re.compile(r"<\|?tool_response\|?>", re.DOTALL)
# Truncated tool-token fragments at end of text: "<|tool_call", "<|", "<", "<|?".
MALFORMED_TOOL_TAIL_RE = re.compile(r"(<\|?tool_call(?:\|)?$|<\|?$|<\|?\?$)")
def _strip_tool_call_markup(text: str) -> str:
    """Remove tool-call markup, thought blocks, and special tokens from model text."""
    value = (text or "").replace("\r", "").strip()
    if not value:
        return ""
    # Applied in order: quote tokens become literal quotes, then thought blocks,
    # tool-call markup, trailing tool responses, and misc special tokens are dropped.
    substitutions = (
        (QUOTES_RE, '"'),
        (THOUGHT_BLOCK_RE, ""),
        (TOOL_CALL_MARKUP_RE, ""),
        (TOOL_RESPONSE_RE, ""),
        (CLEANUP_RE, ""),
    )
    for pattern, replacement in substitutions:
        value = pattern.sub(replacement, value)
    return value.strip()
def _clean_tool_text(text: str) -> str:
    """Strip tool markup, then remove any leftover tool-response markers."""
    stripped = _strip_tool_call_markup(text)
    if not stripped:
        return ""
    without_markers = TOOL_RESPONSE_MARKERS_RE.sub("", stripped)
    return without_markers.strip()
def _strip_trailing_malformed_tool_tokens(text: str) -> str:
    """Remove truncated tool-token fragments (e.g. "<|tool_call", "<|") from the end.

    Bug fix: the previous version removed only ONE character per matched tail and
    re-tested, so a multi-character fragment like "<|tool_call" lost a single
    character ("<|tool_cal" no longer matches the pattern) and the rest leaked
    into the output. Each iteration now cuts the entire matched tail.
    """
    cleaned = (text or "").strip()
    while cleaned:
        match = MALFORMED_TOOL_TAIL_RE.search(cleaned)
        if match is None:
            break
        # Drop the whole matched fragment, plus any whitespace it exposed.
        cleaned = cleaned[: match.start()].rstrip()
    return cleaned
def _clean_language_detector_text(text: str) -> str:
cleaned = []
for ch in str(text or ""):
if ch.isalpha() or ch.isspace():
cleaned.append(ch)
else:
cleaned.append(" ")
return " ".join("".join(cleaned).split())
def detect_jailbreak(text: str) -> dict:
    """Return detector metadata for a user message."""
    prediction = _jailbreak_pipe(text, truncation=True, max_length=512)[0]
    predicted = str(prediction.get("label", "")).lower()
    confidence = float(prediction.get("score", 0.0))
    # Normalize so the score always measures the "unsafe" probability:
    # invert when the top label is "safe", pass through otherwise.
    if predicted == "safe":
        unsafe = 1.0 - confidence
    else:
        unsafe = confidence
    return {
        "score": unsafe,
        "blocked": unsafe >= JAILBREAK_THRESHOLD,
        "predicted_label": predicted,
    }
def detect_prompt_injection(text: str) -> dict:
    """Return detector metadata for a user message using the prompt injection model.

    The returned "score" is normalized to an unsafe probability, mirroring
    detect_jailbreak: "injection" passes through, "safe" is inverted, and any
    other label is passed through unchanged.
    """
    result = _prompt_injection_pipe(text, truncation=True, max_length=512)[0]
    label = str(result.get("label", "")).lower()
    score = float(result.get("score", 0.0))
    # 'injection' is assumed to be the unsafe label for this model.
    # (Fixed a redundant second .lower() on the already-lowercased label.)
    if label == "injection":
        unsafe_score = score
    elif label == "safe":
        unsafe_score = 1.0 - score
    else:
        unsafe_score = score
    return {
        "score": unsafe_score,
        # Reusing JAILBREAK_THRESHOLD so both detectors block consistently.
        "blocked": unsafe_score >= JAILBREAK_THRESHOLD,
        "predicted_label": label,
    }
def detect_refusal_language(text: str) -> str:
    """Detect the language for a refusal message, defaulting to English.

    Falls back to "EN" whenever the normalized code is not one the Gemma
    model supports.
    """
    sample = _clean_language_detector_text(text)
    prediction = _refusal_language_pipe(sample, truncation=True, max_length=512)[0]
    raw_label = str(prediction.get("label", "")).upper().strip()
    code = _normalize_language_label(raw_label)
    return code if code in SUPPORTED_GEMMA_LANGS else "EN"
def detect_preferred_language(text: str) -> str:
    """Detect the user's preferred language code, defaulting to "EN".

    Unlike detect_refusal_language, the result is NOT restricted to the
    Gemma-supported set — any normalized code is returned.
    """
    sample = _clean_language_detector_text(text)
    prediction = _refusal_language_pipe(sample, truncation=True, max_length=512)[0]
    raw_label = str(prediction.get("label", "")).upper().strip()
    return _normalize_language_label(raw_label) or "EN"
def _normalize_language_label(label: str) -> str:
    """Normalize a detector label to an upper-case ISO language code.

    Tries, in order: the supported-languages allow-list, pycountry alpha-2,
    pycountry alpha-3 (for three-letter inputs), and a fuzzy pycountry lookup.
    When nothing resolves, the upper-cased input is returned as-is; an empty
    input yields "".
    """
    raw = str(label or "").strip()
    if not raw:
        return ""
    code = raw.upper()
    if code in SUPPORTED_GEMMA_LANGS:
        return code
    lowered = raw.lower()
    language = pycountry.languages.get(alpha_2=lowered)
    if language is None and len(lowered) == 3:
        language = pycountry.languages.get(alpha_3=lowered)
    if language is None:
        try:
            language = pycountry.languages.lookup(raw)
        except LookupError:
            # Unknown code/name: fall back to the upper-cased input.
            return code
    # Prefer the two-letter code, then the three-letter one.
    for attr in ("alpha_2", "alpha_3"):
        value = getattr(language, attr, None)
        if value:
            return str(value).upper()
    return code
def _sanitize_display_text(text: str, system_prompt: str | None = None) -> str:
    """Clean model output for display, unwrapping JSON content lists if present.

    ``system_prompt`` is accepted for interface compatibility but is unused here.

    If the cleaned text parses as JSON of the form
    ``[{"text": "...", "type": "text"}]``, only the inner text is returned.
    """
    cleaned = _clean_tool_text(text)
    if not cleaned:
        return ""
    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError:
        # Not a JSON string; return the cleaned text unchanged.
        return cleaned
    # Robustness fix: also require the "text" value to be a string — the
    # previous version called .strip() on it unconditionally, raising
    # AttributeError for non-string payloads.
    if (
        isinstance(parsed, list)
        and parsed
        and isinstance(parsed[0], dict)
        and isinstance(parsed[0].get("text"), str)
    ):
        return parsed[0]["text"].strip()
    return cleaned
# These imports are needed for generate_response and generate_response_stream
# They are imported here to avoid circular dependencies with demo.py
from bob_resources import (
connect,
validate,
skip,
clarify_intent,
store_policy,
store_information,
store_app_website,
food_safety_endpoint,
legal_endpoint,
emergency_crisis,
apply_discount,
loyalty_program,
competitor_mentions,
take_order
)
def generate_response(
    messages: list,
    system_prompt: str,
    enable_thinking: bool = False,
) -> str:
    """Generate a complete assistant reply for the conversation.

    Args:
        messages: Chat history as a list of {"role", "content"} dicts.
        system_prompt: System message prepended to the conversation.
        enable_thinking: Passed through to the chat template.

    Returns:
        The decoded new tokens (special tokens stripped), whitespace-trimmed.
    """
    # Local import replaces the previous __import__("torch") hack.
    import torch

    full = [{"role": "system", "content": system_prompt}] + messages
    # NOTE(review): an empty assistant turn is appended even though
    # add_generation_prompt=True — confirm the template expects both.
    full.append({"role": "assistant", "content": ""})
    inputs = _processor.apply_chat_template(
        full,
        tools=[connect, validate, skip, clarify_intent, store_policy,
               store_information, store_app_website, food_safety_endpoint, legal_endpoint,
               emergency_crisis, apply_discount, loyalty_program, competitor_mentions, take_order],
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
        enable_thinking=enable_thinking,
    ).to(_model.device)
    with torch.no_grad():
        out = _model.generate(  # pyright: ignore[reportAttributeAccessIssue]
            **inputs,
            **_GENERATION_CONFIG,
        )
    # Decode only the tokens generated after the prompt.
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    return _processor.decode(new_tokens, skip_special_tokens=True).strip()
def generate_response_stream(
    messages: list,
    system_prompt: str,
    enable_thinking: bool = False,
):
    """Stream an assistant reply, yielding each new text chunk as it is generated.

    Args:
        messages: Chat history as a list of {"role", "content"} dicts.
        system_prompt: System message prepended to the conversation.
        enable_thinking: Passed through to the chat template.

    Yields:
        str: incremental text deltas from the model (special tokens included;
        callers are expected to strip markup themselves).
    """
    full = [{"role": "system", "content": system_prompt}] + messages
    inputs = _processor.apply_chat_template(
        full,
        tools=[connect, validate, skip, clarify_intent, store_policy,
               store_information, store_app_website, food_safety_endpoint, legal_endpoint,
               emergency_crisis, apply_discount, loyalty_program, competitor_mentions, take_order],
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
        enable_thinking=enable_thinking,
    ).to(_model.device)
    from transformers import TextIteratorStreamer
    streamer = TextIteratorStreamer(_processor.tokenizer, skip_prompt=True, skip_special_tokens=False)
    # Run generation on a worker thread; the streamer feeds decoded chunks back.
    thread = threading.Thread(
        target=_model.generate,  # pyright: ignore[reportAttributeAccessIssue]
        kwargs={
            **inputs,
            **_GENERATION_CONFIG,
            "streamer": streamer,
        },
        daemon=True,
    )
    thread.start()
    try:
        for chunk in streamer:
            yield chunk  # Yield only the new delta chunk
    finally:
        # Join even if the consumer abandons the generator early, so the
        # worker thread is not silently orphaned. (Previously an unused
        # accumulator also collected the full text; it was dead code.)
        thread.join()