""" Post-LLM response correction and sanitization module. This module applies the TruthTable constraints to LLM responses: 1. Re-injects protected placeholder values 2. Removes unauthorized entities not present in original prompt 3. Enforces FORBID/MANDATE semantic restrictions Mathematical Foundations ------------------------ 1. Placeholder Substitution: Given placeholder map M = {p₁→v₁, ..., pₖ→vₖ} and response R: R' = R[p₁→v₁][p₂→v₂]...[pₖ→vₖ] Order: process by descending |pᵢ| to avoid partial substitution conflicts. 2. Entity Authorization Check: Let E_orig = {(typeᵢ, valueᵢ)} from shielded prompt Let E_resp = {(typeⱼ, valueⱼ)} extracted from response Unauthorized: E_unauth = E_resp \ E_orig Action: Replace each (t, v) ∈ E_unauth with sanitization marker. 3. Restriction Enforcement: For restriction r with type T and entity e: if T = FORBID: R = R \ {occurrences of e} if T = MANDATE: if e ∉ R: R = R ∥ "[Note: must use e]" Where \ = set difference on text occurrences, ∥ = string concatenation. 4. Regex Pattern Complexity: Pattern matching: O(n · m) where n = text length, m = pattern length Multiple patterns: O(n · Σ|pᵢ|) with optimized regex engine (RE2-style) References ---------- [1] Cox, R. (2007). Regular Expression Matching Can Be Simple And Fast. https://swtch.com/~rsc/regexp/regexp1.html [2] Aho, A. V., & Corasick, M. J. (1975). Efficient string matching. Communications of the ACM, 18(6), 333-340. [3] OpenAI. (2024). Prompt injection and output sanitization best practices. https://platform.openai.com/docs/guides/safety Performance Characteristics --------------------------- - _build_entity_patterns(): O(1) - constant number of patterns - correct() full pipeline: O(n · (k + p + r)) where: n = response length, k = placeholders, p = entity patterns, r = restrictions - Memory: O(|E_orig| + |M|) for entity/placeholder lookup sets Author: IntelliDeep Lab Team License: BSL 1.1 """ from __future__ import annotations import logging import re from typing import List, Optional, Set, Tuple from nlproxy.core.shield import ShieldResult from nlproxy.core.restriction import Restriction logger = logging.getLogger(__name__) class ResponseCorrector: """ Applies TruthTable constraints to sanitize LLM responses. This class ensures that responses respect the security and semantic constraints extracted from the original prompt: 1. Placeholder Re-injection: Restores protected values (code, PII, etc.) 2. Entity Sanitization: Removes entities not authorized in original prompt 3. Restriction Enforcement: Applies FORBID/MANDATE rules to final output Key Design Decisions -------------------- - Longest-first placeholder substitution prevents partial match corruption - Entity type + value tuple matching avoids false positives (e.g., same IP appearing legitimately) - Case-insensitive restriction matching for robust enforcement - Minimal output modification: only redact/add what's necessary Usage Example ------------- >>> corrector = ResponseCorrector(mode="code") >>> sanitized = corrector.correct(llm_response, shield_result) >>> # Response now respects all original constraints """ # Pre-compiled entity patterns (shared across instances for efficiency) # Each pattern uses word boundaries (\b) for exact token matching _BASE_PATTERNS: dict[str, re.Pattern] = { "ip": re.compile( r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}' r'(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)' r'|\b(?:[A-F0-9]{1,4}:){7}[A-F0-9]{1,4}\b', flags=re.IGNORECASE ), "date": re.compile( r'\b\d{4}-\d{2}-\d{2}\b' # ISO: 2025-06-15 r'|\b\d{2}/\d{2}/\d{4}\b' # DD/MM/YYYY r'|\b\d{2}\.\d{2}\.\d{4}\b' # DD.MM.YYYY ), "price": re.compile( r'(?:USD|EUR|GBP|JPY|CHF|CAD|AUD)\s*[\$\€\£\¥]?\s*\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?\s*(?:USD|EUR|GBP|JPY|CHF|CAD|AUD)?' r'|[\$\€\£\¥]\s*\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?\s*(?:USD|EUR|GBP|JPY|CHF|CAD|AUD)?', flags=re.IGNORECASE ), "hash": re.compile(r'\b[A-Fa-f0-9]{32,64}\b'), "percentage": re.compile(r'\b\d+(?:\.\d+)?\s*%\b'), } # Sanitization markers (configurable for audit trail) _ENTITY_REDACT_MARKER: str = "[REDACTED]" _FORBIDDEN_MARKER: str = "[PROHIBITED]" _MANDATE_NOTE_PREFIX: str = "[Note: required entity missing: " def __init__(self, mode: str = "general") -> None: """ Initialize the ResponseCorrector. Parameters ---------- mode : str, optional Domain mode for potential future extensions (default: "general"). Currently affects logging; pattern set is uniform across modes. """ self.mode = mode self.entity_patterns = self._build_entity_patterns() logger.debug(f"ResponseCorrector initialized (mode={mode})") @staticmethod def _build_entity_patterns() -> List[Tuple[str, re.Pattern]]: """ Build the list of entity detection patterns. Returns ------- List[Tuple[str, re.Pattern]] List of (entity_type, compiled_regex) pairs for detection. Pattern Specifications --------------------- - IP: IPv4 (dotted decimal) or IPv6 (hex groups) with word boundaries - Date: ISO 8601, DD/MM/YYYY, or DD.MM.YYYY formats - Price: Currency code + symbol + amount with optional decimals - Hash: 32-64 character hexadecimal strings (MD5, SHA-256, etc.) - Percentage: Numeric value followed by % symbol Complexity ---------- Time: O(1) - constant number of pattern compilations Space: O(1) - fixed pattern set stored at class level """ return [(name, pattern) for name, pattern in ResponseCorrector._BASE_PATTERNS.items()] def _extract_entities_from_text(self, text: str) -> Set[Tuple[str, str]]: """ Extract typed entities from text using registered patterns. Parameters ---------- text : str Text to scan for entities. Returns ------- Set[Tuple[str, str]] Set of (entity_type, entity_value) pairs found in text. Complexity ---------- Time: O(n · p) where n = text length, p = number of patterns Space: O(e) where e = number of unique entities found """ found: Set[Tuple[str, str]] = set() for entity_type, pattern in self.entity_patterns: for match in pattern.finditer(text): found.add((entity_type, match.group())) return found def _reinject_placeholders(self, text: str, placeholder_map: dict[str, str]) -> str: """ Replace placeholders with their original protected values. Processes placeholders in descending length order to prevent partial substitution (e.g., "__PROT_ab" matching inside "__PROT_abc"). Parameters ---------- text : str Text containing placeholders to replace. placeholder_map : Dict[str, str] Mapping: placeholder → original value. Returns ------- str Text with all placeholders substituted. Mathematical Note ----------------- Substitution order matters: if |p₁| > |p₂| and p₂ is a prefix of p₁, substituting p₂ first would corrupt p₁. Sorting by descending length ensures atomic replacement of longer tokens first. Complexity ---------- Time: O(k · n · m) where k = placeholders, n = text length, m = avg placeholder length Space: O(n) for intermediate string during substitution """ # Sort by descending length to avoid partial match conflicts sorted_placeholders = sorted(placeholder_map.keys(), key=len, reverse=True) result = text for placeholder in sorted_placeholders: value = placeholder_map[placeholder] # Escape special regex characters; case-sensitive match for placeholders pattern = re.escape(placeholder) result = re.sub(pattern, value, result) return result def _sanitize_unauthorized_entities( self, text: str, authorized_entities: Set[Tuple[str, str]] ) -> str: """ Remove or redact entities not present in the authorized set. Parameters ---------- text : str Text to sanitize. authorized_entities : Set[Tuple[str, str]] Set of (type, value) pairs that are permitted in output. Returns ------- str Text with unauthorized entities replaced by redaction marker. Algorithm --------- 1. Extract all entities from response text 2. Compute set difference: unauthorized = found \ authorized 3. Replace each unauthorized value with [REDACTED] marker Note: Replacement is value-based (not type-based) to avoid over-redaction when same entity type appears legitimately. Complexity ---------- Time: O(n · p + u · n) where n = text length, p = patterns, u = unauthorized entities Space: O(u) for unauthorized entity set """ # Extract entities present in response response_entities = self._extract_entities_from_text(text) # Identify unauthorized: in response but not in original unauthorized = response_entities - authorized_entities result = text for entity_type, value in unauthorized: # Escape value for safe regex substitution pattern = re.escape(value) result = re.sub(pattern, self._ENTITY_REDACT_MARKER, result) return result def _enforce_restrictions(self, text: str, restrictions: List[Restriction]) -> str: """ Apply FORBID/MANDATE semantic restrictions to the response. Parameters ---------- text : str Response text to constrain. restrictions : List[Restriction] List of semantic constraints from prompt analysis. Returns ------- str Text with restrictions enforced. Enforcement Rules ----------------- FORBID: Remove all case-insensitive occurrences of the entity. Uses word-boundary regex to avoid partial matches. MANDATE: If entity is absent, append a note requiring its use. Does not modify existing content; only adds guidance. Complexity ---------- Time: O(r · n · m) where r = restrictions, n = text length, m = avg entity length Space: O(n) for intermediate string during substitutions """ result = text for restriction in restrictions: entity = re.escape(restriction.entity) word_boundary_pattern = r'\b' + entity + r'\b' if restriction.type == "FORBID": # Remove all occurrences (case-insensitive, word-boundary matched) result = re.sub( word_boundary_pattern, self._FORBIDDEN_MARKER, result, flags=re.IGNORECASE ) logger.debug(f"Enforced FORBID restriction: '{restriction.entity}'") elif restriction.type == "MANDATE": # Check presence (case-insensitive substring match for flexibility) if restriction.entity.lower() not in result.lower(): # Append mandate note to guide downstream processing note = f"{self._MANDATE_NOTE_PREFIX}{restriction.entity}]" result = result.rstrip() + "\n" + note logger.debug(f"Enforced MANDATE restriction: '{restriction.entity}'") return result def _normalize_whitespace(self, text: str) -> str: """ Normalize whitespace and punctuation artifacts from substitutions. Operations: - Collapse multiple spaces to single space - Reduce multiple newlines to single newline - Strip leading/trailing whitespace Parameters ---------- text : str Text to normalize. Returns ------- str Cleaned text with consistent formatting. Complexity ---------- Time: O(n) where n = text length Space: O(n) for output string """ # Collapse multiple spaces text = re.sub(r' +', ' ', text) # Reduce multiple newlines (with optional whitespace) to single newline text = re.sub(r'\n\s*\n+', '\n', text) # Strip leading/trailing whitespace return text.strip() def correct(self, response_text: str, shield_result: ShieldResult) -> str: """ Apply all correction steps to sanitize an LLM response. Pipeline: 1. Re-inject protected placeholder values 2. Extract authorized entities from original prompt 3. Redact unauthorized entities in response 4. Enforce FORBID/MANDATE semantic restrictions 5. Normalize whitespace and formatting Parameters ---------- response_text : str Raw response from the LLM to be corrected. shield_result : ShieldResult Result from PromptShield containing: - placeholder_map: for re-injection - entities: authorized entity set - restrictions: semantic constraints to enforce Returns ------- str Sanitized response respecting all TruthTable constraints. Complexity ---------- Overall: O(n · (k + p + r)) where: n = response length k = number of placeholders p = number of entity patterns r = number of restrictions Space: O(|E_auth| + k) for authorized entity set + placeholder cache Example ------- >>> corrector = ResponseCorrector() >>> sanitized = corrector.correct( ... "The server IP is 192.168.1.1 and we use Python.", ... shield_result ... ) >>> # If 192.168.1.1 was authorized but Python was forbidden: >>> # Output: "The server IP is 192.168.1.1 and we use [PROHIBITED]." """ # Stage 1: Re-inject protected placeholder values text = self._reinject_placeholders(response_text, shield_result.placeholder_map) # Stage 2: Build authorized entity set from original prompt authorized_entities: Set[Tuple[str, str]] = set() if hasattr(shield_result, 'entities'): for entity in shield_result.entities: authorized_entities.add((entity.entity_type, entity.value)) # Stage 3: Redact entities not in authorized set text = self._sanitize_unauthorized_entities(text, authorized_entities) # Stage 4: Enforce semantic restrictions (FORBID/MANDATE) if hasattr(shield_result, 'restrictions') and shield_result.restrictions: text = self._enforce_restrictions(text, shield_result.restrictions) # Stage 5: Normalize whitespace and formatting text = self._normalize_whitespace(text) logger.debug(f"Response correction complete: {len(response_text)} → {len(text)} chars") return text