| """ | |
| Post-LLM response correction and sanitization module. | |
| This module applies the TruthTable constraints to LLM responses: | |
| 1. Re-injects protected placeholder values | |
| 2. Redacts entities that do not appear in the original prompt | |
| 3. Enforces FORBID/MANDATE semantic restrictions | |
| Mathematical Foundations | |
| ------------------------ | |
| 1. Placeholder Substitution: | |
| Given placeholder map M = {p₁→v₁, ..., pₖ→vₖ} and response R: | |
| R' = R[p₁→v₁][p₂→v₂]...[pₖ→vₖ] | |
| Order: process by descending |pᵢ| to avoid partial substitution conflicts. | |
| 2. Entity Authorization Check: | |
| Let E_orig = {(typeᵢ, valueᵢ)} from shielded prompt | |
| Let E_resp = {(typeⱼ, valueⱼ)} extracted from response | |
| Unauthorized: E_unauth = E_resp \ E_orig | |
| Action: Replace each (t, v) ∈ E_unauth with sanitization marker. | |
| 3. Restriction Enforcement: | |
| For restriction r with type T and entity e: | |
| if T = FORBID: R = R with every occurrence of e replaced by "[PROHIBITED]" | |
| if T = MANDATE: if e ∉ R: R = R ∥ newline ∥ "[Note: required entity missing: e]" | |
| Where e ∉ R means e does not occur as a substring of R, and ∥ = string concatenation. | |
| 4. Regex Pattern Complexity: | |
| Pattern matching: O(n · m) where n = text length, m = pattern length | |
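| Multiple patterns: O(n · Σ|pᵢ|) on an automaton-based regex engine (RE2-style) [1]. | |
| Python's `re` module is a backtracking engine, so these bounds assume patterns that do not trigger heavy backtracking. | |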
| References | |
| ---------- | |
| [1] Cox, R. (2007). Regular Expression Matching Can Be Simple And Fast. | |
| https://swtch.com/~rsc/regexp/regexp1.html | |
| [2] Aho, A. V., & Corasick, M. J. (1975). Efficient string matching. | |
| Communications of the ACM, 18(6), 333-340. | |
| [3] OpenAI. (2024). Prompt injection and output sanitization best practices. | |
| https://platform.openai.com/docs/guides/safety | |
| Performance Characteristics | |
| --------------------------- | |
| - _build_entity_patterns(): O(1) - constant number of patterns | |
| - correct() full pipeline: O(n · (k + p + r)) where: | |
| n = response length, k = placeholders, p = entity patterns, r = restrictions | |
| - Memory: O(|E_orig| + |M|) for entity/placeholder lookup sets | |
| Author: IntelliDeep Lab Team | |
| License: BSL 1.1 | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import re | |
| from typing import List, Set, Tuple | |
| from nlproxy.core.shield import ShieldResult | |
| from nlproxy.core.restriction import Restriction | |
| logger = logging.getLogger(__name__) | |
| class ResponseCorrector: | |
| """ | |
| Applies TruthTable constraints to sanitize LLM responses. | |
| This class ensures that responses respect the security and semantic | |
| constraints extracted from the original prompt: | |
| 1. Placeholder Re-injection: Restores protected values (code, PII, etc.) | |
| 2. Entity Sanitization: Redacts entities not authorized in original prompt | |
| 3. Restriction Enforcement: Applies FORBID/MANDATE rules to final output | |
| Key Design Decisions | |
| -------------------- | |
| - Longest-first placeholder substitution prevents partial match corruption | |
| - Entity type + value tuple matching avoids false positives (e.g., same IP appearing legitimately) | |
| - Case-insensitive restriction matching for robust enforcement | |
| - Minimal output modification: only redact/add what's necessary | |
| Usage Example | |
| ------------- | |
| >>> corrector = ResponseCorrector(mode="code") | |
| >>> sanitized = corrector.correct(llm_response, shield_result) | |
| >>> # Response now respects all original constraints | |
| """ | |
| # Pre-compiled entity patterns (shared across instances for efficiency) | |
| # Most patterns use word boundaries (\b) for exact token matching; | |
| # the price pattern is anchored by currency symbols or codes instead | |
| _BASE_PATTERNS: dict[str, re.Pattern] = { | |
| "ip": re.compile( | |
| r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}' | |
| r'(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b' # trailing \b so the match cannot stop partway through a longer digit run | |
| r'|\b(?:[A-F0-9]{1,4}:){7}[A-F0-9]{1,4}\b', | |
| flags=re.IGNORECASE | |
| ), | |
| "date": re.compile( | |
| r'\b\d{4}-\d{2}-\d{2}\b' # ISO: 2025-06-15 | |
| r'|\b\d{2}/\d{2}/\d{4}\b' # DD/MM/YYYY | |
| r'|\b\d{2}\.\d{2}\.\d{4}\b' # DD.MM.YYYY | |
| ), | |
| "price": re.compile( | |
| r'(?:USD|EUR|GBP|JPY|CHF|CAD|AUD)\s*[\$\€\£\¥]?\s*\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?\s*(?:USD|EUR|GBP|JPY|CHF|CAD|AUD)?' | |
| r'|[\$\€\£\¥]\s*\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?\s*(?:USD|EUR|GBP|JPY|CHF|CAD|AUD)?', | |
| flags=re.IGNORECASE | |
| ), | |
| "hash": re.compile(r'\b[A-Fa-f0-9]{32,64}\b'), | |
| "percentage": re.compile(r'\b\d+(?:\.\d+)?\s*%'), # no trailing \b: "%" is not a word character, so \b would block "50% of" | |
| } | |
| # Sanitization markers (configurable for audit trail) | |
| _ENTITY_REDACT_MARKER: str = "[REDACTED]" | |
| _FORBIDDEN_MARKER: str = "[PROHIBITED]" | |
| _MANDATE_NOTE_PREFIX: str = "[Note: required entity missing: " | |
| def __init__(self, mode: str = "general") -> None: | |
| """ | |
| Initialize the ResponseCorrector. | |
| Parameters | |
| ---------- | |
| mode : str, optional | |
| Domain mode for potential future extensions (default: "general"). | |
| Currently affects logging; pattern set is uniform across modes. | |
| """ | |
| self.mode = mode | |
| self.entity_patterns = self._build_entity_patterns() | |
| logger.debug(f"ResponseCorrector initialized (mode={mode})") | |
| @staticmethod | |
| def _build_entity_patterns() -> List[Tuple[str, re.Pattern]]: | |
| """ | |
| Build the list of entity detection patterns. | |
| Returns | |
| ------- | |
| List[Tuple[str, re.Pattern]] | |
| List of (entity_type, compiled_regex) pairs for detection. | |
| Pattern Specifications | |
| --------------------- | |
| - IP: IPv4 (dotted decimal) or IPv6 (hex groups) with word boundaries | |
| - Date: ISO 8601, DD/MM/YYYY, or DD.MM.YYYY formats | |
| - Price: Currency code + symbol + amount with optional decimals | |
| - Hash: 32-64 character hexadecimal strings (MD5, SHA-256, etc.) | |
| - Percentage: Numeric value followed by % symbol | |
| Complexity | |
| ---------- | |
| Time: O(1) - constant number of pattern compilations | |
| Space: O(1) - fixed pattern set stored at class level | |
| """ | |
| return [(name, pattern) for name, pattern in ResponseCorrector._BASE_PATTERNS.items()] | |
| def _extract_entities_from_text(self, text: str) -> Set[Tuple[str, str]]: | |
| """ | |
| Extract typed entities from text using registered patterns. | |
| Parameters | |
| ---------- | |
| text : str | |
| Text to scan for entities. | |
| Returns | |
| ------- | |
| Set[Tuple[str, str]] | |
| Set of (entity_type, entity_value) pairs found in text. | |
| Complexity | |
| ---------- | |
| Time: O(n · p) where n = text length, p = number of patterns | |
| Space: O(e) where e = number of unique entities found | |
| """ | |
| found: Set[Tuple[str, str]] = set() | |
| for entity_type, pattern in self.entity_patterns: | |
| for match in pattern.finditer(text): | |
| found.add((entity_type, match.group())) | |
| return found | |
| def _reinject_placeholders(self, text: str, placeholder_map: dict[str, str]) -> str: | |
| """ | |
| Replace placeholders with their original protected values. | |
| Processes placeholders in descending length order to prevent | |
| partial substitution (e.g., "__PROT_ab" matching inside "__PROT_abc"). | |
| Parameters | |
| ---------- | |
| text : str | |
| Text containing placeholders to replace. | |
| placeholder_map : dict[str, str] | |
| Mapping: placeholder → original value. | |
| Returns | |
| ------- | |
| str | |
| Text with all placeholders substituted. | |
| Mathematical Note | |
| ----------------- | |
| Substitution order matters: if |p₁| > |p₂| and p₂ is a prefix of p₁, | |
| substituting p₂ first would corrupt p₁. Sorting by descending length | |
| ensures atomic replacement of longer tokens first. | |
| Complexity | |
| ---------- | |
| Time: O(k · n · m) where k = placeholders, n = text length, m = avg placeholder length | |
| Space: O(n) for intermediate string during substitution | |
| """ | |
| # Sort by descending length to avoid partial match conflicts | |
| sorted_placeholders = sorted(placeholder_map.keys(), key=len, reverse=True) | |
| result = text | |
| for placeholder in sorted_placeholders: | |
| value = placeholder_map[placeholder] | |
| # Literal, case-sensitive replacement. str.replace is used instead of re.sub so that | |
| # backslashes in the protected value (e.g. "\n" or "\1" inside code) are not read as regex escapes. | |
| result = result.replace(placeholder, value) | |
| return result | |
| def _sanitize_unauthorized_entities( | |
| self, | |
| text: str, | |
| authorized_entities: Set[Tuple[str, str]] | |
| ) -> str: | |
| """ | |
| Remove or redact entities not present in the authorized set. | |
| Parameters | |
| ---------- | |
| text : str | |
| Text to sanitize. | |
| authorized_entities : Set[Tuple[str, str]] | |
| Set of (type, value) pairs that are permitted in output. | |
| Returns | |
| ------- | |
| str | |
| Text with unauthorized entities replaced by redaction marker. | |
| Algorithm | |
| --------- | |
| 1. Extract all entities from response text | |
| 2. Compute set difference: unauthorized = found \ authorized | |
| 3. Replace each unauthorized value with [REDACTED] marker | |
| Note: Replacement is value-based (not type-based) to avoid | |
| over-redaction when same entity type appears legitimately. | |
| Complexity | |
| ---------- | |
| Time: O(n · p + u · n) where n = text length, p = patterns, u = unauthorized entities | |
| Space: O(u) for unauthorized entity set | |
| """ | |
| # Extract entities present in response | |
| response_entities = self._extract_entities_from_text(text) | |
| # Identify unauthorized: in response but not in original | |
| unauthorized = response_entities - authorized_entities | |
| result = text | |
| for entity_type, value in unauthorized: | |
| # Escape value for safe regex substitution | |
| pattern = re.escape(value) | |
| result = re.sub(pattern, self._ENTITY_REDACT_MARKER, result) | |
| return result | |
| def _enforce_restrictions(self, text: str, restrictions: List[Restriction]) -> str: | |
| """ | |
| Apply FORBID/MANDATE semantic restrictions to the response. | |
| Parameters | |
| ---------- | |
| text : str | |
| Response text to constrain. | |
| restrictions : List[Restriction] | |
| List of semantic constraints from prompt analysis. | |
| Returns | |
| ------- | |
| str | |
| Text with restrictions enforced. | |
| Enforcement Rules | |
| ----------------- | |
| FORBID: Replace all case-insensitive occurrences of the entity with the [PROHIBITED] marker. | |
| Uses word-boundary regex to avoid partial matches. | |
| MANDATE: If entity is absent, append a note requiring its use. | |
| Does not modify existing content; only adds guidance. | |
| Complexity | |
| ---------- | |
| Time: O(r · n · m) where r = restrictions, n = text length, m = avg entity length | |
| Space: O(n) for intermediate string during substitutions | |
| """ | |
| result = text | |
| for restriction in restrictions: | |
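| entity = re.escape(restriction.entity) | |
| # Lookarounds instead of \b so entities that start or end with a non-word character (e.g. "C++") still match | |
| word_boundary_pattern = r'(?<!\w)' + entity + r'(?!\w)' | |
| if restriction.type == "FORBID": | |
| # Replace all occurrences with the marker (case-insensitive, whole-token match) | |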
| result = re.sub( | |
| word_boundary_pattern, | |
| self._FORBIDDEN_MARKER, | |
| result, | |
| flags=re.IGNORECASE | |
| ) | |
| logger.debug(f"Enforced FORBID restriction: '{restriction.entity}'") | |
| elif restriction.type == "MANDATE": | |
| # Check presence (case-insensitive substring match for flexibility) | |
| if restriction.entity.lower() not in result.lower(): | |
| # Append mandate note to guide downstream processing | |
| note = f"{self._MANDATE_NOTE_PREFIX}{restriction.entity}]" | |
| result = result.rstrip() + "\n" + note | |
| logger.debug(f"Enforced MANDATE restriction: '{restriction.entity}'") | |
| return result | |
| def _normalize_whitespace(self, text: str) -> str: | |
| """ | |
| Normalize whitespace artifacts left behind by substitutions. | |
| Operations: | |
| - Collapse multiple spaces to single space | |
| - Reduce multiple newlines to single newline | |
| - Strip leading/trailing whitespace | |
| Parameters | |
| ---------- | |
| text : str | |
| Text to normalize. | |
| Returns | |
| ------- | |
| str | |
| Cleaned text with consistent formatting. | |
| Complexity | |
| ---------- | |
| Time: O(n) where n = text length | |
| Space: O(n) for output string | |
| """ | |
| # Collapse multiple spaces | |
| text = re.sub(r' +', ' ', text) | |
| # Reduce multiple newlines (with optional whitespace) to single newline | |
| text = re.sub(r'\n\s*\n+', '\n', text) | |
| # Strip leading/trailing whitespace | |
| return text.strip() | |
| def correct(self, response_text: str, shield_result: ShieldResult) -> str: | |
| """ | |
| Apply all correction steps to sanitize an LLM response. | |
| Pipeline: | |
| 1. Re-inject protected placeholder values | |
| 2. Extract authorized entities from original prompt | |
| 3. Redact unauthorized entities in response | |
| 4. Enforce FORBID/MANDATE semantic restrictions | |
| 5. Normalize whitespace and formatting | |
| Parameters | |
| ---------- | |
| response_text : str | |
| Raw response from the LLM to be corrected. | |
| shield_result : ShieldResult | |
| Result from PromptShield containing: | |
| - placeholder_map: for re-injection | |
| - entities: authorized entity set | |
| - restrictions: semantic constraints to enforce | |
| Returns | |
| ------- | |
| str | |
| Sanitized response respecting all TruthTable constraints. | |
| Complexity | |
| ---------- | |
| Overall: O(n · (k + p + r)) where: | |
| n = response length | |
| k = number of placeholders | |
| p = number of entity patterns | |
| r = number of restrictions | |
| Space: O(|E_auth| + k) for authorized entity set + placeholder cache | |
| Example | |
| ------- | |
| >>> corrector = ResponseCorrector() | |
| >>> sanitized = corrector.correct( | |
| ... "The server IP is 192.168.1.1 and we use Python.", | |
| ... shield_result | |
| ... ) | |
| >>> # If 192.168.1.1 was authorized but Python was forbidden: | |
| >>> # Output: "The server IP is 192.168.1.1 and we use [PROHIBITED]." | |
| """ | |
| # Stage 1: Re-inject protected placeholder values | |
| text = self._reinject_placeholders(response_text, shield_result.placeholder_map) | |
| # Stage 2: Build authorized entity set from original prompt | |
| authorized_entities: Set[Tuple[str, str]] = set() | |
| if hasattr(shield_result, 'entities'): | |
| for entity in shield_result.entities: | |
| authorized_entities.add((entity.entity_type, entity.value)) | |
| # Stage 3: Redact entities not in authorized set | |
| text = self._sanitize_unauthorized_entities(text, authorized_entities) | |
| # Stage 4: Enforce semantic restrictions (FORBID/MANDATE) | |
| if hasattr(shield_result, 'restrictions') and shield_result.restrictions: | |
| text = self._enforce_restrictions(text, shield_result.restrictions) | |
| # Stage 5: Normalize whitespace and formatting | |
| text = self._normalize_whitespace(text) | |
| logger.debug(f"Response correction complete: {len(response_text)} → {len(text)} chars") | |
| return text | |
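
The pipeline in `correct()` can be tried without the `nlproxy` package using the standalone sketch below. It is not the module's implementation. The function names `reinject`, `redact_unauthorized` and `enforce`, the simplified `IPV4` pattern, and the `(kind, entity)` tuples standing in for `Restriction` objects are all hypothetical and exist only for illustration. The sketch also assumes literal replacement via `str.replace` for placeholder values, and lookarounds rather than `\b` for FORBID matching so that entities such as "C++" are caught.

```python
from __future__ import annotations

import re

# Simplified IPv4 matcher, for illustration only (it does not range-check octets).
IPV4 = re.compile(r"\b(?:\d{1,3}\.){3}\d{1,3}\b")


def reinject(text: str, placeholder_map: dict[str, str]) -> str:
    """Substitute placeholders longest-first so a short key cannot split a longer one."""
    for placeholder in sorted(placeholder_map, key=len, reverse=True):
        # str.replace inserts the value literally, so backslashes survive unchanged.
        text = text.replace(placeholder, placeholder_map[placeholder])
    return text


def redact_unauthorized(text: str, authorized: set[str]) -> str:
    """Redact every detected IPv4 value that is not in the authorized set."""
    found = {match.group() for match in IPV4.finditer(text)}
    for value in found - authorized:  # set difference: found but not authorized
        text = text.replace(value, "[REDACTED]")
    return text


def enforce(text: str, restrictions: list[tuple[str, str]]) -> str:
    """Apply (kind, entity) restrictions, where kind is "FORBID" or "MANDATE"."""
    for kind, entity in restrictions:
        if kind == "FORBID":
            # Whole-token, case-insensitive match that also works for "C++".
            pattern = r"(?<!\w)" + re.escape(entity) + r"(?!\w)"
            text = re.sub(pattern, "[PROHIBITED]", text, flags=re.IGNORECASE)
        elif kind == "MANDATE" and entity.lower() not in text.lower():
            # Entity is absent: append a note instead of rewriting the content.
            text = text.rstrip() + "\n[Note: required entity missing: " + entity + "]"
    return text


if __name__ == "__main__":
    placeholders = {"__PROT_1": r"C:\new\temp", "__PROT_10": "10.0.0.5"}
    raw = "Copy __PROT_1 to host __PROT_10 or 8.8.8.8 using Perl."
    text = reinject(raw, placeholders)
    text = redact_unauthorized(text, authorized={"10.0.0.5"})
    text = enforce(text, [("FORBID", "perl"), ("MANDATE", "Python")])
    print(text)
```

In the demo, `__PROT_1` is a prefix of `__PROT_10`, so replacing the shorter key first would leave a stray `0` after the Windows path. The path value also contains `\n` and `\t` sequences, which would be mangled if the value were passed as a `re.sub` replacement string.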