# backend.py import importlib import json import time from typing import Tuple, Generator, Any import config from llm_clients.base import LlmClient from llm_clients.performance_utils import apply_performance_optimizations from english_detector import is_english_by_ascii_letters_only # Apply performance optimizations early apply_performance_optimizations() class OutputGuardrailManager: """Manages the loading and application of modular output-specific guardrails.""" def __init__(self, guard_configs: dict): self.guards = [] print("\nInitializing Modular Output Guardrail Manager...") for name, g_config in guard_configs.items(): if g_config.get("enabled"): try: # Dynamically import the guardrail module module = importlib.import_module(f"guardrails.{name}") # Construct the class name from the guardrail name guard_class_name = name.replace("_", " ").title().replace(" ", "") guard_class = getattr(module, guard_class_name) guard_instance = guard_class(g_config) self.guards.append(guard_instance) print(f" āœ… Loaded output guardrail: {name}") except (ModuleNotFoundError, AttributeError, ImportError) as e: print(f" āš ļø Could not load output guardrail '{name}': {e}") def process_complete_output(self, text: str) -> Tuple[str, bool]: """Process complete output text through all loaded output guardrails.""" safe = True current_text = text for guard in self.guards: if hasattr(guard, "process_complete_output"): current_text, safe = guard.process_complete_output(current_text) if not safe: return current_text, False return current_text, True class Backend: """Handles the core logic of processing requests with AI detection and modular output guardrails.""" def __init__(self, output_test_mode: bool = False): self.output_test_mode = output_test_mode self._translator_client: LlmClient | None = None if output_test_mode: print("\nšŸ“ Output Testing Mode: ENABLED") print(" Only modular output guardrails will be active.") self.attack_detector = None else: print("\nšŸ”’ AI Detection Mode: ENABLED") print(" Using finetuned model for input guardrails.") try: self.attack_detector = self._load_attack_detector() except Exception as e: print(f"āš ļø WARNING: Failed to load attack detector: {e}") print(" šŸ”„ Falling back to output-only mode for better compatibility") print(" šŸ’” The system will still work with output guardrails only") self.attack_detector = None self.output_test_mode = True # Switch to output-only mode # Initialize output guardrails in both modes self.output_guardrail_manager = OutputGuardrailManager(config.OUTPUT_GUARDRAILS_CONFIG) # Initialize attachment guardrails self.attachment_guardrail_manager = self._load_attachment_guardrails() self.llm_client = self._load_llm_client() def _get_translator_client(self) -> LlmClient: """Lazily load and return the translation client for non-English text.""" if self._translator_client is not None: return self._translator_client translator_cfg = getattr(config, "NON_ENGLISH_TRANSLATOR", {"enabled": False}) if not translator_cfg.get("enabled", False): raise ImportError("Non-English translator disabled in config.") provider = translator_cfg.get("provider", "qwen_translator") provider_cfg = translator_cfg.get("config", {}) try: module = importlib.import_module(f"llm_clients.{provider}") client_class_name = provider.replace("_", " ").title().replace(" ", "") + "Client" client_class = getattr(module, client_class_name) # System prompt not needed (client has its own translation prompt), pass empty string self._translator_client = client_class(provider_cfg, "") print(f" šŸŒ Translation client loaded: {provider} ({provider_cfg.get('model', '')})") return self._translator_client except Exception as e: raise ImportError(f"Could not load translation client '{provider}': {e}") def _load_attack_detector(self) -> LlmClient: """Loads the attack detection client (finetuned model via FinetunedGuard).""" ai_config = config.AI_DETECTION_MODE provider = ai_config["attack_llm_provider"] llm_config = ai_config["attack_llm_config"] try: # Use shared model for finetuned_guard provider to avoid duplicate loading if provider == "finetuned_guard": from llm_clients.shared_models import shared_model_manager model_name = llm_config.get("model_name", "zazaman/fmb") shared_client = shared_model_manager.get_finetuned_guard_client(model_name) if shared_client: print(f" šŸ” Main Attack Detector: Using shared model {model_name}") return shared_client else: raise ImportError(f"Could not get shared finetuned model {model_name}") else: # For other providers, load normally module = importlib.import_module(f"llm_clients.{provider}") client_class_name = provider.replace("_", " ").title().replace(" ", "") + "Client" client_class = getattr(module, client_class_name) return client_class(llm_config, "") # No system prompt needed for classification model except (ModuleNotFoundError, AttributeError, ImportError) as e: raise ImportError(f"Could not load attack detection client for '{provider}': {e}") def _load_attachment_guardrails(self): """Load and initialize attachment guardrails manager.""" try: from guardrails.attachments.base import AttachmentGuardrailManager return AttachmentGuardrailManager(config.ATTACHMENT_GUARDRAILS_CONFIG) except Exception as e: print(f"āš ļø Could not load attachment guardrails: {e}") return None def _load_llm_client(self) -> LlmClient: """Dynamically loads and initializes the configured LLM client.""" provider = config.LLM_PROVIDER llm_config = config.LLM_CONFIG.get(provider) if llm_config is None: raise ValueError(f"LLM provider '{provider}' not configured in config.py") try: module = importlib.import_module(f"llm_clients.{provider}") client_class_name = provider.replace("_", " ").title().replace(" ", "") + "Client" client_class = getattr(module, client_class_name) return client_class(llm_config, config.SYSTEM_PROMPT) except (ModuleNotFoundError, AttributeError, ImportError) as e: raise ImportError(f"Could not load LLM client for '{provider}': {e}") def _check_with_ai_detector(self, prompt: str) -> Tuple[bool, str]: """ Checks the prompt with the AI attack detector (finetuned model). If the prompt is non-English, translates it to English first, then classifies. Returns (is_safe, reason). """ original_prompt = prompt translated_prompt = prompt # Check if prompt is non-English and translate if needed if not is_english_by_ascii_letters_only(prompt): try: print("šŸŒ Detected non-English input. Translating to English...", flush=True) print(f" Original text: '{prompt[:100]}...'", flush=True) translator_client = self._get_translator_client() translation_start = time.time() translated_prompt = translator_client.generate_content(prompt) translation_time = (time.time() - translation_start) * 1000 print(f" āœ… Translated to English ({translation_time:.1f}ms): '{translated_prompt[:200]}...'", flush=True) print(f" šŸ” Will classify translated text (length: {len(translated_prompt)} chars)", flush=True) except Exception as e: error_msg = str(e) print(f"āš ļø Translation failed: {error_msg}", flush=True) print(f" Proceeding with original text (may cause classification issues).", flush=True) # Continue with original prompt - the classifier might still work or fail gracefully translated_prompt = prompt else: print(f" āœ… Text is English, no translation needed", flush=True) translated_prompt = prompt try: # Measure classification latency (always use ModernBERT on translated/English text) print(f" šŸ” Classifying text: '{translated_prompt[:100]}...'", flush=True) print(f" Text length: {len(translated_prompt)} chars", flush=True) start_time = time.time() response = self.attack_detector.generate_content(translated_prompt) end_time = time.time() latency_ms = (end_time - start_time) * 1000 # Parse the JSON response json_response = self._extract_json_from_response(response) try: result = json.loads(json_response) safety_status = result.get("safety_status", "unsafe") attack_type = result.get("attack_type", "unknown") confidence = result.get("confidence", 1.0) reason = result.get("reason", "No specific reason provided") is_safe = safety_status.lower() == "safe" print(f" šŸ“Š Classification result: safety_status='{safety_status}', is_safe={is_safe}, confidence={confidence:.2f}", flush=True) if not is_safe: block_reason = f"šŸ¤– AI Security Scanner: Detected {attack_type} attack (confidence: {confidence:.2f}, latency: {latency_ms:.1f}ms). Reason: {reason}" if original_prompt != translated_prompt: block_reason += f" [Original non-English text was translated to English for analysis]" print(f"🚨 Attack detected: {attack_type} - {reason} (confidence: {confidence:.2f}, latency: {latency_ms:.1f}ms)") return False, block_reason else: safe_msg = f"āœ… AI Security Scanner: Prompt classified as safe (confidence: {confidence:.2f}, latency: {latency_ms:.1f}ms)" if original_prompt != translated_prompt: safe_msg += f" [Non-English text was translated to English for analysis]" print(safe_msg) return True, "" except json.JSONDecodeError as e: print(f"āš ļø Could not parse AI detector JSON response: {json_response}") print(f" JSON Error: {e}") print(f" Full response: {response[:200]}...") # Default to unsafe if we can't parse the response return False, f"šŸ¤– AI Security Scanner: Could not parse security analysis (latency: {latency_ms:.1f}ms). Request blocked for safety." except Exception as e: print(f"āŒ Error communicating with AI attack detector: {e}") # Default to unsafe if there's an error return False, f"šŸ¤– AI Security Scanner: Error during security analysis: {str(e)}. Request blocked for safety." def _extract_json_from_response(self, response: str) -> str: """ Extract JSON from the response, handling thinking tags and other extra content. """ # Remove thinking tags if present if "" in response: # Find the end of thinking tag and get everything after it think_end = response.find("") if think_end != -1: response = response[think_end + 8:].strip() # Look for JSON object boundaries json_start = response.find("{") if json_start == -1: return response.strip() # Find the matching closing brace brace_count = 0 json_end = -1 for i in range(json_start, len(response)): if response[i] == "{": brace_count += 1 elif response[i] == "}": brace_count -= 1 if brace_count == 0: json_end = i + 1 break if json_end == -1: # If we can't find proper JSON boundaries, return everything after the first { return response[json_start:].strip() return response[json_start:json_end].strip() def _adapt_stream_to_text(self, stream: Generator[Any, None, None]) -> Generator[str, None, None]: """ Adapts an LLM client's output stream into a consistent stream of text chunks. This is necessary because different LLM clients may yield different object types. """ # The Gemini client yields `GenerateContentResponse` objects. We need to extract the text. if config.LLM_PROVIDER == "gemini": for chunk in stream: if hasattr(chunk, 'text'): yield chunk.text # Other clients are expected to yield strings directly. else: yield from stream def _apply_output_guardrails_to_stream(self, stream: Generator[str, None, None]) -> Generator[str, None, None]: """ Applies output guardrails to a streaming response by collecting the full response first, then processing it through guardrails and yielding the result. """ # Collect the full response from the stream full_response = "" for chunk in stream: full_response += chunk # Apply output guardrails to the complete response processed_response, is_safe = self.output_guardrail_manager.process_complete_output(full_response) if not is_safe: # If blocked, yield the block message yield processed_response else: # If safe, yield the processed response (may be anonymized) yield processed_response def process_request( self, prompt: str, stream: bool = False ) -> Tuple[Any, bool, str]: """ Processes a request by applying AI detection, calling the LLM, and returning the response. Returns: - The response (blocked message or stream) - A boolean indicating if the request was safe - The processed prompt that was sent to the LLM """ if not self.output_test_mode: # Check with AI detector (finetuned model) is_safe, block_reason = self._check_with_ai_detector(prompt) if not is_safe: return block_reason, False, prompt # Prompt may have been translated in _check_with_ai_detector, but we use original for LLM processed_prompt = prompt else: # In output test mode, skip AI detection processed_prompt = prompt # Send to LLM if stream: response_stream = self.llm_client.generate_content_stream(processed_prompt) # Adapt the stream to a consistent text-only stream text_stream = self._adapt_stream_to_text(response_stream) # Apply output guardrails to streaming response processed_stream = self._apply_output_guardrails_to_stream(text_stream) return processed_stream, True, processed_prompt else: # For non-streaming, we expect a simple string response from the client response = self.llm_client.generate_content(processed_prompt) # Apply output guardrails to complete response processed_response, is_safe = self.output_guardrail_manager.process_complete_output(response) if not is_safe: return processed_response, False, processed_prompt return processed_response, True, processed_prompt def test_output_guardrails(self, prompt: str, manual_output: str) -> Tuple[str, bool]: """ Test modular output guardrails with manual input. This method is specifically for output testing mode where users provide both prompt and expected output. """ if not self.output_test_mode: raise ValueError("Backend must be initialized in output_test_mode to use this method") print(f"\nšŸ” Testing modular output guardrails on provided text...") print(f" Input length: {len(manual_output)} characters") # Process the manual output through modular output guardrails processed_output, is_safe = self.output_guardrail_manager.process_complete_output(manual_output) if not is_safe: print(f"šŸ”’ Output was BLOCKED by guardrails") return processed_output, False else: print(f"āœ… Output passed all guardrails") if processed_output != manual_output: print(f" (Output was modified by guardrails)") return processed_output, True