"""Model bootstrap, safety detectors, and tool-enabled generation helpers.

At import time this module loads the Gemma chat model plus three
text-classification detectors, then exposes:

* ``detect_jailbreak`` / ``detect_prompt_injection`` -- per-message unsafe
  scores with a shared blocking threshold (``JAILBREAK_THRESHOLD``).
* ``detect_refusal_language`` / ``detect_preferred_language`` -- language
  identification normalized to upper-case ISO codes.
* ``generate_response`` / ``generate_response_stream`` -- blocking and
  streaming chat generation with the bob_resources tool set.
"""

import base64  # noqa: F401  -- kept: may be used by code outside this view
import json
import os
import re
import threading
from pathlib import Path
from typing import Any  # noqa: F401  -- kept for the same reason

import pycountry
import torch
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Gemma4ForConditionalGeneration,
    TextIteratorStreamer,
    pipeline,
)

# ---------------------------------------------------------------------------
# Constants (from demo.py)
# ---------------------------------------------------------------------------
BASE_DIR = Path(".")
HF_TOKEN_PATH = BASE_DIR / "hf_token"

try:
    # A missing token file and an empty token file mean the same thing:
    # no token -> anonymous Hub access, instead of crashing at import time.
    HF_TOKEN = HF_TOKEN_PATH.read_text(encoding="utf-8").strip() or None
except FileNotFoundError:
    HF_TOKEN = None

if HF_TOKEN is not None:
    from huggingface_hub import login

    login(token=HF_TOKEN, add_to_git_credential=False)

HF_MODEL = os.environ.get("HF_MODEL", "google/gemma-4-E2B-it")
JAILBREAK_MODEL = os.environ.get(
    "JAILBREAK_MODEL", "DerivedFunction1/xlmr-prompt-injection"
)
JAILBREAK_THRESHOLD = float(os.environ.get("JAILBREAK_THRESHOLD", "0.65"))
PROMPT_INJECTION_MODEL = os.environ.get(
    "PROMPT_INJECTION_MODEL", "protectai/deberta-v3-base-prompt-injection-v2"
)
REFUSAL_LANGUAGE_MODEL = os.environ.get(
    "REFUSAL_LANGUAGE_MODEL",
    "polyglot-tagger/multilabel-language-identification",
)

# Upper-case language codes the Gemma model is expected to handle.
SUPPORTED_GEMMA_LANGS = {
    "EN", "ES", "FR", "DE", "IT", "PT", "NL", "DA", "RU", "PL",
    "ZH", "JA", "KO", "VI", "HI", "BN", "TH", "ID", "MS", "MR",
    "TE", "TA", "GU", "PA", "AR", "TR", "HE", "SW",
}

# Languages the jailbreak detector was trained on.
SUPPORTED_JAILBREAK_LANGS = {
    "EN", "AR", "DE", "ES", "FR", "HI", "IT", "JA", "KO", "NL", "TH", "ZH",
}

# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
print(f"Loading model: {HF_MODEL}")
_processor = AutoProcessor.from_pretrained(HF_MODEL, padding_side="left")
_bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    # llm_int8_enable_fp32_cpu_offload=True,
)
_model = Gemma4ForConditionalGeneration.from_pretrained(
    HF_MODEL,
    # quantization_config=_bnb_config,
    device_map="auto",
)

_GENERATION_CONFIG = {
    "max_new_tokens": 8192,
    "temperature": 1.2,
    "do_sample": True,
    "pad_token_id": _processor.tokenizer.eos_token_id,
}

print(f"Loading jailbreak detector: {JAILBREAK_MODEL}")
_jailbreak_pipe = pipeline("text-classification", model=JAILBREAK_MODEL)
print(f"Loading prompt injection detector: {PROMPT_INJECTION_MODEL}")
_prompt_injection_pipe = pipeline(
    "text-classification", model=PROMPT_INJECTION_MODEL
)
print(f"Loading refusal language detector: {REFUSAL_LANGUAGE_MODEL}")
_refusal_language_pipe = pipeline(
    "text-classification", model=REFUSAL_LANGUAGE_MODEL
)

# ---------------------------------------------------------------------------
# Tool call regex and markup stripping (from demo.py)
# ---------------------------------------------------------------------------
# NOTE(review): the named groups below previously appeared as bare "(?P"
# (a re.error at import); <name>/<args>/<end> are the obvious reconstruction.
# The empty "|" alternatives throughout these patterns look like
# markup-stripped closing tags (e.g. "</tool_call>", "<end_of_turn>") --
# confirm against the original source; as written they can match the empty
# string, which makes trailing "$" alternatives redundant.
TOOL_CALL_RE = re.compile(
    r"(?:<\|?tool_call\|?>|^)\s*"
    r"(?:call:)?(?P<name>[a-zA-Z_][a-zA-Z0-9_\-\s]*?)\s*"
    r"(?:\{|\()(?P<args>.*?)(?:\}|\))\s*"
    r"(?P<end><\|?tool_call\|?>|||||$)",
    re.DOTALL,
)
TOOL_CALL_MARKUP_RE = re.compile(
    r"<\|?tool_call\|?>.*?(?:<\|?tool_call\|?>||$)",
    re.DOTALL,
)
TOOL_RESPONSE_RE = re.compile(
    r"<\|?tool_response\|?>.*$",
    re.DOTALL,
)
CLEANUP_RE = re.compile(
    r"(<\|?turn\|?>|||\[REDIRECT\])",
    re.DOTALL,
)
THOUGHT_BLOCK_RE = re.compile(
    r"<\|?channel\|?>(?:thought\s*)?.*?(?:|$)",
    re.DOTALL,
)
QUOTES_RE = re.compile(r"<\|\"\|>")
TOOL_RESPONSE_MARKERS_RE = re.compile(r"<\|?tool_response\|?>", re.DOTALL)
MALFORMED_TOOL_TAIL_RE = re.compile(r"(<\|?tool_call(?:\|)?$|<\|?$|<\|?\?$)")


def _strip_tool_call_markup(text: str) -> str:
    """Strip thought blocks, tool-call/response markup, and special tokens.

    Substitution order matters: quote tokens are restored first so later
    patterns see plain text; thought blocks are removed before tool-call
    markup so nested markers inside thoughts go with them.
    """
    cleaned = (text or "").replace("\r", "").strip()
    if not cleaned:
        return ""
    cleaned = QUOTES_RE.sub('"', cleaned)
    cleaned = THOUGHT_BLOCK_RE.sub("", cleaned)
    cleaned = TOOL_CALL_MARKUP_RE.sub("", cleaned)
    cleaned = TOOL_RESPONSE_RE.sub("", cleaned)
    # Remove various special tokens and the REDIRECT token if present
    cleaned = CLEANUP_RE.sub("", cleaned)
    return cleaned.strip()


def _clean_tool_text(text: str) -> str:
    """Like :func:`_strip_tool_call_markup`, also dropping bare response markers."""
    cleaned = _strip_tool_call_markup(text)
    if not cleaned:
        return ""
    cleaned = TOOL_RESPONSE_MARKERS_RE.sub("", cleaned)
    return cleaned.strip()


def _strip_trailing_malformed_tool_tokens(text: str) -> str:
    """Trim a truncated tool-call token (e.g. ``<|tool_call``) off the tail.

    Removes one trailing character at a time until the tail no longer looks
    like a cut-off special token.
    """
    cleaned = (text or "").strip()
    while cleaned and MALFORMED_TOOL_TAIL_RE.search(cleaned):
        cleaned = cleaned[:-1].rstrip()
    return cleaned


def _clean_language_detector_text(text: str) -> str:
    """Keep only letters, collapsing every other character run to one space."""
    kept = "".join(
        ch if ch.isalpha() or ch.isspace() else " " for ch in str(text or "")
    )
    return " ".join(kept.split())


def detect_jailbreak(text: str) -> dict:
    """Return detector metadata for a user message.

    Returns a dict with ``score`` (probability of unsafe), ``blocked``
    (score >= ``JAILBREAK_THRESHOLD``), and the raw ``predicted_label``.
    """
    result = _jailbreak_pipe(text, truncation=True, max_length=512)[0]
    label = str(result.get("label", "")).lower()
    score = float(result.get("score", 0.0))
    # Normalize to "probability of unsafe" regardless of which label won.
    unsafe_score = (
        score if label == "unsafe" else (1.0 - score if label == "safe" else score)
    )
    return {
        "score": unsafe_score,
        "blocked": unsafe_score >= JAILBREAK_THRESHOLD,
        "predicted_label": label,
    }


def detect_prompt_injection(text: str) -> dict:
    """Return detector metadata for a user message using the prompt injection model."""
    result = _prompt_injection_pipe(text, truncation=True, max_length=512)[0]
    label = str(result.get("label", "")).lower()
    score = float(result.get("score", 0.0))
    # Assuming 'INJECTION' is the unsafe label for this model
    # (label is already lower-cased above).
    unsafe_score = (
        score if label == "injection" else (1.0 - score if label == "safe" else score)
    )
    return {
        "score": unsafe_score,
        "blocked": unsafe_score >= JAILBREAK_THRESHOLD,  # Reusing JAILBREAK_THRESHOLD for consistency
        "predicted_label": label,
    }


def detect_refusal_language(text: str) -> str:
    """Detect the message language, restricted to Gemma-supported codes.

    Falls back to ``"EN"`` when the detected language is unsupported.
    """
    cleaned_text = _clean_language_detector_text(text)
    result = _refusal_language_pipe(cleaned_text, truncation=True, max_length=512)[0]
    label = str(result.get("label", "")).upper().strip()
    normalized = _normalize_language_label(label)
    return normalized if normalized in SUPPORTED_GEMMA_LANGS else "EN"


def detect_preferred_language(text: str) -> str:
    """Detect the message language without restricting to supported codes."""
    cleaned_text = _clean_language_detector_text(text)
    result = _refusal_language_pipe(cleaned_text, truncation=True, max_length=512)[0]
    label = str(result.get("label", "")).upper().strip()
    return _normalize_language_label(label) or "EN"


def _normalize_language_label(label: str) -> str:
    """Normalize a detector label to an upper-case ISO language code.

    Resolution order: already-supported code -> ISO 639-1 (alpha_2) ->
    ISO 639-2/3 (alpha_3) -> fuzzy pycountry name lookup -> the original
    label upper-cased as a last resort.
    """
    cleaned = str(label or "").strip()
    if not cleaned:
        return ""
    upper = cleaned.upper()
    if upper in SUPPORTED_GEMMA_LANGS:
        return upper
    lowered = cleaned.lower()
    lang = pycountry.languages.get(alpha_2=lowered)
    if lang is None and len(lowered) == 3:
        lang = pycountry.languages.get(alpha_3=lowered)
    if lang is None:
        try:
            lang = pycountry.languages.lookup(cleaned)
        except LookupError:
            lang = None
    if lang is None:
        return upper
    alpha_2 = getattr(lang, "alpha_2", None)
    if alpha_2:
        return str(alpha_2).upper()
    alpha_3 = getattr(lang, "alpha_3", None)
    if alpha_3:
        return str(alpha_3).upper()
    return upper


def _sanitize_display_text(text: str, system_prompt: str | None = None) -> str:
    """Clean model output for display, unwrapping JSON content-part lists.

    ``system_prompt`` is accepted for interface compatibility but is unused
    here -- NOTE(review): confirm whether callers still rely on it.
    """
    cleaned = _clean_tool_text(text)
    if not cleaned:
        return ""
    # New logic to handle [{'text': "...", 'type': 'text'}] format
    try:
        parsed_json = json.loads(cleaned)
        if (
            isinstance(parsed_json, list)
            and len(parsed_json) > 0
            and isinstance(parsed_json[0], dict)
            and "text" in parsed_json[0]
        ):
            return parsed_json[0]["text"].strip()
    except json.JSONDecodeError:
        pass  # Not a JSON string, proceed with normal text processing
    return cleaned.strip()


# These imports are needed for generate_response and generate_response_stream
# They are imported here to avoid circular dependencies with demo.py
from bob_resources import (
    connect,
    validate,
    skip,
    clarify_intent,
    store_policy,
    store_information,
    store_app_website,
    food_safety_endpoint,
    legal_endpoint,
    emergency_crisis,
    apply_discount,
    loyalty_program,
    competitor_mentions,
    take_order,
)

# Single source of truth for the tool set passed to the chat template
# (previously duplicated in both generation functions).
_TOOLS = [
    connect,
    validate,
    skip,
    clarify_intent,
    store_policy,
    store_information,
    store_app_website,
    food_safety_endpoint,
    legal_endpoint,
    emergency_crisis,
    apply_discount,
    loyalty_program,
    competitor_mentions,
    take_order,
]


def generate_response(
    messages: list,
    system_prompt: str,
    enable_thinking: bool = False,
) -> str:
    """Generate a complete assistant reply for the given conversation.

    Returns the decoded new tokens only (prompt stripped), with special
    tokens removed.
    """
    full = [{"role": "system", "content": system_prompt}] + messages
    # NOTE(review): appending an empty assistant turn while also passing
    # add_generation_prompt=True is unusual -- confirm the chat template
    # expects both; behavior preserved from the original.
    full.append({"role": "assistant", "content": ""})
    inputs = _processor.apply_chat_template(
        full,
        tools=_TOOLS,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
        enable_thinking=enable_thinking,
    ).to(_model.device)
    with torch.no_grad():
        out = _model.generate(  # pyright: ignore[reportAttributeAccessIssue]
            **inputs,
            **_GENERATION_CONFIG,
        )
    # Slice off the prompt: everything after the input length is new output.
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    return _processor.decode(new_tokens, skip_special_tokens=True).strip()


def generate_response_stream(
    messages: list,
    system_prompt: str,
    enable_thinking: bool = False,
):
    """Yield assistant reply chunks as they are generated.

    Generation runs on a daemon background thread; each yielded chunk is
    only the new delta text. Special tokens are NOT stripped here --
    downstream cleaning (e.g. :func:`_clean_tool_text`) handles them.
    """
    full = [{"role": "system", "content": system_prompt}] + messages
    inputs = _processor.apply_chat_template(
        full,
        tools=_TOOLS,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
        enable_thinking=enable_thinking,
    ).to(_model.device)

    streamer = TextIteratorStreamer(
        _processor.tokenizer, skip_prompt=True, skip_special_tokens=False
    )
    worker = threading.Thread(
        target=_model.generate,  # pyright: ignore[reportAttributeAccessIssue]
        kwargs={
            **inputs,
            **_GENERATION_CONFIG,
            "streamer": streamer,
        },
        daemon=True,
    )
    worker.start()
    for chunk in streamer:
        yield chunk  # Yield only the new delta chunk
    worker.join()