File size: 21,942 Bytes
afad319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
#!/usr/bin/env python3
"""
# Guard Rails System for RAG

This module provides comprehensive guard rails for the RAG system to ensure:
- Input validation and sanitization
- Output safety and content filtering
- Model safety and prompt injection protection
- Data privacy and PII detection
- Rate limiting and abuse prevention

## Guard Rail Categories

1. **Input Guards**: Validate and sanitize user inputs
2. **Output Guards**: Filter and validate generated responses
3. **Model Guards**: Protect against prompt injection and jailbreaks
4. **Data Guards**: Detect and handle sensitive information
5. **System Guards**: Rate limiting and resource protection
"""

import re
import time
import hashlib
from typing import List, Dict, Optional, Tuple, Any
from dataclasses import dataclass
from collections import defaultdict, deque
import logging
from loguru import logger


# =============================================================================
# DATA STRUCTURES
# =============================================================================


@dataclass
class GuardRailResult:
    """
    Outcome of a single guard rail check.

    Attributes:
        passed: True when the check succeeded.
        blocked: True when the checked input/output must be rejected outright.
        reason: Human-readable explanation of the decision.
        confidence: Confidence score (0.0-1.0) backing the decision.
        metadata: Extra details about the check (pattern hits, counts, ...).
    """

    passed: bool
    blocked: bool
    reason: str
    confidence: float
    metadata: Dict[str, Any]


@dataclass
class GuardRailConfig:
    """
    Tunable settings for the guard rail system.

    Attributes:
        max_query_length: Longest query (chars) accepted from a user.
        max_response_length: Longest generated response (chars) allowed out.
        min_confidence_threshold: Responses below this confidence are flagged.
        rate_limit_requests: Requests permitted per user per window.
        rate_limit_window: Length of the rate-limit window, in seconds.
        enable_pii_detection: Toggle PII scanning of text.
        enable_content_filtering: Toggle harmful-content keyword filtering.
        enable_prompt_injection_detection: Toggle injection-pattern checks.
    """

    max_query_length: int = 1000
    max_response_length: int = 5000
    min_confidence_threshold: float = 0.3
    rate_limit_requests: int = 100
    rate_limit_window: int = 3600  # one hour
    enable_pii_detection: bool = True
    enable_content_filtering: bool = True
    enable_prompt_injection_detection: bool = True


# =============================================================================
# INPUT GUARD RAILS
# =============================================================================


class InputGuards:
    """Guard rails for input validation and sanitization."""

    def __init__(self, config: GuardRailConfig):
        self.config = config

        # Compiled once here so validate_query() stays cheap per call.
        # These flag likely prompt-injection attempts: chat role markers,
        # instruction overrides, role-play requests, scripts, and URLs.
        self.suspicious_patterns = [
            re.compile(r"system:|assistant:|user:", re.IGNORECASE),
            re.compile(r"ignore previous|forget everything", re.IGNORECASE),
            re.compile(r"you are now|act as|pretend to be", re.IGNORECASE),
            re.compile(r"<script|javascript:|eval\(", re.IGNORECASE),
            re.compile(
                r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+"
            ),
        ]

        # Keyword blocklist for obviously harmful topics.
        # NOTE(review): these are broad ("hack", "drug", ...) and may
        # over-block benign queries; tune per deployment.
        self.harmful_patterns = [
            re.compile(r"\b(hack|crack|exploit|vulnerability)\b", re.IGNORECASE),
            re.compile(r"\b(bomb|weapon|explosive)\b", re.IGNORECASE),
            re.compile(r"\b(drug|illegal|contraband)\b", re.IGNORECASE),
        ]

    def validate_query(self, query: str, user_id: str = "anonymous") -> GuardRailResult:
        """
        Validate a user query for safety and appropriateness.

        Checks, in order: length limit, non-emptiness, prompt-injection
        patterns (if enabled), harmful-content keywords (if enabled).

        Args:
            query: User's query string.
            user_id: User identifier. Accepted for interface symmetry; rate
                limiting itself happens in SystemGuards, not here.

        Returns:
            GuardRailResult with the validation outcome.
        """
        # Hard length limit — oversized queries are rejected outright.
        if len(query) > self.config.max_query_length:
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason=f"Query too long ({len(query)} chars, max {self.config.max_query_length})",
                confidence=1.0,
                metadata={"query_length": len(query)},
            )

        # Reject empty / whitespace-only queries.
        if not query.strip():
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason="Empty or whitespace-only query",
                confidence=1.0,
                metadata={},
            )

        # Potential prompt injection: first matching pattern blocks.
        if self.config.enable_prompt_injection_detection:
            for pattern in self.suspicious_patterns:
                if pattern.search(query):
                    return GuardRailResult(
                        passed=False,
                        blocked=True,
                        reason="Suspicious pattern detected (potential prompt injection)",
                        confidence=0.8,
                        metadata={"pattern": pattern.pattern},
                    )

        # Harmful content: collect every matching pattern for the metadata.
        if self.config.enable_content_filtering:
            harmful_matches = [
                pattern.pattern
                for pattern in self.harmful_patterns
                if pattern.search(query)
            ]

            if harmful_matches:
                return GuardRailResult(
                    passed=False,
                    blocked=True,
                    reason="Harmful content detected",
                    confidence=0.7,
                    metadata={"harmful_patterns": harmful_matches},
                )

        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="Query validated successfully",
            confidence=1.0,
            metadata={},
        )

    def sanitize_query(self, query: str) -> str:
        """
        Sanitize a query by stripping markup and normalizing whitespace.

        Args:
            query: Raw query string.

        Returns:
            Sanitized query string.
        """
        # Remove <script>...</script> blocks FIRST (tag AND payload).
        # BUGFIX: the previous order stripped generic tags first, which
        # deleted the <script> markers and left the script body behind —
        # this regex could then never match.
        query = re.sub(
            r"<script.*?</script>", "", query, flags=re.IGNORECASE | re.DOTALL
        )

        # Remove any remaining HTML tags.
        query = re.sub(r"<[^>]+>", "", query)

        # Collapse runs of whitespace into single spaces.
        query = re.sub(r"\s+", " ", query).strip()

        return query


# =============================================================================
# OUTPUT GUARD RAILS
# =============================================================================


class OutputGuards:
    """Guard rails for output validation and filtering"""

    def __init__(self, config: GuardRailConfig):
        self.config = config

        # Phrases that usually signal an unhelpful / non-answer response.
        self.low_quality_patterns = [
            re.compile(r"\b(i don't know|i cannot|i am unable)\b", re.IGNORECASE),
            re.compile(r"\b(no information|not found|not available)\b", re.IGNORECASE),
        ]

        # Phrases that attribute claims to the source material — candidates
        # for hallucination checking (see _detect_hallucination).
        self.hallucination_patterns = [
            re.compile(
                r"\b(according to the document|as mentioned in|the document states)\b",
                re.IGNORECASE,
            ),
            re.compile(
                r"\b(based on the provided|in the given|from the text)\b", re.IGNORECASE
            ),
        ]

    def validate_response(
        self, response: str, confidence: float, context: str = ""
    ) -> GuardRailResult:
        """
        Validate a generated response for safety and quality.

        Checks, in order: length limit (blocks), confidence threshold
        (flags, does not block), low-quality phrasing (flags), potential
        hallucination against the retrieved context (flags).

        Args:
            response: Generated response text.
            confidence: Confidence score from the RAG system.
            context: Retrieved context used to generate the response.

        Returns:
            GuardRailResult with the validation outcome.
        """
        # Hard length limit — only this check actually blocks the response.
        if len(response) > self.config.max_response_length:
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason=f"Response too long ({len(response)} chars, max {self.config.max_response_length})",
                confidence=1.0,
                metadata={"response_length": len(response)},
            )

        # Low confidence is flagged but not blocked (blocked=False).
        if confidence < self.config.min_confidence_threshold:
            return GuardRailResult(
                passed=False,
                blocked=False,
                reason=f"Low confidence response ({confidence:.2f} < {self.config.min_confidence_threshold})",
                confidence=confidence,
                metadata={"confidence": confidence},
            )

        # Two or more non-answer phrases => likely a low-quality response.
        low_quality_count = sum(
            1 for pattern in self.low_quality_patterns if pattern.search(response)
        )

        if low_quality_count >= 2:
            return GuardRailResult(
                passed=False,
                blocked=False,
                reason="Low quality response detected",
                confidence=0.6,
                metadata={"low_quality_indicators": low_quality_count},
            )

        # Hallucination heuristic only applies when we have context.
        if context and self._detect_hallucination(response, context):
            return GuardRailResult(
                passed=False,
                blocked=False,
                reason="Potential hallucination detected",
                confidence=0.7,
                metadata={"hallucination_risk": "high"},
            )

        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="Response validated successfully",
            confidence=confidence,
            metadata={},
        )

    def _detect_hallucination(self, response: str, context: str) -> bool:
        """
        Detect potential hallucinations in a response.

        PLACEHOLDER: this heuristic is deliberately conservative and
        currently always returns False — it scans for claim-attribution
        phrases but does not yet verify them against the context. Kept so
        the validate_response() hook point is stable for a real
        implementation later.

        Args:
            response: Generated response.
            context: Retrieved context.

        Returns:
            True if a hallucination is likely detected (currently never).
        """
        response_lower = response.lower()
        context_lower = context.lower()

        # Phrases that attribute a claim to the source material.
        claim_indicators = [
            "the document states",
            "according to the text",
            "as mentioned in",
            "the information shows",
        ]

        for indicator in claim_indicators:
            if indicator in response_lower:
                # A real implementation would verify the attributed claim
                # actually appears in context_lower; be conservative for now.
                return False

        return False

    def filter_response(self, response: str) -> str:
        """
        Filter a response to remove markup and tidy whitespace.

        Args:
            response: Raw response string.

        Returns:
            Filtered response string.
        """
        # Remove <script>...</script> blocks FIRST (tag AND payload).
        # BUGFIX: the previous order stripped generic tags first, which
        # deleted the <script> markers and left the script body behind —
        # this regex could then never match.
        response = re.sub(
            r"<script.*?</script>", "", response, flags=re.IGNORECASE | re.DOTALL
        )

        # Remove any remaining HTML tags.
        response = re.sub(r"<[^>]+>", "", response)

        # Collapse three-or-more newlines into a single blank line.
        response = re.sub(r"\n\s*\n\s*\n+", "\n\n", response)

        return response.strip()


# =============================================================================
# DATA GUARD RAILS
# =============================================================================


class DataGuards:
    """Guard rails for data privacy and PII detection"""

    def __init__(self, config: GuardRailConfig):
        self.config = config

        # Compiled PII patterns, keyed by PII type. Insertion order is the
        # order in which sanitize_pii() applies its masks. These are shared
        # by detect_pii() and sanitize_pii() so the two can never drift.
        # NOTE: the email TLD class is [A-Za-z] — the original [A-Z|a-z]
        # accidentally included a literal '|'.
        self.pii_patterns = {
            "email": re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"),
            "phone": re.compile(r"\b\d{3}[-.]?\d{3}[-.]?\d{4}\b"),
            "ssn": re.compile(r"\b\d{3}-\d{2}-\d{4}\b"),
            "credit_card": re.compile(r"\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b"),
            "ip_address": re.compile(r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b"),
        }

    def detect_pii(self, text: str) -> GuardRailResult:
        """
        Detect personally identifiable information in text.

        Args:
            text: Text to analyze for PII.

        Returns:
            GuardRailResult; blocks when any PII type matches, with
            per-type match counts in metadata["detected_pii"].
        """
        if not self.config.enable_pii_detection:
            return GuardRailResult(
                passed=True,
                blocked=False,
                reason="PII detection disabled",
                confidence=1.0,
                metadata={},
            )

        # Count matches per PII type; only types with hits are recorded.
        detected_pii = {}
        for pii_type, pattern in self.pii_patterns.items():
            matches = pattern.findall(text)
            if matches:
                detected_pii[pii_type] = len(matches)

        if detected_pii:
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason=f"PII detected: {', '.join(detected_pii.keys())}",
                confidence=0.9,
                metadata={"detected_pii": detected_pii},
            )

        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="No PII detected",
            confidence=1.0,
            metadata={},
        )

    def sanitize_pii(self, text: str) -> str:
        """
        Sanitize text by masking PII in place.

        Each match of a pattern in self.pii_patterns is replaced with a
        bracketed tag derived from its type, e.g. "email" -> "[EMAIL]".

        Args:
            text: Text containing potential PII.

        Returns:
            Sanitized text with PII masked.
        """
        # REFACTOR: drive masking from the compiled patterns in __init__
        # instead of re-spelling each regex inline — detection and masking
        # previously duplicated the patterns and could silently diverge.
        for pii_type, pattern in self.pii_patterns.items():
            text = pattern.sub(f"[{pii_type.upper()}]", text)

        return text


# =============================================================================
# SYSTEM GUARD RAILS
# =============================================================================


class SystemGuards:
    """Guard rails for system-level protection"""

    def __init__(self, config: GuardRailConfig):
        self.config = config
        # Per-user deque of request timestamps; maxlen caps memory per user.
        self.request_history = defaultdict(lambda: deque(maxlen=1000))
        # Users currently blocked for exceeding the rate limit. Blocks are
        # temporary: check_rate_limit() lifts them once the window clears.
        self.blocked_users = set()

    def check_rate_limit(self, user_id: str) -> GuardRailResult:
        """
        Check whether a user has exceeded the rate limit.

        A user who exceeds the limit is blocked until enough of their
        requests age out of the rate-limit window.

        Args:
            user_id: User identifier.

        Returns:
            GuardRailResult with the rate limit check outcome.
        """
        current_time = time.time()
        user_requests = self.request_history[user_id]

        # Drop timestamps that fell outside the rate-limit window.
        while (
            user_requests
            and current_time - user_requests[0] > self.config.rate_limit_window
        ):
            user_requests.popleft()

        # BUGFIX: blocks were permanent — nothing ever removed a user from
        # blocked_users. Lift the block once the user's request count has
        # dropped back below the limit, making the block truly temporary.
        if (
            user_id in self.blocked_users
            and len(user_requests) < self.config.rate_limit_requests
        ):
            self.blocked_users.discard(user_id)

        # Still blocked: the window has not cleared yet.
        if user_id in self.blocked_users:
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason="User is blocked due to previous violations",
                confidence=1.0,
                metadata={"user_id": user_id},
            )

        # Over the limit: block until the window clears.
        if len(user_requests) >= self.config.rate_limit_requests:
            self.blocked_users.add(user_id)
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason=f"Rate limit exceeded ({len(user_requests)} requests in {self.config.rate_limit_window}s)",
                confidence=1.0,
                metadata={"requests": len(user_requests)},
            )

        # Record this request against the user's quota.
        user_requests.append(current_time)

        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="Rate limit check passed",
            confidence=1.0,
            metadata={"requests": len(user_requests)},
        )

    def check_resource_usage(
        self, memory_usage: float, cpu_usage: float
    ) -> GuardRailResult:
        """
        Check system resource usage against fixed thresholds.

        Args:
            memory_usage: Current memory usage percentage (0-100).
            cpu_usage: Current CPU usage percentage (0-100).

        Returns:
            GuardRailResult; blocks when either threshold is exceeded.
        """
        # Fixed thresholds; memory is checked (and reported) before CPU.
        memory_threshold = 90.0  # 90% memory usage
        cpu_threshold = 95.0  # 95% CPU usage

        if memory_usage > memory_threshold:
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason=f"High memory usage ({memory_usage:.1f}%)",
                confidence=1.0,
                metadata={"memory_usage": memory_usage},
            )

        if cpu_usage > cpu_threshold:
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason=f"High CPU usage ({cpu_usage:.1f}%)",
                confidence=1.0,
                metadata={"cpu_usage": cpu_usage},
            )

        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="Resource usage acceptable",
            confidence=1.0,
            metadata={"memory_usage": memory_usage, "cpu_usage": cpu_usage},
        )


# =============================================================================
# MAIN GUARD RAIL SYSTEM
# =============================================================================


class GuardRailSystem:
    """
    Comprehensive guard rail system for RAG

    This class orchestrates all guard rail components (input, output,
    data, and system guards) to ensure safe and reliable operation of
    the RAG system.
    """

    def __init__(self, config: Optional[GuardRailConfig] = None):
        """
        Initialize the guard rail system.

        Args:
            config: Guard rail configuration; defaults to GuardRailConfig()
                when omitted. (Annotation fixed to Optional — the default
                is None.)
        """
        self.config = config or GuardRailConfig()

        # One instance of each guard component, all sharing this config.
        self.input_guards = InputGuards(self.config)
        self.output_guards = OutputGuards(self.config)
        self.data_guards = DataGuards(self.config)
        self.system_guards = SystemGuards(self.config)

        logger.info("Guard rail system initialized successfully")

    def validate_input(self, query: str, user_id: str = "anonymous") -> GuardRailResult:
        """
        Comprehensive input validation.

        Runs, in order: rate limiting, query validation, PII detection.
        The first failing check's result is returned unchanged.

        Args:
            query: User query.
            user_id: User identifier (used for rate limiting).

        Returns:
            GuardRailResult with the validation outcome.
        """
        # Rate limit first so abusive users pay no further validation cost.
        rate_limit_result = self.system_guards.check_rate_limit(user_id)
        if not rate_limit_result.passed:
            return rate_limit_result

        # Length / injection / harmful-content checks.
        query_result = self.input_guards.validate_query(query, user_id)
        if not query_result.passed:
            return query_result

        # Refuse to process queries containing PII.
        pii_result = self.data_guards.detect_pii(query)
        if not pii_result.passed:
            return pii_result

        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="Input validation passed",
            confidence=1.0,
            metadata={},
        )

    def validate_output(
        self, response: str, confidence: float, context: str = ""
    ) -> GuardRailResult:
        """
        Comprehensive output validation.

        Runs response validation (length, confidence, quality,
        hallucination heuristic) and then PII detection; the first
        failing check's result is returned unchanged.

        Args:
            response: Generated response.
            confidence: Confidence score from the RAG system.
            context: Retrieved context.

        Returns:
            GuardRailResult with the validation outcome.
        """
        response_result = self.output_guards.validate_response(
            response, confidence, context
        )
        if not response_result.passed:
            return response_result

        # Refuse to emit responses containing PII.
        pii_result = self.data_guards.detect_pii(response)
        if not pii_result.passed:
            return pii_result

        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="Output validation passed",
            confidence=confidence,
            metadata={},
        )

    def sanitize_input(self, query: str) -> str:
        """Sanitize user input (strip markup, normalize whitespace)."""
        return self.input_guards.sanitize_query(query)

    def sanitize_output(self, response: str) -> str:
        """Sanitize generated output (strip markup, tidy whitespace)."""
        return self.output_guards.filter_response(response)

    def sanitize_data(self, text: str) -> str:
        """Sanitize data by masking PII."""
        return self.data_guards.sanitize_pii(text)

    def get_system_status(self) -> Dict[str, Any]:
        """
        Get current system status and statistics.

        Returns:
            Dictionary with user counts and the active configuration limits.
        """
        return {
            "total_users": len(self.system_guards.request_history),
            "blocked_users": len(self.system_guards.blocked_users),
            "config": {
                "max_query_length": self.config.max_query_length,
                "max_response_length": self.config.max_response_length,
                "min_confidence_threshold": self.config.min_confidence_threshold,
                "rate_limit_requests": self.config.rate_limit_requests,
                "rate_limit_window": self.config.rate_limit_window,
            },
        }