# backend.py

import importlib
import json
import time
from typing import Tuple, Generator, Any

import config
from llm_clients.base import LlmClient
from llm_clients.performance_utils import apply_performance_optimizations
from english_detector import is_english_by_ascii_letters_only

# Apply performance optimizations early
apply_performance_optimizations()
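
# The guardrail manager below expects config.OUTPUT_GUARDRAILS_CONFIG to be a mapping of
# guardrail name -> per-guardrail settings with at least an "enabled" flag. A rough,
# purely illustrative shape (the guardrail names here are hypothetical; the real keys
# live in config.py):
#
#   OUTPUT_GUARDRAILS_CONFIG = {
#       "pii_filter": {"enabled": True},
#       "toxicity_check": {"enabled": False},
#   }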


class OutputGuardrailManager:
    """Manages the loading and application of modular output-specific guardrails."""

    def __init__(self, guard_configs: dict):
        self.guards = []
        print("\nInitializing Modular Output Guardrail Manager...")
        for name, g_config in guard_configs.items():
            if g_config.get("enabled"):
                try:
                    # Dynamically import the guardrail module
                    module = importlib.import_module(f"guardrails.{name}")
                    # Construct the class name from the guardrail name
                    guard_class_name = name.replace("_", " ").title().replace(" ", "")
                    guard_class = getattr(module, guard_class_name)
                    guard_instance = guard_class(g_config)
                    self.guards.append(guard_instance)
                    print(f"  ✅ Loaded output guardrail: {name}")
                except (ModuleNotFoundError, AttributeError, ImportError) as e:
                    print(f"  ⚠️ Could not load output guardrail '{name}': {e}")

    def process_complete_output(self, text: str) -> Tuple[str, bool]:
        """Process complete output text through all loaded output guardrails."""
        safe = True
        current_text = text
        for guard in self.guards:
            if hasattr(guard, "process_complete_output"):
                current_text, safe = guard.process_complete_output(current_text)
                if not safe:
                    return current_text, False
        return current_text, True
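
# Illustrative sketch of the module layout the manager above assumes (hypothetical file and
# class; only the naming convention and the process_complete_output signature come from the
# loading code). A guardrail configured as "pii_filter" would be imported from
# guardrails/pii_filter.py and resolved to a class named PiiFilter
# ("pii_filter" -> "Pii Filter" -> "PiiFilter"):
#
#   # guardrails/pii_filter.py (hypothetical example)
#   class PiiFilter:
#       def __init__(self, config: dict):
#           self.config = config
#
#       def process_complete_output(self, text: str) -> tuple[str, bool]:
#           # Return the (possibly modified) text and whether it is safe to show
#           return text, True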


class Backend:
    """Handles the core logic of processing requests with AI detection and modular output guardrails."""

    def __init__(self, output_test_mode: bool = False):
        self.output_test_mode = output_test_mode
        self._translator_client: LlmClient | None = None

        if output_test_mode:
            print("\nOutput Testing Mode: ENABLED")
            print("   Only modular output guardrails will be active.")
            self.attack_detector = None
        else:
            print("\nAI Detection Mode: ENABLED")
            print("   Using finetuned model for input guardrails.")
            try:
                self.attack_detector = self._load_attack_detector()
            except Exception as e:
                print(f"⚠️ WARNING: Failed to load attack detector: {e}")
                print("   Falling back to output-only mode for better compatibility")
                print("   💡 The system will still work with output guardrails only")
                self.attack_detector = None
                self.output_test_mode = True  # Switch to output-only mode

        # Initialize output guardrails in both modes
        self.output_guardrail_manager = OutputGuardrailManager(config.OUTPUT_GUARDRAILS_CONFIG)
        # Initialize attachment guardrails
        self.attachment_guardrail_manager = self._load_attachment_guardrails()
        self.llm_client = self._load_llm_client()

    def _get_translator_client(self) -> LlmClient:
        """Lazily load and return the translation client for non-English text."""
        if self._translator_client is not None:
            return self._translator_client

        translator_cfg = getattr(config, "NON_ENGLISH_TRANSLATOR", {"enabled": False})
        if not translator_cfg.get("enabled", False):
            raise ImportError("Non-English translator disabled in config.")

        provider = translator_cfg.get("provider", "qwen_translator")
        provider_cfg = translator_cfg.get("config", {})
        try:
            module = importlib.import_module(f"llm_clients.{provider}")
            client_class_name = provider.replace("_", " ").title().replace(" ", "") + "Client"
            client_class = getattr(module, client_class_name)
            # System prompt not needed (client has its own translation prompt), pass empty string
            self._translator_client = client_class(provider_cfg, "")
            print(f"   Translation client loaded: {provider} ({provider_cfg.get('model', '')})")
            return self._translator_client
        except Exception as e:
            raise ImportError(f"Could not load translation client '{provider}': {e}")

    def _load_attack_detector(self) -> LlmClient:
        """Loads the attack detection client (finetuned model via FinetunedGuard)."""
        ai_config = config.AI_DETECTION_MODE
        provider = ai_config["attack_llm_provider"]
        llm_config = ai_config["attack_llm_config"]
        try:
            # Use shared model for finetuned_guard provider to avoid duplicate loading
            if provider == "finetuned_guard":
                from llm_clients.shared_models import shared_model_manager

                model_name = llm_config.get("model_name", "zazaman/fmb")
                shared_client = shared_model_manager.get_finetuned_guard_client(model_name)
                if shared_client:
                    print(f"   Main Attack Detector: Using shared model {model_name}")
                    return shared_client
                else:
                    raise ImportError(f"Could not get shared finetuned model {model_name}")
            else:
                # For other providers, load normally
                module = importlib.import_module(f"llm_clients.{provider}")
                client_class_name = provider.replace("_", " ").title().replace(" ", "") + "Client"
                client_class = getattr(module, client_class_name)
                return client_class(llm_config, "")  # No system prompt needed for classification model
        except (ModuleNotFoundError, AttributeError, ImportError) as e:
            raise ImportError(f"Could not load attack detection client for '{provider}': {e}")

    def _load_attachment_guardrails(self):
        """Load and initialize attachment guardrails manager."""
        try:
            from guardrails.attachments.base import AttachmentGuardrailManager

            return AttachmentGuardrailManager(config.ATTACHMENT_GUARDRAILS_CONFIG)
        except Exception as e:
            print(f"⚠️ Could not load attachment guardrails: {e}")
            return None

    def _load_llm_client(self) -> LlmClient:
        """Dynamically loads and initializes the configured LLM client."""
        provider = config.LLM_PROVIDER
        llm_config = config.LLM_CONFIG.get(provider)
        if llm_config is None:
            raise ValueError(f"LLM provider '{provider}' not configured in config.py")
        try:
            module = importlib.import_module(f"llm_clients.{provider}")
            client_class_name = provider.replace("_", " ").title().replace(" ", "") + "Client"
            client_class = getattr(module, client_class_name)
            return client_class(llm_config, config.SYSTEM_PROMPT)
        except (ModuleNotFoundError, AttributeError, ImportError) as e:
            raise ImportError(f"Could not load LLM client for '{provider}': {e}")
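
    # Naming convention shared by the loaders above: a provider or guardrail key maps to a
    # module under llm_clients/ and a CamelCase class ending in "Client". For example, the
    # default translator provider "qwen_translator" resolves to llm_clients/qwen_translator.py
    # and a class named QwenTranslatorClient:
    #
    #   "qwen_translator".replace("_", " ").title().replace(" ", "") + "Client"
    #   # -> "qwen translator" -> "Qwen Translator" -> "QwenTranslator" -> "QwenTranslatorClient"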

    def _check_with_ai_detector(self, prompt: str) -> Tuple[bool, str]:
        """
        Checks the prompt with the AI attack detector (finetuned model).
        If the prompt is non-English, translates it to English first, then classifies.
        Returns (is_safe, reason).
        """
        original_prompt = prompt
        translated_prompt = prompt

        # Check if the prompt is non-English and translate if needed
        if not is_english_by_ascii_letters_only(prompt):
            try:
                print("Detected non-English input. Translating to English...", flush=True)
                print(f"   Original text: '{prompt[:100]}...'", flush=True)
                translator_client = self._get_translator_client()
                translation_start = time.time()
                translated_prompt = translator_client.generate_content(prompt)
                translation_time = (time.time() - translation_start) * 1000
                print(f"   ✅ Translated to English ({translation_time:.1f}ms): '{translated_prompt[:200]}...'", flush=True)
                print(f"   Will classify translated text (length: {len(translated_prompt)} chars)", flush=True)
            except Exception as e:
                error_msg = str(e)
                print(f"⚠️ Translation failed: {error_msg}", flush=True)
                print("   Proceeding with original text (may cause classification issues).", flush=True)
                # Continue with original prompt - the classifier might still work or fail gracefully
                translated_prompt = prompt
        else:
            print("   ✅ Text is English, no translation needed", flush=True)
            translated_prompt = prompt

        try:
            # Measure classification latency (always use ModernBERT on translated/English text)
            print(f"   Classifying text: '{translated_prompt[:100]}...'", flush=True)
            print(f"   Text length: {len(translated_prompt)} chars", flush=True)
            start_time = time.time()
            response = self.attack_detector.generate_content(translated_prompt)
            end_time = time.time()
            latency_ms = (end_time - start_time) * 1000

            # Parse the JSON response
            json_response = self._extract_json_from_response(response)
            try:
                result = json.loads(json_response)
                safety_status = result.get("safety_status", "unsafe")
                attack_type = result.get("attack_type", "unknown")
                confidence = result.get("confidence", 1.0)
                reason = result.get("reason", "No specific reason provided")
                is_safe = safety_status.lower() == "safe"
                print(f"   Classification result: safety_status='{safety_status}', is_safe={is_safe}, confidence={confidence:.2f}", flush=True)

                if not is_safe:
                    block_reason = (
                        f"🤖 AI Security Scanner: Detected {attack_type} attack "
                        f"(confidence: {confidence:.2f}, latency: {latency_ms:.1f}ms). Reason: {reason}"
                    )
                    if original_prompt != translated_prompt:
                        block_reason += " [Original non-English text was translated to English for analysis]"
                    print(f"🚨 Attack detected: {attack_type} - {reason} (confidence: {confidence:.2f}, latency: {latency_ms:.1f}ms)")
                    return False, block_reason
                else:
                    safe_msg = f"✅ AI Security Scanner: Prompt classified as safe (confidence: {confidence:.2f}, latency: {latency_ms:.1f}ms)"
                    if original_prompt != translated_prompt:
                        safe_msg += " [Non-English text was translated to English for analysis]"
                    print(safe_msg)
                    return True, ""
            except json.JSONDecodeError as e:
                print(f"⚠️ Could not parse AI detector JSON response: {json_response}")
                print(f"   JSON Error: {e}")
                print(f"   Full response: {response[:200]}...")
                # Default to unsafe if we can't parse the response
                return False, f"🤖 AI Security Scanner: Could not parse security analysis (latency: {latency_ms:.1f}ms). Request blocked for safety."
        except Exception as e:
            print(f"❌ Error communicating with AI attack detector: {e}")
            # Default to unsafe if there's an error
            return False, f"🤖 AI Security Scanner: Error during security analysis: {str(e)}. Request blocked for safety."
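
    # The parsing above expects the detector to return a JSON object of roughly this shape
    # (field values below are made up for illustration):
    #
    #   {
    #       "safety_status": "unsafe",          # "safe" or "unsafe"
    #       "attack_type": "prompt_injection",  # example label, not an exhaustive list
    #       "confidence": 0.97,
    #       "reason": "Instruction override attempt detected"
    #   }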

    def _extract_json_from_response(self, response: str) -> str:
        """
        Extract JSON from the response, handling thinking tags and other extra content.
        """
        # Remove thinking tags if present
        if "<think>" in response:
            # Find the end of the thinking tag and keep everything after it
            think_end = response.find("</think>")
            if think_end != -1:
                response = response[think_end + 8:].strip()

        # Look for JSON object boundaries
        json_start = response.find("{")
        if json_start == -1:
            return response.strip()

        # Find the matching closing brace
        brace_count = 0
        json_end = -1
        for i in range(json_start, len(response)):
            if response[i] == "{":
                brace_count += 1
            elif response[i] == "}":
                brace_count -= 1
                if brace_count == 0:
                    json_end = i + 1
                    break

        if json_end == -1:
            # If we can't find proper JSON boundaries, return everything after the first {
            return response[json_start:].strip()
        return response[json_start:json_end].strip()
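
    # Illustrative behaviour of the extraction above (input string made up):
    #
    #   _extract_json_from_response('<think>reasoning...</think> {"safety_status": "safe"} extra')
    #   # -> '{"safety_status": "safe"}'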

    def _adapt_stream_to_text(self, stream: Generator[Any, None, None]) -> Generator[str, None, None]:
        """
        Adapts an LLM client's output stream into a consistent stream of text chunks.
        This is necessary because different LLM clients may yield different object types.
        """
        # The Gemini client yields `GenerateContentResponse` objects. We need to extract the text.
        if config.LLM_PROVIDER == "gemini":
            for chunk in stream:
                if hasattr(chunk, "text"):
                    yield chunk.text
        # Other clients are expected to yield strings directly.
        else:
            yield from stream

    def _apply_output_guardrails_to_stream(self, stream: Generator[str, None, None]) -> Generator[str, None, None]:
        """
        Applies output guardrails to a streaming response by collecting the full response first,
        then processing it through guardrails and yielding the result.
        """
        # Collect the full response from the stream
        full_response = ""
        for chunk in stream:
            full_response += chunk

        # Apply output guardrails to the complete response
        processed_response, is_safe = self.output_guardrail_manager.process_complete_output(full_response)
        if not is_safe:
            # If blocked, yield the block message
            yield processed_response
        else:
            # If safe, yield the processed response (may be anonymized)
            yield processed_response

    def process_request(
        self, prompt: str, stream: bool = False
    ) -> Tuple[Any, bool, str]:
        """
        Processes a request by applying AI detection, calling the LLM, and returning the response.

        Returns:
            - The response (blocked message or stream)
            - A boolean indicating if the request was safe
            - The processed prompt that was sent to the LLM
        """
        if not self.output_test_mode:
            # Check with AI detector (finetuned model)
            is_safe, block_reason = self._check_with_ai_detector(prompt)
            if not is_safe:
                return block_reason, False, prompt
            # Prompt may have been translated in _check_with_ai_detector, but we use the original for the LLM
            processed_prompt = prompt
        else:
            # In output test mode, skip AI detection
            processed_prompt = prompt

        # Send to LLM
        if stream:
            response_stream = self.llm_client.generate_content_stream(processed_prompt)
            # Adapt the stream to a consistent text-only stream
            text_stream = self._adapt_stream_to_text(response_stream)
            # Apply output guardrails to the streaming response
            processed_stream = self._apply_output_guardrails_to_stream(text_stream)
            return processed_stream, True, processed_prompt
        else:
            # For non-streaming, we expect a simple string response from the client
            response = self.llm_client.generate_content(processed_prompt)
            # Apply output guardrails to the complete response
            processed_response, is_safe = self.output_guardrail_manager.process_complete_output(response)
            if not is_safe:
                return processed_response, False, processed_prompt
            return processed_response, True, processed_prompt

    def test_output_guardrails(self, prompt: str, manual_output: str) -> Tuple[str, bool]:
        """
        Test modular output guardrails with manual input. This method is specifically
        for output testing mode where users provide both prompt and expected output.
        """
        if not self.output_test_mode:
            raise ValueError("Backend must be initialized in output_test_mode to use this method")

        print("\nTesting modular output guardrails on provided text...")
        print(f"   Input length: {len(manual_output)} characters")

        # Process the manual output through modular output guardrails
        processed_output, is_safe = self.output_guardrail_manager.process_complete_output(manual_output)

        if not is_safe:
            print("Output was BLOCKED by guardrails")
            return processed_output, False
        else:
            print("✅ Output passed all guardrails")
            if processed_output != manual_output:
                print("   (Output was modified by guardrails)")
            return processed_output, True
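

# Minimal usage sketch (illustrative, not part of the original module). It assumes that
# config.py provides the settings referenced above and that the configured LLM providers
# are reachable; intended only as a local smoke test.
if __name__ == "__main__":
    backend = Backend(output_test_mode=False)
    reply, was_safe, sent_prompt = backend.process_request("What is the capital of France?", stream=False)
    print(f"safe={was_safe}\n{reply}")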