"""
Post-LLM response correction and sanitization module.

This module applies the TruthTable constraints to LLM responses:
1. Re-injects protected placeholder values
2. Removes unauthorized entities not present in original prompt
3. Enforces FORBID/MANDATE semantic restrictions

Mathematical Foundations
------------------------
1. Placeholder Substitution:
    Given placeholder map M = {p₁→v₁, ..., pₖ→vₖ} and response R:
        R' = R[p₁→v₁][p₂→v₂]...[pₖ→vₖ]
    Order: process by descending |pᵢ| to avoid partial substitution conflicts.

2. Entity Authorization Check:
    Let E_orig = {(typeᵢ, valueᵢ)} from shielded prompt
    Let E_resp = {(typeⱼ, valueⱼ)} extracted from response
    Unauthorized: E_unauth = E_resp \ E_orig
    Action: Replace each (t, v) ∈ E_unauth with sanitization marker.

3. Restriction Enforcement:
    For restriction r with type T and entity e:
        if T = FORBID:  R = R \ {occurrences of e}
        if T = MANDATE: if e ∉ R: R = R ∥ "[Note: must use e]"
    Where \ = set difference on text occurrences, ∥ = string concatenation.

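4. Regex Pattern Complexity:
    Pattern matching: O(n · m) where n = text length, m = pattern length
    Multiple patterns: O(n · Σ|pᵢ|)
    These bounds hold for automaton-based engines (RE2-style) [1]. Python's
    `re` module is a backtracking engine, so treat them as typical-case
    estimates for the simple patterns used here, not worst-case guarantees.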

References
----------
[1] Cox, R. (2007). Regular Expression Matching Can Be Simple And Fast.
    https://swtch.com/~rsc/regexp/regexp1.html

[2] Aho, A. V., & Corasick, M. J. (1975). Efficient string matching:
    An aid to bibliographic search. Communications of the ACM, 18(6), 333-340.

[3] OpenAI. (2024). Prompt injection and output sanitization best practices.
    https://platform.openai.com/docs/guides/safety

Performance Characteristics
---------------------------
- _build_entity_patterns(): O(1) - constant number of patterns
- correct() full pipeline: O(n · (k + p + r)) where:
    n = response length, k = placeholders, p = entity patterns, r = restrictions
- Memory: O(|E_orig| + |M|) for entity/placeholder lookup sets

Author: IntelliDeep Lab Team
License: BSL 1.1
"""

from __future__ import annotations

import logging
import re
from typing import List, Set, Tuple

from nlproxy.core.shield import ShieldResult
from nlproxy.core.restriction import Restriction

logger = logging.getLogger(__name__)
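

# Illustrative sketch only: a hypothetical helper that ResponseCorrector does
# not call. It shows the longest-first substitution order described in the
# module docstring ("Placeholder Substitution"). If "__PROT_ab" were replaced
# before "__PROT_abc", the longer token would be corrupted into "<value>c".
# Plain str.replace is assumed here so restored values are inserted literally.
def _sketch_longest_first_substitution(text: str, placeholder_map: dict[str, str]) -> str:
    # Sort keys by descending length so longer placeholders are consumed first.
    for placeholder in sorted(placeholder_map, key=len, reverse=True):
        text = text.replace(placeholder, placeholder_map[placeholder])
    return text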


class ResponseCorrector:
    """
    Applies TruthTable constraints to sanitize LLM responses.

    This class ensures that responses respect the security and semantic
    constraints extracted from the original prompt:

    1. Placeholder Re-injection: Restores protected values (code, PII, etc.)
    2. Entity Sanitization: Removes entities not authorized in original prompt
    3. Restriction Enforcement: Applies FORBID/MANDATE rules to final output

    Key Design Decisions
    --------------------
    - Longest-first placeholder substitution prevents partial match corruption
    - Entity type + value tuple matching avoids false positives (e.g., same IP appearing legitimately)
    - Case-insensitive restriction matching for robust enforcement
    - Minimal output modification: only redact/add what's necessary

    Usage Example
    -------------
    >>> corrector = ResponseCorrector(mode="code")
    >>> sanitized = corrector.correct(llm_response, shield_result)
    >>> # Response now respects all original constraints
    """

    # Pre-compiled entity patterns (shared across instances for efficiency)
    # Each pattern uses word boundaries (\b) for exact token matching
    _BASE_PATTERNS: dict[str, re.Pattern] = {
        "ip": re.compile(
            r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}'
            r'(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b'
            r'|\b(?:[A-F0-9]{1,4}:){7}[A-F0-9]{1,4}\b',
            flags=re.IGNORECASE
        ),
        "date": re.compile(
            r'\b\d{4}-\d{2}-\d{2}\b'      # ISO: 2025-06-15
            r'|\b\d{2}/\d{2}/\d{4}\b'     # DD/MM/YYYY
            r'|\b\d{2}\.\d{2}\.\d{4}\b'   # DD.MM.YYYY
        ),
        "price": re.compile(
            r'(?:USD|EUR|GBP|JPY|CHF|CAD|AUD)\s*[\$\€\£\¥]?\s*\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?\s*(?:USD|EUR|GBP|JPY|CHF|CAD|AUD)?'
            r'|[\$\€\£\¥]\s*\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?\s*(?:USD|EUR|GBP|JPY|CHF|CAD|AUD)?',
            flags=re.IGNORECASE
        ),
        "hash": re.compile(r'\b[A-Fa-f0-9]{32,64}\b'),
        "percentage": re.compile(r'\b\d+(?:\.\d+)?\s*%\b'),
    }

    # Sanitization markers (configurable for audit trail)
    _ENTITY_REDACT_MARKER: str = "[REDACTED]"
    _FORBIDDEN_MARKER: str = "[PROHIBITED]"
    _MANDATE_NOTE_PREFIX: str = "[Note: required entity missing: "


    def __init__(self, mode: str = "general") -> None:
        """
        Initialize the ResponseCorrector.

        Parameters
        ----------
        mode : str, optional
            Domain mode for potential future extensions (default: "general").
            Currently only recorded in the debug log; the pattern set is the same for all modes.
        """
        self.mode = mode
        self.entity_patterns = self._build_entity_patterns()
        logger.debug(f"ResponseCorrector initialized (mode={mode})")


    @staticmethod
    def _build_entity_patterns() -> List[Tuple[str, re.Pattern]]:
        """
        Build the list of entity detection patterns.

        Returns
        -------
        List[Tuple[str, re.Pattern]]
            List of (entity_type, compiled_regex) pairs for detection.

        Pattern Specifications
        ---------------------
        - IP: IPv4 (dotted decimal) or IPv6 (hex groups) with word boundaries
        - Date: ISO 8601, DD/MM/YYYY, or DD.MM.YYYY formats
        - Price: Currency code + symbol + amount with optional decimals
        - Hash: 32-64 character hexadecimal strings (MD5, SHA-256, etc.)
        - Percentage: Numeric value followed by % symbol

        Complexity
        ----------
        Time: O(1) - patterns are compiled once at class definition; this only lists them
        Space: O(1) - fixed pattern set stored at class level
        """
        return list(ResponseCorrector._BASE_PATTERNS.items())


    def _extract_entities_from_text(self, text: str) -> Set[Tuple[str, str]]:
        """
        Extract typed entities from text using registered patterns.

        Parameters
        ----------
        text : str
            Text to scan for entities.

        Returns
        -------
        Set[Tuple[str, str]]
            Set of (entity_type, entity_value) pairs found in text.

        Complexity
        ----------
        Time: O(n · p) where n = text length, p = number of patterns
        Space: O(e) where e = number of unique entities found
        """
        found: Set[Tuple[str, str]] = set()
        for entity_type, pattern in self.entity_patterns:
            for match in pattern.finditer(text):
                found.add((entity_type, match.group()))
        return found


    def _reinject_placeholders(self, text: str, placeholder_map: dict[str, str]) -> str:
        """
        Replace placeholders with their original protected values.

        Processes placeholders in descending length order to prevent
        partial substitution (e.g., "__PROT_ab" matching inside "__PROT_abc").

        Parameters
        ----------
        text : str
            Text containing placeholders to replace.
        placeholder_map : Dict[str, str]
            Mapping: placeholder → original value.

        Returns
        -------
        str
            Text with all placeholders substituted.

        Mathematical Note
        -----------------
        Substitution order matters: if |p₁| > |p₂| and p₂ is a prefix of p₁,
        substituting p₂ first would corrupt p₁. Sorting by descending length
        ensures atomic replacement of longer tokens first.

        Complexity
        ----------
        Time: O(k · n · m) where k = placeholders, n = text length, m = avg placeholder length
        Space: O(n) for intermediate string during substitution
        """
        # Sort by descending length to avoid partial match conflicts
        sorted_placeholders = sorted(placeholder_map.keys(), key=len, reverse=True)

        result = text
        for placeholder in sorted_placeholders:
            value = placeholder_map[placeholder]
            # Literal, case-sensitive replacement: str.replace avoids re.sub
            # interpreting backslashes or group references (e.g. "\1") in the
            # restored value, which is common in protected code snippets.
            result = result.replace(placeholder, value)

        return result


    def _sanitize_unauthorized_entities(
        self,
        text: str,
        authorized_entities: Set[Tuple[str, str]]
    ) -> str:
        """
        Remove or redact entities not present in the authorized set.

        Parameters
        ----------
        text : str
            Text to sanitize.
        authorized_entities : Set[Tuple[str, str]]
            Set of (type, value) pairs that are permitted in output.

        Returns
        -------
        str
            Text with unauthorized entities replaced by redaction marker.

        Algorithm
        ---------
        1. Extract all entities from response text
        2. Compute set difference: unauthorized = found - authorized
        3. Replace each unauthorized value with [REDACTED] marker

        Note: Replacement is value-based (not type-based) to avoid
        over-redaction when same entity type appears legitimately.

        Complexity
        ----------
        Time: O(n · p + u · n) where n = text length, p = patterns, u = unauthorized entities
        Space: O(u) for unauthorized entity set
        """
        # Extract entities present in response
        response_entities = self._extract_entities_from_text(text)

        # Identify unauthorized: in response but not in original
        unauthorized = response_entities - authorized_entities

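        result = text
        # Longest values first so a short value cannot corrupt a longer one
        # that contains it; literal replacement needs no regex escaping.
        for _entity_type, value in sorted(
            unauthorized, key=lambda item: len(item[1]), reverse=True
        ):
            result = result.replace(value, self._ENTITY_REDACT_MARKER)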

        return result


    def _enforce_restrictions(self, text: str, restrictions: List[Restriction]) -> str:
        """
        Apply FORBID/MANDATE semantic restrictions to the response.

        Parameters
        ----------
        text : str
            Response text to constrain.
        restrictions : List[Restriction]
            List of semantic constraints from prompt analysis.

        Returns
        -------
        str
            Text with restrictions enforced.

        Enforcement Rules
        -----------------
        FORBID: Replace all case-insensitive occurrences of the entity with
                the [PROHIBITED] marker. Matches whole words only to avoid
                partial matches.
        
        MANDATE: If entity is absent, append a note requiring its use.
                Does not modify existing content; only adds guidance.

        Complexity
        ----------
        Time: O(r · n · m) where r = restrictions, n = text length, m = avg entity length
        Space: O(n) for intermediate string during substitutions
        """
        result = text

        for restriction in restrictions:
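            entity = re.escape(restriction.entity)
            # Lookarounds behave like \b but also work when the entity starts
            # or ends with a non-word character (e.g. "C++").
            word_boundary_pattern = r'(?<!\w)' + entity + r'(?!\w)'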

            if restriction.type == "FORBID":
                # Remove all occurrences (case-insensitive, word-boundary matched)
                result = re.sub(
                    word_boundary_pattern,
                    self._FORBIDDEN_MARKER,
                    result,
                    flags=re.IGNORECASE
                )
                logger.debug(f"Enforced FORBID restriction: '{restriction.entity}'")

            elif restriction.type == "MANDATE":
                # Check presence (case-insensitive substring match for flexibility)
                if restriction.entity.lower() not in result.lower():
                    # Append mandate note to guide downstream processing
                    note = f"{self._MANDATE_NOTE_PREFIX}{restriction.entity}]"
                    result = result.rstrip() + "\n" + note
                    logger.debug(f"Enforced MANDATE restriction: '{restriction.entity}'")

        return result


    def _normalize_whitespace(self, text: str) -> str:
        """
        Normalize whitespace artifacts left behind by substitutions.

        Operations:
        - Collapse multiple spaces to single space
        - Reduce multiple newlines to single newline
        - Strip leading/trailing whitespace

        Parameters
        ----------
        text : str
            Text to normalize.

        Returns
        -------
        str
            Cleaned text with consistent formatting.

        Complexity
        ----------
        Time: O(n) where n = text length
        Space: O(n) for output string
        """
        # Collapse multiple spaces
        text = re.sub(r' +', ' ', text)
        # Reduce multiple newlines (with optional whitespace) to single newline
        text = re.sub(r'\n\s*\n+', '\n', text)
        # Strip leading/trailing whitespace
        return text.strip()


    def correct(self, response_text: str, shield_result: ShieldResult) -> str:
        """
        Apply all correction steps to sanitize an LLM response.

        Pipeline:
        1. Re-inject protected placeholder values
        2. Extract authorized entities from original prompt
        3. Redact unauthorized entities in response
        4. Enforce FORBID/MANDATE semantic restrictions
        5. Normalize whitespace and formatting

        Parameters
        ----------
        response_text : str
            Raw response from the LLM to be corrected.
        shield_result : ShieldResult
            Result from PromptShield containing:
            - placeholder_map: for re-injection
            - entities: authorized entity set
            - restrictions: semantic constraints to enforce

        Returns
        -------
        str
            Sanitized response respecting all TruthTable constraints.

        Complexity
        ----------
        Overall: O(n · (k + p + r)) where:
            n = response length
            k = number of placeholders
            p = number of entity patterns
            r = number of restrictions

        Space: O(|E_auth| + k) for authorized entity set + placeholder cache

        Example
        -------
        >>> corrector = ResponseCorrector()
        >>> sanitized = corrector.correct(
        ...     "The server IP is 192.168.1.1 and we use Python.",
        ...     shield_result
        ... )
        >>> # If 192.168.1.1 was authorized but Python was forbidden:
        >>> # Output: "The server IP is 192.168.1.1 and we use [PROHIBITED]."
        """
        # Stage 1: Re-inject protected placeholder values
        text = self._reinject_placeholders(response_text, shield_result.placeholder_map)

        # Stage 2: Build authorized entity set from original prompt
        authorized_entities: Set[Tuple[str, str]] = set()
        if hasattr(shield_result, 'entities'):
            for entity in shield_result.entities:
                authorized_entities.add((entity.entity_type, entity.value))

        # Stage 3: Redact entities not in authorized set
        text = self._sanitize_unauthorized_entities(text, authorized_entities)

        # Stage 4: Enforce semantic restrictions (FORBID/MANDATE)
        if hasattr(shield_result, 'restrictions') and shield_result.restrictions:
            text = self._enforce_restrictions(text, shield_result.restrictions)

        # Stage 5: Normalize whitespace and formatting
        text = self._normalize_whitespace(text)

        logger.debug(f"Response correction complete: {len(response_text)} → {len(text)} chars")
        return text
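

# Illustrative sketch only: a hypothetical helper that ResponseCorrector does
# not use. It reduces the entity authorization check from the module docstring
# (unauthorized = entities in the response minus entities in the original
# prompt) to a single entity type. The ISO-date pattern, the helper name and
# the hard-coded "[REDACTED]" marker are assumptions made for this example.
_SKETCH_DATE_PATTERN = re.compile(r'\b\d{4}-\d{2}-\d{2}\b')


def _sketch_redact_unauthorized_dates(text: str, authorized: set[str]) -> str:
    # Every ISO date found in the response.
    found = set(_SKETCH_DATE_PATTERN.findall(text))
    # Set difference: dates the original prompt never mentioned.
    unauthorized = found - authorized
    # Longest values first so no value is clobbered by a shorter substring.
    for value in sorted(unauthorized, key=len, reverse=True):
        text = text.replace(value, "[REDACTED]")
    return text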