# backend.py
import importlib
import json
import time
from typing import Tuple, Generator, Any
import config
from llm_clients.base import LlmClient
from llm_clients.performance_utils import apply_performance_optimizations
from english_detector import is_english_by_ascii_letters_only
# Apply performance optimizations early
apply_performance_optimizations()
class OutputGuardrailManager:
"""Manages the loading and application of modular output-specific guardrails."""
def __init__(self, guard_configs: dict):
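        # Expected guard_configs shape (illustrative; real entries come from
        # config.OUTPUT_GUARDRAILS_CONFIG):
        #   {"some_guard": {"enabled": True, ...}, ...}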
self.guards = []
print("\nInitializing Modular Output Guardrail Manager...")
for name, g_config in guard_configs.items():
if g_config.get("enabled"):
try:
# Dynamically import the guardrail module
module = importlib.import_module(f"guardrails.{name}")
# Construct the class name from the guardrail name
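                    # e.g. a hypothetical "pii_filter" entry resolves to guardrails.pii_filter.PiiFilter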
guard_class_name = name.replace("_", " ").title().replace(" ", "")
guard_class = getattr(module, guard_class_name)
guard_instance = guard_class(g_config)
self.guards.append(guard_instance)
print(f" βœ… Loaded output guardrail: {name}")
                except (ImportError, AttributeError) as e:
                    print(f" ⚠️ Could not load output guardrail '{name}': {e}")
def process_complete_output(self, text: str) -> Tuple[str, bool]:
"""Process complete output text through all loaded output guardrails."""
safe = True
current_text = text
for guard in self.guards:
if hasattr(guard, "process_complete_output"):
current_text, safe = guard.process_complete_output(current_text)
if not safe:
return current_text, False
return current_text, True
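# Illustrative guard module (hypothetical guardrails/pii_filter.py), matching the
# loading and calling conventions used by OutputGuardrailManager above:
#
#   class PiiFilter:
#       def __init__(self, config: dict):
#           self.config = config
#
#       def process_complete_output(self, text: str) -> tuple[str, bool]:
#           # Return (possibly modified text, is_safe)
#           return text, True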
class Backend:
"""Handles the core logic of processing requests with AI detection and modular output guardrails."""
def __init__(self, output_test_mode: bool = False):
self.output_test_mode = output_test_mode
self._translator_client: LlmClient | None = None
if output_test_mode:
print("\nπŸ“ Output Testing Mode: ENABLED")
print(" Only modular output guardrails will be active.")
self.attack_detector = None
else:
print("\nπŸ”’ AI Detection Mode: ENABLED")
print(" Using finetuned model for input guardrails.")
try:
self.attack_detector = self._load_attack_detector()
except Exception as e:
print(f"⚠️ WARNING: Failed to load attack detector: {e}")
print(" πŸ”„ Falling back to output-only mode for better compatibility")
print(" πŸ’‘ The system will still work with output guardrails only")
self.attack_detector = None
self.output_test_mode = True # Switch to output-only mode
# Initialize output guardrails in both modes
self.output_guardrail_manager = OutputGuardrailManager(config.OUTPUT_GUARDRAILS_CONFIG)
# Initialize attachment guardrails
self.attachment_guardrail_manager = self._load_attachment_guardrails()
self.llm_client = self._load_llm_client()
def _get_translator_client(self) -> LlmClient:
"""Lazily load and return the translation client for non-English text."""
if self._translator_client is not None:
return self._translator_client
translator_cfg = getattr(config, "NON_ENGLISH_TRANSLATOR", {"enabled": False})
if not translator_cfg.get("enabled", False):
raise ImportError("Non-English translator disabled in config.")
provider = translator_cfg.get("provider", "qwen_translator")
provider_cfg = translator_cfg.get("config", {})
try:
module = importlib.import_module(f"llm_clients.{provider}")
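            # e.g. the default provider "qwen_translator" maps to class QwenTranslatorClient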
client_class_name = provider.replace("_", " ").title().replace(" ", "") + "Client"
client_class = getattr(module, client_class_name)
# System prompt not needed (client has its own translation prompt), pass empty string
self._translator_client = client_class(provider_cfg, "")
print(f" 🌍 Translation client loaded: {provider} ({provider_cfg.get('model', '')})")
return self._translator_client
except Exception as e:
raise ImportError(f"Could not load translation client '{provider}': {e}")
def _load_attack_detector(self) -> LlmClient:
"""Loads the attack detection client (finetuned model via FinetunedGuard)."""
ai_config = config.AI_DETECTION_MODE
provider = ai_config["attack_llm_provider"]
llm_config = ai_config["attack_llm_config"]
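        # Assumed shape of config.AI_DETECTION_MODE (inferred from these lookups):
        #   {"attack_llm_provider": "finetuned_guard",
        #    "attack_llm_config": {"model_name": "zazaman/fmb", ...}}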
try:
# Use shared model for finetuned_guard provider to avoid duplicate loading
if provider == "finetuned_guard":
from llm_clients.shared_models import shared_model_manager
model_name = llm_config.get("model_name", "zazaman/fmb")
shared_client = shared_model_manager.get_finetuned_guard_client(model_name)
if shared_client:
print(f" πŸ” Main Attack Detector: Using shared model {model_name}")
return shared_client
else:
raise ImportError(f"Could not get shared finetuned model {model_name}")
else:
# For other providers, load normally
module = importlib.import_module(f"llm_clients.{provider}")
client_class_name = provider.replace("_", " ").title().replace(" ", "") + "Client"
client_class = getattr(module, client_class_name)
return client_class(llm_config, "") # No system prompt needed for classification model
        except (ImportError, AttributeError) as e:
raise ImportError(f"Could not load attack detection client for '{provider}': {e}")
def _load_attachment_guardrails(self):
"""Load and initialize attachment guardrails manager."""
try:
from guardrails.attachments.base import AttachmentGuardrailManager
return AttachmentGuardrailManager(config.ATTACHMENT_GUARDRAILS_CONFIG)
except Exception as e:
print(f"⚠️ Could not load attachment guardrails: {e}")
return None
def _load_llm_client(self) -> LlmClient:
"""Dynamically loads and initializes the configured LLM client."""
provider = config.LLM_PROVIDER
llm_config = config.LLM_CONFIG.get(provider)
if llm_config is None:
raise ValueError(f"LLM provider '{provider}' not configured in config.py")
try:
module = importlib.import_module(f"llm_clients.{provider}")
client_class_name = provider.replace("_", " ").title().replace(" ", "") + "Client"
client_class = getattr(module, client_class_name)
return client_class(llm_config, config.SYSTEM_PROMPT)
        except (ImportError, AttributeError) as e:
raise ImportError(f"Could not load LLM client for '{provider}': {e}")
def _check_with_ai_detector(self, prompt: str) -> Tuple[bool, str]:
"""
Checks the prompt with the AI attack detector (finetuned model).
If the prompt is non-English, translates it to English first, then classifies.
Returns (is_safe, reason).
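
        Assumed detector JSON shape (inferred from the parsing below):
            {"safety_status": "safe" | "unsafe", "attack_type": "...",
             "confidence": 0.0-1.0, "reason": "..."}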
"""
original_prompt = prompt
translated_prompt = prompt
# Check if prompt is non-English and translate if needed
if not is_english_by_ascii_letters_only(prompt):
try:
print("🌍 Detected non-English input. Translating to English...", flush=True)
print(f" Original text: '{prompt[:100]}...'", flush=True)
translator_client = self._get_translator_client()
translation_start = time.time()
translated_prompt = translator_client.generate_content(prompt)
translation_time = (time.time() - translation_start) * 1000
print(f" βœ… Translated to English ({translation_time:.1f}ms): '{translated_prompt[:200]}...'", flush=True)
print(f" πŸ” Will classify translated text (length: {len(translated_prompt)} chars)", flush=True)
except Exception as e:
error_msg = str(e)
print(f"⚠️ Translation failed: {error_msg}", flush=True)
                print(" Proceeding with original text (may cause classification issues).", flush=True)
# Continue with original prompt - the classifier might still work or fail gracefully
translated_prompt = prompt
else:
            print(" βœ… Text is English, no translation needed", flush=True)
translated_prompt = prompt
try:
# Measure classification latency (always use ModernBERT on translated/English text)
print(f" πŸ” Classifying text: '{translated_prompt[:100]}...'", flush=True)
print(f" Text length: {len(translated_prompt)} chars", flush=True)
start_time = time.time()
response = self.attack_detector.generate_content(translated_prompt)
end_time = time.time()
latency_ms = (end_time - start_time) * 1000
# Parse the JSON response
json_response = self._extract_json_from_response(response)
try:
result = json.loads(json_response)
safety_status = result.get("safety_status", "unsafe")
attack_type = result.get("attack_type", "unknown")
confidence = result.get("confidence", 1.0)
reason = result.get("reason", "No specific reason provided")
is_safe = safety_status.lower() == "safe"
print(f" πŸ“Š Classification result: safety_status='{safety_status}', is_safe={is_safe}, confidence={confidence:.2f}", flush=True)
if not is_safe:
block_reason = f"πŸ€– AI Security Scanner: Detected {attack_type} attack (confidence: {confidence:.2f}, latency: {latency_ms:.1f}ms). Reason: {reason}"
if original_prompt != translated_prompt:
                        block_reason += " [Original non-English text was translated to English for analysis]"
print(f"🚨 Attack detected: {attack_type} - {reason} (confidence: {confidence:.2f}, latency: {latency_ms:.1f}ms)")
return False, block_reason
else:
safe_msg = f"βœ… AI Security Scanner: Prompt classified as safe (confidence: {confidence:.2f}, latency: {latency_ms:.1f}ms)"
if original_prompt != translated_prompt:
                        safe_msg += " [Non-English text was translated to English for analysis]"
print(safe_msg)
return True, ""
except json.JSONDecodeError as e:
print(f"⚠️ Could not parse AI detector JSON response: {json_response}")
print(f" JSON Error: {e}")
print(f" Full response: {response[:200]}...")
# Default to unsafe if we can't parse the response
return False, f"πŸ€– AI Security Scanner: Could not parse security analysis (latency: {latency_ms:.1f}ms). Request blocked for safety."
except Exception as e:
print(f"❌ Error communicating with AI attack detector: {e}")
# Default to unsafe if there's an error
return False, f"πŸ€– AI Security Scanner: Error during security analysis: {str(e)}. Request blocked for safety."
def _extract_json_from_response(self, response: str) -> str:
"""
Extract JSON from the response, handling thinking tags and other extra content.
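
        Illustrative behavior: given '<think>...</think> {"safety_status": "safe"} extra',
        this returns '{"safety_status": "safe"}'.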
"""
# Remove thinking tags if present
if "<think>" in response:
# Find the end of thinking tag and get everything after it
            think_end = response.find("</think>")
            if think_end != -1:
                response = response[think_end + len("</think>"):].strip()
# Look for JSON object boundaries
json_start = response.find("{")
if json_start == -1:
return response.strip()
# Find the matching closing brace
brace_count = 0
json_end = -1
for i in range(json_start, len(response)):
if response[i] == "{":
brace_count += 1
elif response[i] == "}":
brace_count -= 1
if brace_count == 0:
json_end = i + 1
break
if json_end == -1:
# If we can't find proper JSON boundaries, return everything after the first {
return response[json_start:].strip()
return response[json_start:json_end].strip()
def _adapt_stream_to_text(self, stream: Generator[Any, None, None]) -> Generator[str, None, None]:
"""
Adapts an LLM client's output stream into a consistent stream of text chunks.
This is necessary because different LLM clients may yield different object types.
"""
# The Gemini client yields `GenerateContentResponse` objects. We need to extract the text.
if config.LLM_PROVIDER == "gemini":
for chunk in stream:
if hasattr(chunk, 'text'):
yield chunk.text
# Other clients are expected to yield strings directly.
else:
yield from stream
def _apply_output_guardrails_to_stream(self, stream: Generator[str, None, None]) -> Generator[str, None, None]:
"""
        Applies output guardrails to a streaming response by buffering the full
        response first, then processing it through the guardrails. This trades
        incremental streaming for whole-text analysis: the caller receives a
        single processed chunk.
"""
# Collect the full response from the stream
full_response = ""
for chunk in stream:
full_response += chunk
        # Apply output guardrails to the complete response. Both outcomes yield a
        # single chunk: the block message if unsafe, otherwise the (possibly
        # anonymized) processed text.
        processed_response, _ = self.output_guardrail_manager.process_complete_output(full_response)
        yield processed_response
def process_request(
self, prompt: str, stream: bool = False
) -> Tuple[Any, bool, str]:
"""
Processes a request by applying AI detection, calling the LLM, and returning the response.
Returns:
- The response (blocked message or stream)
- A boolean indicating if the request was safe
- The processed prompt that was sent to the LLM
"""
if not self.output_test_mode:
# Check with AI detector (finetuned model)
is_safe, block_reason = self._check_with_ai_detector(prompt)
if not is_safe:
return block_reason, False, prompt
            # The prompt may have been translated inside _check_with_ai_detector,
            # but the original text is what we send to the LLM.
processed_prompt = prompt
else:
# In output test mode, skip AI detection
processed_prompt = prompt
# Send to LLM
if stream:
response_stream = self.llm_client.generate_content_stream(processed_prompt)
# Adapt the stream to a consistent text-only stream
text_stream = self._adapt_stream_to_text(response_stream)
# Apply output guardrails to streaming response
processed_stream = self._apply_output_guardrails_to_stream(text_stream)
return processed_stream, True, processed_prompt
else:
# For non-streaming, we expect a simple string response from the client
response = self.llm_client.generate_content(processed_prompt)
# Apply output guardrails to complete response
processed_response, is_safe = self.output_guardrail_manager.process_complete_output(response)
if not is_safe:
return processed_response, False, processed_prompt
return processed_response, True, processed_prompt
def test_output_guardrails(self, prompt: str, manual_output: str) -> Tuple[str, bool]:
"""
        Test modular output guardrails with manually supplied text. This method is
        specifically for output testing mode, where the user provides both a prompt
        and the output to screen. Only the provided output is evaluated; the prompt
        is not used here.
"""
if not self.output_test_mode:
raise ValueError("Backend must be initialized in output_test_mode to use this method")
        print("\nπŸ” Testing modular output guardrails on provided text...")
print(f" Input length: {len(manual_output)} characters")
# Process the manual output through modular output guardrails
processed_output, is_safe = self.output_guardrail_manager.process_complete_output(manual_output)
        if not is_safe:
            print("πŸ”’ Output was BLOCKED by guardrails")
            return processed_output, False
        print("βœ… Output passed all guardrails")
        if processed_output != manual_output:
            print(" (Output was modified by guardrails)")
        return processed_output, True
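

if __name__ == "__main__":
    # Minimal smoke-test sketch (assumes config.py supplies LLM_PROVIDER, LLM_CONFIG,
    # SYSTEM_PROMPT, and the guardrail configs; not a production entry point).
    # Output test mode skips the attack detector but still loads the LLM client
    # and output guardrails.
    backend = Backend(output_test_mode=True)
    result, is_safe = backend.test_output_guardrails(
        "example prompt", "Example model output to screen."
    )
    print(f"is_safe={is_safe}, result={result!r}")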