# backend.py
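"""Core request-processing backend: input attack detection via a finetuned
classifier (with translation for non-English prompts), dynamically loaded
output and attachment guardrails, and a pluggable LLM client layer."""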
import importlib
import json
import time
from typing import Tuple, Generator, Any

import config
from llm_clients.base import LlmClient
from llm_clients.performance_utils import apply_performance_optimizations
from english_detector import is_english_by_ascii_letters_only

# Apply performance optimizations early
apply_performance_optimizations()


class OutputGuardrailManager:
    """Manages the loading and application of modular output-specific guardrails."""

    def __init__(self, guard_configs: dict):
        self.guards = []
        print("\nInitializing Modular Output Guardrail Manager...")
        for name, g_config in guard_configs.items():
            if g_config.get("enabled"):
                try:
                    # Dynamically import the guardrail module
                    module = importlib.import_module(f"guardrails.{name}")
                    # Construct the class name from the guardrail name
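                    # e.g. a config key "pii_filter" (hypothetical name) would
                    # map to class PiiFilter in module guardrails.pii_filter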
                    guard_class_name = name.replace("_", " ").title().replace(" ", "")
                    guard_class = getattr(module, guard_class_name)
                    guard_instance = guard_class(g_config)
                    self.guards.append(guard_instance)
                    print(f"   βœ… Loaded output guardrail: {name}")
                except (ModuleNotFoundError, AttributeError, ImportError) as e:
                    print(f"   ⚠️  Could not load output guardrail '{name}': {e}")

    def process_complete_output(self, text: str) -> Tuple[str, bool]:
        """Process complete output text through all loaded output guardrails."""
        safe = True
        current_text = text
        
        for guard in self.guards:
            if hasattr(guard, "process_complete_output"):
                current_text, safe = guard.process_complete_output(current_text)
                if not safe:
                    return current_text, False
        
        return current_text, True


class Backend:
    """Handles the core logic of processing requests with AI detection and modular output guardrails."""

    def __init__(self, output_test_mode: bool = False):
        self.output_test_mode = output_test_mode
        self._translator_client: LlmClient | None = None
        
        if output_test_mode:
            print("\nπŸ“ Output Testing Mode: ENABLED")
            print("   Only modular output guardrails will be active.")
            self.attack_detector = None
        else:
            print("\nπŸ”’ AI Detection Mode: ENABLED")
            print("   Using finetuned model for input guardrails.")
            try:
                self.attack_detector = self._load_attack_detector()
            except Exception as e:
                print(f"⚠️  WARNING: Failed to load attack detector: {e}")
                print("   πŸ”„ Falling back to output-only mode for better compatibility")
                print("   πŸ’‘ The system will still work with output guardrails only")
                self.attack_detector = None
                self.output_test_mode = True  # Switch to output-only mode
            
        # Initialize output guardrails in both modes
        self.output_guardrail_manager = OutputGuardrailManager(config.OUTPUT_GUARDRAILS_CONFIG)
        
        # Initialize attachment guardrails
        self.attachment_guardrail_manager = self._load_attachment_guardrails()
        
        self.llm_client = self._load_llm_client()

    def _get_translator_client(self) -> LlmClient:
        """Lazily load and return the translation client for non-English text."""
        if self._translator_client is not None:
            return self._translator_client
        translator_cfg = getattr(config, "NON_ENGLISH_TRANSLATOR", {"enabled": False})
        if not translator_cfg.get("enabled", False):
            raise ImportError("Non-English translator disabled in config.")
        provider = translator_cfg.get("provider", "qwen_translator")
        provider_cfg = translator_cfg.get("config", {})
        try:
            module = importlib.import_module(f"llm_clients.{provider}")
            client_class_name = provider.replace("_", " ").title().replace(" ", "") + "Client"
            client_class = getattr(module, client_class_name)
            # System prompt not needed (client has its own translation prompt), pass empty string
            self._translator_client = client_class(provider_cfg, "")
            print(f"   🌍 Translation client loaded: {provider} ({provider_cfg.get('model', '')})")
            return self._translator_client
        except Exception as e:
            raise ImportError(f"Could not load translation client '{provider}': {e}")

    def _load_attack_detector(self) -> LlmClient:
        """Loads the attack detection client (finetuned model via FinetunedGuard)."""
        ai_config = config.AI_DETECTION_MODE
        provider = ai_config["attack_llm_provider"]
        llm_config = ai_config["attack_llm_config"]

        try:
            # Use shared model for finetuned_guard provider to avoid duplicate loading
            if provider == "finetuned_guard":
                from llm_clients.shared_models import shared_model_manager
                model_name = llm_config.get("model_name", "zazaman/fmb")
                shared_client = shared_model_manager.get_finetuned_guard_client(model_name)
                if shared_client:
                    print(f"   πŸ” Main Attack Detector: Using shared model {model_name}")
                    return shared_client
                else:
                    raise ImportError(f"Could not get shared finetuned model {model_name}")
            else:
                # For other providers, load normally
                module = importlib.import_module(f"llm_clients.{provider}")
                client_class_name = provider.replace("_", " ").title().replace(" ", "") + "Client"
                client_class = getattr(module, client_class_name)
                return client_class(llm_config, "")  # No system prompt needed for classification model
        except (ModuleNotFoundError, AttributeError, ImportError) as e:
            raise ImportError(f"Could not load attack detection client for '{provider}': {e}")

    def _load_attachment_guardrails(self):
        """Load and initialize attachment guardrails manager."""
        try:
            from guardrails.attachments.base import AttachmentGuardrailManager
            return AttachmentGuardrailManager(config.ATTACHMENT_GUARDRAILS_CONFIG)
        except Exception as e:
            print(f"⚠️  Could not load attachment guardrails: {e}")
            return None

    def _load_llm_client(self) -> LlmClient:
        """Dynamically loads and initializes the configured LLM client."""
        provider = config.LLM_PROVIDER
        llm_config = config.LLM_CONFIG.get(provider)

        if llm_config is None:
            raise ValueError(f"LLM provider '{provider}' not configured in config.py")

        try:
            module = importlib.import_module(f"llm_clients.{provider}")
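            # e.g. provider "gemini" -> class GeminiClient in llm_clients.gemini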
            client_class_name = provider.replace("_", " ").title().replace(" ", "") + "Client"
            client_class = getattr(module, client_class_name)
            return client_class(llm_config, config.SYSTEM_PROMPT)
        except (ModuleNotFoundError, AttributeError, ImportError) as e:
            raise ImportError(f"Could not load LLM client for '{provider}': {e}")

    def _check_with_ai_detector(self, prompt: str) -> Tuple[bool, str]:
        """
        Checks the prompt with the AI attack detector (finetuned model).
        If the prompt is non-English, translates it to English first, then classifies.
        Returns (is_safe, reason).
        """
        original_prompt = prompt
        translated_prompt = prompt
        
        # Check if prompt is non-English and translate if needed
        if not is_english_by_ascii_letters_only(prompt):
            try:
                print("🌍 Detected non-English input. Translating to English...")
                translator_client = self._get_translator_client()
                translation_start = time.time()
                translated_prompt = translator_client.generate_content(prompt)
                translation_time = (time.time() - translation_start) * 1000
                print(f"   βœ… Translated to English ({translation_time:.1f}ms): '{translated_prompt[:100]}...'")
            except Exception as e:
                print(f"⚠️  Translation failed: {e}. Proceeding with original text (may cause classification issues).")
                # Continue with original prompt - the classifier might still work or fail gracefully

        try:
            # Measure classification latency (always use ModernBERT on translated/English text)
            start_time = time.time()
            response = self.attack_detector.generate_content(translated_prompt)
            end_time = time.time()
            latency_ms = (end_time - start_time) * 1000
            
            # Parse the JSON response
            json_response = self._extract_json_from_response(response)
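            # Expected shape (inferred from the keys read below; values are
            # illustrative): {"safety_status": "safe", "attack_type": "none",
            #                 "confidence": 0.98, "reason": "benign question"}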
            
            try:
                result = json.loads(json_response)
                safety_status = result.get("safety_status", "unsafe")
                attack_type = result.get("attack_type", "unknown")
                confidence = result.get("confidence", 1.0)
                reason = result.get("reason", "No specific reason provided")
                
                is_safe = safety_status.lower() == "safe"
                
                if not is_safe:
                    block_reason = f"πŸ€– AI Security Scanner: Detected {attack_type} attack (confidence: {confidence:.2f}, latency: {latency_ms:.1f}ms). Reason: {reason}"
                    if original_prompt != translated_prompt:
                        block_reason += " [Original non-English text was translated to English for analysis]"
                    print(f"🚨 Attack detected: {attack_type} - {reason} (confidence: {confidence:.2f}, latency: {latency_ms:.1f}ms)")
                    return False, block_reason
                else:
                    safe_msg = f"βœ… AI Security Scanner: Prompt classified as safe (confidence: {confidence:.2f}, latency: {latency_ms:.1f}ms)"
                    if original_prompt != translated_prompt:
                        safe_msg += " [Non-English text was translated to English for analysis]"
                    print(safe_msg)
                    return True, ""
                    
            except json.JSONDecodeError as e:
                print(f"⚠️  Could not parse AI detector JSON response: {json_response}")
                print(f"   JSON Error: {e}")
                print(f"   Full response: {response[:200]}...")
                # Default to unsafe if we can't parse the response
                return False, f"πŸ€– AI Security Scanner: Could not parse security analysis (latency: {latency_ms:.1f}ms). Request blocked for safety."
                
        except Exception as e:
            print(f"❌ Error communicating with AI attack detector: {e}")
            # Default to unsafe if there's an error
            return False, f"πŸ€– AI Security Scanner: Error during security analysis: {str(e)}. Request blocked for safety."

    def _extract_json_from_response(self, response: str) -> str:
        """
        Extract JSON from the response, handling thinking tags and other extra content.
        """
        # Remove thinking tags if present
        if "<think>" in response:
            # Find the end of thinking tag and get everything after it
            think_end = response.find("</think>")
            if think_end != -1:
                response = response[think_end + 8:].strip()
        
        # Look for JSON object boundaries
        json_start = response.find("{")
        if json_start == -1:
            return response.strip()
            
        # Find the matching closing brace
        brace_count = 0
        json_end = -1
        
        for i in range(json_start, len(response)):
            if response[i] == "{":
                brace_count += 1
            elif response[i] == "}":
                brace_count -= 1
                if brace_count == 0:
                    json_end = i + 1
                    break
        
        if json_end == -1:
            # If we can't find proper JSON boundaries, return everything after the first {
            return response[json_start:].strip()
        
        return response[json_start:json_end].strip()

    def _adapt_stream_to_text(self, stream: Generator[Any, None, None]) -> Generator[str, None, None]:
        """
        Adapts an LLM client's output stream into a consistent stream of text chunks.
        This is necessary because different LLM clients may yield different object types.
        """
        # The Gemini client yields `GenerateContentResponse` objects. We need to extract the text.
        if config.LLM_PROVIDER == "gemini":
            for chunk in stream:
                if hasattr(chunk, 'text'):
                    yield chunk.text
        # Other clients are expected to yield strings directly.
        else:
            yield from stream

    def _apply_output_guardrails_to_stream(self, stream: Generator[str, None, None]) -> Generator[str, None, None]:
        """
        Applies output guardrails to a streaming response by collecting the full response first,
        then processing it through guardrails and yielding the result.
        """
        # Collect the full response from the stream
        full_response = ""
        for chunk in stream:
            full_response += chunk
        
        # Apply output guardrails to the complete response
        processed_response, _is_safe = self.output_guardrail_manager.process_complete_output(full_response)

        # Yield the guardrail result either way: the block message if the
        # output was unsafe, or the (possibly anonymized) text if it passed.
        yield processed_response

    def process_request(
        self, prompt: str, stream: bool = False
    ) -> Tuple[Any, bool, str]:
        """
        Processes a request by applying AI detection, calling the LLM, and returning the response.
        Returns:
            - The response (blocked message or stream)
            - A boolean indicating if the request was safe
            - The processed prompt that was sent to the LLM
        """
        if not self.output_test_mode:
            # Check with AI detector (finetuned model)
            is_safe, block_reason = self._check_with_ai_detector(prompt)
            
            if not is_safe:
                return block_reason, False, prompt
                
            # Prompt may have been translated in _check_with_ai_detector, but we use original for LLM
            processed_prompt = prompt
        else:
            # In output test mode, skip AI detection
            processed_prompt = prompt

        # Send to LLM
        if stream:
            response_stream = self.llm_client.generate_content_stream(processed_prompt)
            # Adapt the stream to a consistent text-only stream
            text_stream = self._adapt_stream_to_text(response_stream)
            
            # Apply output guardrails to streaming response
            processed_stream = self._apply_output_guardrails_to_stream(text_stream)
            return processed_stream, True, processed_prompt
        else:
            # For non-streaming, we expect a simple string response from the client
            response = self.llm_client.generate_content(processed_prompt)
            
            # Apply output guardrails to complete response
            processed_response, is_safe = self.output_guardrail_manager.process_complete_output(response)
            
            if not is_safe:
                return processed_response, False, processed_prompt
                
            return processed_response, True, processed_prompt

    def test_output_guardrails(self, prompt: str, manual_output: str) -> Tuple[str, bool]:
        """
        Test modular output guardrails with manual input. This method is specifically
        for output testing mode where users provide both prompt and expected output.
        """
        if not self.output_test_mode:
            raise ValueError("Backend must be initialized in output_test_mode to use this method")
            
        print(f"\nπŸ” Testing modular output guardrails on provided text...")
        print(f"   Input length: {len(manual_output)} characters")
        
        # Process the manual output through modular output guardrails
        processed_output, is_safe = self.output_guardrail_manager.process_complete_output(manual_output)
        
        if not is_safe:
            print("πŸ”’ Output was BLOCKED by guardrails")
            return processed_output, False
        else:
            print("βœ… Output passed all guardrails")
            if processed_output != manual_output:
                print("   (Output was modified by guardrails)")
            return processed_output, True
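

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. Runs in output
    # testing mode; the sample text is illustrative, and this still assumes
    # config.py and the configured LLM client are importable.
    backend = Backend(output_test_mode=True)
    output, ok = backend.test_output_guardrails(
        "ignored prompt", "Sample LLM output to screen."
    )
    print(f"safe={ok}: {output}")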