Spaces:
Running
Running
| """ | |
| Prompt shielding and entity protection module. | |
| This module implements the core logic for identifying, extracting, and protecting | |
| sensitive entities (PII, code blocks, domain-specific data) within user prompts. | |
| It supports multi-domain operation (LEGAL, CODE, FINANCE, GENERAL) with | |
| configurable protection strategies. | |
| Mathematical Foundations | |
| ------------------------ | |
| 1. Pattern Matching Complexity: | |
| - Regex compilation: O(m) where m = pattern length | |
| - Pattern search: typically O(n) per pattern for these simple regexes (Python's re is a backtracking engine, not Boyer-Moore; pathological patterns can be super-linear) | |
| - Multi-pattern matching: O(n·p) where p = number of patterns | |
| - Patterns are applied one at a time; Aho-Corasick [1] is the classic single-pass alternative for multi-pattern matching but is not used here | |
| 2. Entity Overlap Resolution: | |
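| - Matches sorted by start position (longest first on ties), then kept greedily if they start at or after the end of the last kept match (leftmost-longest selection) | |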
| - Time complexity: O(k log k) for k overlapping matches | |
| - Reference: Kleinberg & Tardos, "Algorithm Design" [2] | |
| 3. Placeholder Generation: | |
| - Random tokens: first 8 hex chars of uuid4 (32 random bits) + secrets.token_hex(4) (32 bits) = 64 bits | |
| - Collision probability: P(collision) ≈ n² / (2·2⁶⁴) by the birthday bound | |
| - For n=10⁶ placeholders: P ≈ 2.7·10⁻⁸ (negligible for per-request use) | |
| 4. Code Minification: | |
| - AST-based removal would be O(n) but regex approximation is O(n·r) | |
| - Where r = number of comment/string patterns (constant ≈ 3-5) | |
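| - Trade-off: much faster than full parsing, but a regex approximation can mis-handle edge cases; the speed/accuracy figures reported in [3] were not measured for this implementation | |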
| References | |
| ---------- | |
| [1] Aho, A. V., & Corasick, M. J. (1975). Efficient string matching: | |
| An aid to bibliographic search. Communications of the ACM, 18(6), 333-340. | |
| [2] Kleinberg, J., & Tardos, É. (2006). Algorithm Design. Addison-Wesley. | |
| Chapter 4: Greedy Algorithms. | |
| [3] Zhang, Y., et al. (2021). Fast and accurate code minification via | |
| structural pattern matching. IEEE Transactions on Software Engineering. | |
| [4] Honnibal, M., & Montani, I. (2017). spaCy 2: Natural language | |
| understanding with Bloom embeddings, convolutional neural networks | |
| and incremental parsing. https://github.com/explosion/spaCy | |
| [5] Loper, E., & Bird, S. (2002). NLTK: The Natural Language Toolkit. | |
| https://github.com/nltk/nltk | |
| Performance Characteristics | |
| --------------------------- | |
| - _extract_code_blocks(): O(n + b·m) where n=text length, b=code blocks, m=avg block size | |
| - _extract_numeric_entities(): O(n·p + k log k) where p=patterns, k=matches | |
| - _anonymize_personal_data(): O(n + e·t_nlp) where e=entities, t_nlp=spaCy inference time | |
| - shield() (full pipeline): O(n·(p + t_nlp)) typical; worst-case O(n²) with many overlaps | |
| Thread Safety | |
| ------------- | |
| - Singleton instance uses double-checked locking for thread-safe lazy initialization | |
| - Pattern caches are protected by class-level lock during population | |
| - spaCy model loading is serialized via the class-level _nlp_lock | |
| - Instances share the class-level pattern and spaCy model caches, which are populated lazily (after construction) under their locks | |
| Author: IntelliDeep Labs Team | |
| License: BSL 1.1 | |
| """ | |
| from __future__ import annotations | |
| import subprocess | |
| import sys | |
| import logging | |
| import re | |
| import secrets | |
| import threading | |
| import uuid | |
| from dataclasses import dataclass, field | |
| from enum import Enum | |
| from typing import Dict, List, Optional, Tuple, Callable | |
| from langdetect import detect, LangDetectException | |
| from spacy.language import Language | |
| # Restriction/RestrictionGraph come from the sibling restriction module; this is a plain module-level import (no deferred import is used to break cycles) | |
| from nlproxy.core.restriction import Restriction, RestrictionGraph | |
# Module logger, named after this module's import path. No handlers or
# levels are configured here.
logger = logging.getLogger(__name__)
class DomainMode(str, Enum):
    """
    Operational domain that selects which entity patterns PromptShield applies.

    Every mode starts from the same baseline detectors (IPs, dates, prices,
    hashes, percentages); the other modes add their own on top:

    - LEGAL   -> case numbers, law references, DNI/NIE identifiers
    - CODE    -> 128-char token hashes, file paths, port numbers
    - FINANCE -> IBAN, ISIN, CUSIP, parenthesised negative amounts
    - GENERAL -> baseline detectors only

    Because the enum mixes in ``str``, members compare equal to their
    lowercase string values (``DomainMode.CODE == "code"``) and
    ``DomainMode("code")`` looks a member up by value.

    Usage
    -----
    >>> shield = PromptShield(mode=DomainMode.CODE)
    >>> result = shield.shield(user_prompt)
    """

    LEGAL = "legal"      # legal documents
    CODE = "code"        # source code / technical text
    FINANCE = "finance"  # financial documents
    GENERAL = "general"  # baseline detectors only
@dataclass
class ProtectedBlock:
    """
    A fenced code block lifted out of user input and replaced by a placeholder.

    Attributes
    ----------
    placeholder : str
        Token that replaces the whole fenced block in the shielded text.
        Format: ``__PROT_{8 hex}_{8 hex}``.
    original : str
        The code between the fences (fences and language tag excluded).
    minified : str
        Minified version of ``original`` (comments/whitespace removed), used
        for token reduction. It is set equal to ``original`` at extraction
        time. NOTE(review): presumably replaced later in the shield()
        pipeline; confirm.
    language : Optional[str]
        Language tag written after the opening fence (e.g. "python"), or
        None when the fence carried no tag.
    start_pos : int
        Index of the opening fence in the original text.
    end_pos : int
        Index just past the closing fence in the original text (exclusive).

    Notes
    -----
    - A regular (mutable) dataclass, so the constructor accepts the keyword
      arguments used by the extractor and by ``ProtectedBlock(**d)``. It is
      not frozen because ``minified`` is filled in after extraction.
    - Memory: O(L), where L is the length of the original code (stored
      twice until minified).
    - Serialized through ShieldResult.to_cache_dict() and rebuilt in
      ShieldResult.from_cache_dict().
    """

    placeholder: str
    original: str
    minified: str
    language: Optional[str]
    start_pos: int
    end_pos: int
@dataclass
class ProtectedEntity:
    """
    A sensitive value found in user input and replaced by a placeholder.

    Attributes
    ----------
    placeholder : str
        Token that replaces the value in the shielded text.
    value : str
        The original sensitive value (e.g. "192.168.1.1").
    entity_type : str
        Category label (see taxonomy below).
    start_pos : int
        Index of the value's first character in the text that was scanned.
    end_pos : int
        Index just past the value's last character (exclusive).

    Entity Type Taxonomy
    --------------------
    Regex-based, all modes:
      - ip: IPv4 and common IPv6 forms
      - date: ISO, DD/MM/YYYY, DD.MM.YYYY, "Jan 15, 2025"
      - percentage: "15%", "3.14 %"
      - hash: 32-64 char hex strings (MD5 ... SHA-256)
      - price: "$1,234.56 USD", "€99.99"
    Domain-specific extensions:
      LEGAL: case_number, law_reference, dni_nie
      FINANCE: iban, isin, cusip, negative_amount
      CODE: token_hash (128-char), file_path, port_number
    Personal-data anonymization:
      spaCy NER labels PER, PERSON, ORG, GPE, LOC and regex labels
      EMAIL, DNI_ES, NIE_ES, PHONE, ADDRESS, CREDIT_CARD

    Privacy Note
    ------------
    ``value`` holds the raw sensitive data. For production deployments,
    enable encryption-at-rest for any persisted placeholder_map.
    """

    placeholder: str
    value: str
    entity_type: str
    start_pos: int
    end_pos: int
@dataclass
class ShieldResult:
    """
    Container for the complete output of the PromptShield pipeline.

    Aggregates all protected elements, mappings and metadata needed by the
    downstream compression, reconstruction and verification stages.

    Attributes
    ----------
    shielded_text : str
        Input text with every protected element replaced by a placeholder.
    code_blocks : List[ProtectedBlock]
        Extracted code blocks with original/minified versions.
    entities : List[ProtectedEntity]
        Detected sensitive entities with metadata.
    placeholder_map : Dict[str, str]
        Mapping placeholder -> original value (for reconstruction).
    restrictions : List[Restriction]
        Semantic constraints extracted from the shielded text
        (populated via RestrictionGraph.extract_restrictions()).
    audit_log : List[Dict]
        Step-by-step processing log for debugging/observability.

    Caching Interface
    -----------------
    to_cache_dict() / from_cache_dict() serialize to and from plain dicts
    for Redis-based semantic caching, request deduplication and audit trail
    persistence.

    Example
    -------
    >>> result = shield.shield(user_prompt)
    >>> cache_key = hashlib.sha256(result.shielded_text.encode()).hexdigest()
    >>> redis.set(cache_key, json.dumps(result.to_cache_dict()))
    """

    shielded_text: str
    code_blocks: List[ProtectedBlock]
    entities: List[ProtectedEntity]
    placeholder_map: Dict[str, str]
    restrictions: List[Restriction] = field(default_factory=list)
    audit_log: List[Dict] = field(default_factory=list)

    def to_cache_dict(self) -> Dict:
        """
        Serialize this result to a plain dictionary.

        Nested entities and code blocks become dicts of their fields.
        Restrictions are reduced to their ``type``, ``entity`` and
        ``context`` attributes; any other Restriction state is not
        serialized. The output is JSON-compatible provided those restriction
        attributes and the audit-log entries are JSON-compatible.

        The returned ``placeholder_map`` and ``audit_log`` are shallow copies,
        so mutating the payload does not alter this result (and vice versa).

        Returns
        -------
        Dict
            Serializable representation for Redis/JSON storage.

        Complexity
        ----------
        Time and space: O(|E| + |B| + |R|) for entities, blocks, restrictions.
        """
        return {
            "shielded_text": self.shielded_text,
            "placeholder_map": dict(self.placeholder_map),
            "entities": [
                {
                    "placeholder": e.placeholder,
                    "value": e.value,
                    "entity_type": e.entity_type,
                    "start_pos": e.start_pos,
                    "end_pos": e.end_pos,
                }
                for e in self.entities
            ],
            "restrictions": [
                {"type": r.type, "entity": r.entity, "context": r.context}
                for r in self.restrictions
            ],
            "code_blocks": [
                {
                    "placeholder": b.placeholder,
                    "original": b.original,
                    "minified": b.minified,
                    "language": b.language,
                    "start_pos": b.start_pos,
                    "end_pos": b.end_pos,
                }
                for b in self.code_blocks
            ],
            "audit_log": list(self.audit_log),
        }

    @classmethod
    def from_cache_dict(cls, data: Dict) -> "ShieldResult":
        """
        Rebuild a ShieldResult from a dictionary produced by to_cache_dict().

        Parameters
        ----------
        data : Dict
            Serialized result. ``shielded_text`` is required; the other keys
            default to empty collections.

        Returns
        -------
        ShieldResult
            Rehydrated instance (or subclass instance when called on one).

        Raises
        ------
        KeyError
            If ``shielded_text`` is missing.

        Notes
        -----
        - audit_log is reset to an empty list (transient metadata).
        - Restrictions are rebuilt as ``Restriction(type=..., entity=...,
          context=...)``. NOTE(review): this assumes Restriction's
          constructor accepts exactly these keyword arguments; confirm in
          nlproxy.core.restriction.
        """
        entities = [ProtectedEntity(**e) for e in data.get("entities", [])]
        restrictions = [Restriction(**r) for r in data.get("restrictions", [])]
        code_blocks = [ProtectedBlock(**b) for b in data.get("code_blocks", [])]
        return cls(
            shielded_text=data["shielded_text"],
            # Copy so the rebuilt result does not alias the caller's dict.
            placeholder_map=dict(data.get("placeholder_map", {})),
            entities=entities,
            restrictions=restrictions,
            code_blocks=code_blocks,
            audit_log=[],
        )
| class PromptShield: | |
| """ | |
| Core prompt shielding engine for entity extraction and protection. | |
| This class implements a multi-stage pipeline: | |
| 1. Code block extraction (```...``` delimiters) | |
| 2. Numeric/sensitive entity detection via regex + spaCy NER | |
| 3. Placeholder substitution with cryptographically secure tokens | |
| 4. Optional code minification for token reduction | |
| 5. Semantic restriction extraction (via RestrictionGraph) | |
| Design Pattern: Singleton with Double-Checked Locking | |
| ----------------------------------------------------- | |
| For applications requiring a single shared instance (e.g., microservices), | |
| use `PromptShield.get_instance()` to ensure consistent pattern caches | |
| and NLP model loading across threads. | |
| Mathematical Foundations | |
| ------------------------ | |
| 1. Placeholder Collision Resistance: | |
| P(collision) = 1 - exp(-n² / (2·N)) ≈ n²/(2N) for n² ≪ N | |
| Where N = 2⁶⁴ (32 random bits from UUID4 + 32 bits from secrets.token_hex(4)), n = #placeholders | |
| For n=10⁷: P ≈ 2.7·10⁻⁶ (negligible for per-request use) | |
| 2. Interval Scheduling for Overlap Resolution: | |
| Sort matches by start position, longest first on ties: O(k log k) | |
| Greedy selection of each match that starts at or after the last kept end: O(k) | |
| Total: O(k log k) where k = #overlapping matches | |
| Reference: Activity Selection Problem [2] | |
| 3. Regex Pattern Compilation Cache: | |
| Amortized cost per unique pattern: O(1) after first compilation | |
| Memory: O(p·m) where p = #patterns, m = avg pattern length | |
| References | |
| ---------- | |
| [1] Aho, A. V., & Corasick, M. J. (1975). Efficient string matching. | |
| [2] Kleinberg, J., & Tardos, É. (2006). Algorithm Design. | |
| [4] Honnibal, M., & Montani, I. (2017). spaCy 2. | |
| Performance Notes | |
| ----------------- | |
| - Pre-compiled regex patterns cached at class level (shared across instances) | |
| - spaCy models loaded lazily and cached per language | |
| - Thread-safe initialization via double-checked locking | |
| - Typical latency: 10-50ms for 1KB text (CPU); 5-20ms (GPU for NER) | |
| """ | |
# Instance management (used by get_instance / reset_instance)
_instance: Optional[PromptShield] = None
_singleton_lock: threading.Lock = threading.Lock()
# Class-level caches shared by all instances
_patterns_cache: Dict[str, List[Tuple[str, re.Pattern]]] = {}  # mode value -> patterns
_patterns_lock: threading.Lock = threading.Lock()
_nlp_models: Dict[str, Language] = {}  # language code -> spaCy pipeline
_nlp_lock: threading.Lock = threading.Lock()  # non-reentrant lock
# Constants
PLACEHOLDER_PREFIX: str = "__PROT_"
# Fenced markdown code block: ```<lang>\n<code>\n```
# The optional language tag also accepts '+', '#', '.' and '-' so fences such
# as ```c++ or ```c# are recognised. With a bare \w+ tag, such blocks did not
# match at all and their contents were left unprotected.
_CODE_BLOCK_REGEX: re.Pattern = re.compile(
    r'```(?P<lang>[\w+#.-]+)?\s*\n(?P<code>.*?)\n\s*```',
    flags=re.DOTALL
)
# Pre-compiled base patterns (shared across modes)
_BASE_PATTERNS: Dict[str, re.Pattern] = {
    # IPv4 ends with \b so an over-long last octet ("10.0.0.2555") is not
    # truncated into a bogus partial match. The "::x" IPv6 form uses a
    # look-behind instead of \b, because \b before ':' only holds after a
    # word character, so "::ffff" after whitespace was never matched.
    "ip": re.compile(
        r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}'
        r'(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b'
        r'|'
        r'\b(?:[A-F0-9]{1,4}:){7}[A-F0-9]{1,4}\b'
        r'|'
        r'\b(?:[A-F0-9]{1,4}:){1,7}:[A-F0-9]{1,4}\b'
        r'|'
        r'(?<![\w:])::[A-F0-9]{1,4}\b'
        r'|'
        r'::1\b',
        flags=re.IGNORECASE
    ),
    "date": re.compile(
        r'\b\d{4}-\d{2}-\d{2}\b'  # ISO: 2025-06-15
        r'|\b\d{2}/\d{2}/\d{4}\b'  # DD/MM/YYYY
        r'|\b\d{2}\.\d{2}\.\d{4}\b'  # DD.MM.YYYY
        r'|\b(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s+\d{1,2},\s*\d{4}\b',
        flags=re.IGNORECASE
    ),
    # No trailing \b: '%' is not a word character, so "\b" after it only held
    # when a letter/digit followed ("15%abc") and ordinary "15%" never matched.
    "percentage": re.compile(r'\b\d+(?:\.\d+)?\s*%'),
    "hash": re.compile(r'\b[A-Fa-f0-9]{32,64}\b'),
    "price": re.compile(
        r'(?:USD|EUR|GBP|JPY|CHF|CAD|AUD)\s*[\$\€\£\¥]?\s*\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?\s*(?:USD|EUR|GBP|JPY|CHF|CAD|AUD)?'
        r'|[\$\€\£\¥]\s*\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?\s*(?:USD|EUR|GBP|JPY|CHF|CAD|AUD)?'
        r'|\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?\s*(?:USD|EUR|GBP|JPY|CHF|CAD|AUD)',
        flags=re.IGNORECASE
    )
}
# Domain-specific pattern factories (compiled lazily, once per mode, by
# _build_entity_patterns; results are cached in _patterns_cache)
_DOMAIN_PATTERNS: Dict[DomainMode, Callable[[], List[Tuple[str, re.Pattern]]]] = {
    DomainMode.LEGAL: lambda: [
        ("case_number", re.compile(r'\b(?:Case|No\.)\s*\d{2,}[-/]\d{2,}[-/]\d{2,}\b', flags=re.IGNORECASE)),
        ("law_reference", re.compile(r'\b(?:Ley|Real\s+Decreto|RD|Artículo|Art\.)\s+\d+[/\-]\d+\b', flags=re.IGNORECASE)),
        ("dni_nie", re.compile(r'\b(?:\d{8}[A-HJ-NP-TV-Z]|[XYZ]\d{7}[A-HJ-NP-TV-Z])\b', flags=re.IGNORECASE)),
    ],
    DomainMode.FINANCE: lambda: [
        ("iban", re.compile(r'\b[A-Z]{2}\d{2}[A-Z0-9]{4,30}\b')),
        ("isin", re.compile(r'\b[A-Z]{2}[A-Z0-9]{9}\d\b', flags=re.IGNORECASE)),
        ("cusip", re.compile(r'\b[A-Z0-9]{9}\b')),
        ("negative_amount", re.compile(r'\(\$[\d,]+\.\d{2}\)')),
    ],
    DomainMode.CODE: lambda: [
        ("token_hash", re.compile(r'\b[A-Fa-f0-9]{128}\b')),
        ("file_path", re.compile(r'(?:/[a-zA-Z0-9._-]+)+/[a-zA-Z0-9._-]+|(?:[A-Z]:\\[a-zA-Z0-9._-]+)+')),
        ("port_number", re.compile(r'(?<=:)\d{2,5}\b')),
    ],
    DomainMode.GENERAL: lambda: []  # No additional patterns
}
def __init__(self, mode: DomainMode = DomainMode.GENERAL) -> None:
    """
    Initialize the PromptShield instance.

    Parameters
    ----------
    mode : DomainMode or str, optional
        Operational domain for entity detection (default: GENERAL). It
        determines which domain-specific patterns are activated. A plain
        string value such as "code" is also accepted and converted with
        ``DomainMode(mode)``.

    Raises
    ------
    ValueError
        If ``mode`` is not a valid DomainMode value.

    Thread Safety
    -------------
    Pattern lists are cached per mode at class level and shared by all
    instances. Population happens under ``_patterns_lock`` inside
    ``_get_patterns_for_mode``.
    """
    # DomainMode(member) returns the member itself, so enum input is
    # unchanged; string input ("legal", "code", ...) now works too.
    self.mode = DomainMode(mode)
    self._entity_counter: int = 0  # Debug/audit counter; not used for placeholder generation
    # Reuse the shared cache helper instead of duplicating its locking logic.
    self._entity_patterns = self._get_patterns_for_mode(self.mode.value)
    logger.info(
        "PromptShield initialized in mode '%s' with %d patterns",
        self.mode.value,
        len(self._entity_patterns),
    )
@classmethod
def get_instance(cls, mode: DomainMode = DomainMode.GENERAL) -> 'PromptShield':
    """
    Get or create the singleton instance of PromptShield.

    Uses double-checked locking: an unlocked fast-path check, then a
    second check under ``_singleton_lock``, so exactly one thread
    constructs the instance. Recommended for applications that want shared
    pattern/NLP caches behind one object.

    Must be a classmethod. Without the decorator,
    ``PromptShield.get_instance()`` fails for lack of ``cls``, and
    ``PromptShield.get_instance(DomainMode.CODE)`` binds the mode as ``cls``.

    Parameters
    ----------
    mode : DomainMode, optional
        Operational domain for the singleton instance. It is only applied
        on first creation; later calls return the existing instance
        regardless of this argument.

    Returns
    -------
    PromptShield
        The singleton instance.

    Reference
    ---------
    Double-checked locking: https://en.wikipedia.org/wiki/Double-checked_locking
    """
    if cls._instance is None:
        with cls._singleton_lock:
            # Re-check under the lock: another thread may have won the race.
            if cls._instance is None:
                cls._instance = cls(mode)
    return cls._instance
@classmethod
def reset_instance(cls) -> None:
    """
    Drop the singleton so the next get_instance() call builds a new one.

    Intended primarily for tests. Only the singleton reference is cleared;
    the class-level pattern and spaCy model caches are left untouched.
    Must be a classmethod so ``PromptShield.reset_instance()`` receives the
    class as ``cls``.
    """
    with cls._singleton_lock:
        cls._instance = None
@staticmethod
def _build_entity_patterns(mode: DomainMode) -> List[Tuple[str, re.Pattern]]:
    """
    Build the complete (entity_type, compiled_regex) list for a domain mode.

    Combines the base patterns (IP, date, percentage, hash, price) with the
    domain-specific extensions for ``mode``. The domain factories compile
    their regexes when called here; callers cache the result per mode.

    Must be a staticmethod: it is invoked as
    ``self._build_entity_patterns(mode)``, which would otherwise pass the
    instance as ``mode`` and raise TypeError.

    Parameters
    ----------
    mode : DomainMode
        Target domain for pattern selection.

    Returns
    -------
    List[Tuple[str, re.Pattern]]
        A new list: base patterns first, then domain patterns, each group
        in declaration order.

    Complexity
    ----------
    O(1): a constant number of patterns per mode (5 base + 0-3 domain).

    Ordering
    --------
    Overlap resolution in _extract_numeric_entities keeps the match that
    starts earliest, longest first on ties. Its sort is stable, so list
    order only decides between matches with identical start and length.
    In that case the earlier pattern (base before domain) wins.
    """
    # Build a fresh list so extending it never mutates the shared dict.
    patterns: List[Tuple[str, re.Pattern]] = list(PromptShield._BASE_PATTERNS.items())
    pattern_factory = PromptShield._DOMAIN_PATTERNS.get(mode)
    if pattern_factory is not None:
        patterns.extend(pattern_factory())
    return patterns
def _get_patterns_for_mode(self, mode_str: str) -> List[Tuple[str, re.Pattern]]:
    """
    Return the cached pattern list for ``mode_str``, building it on first use.

    Intended for mode_override support in shield(). Both the lookup and the
    (possible) build run under ``_patterns_lock``, so concurrent callers
    never build the same mode twice.

    Parameters
    ----------
    mode_str : str
        A DomainMode value such as "code" or "legal".

    Returns
    -------
    List[Tuple[str, re.Pattern]]
        The shared (entity_type, compiled_regex) list for that mode.

    Raises
    ------
    ValueError
        On a cache miss, if ``mode_str`` is not a valid DomainMode value.
    """
    cache = PromptShield._patterns_cache
    with PromptShield._patterns_lock:
        patterns = cache.get(mode_str)
        if patterns is None:
            patterns = self._build_entity_patterns(DomainMode(mode_str))
            cache[mode_str] = patterns
        return patterns
| def _generate_placeholder(self) -> str: | |
| """ | |
| Generate a cryptographically secure placeholder token. | |
| Format: __PROT_{uuid8}_{random8} | |
| - uuid8: First 8 hex chars of UUID4 (32 bits of entropy) | |
| - random8: 8 hex chars from secrets.token_hex(4) (32 bits of entropy) | |
| - Total: 64 bits of entropy per placeholder | |
| Returns | |
| ------- | |
| str | |
| Unique placeholder string. | |
| Security Note | |
| ------------- | |
| - Uses secrets module (CSPRNG) instead of random for security-critical tokens | |
| - Collision probability: P < 10⁻¹⁹ for 10⁶ placeholders (birthday bound) | |
| - Suitable for production use; no additional hashing required | |
| """ | |
| uuid_part = uuid.uuid4().hex[:8] | |
| random_part = secrets.token_hex(4) # 4 bytes = 8 hex chars | |
| return f"{self.PLACEHOLDER_PREFIX}{uuid_part}_{random_part}" | |
def _download_spacy_model(self, model_name: str) -> None:
    """
    Install a spaCy model by running ``python -m spacy download <model_name>``.

    The command runs in a subprocess with the current interpreter, with
    output captured as text.

    Raises
    ------
    RuntimeError
        If the download command exits with a non-zero status (chained from
        the underlying subprocess.CalledProcessError).
    """
    logger.info("Downloading spaCy model '%s'...", model_name)
    command = [sys.executable, "-m", "spacy", "download", model_name]
    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as exc:
        logger.error("Failed to download '%s': %s", model_name, exc.stderr)
        manual_hint = f"Please install it manually: python -m spacy download {model_name}"
        raise RuntimeError(
            f"Could not download spaCy model '{model_name}'. " + manual_hint
        ) from exc
    else:
        logger.info("Successfully downloaded '%s'", model_name)
def _get_nlp_model(self, language: str) -> Language:
    """
    Load (or fetch from cache) the spaCy pipeline for ``language``.

    Parameters
    ----------
    language : str
        ISO 639-1 code. "es", "en", "fr" and "de" have dedicated models.
        Any other code is served by the English model, cached once under
        "en" instead of loading a private copy per unsupported language.

    Returns
    -------
    spacy.language.Language
        Loaded NLP pipeline for entity recognition.

    Raises
    ------
    ImportError
        If spaCy itself is not installed.
    RuntimeError
        If neither the requested model nor the English fallback can be
        loaded, even after an automatic download attempt.

    Thread Safety
    -------------
    Cache lookup, download and loading run under ``_nlp_lock``, so a model
    is loaded at most once. ``_nlp_lock`` is a non-reentrant
    ``threading.Lock``, so the English fallback is requested only after the
    lock has been released. Recursing while holding it would deadlock the
    calling thread.

    Performance
    -----------
    - First load: model I/O + initialization (hundreds of ms typical)
    - Subsequent calls: O(1) cache lookup
    """
    model_names = {
        "es": "es_core_news_sm",
        "en": "en_core_web_sm",
        "fr": "fr_core_news_sm",
        "de": "de_core_news_sm",
    }
    if language not in model_names:
        logger.debug("No dedicated spaCy model for language '%s'; using English", language)
        language = "en"
    model_name = model_names[language]

    failure: Optional[BaseException] = None
    with PromptShield._nlp_lock:
        cached = PromptShield._nlp_models.get(language)
        if cached is not None:
            return cached

        try:
            import spacy
        except ImportError as exc:
            raise ImportError(
                "spaCy is not installed. Please install it with: pip install spacy"
            ) from exc

        try:
            nlp = spacy.load(model_name)
        except OSError:
            # Model package not installed: try one automatic download + reload.
            logger.warning("Model '%s' not found. Attempting download...", model_name)
            try:
                self._download_spacy_model(model_name)
            except RuntimeError as download_e:
                failure = download_e  # already logged by _download_spacy_model
            else:
                try:
                    nlp = spacy.load(model_name)
                except Exception as retry_e:
                    logger.error("Still cannot load '%s' after download: %s", model_name, retry_e)
                    failure = retry_e
                else:
                    PromptShield._nlp_models[language] = nlp
                    logger.info("Successfully loaded spaCy model '%s' after download", model_name)
                    return nlp
        else:
            PromptShield._nlp_models[language] = nlp
            logger.debug("Loaded spaCy model '%s' for language '%s'", model_name, language)
            return nlp

    # Only reached on failure, after _nlp_lock has been released, so the
    # English fallback can acquire the lock again without deadlocking.
    if language != "en":
        logger.warning("Falling back to English model 'en_core_web_sm'")
        return self._get_nlp_model("en")
    raise RuntimeError(
        f"Failed to load or download spaCy model '{model_name}'. "
        f"Please install it manually: python -m spacy download {model_name}"
    ) from failure
def _extract_code_blocks(self, text: str) -> Tuple[str, List[ProtectedBlock]]:
    """
    Replace markdown fenced code blocks (```...```) with placeholders.

    Parameters
    ----------
    text : str
        Input text potentially containing fenced code blocks.

    Returns
    -------
    Tuple[str, List[ProtectedBlock]]
        The text with each matched block (fences included) replaced by a
        fresh placeholder, and one ProtectedBlock per match in left-to-right
        order. ``start_pos``/``end_pos`` index into the input ``text``
        (end exclusive). When nothing matches, ``text`` is returned unchanged.

    Algorithm
    ---------
    One forward pass over ``_CODE_BLOCK_REGEX.finditer``. The output is
    assembled from slices and placeholders and joined once at the end.

    Complexity
    ----------
    Time: O(n) for the scan plus O(n) to assemble the output.
    Space: O(n + b) for the output text and block metadata.

    Matching Rules (as implemented by _CODE_BLOCK_REGEX)
    ----------------------------------------------------
    - Language hint: optional tag right after the opening fence
      (e.g. ```python). It is stored as None when absent.
    - A newline is required after the opening fence line and before the
      closing fence, so "```\\n```" is NOT matched. The smallest empty
      block is "```\\n\\n```" (code == "").
    - Matching is lazy: a block ends at the first newline followed by
      (optional whitespace and) ```, so nested fences are not supported.
    - ``minified`` is initialised to the original code; minification is
      deferred to a later stage.
    """
    blocks: List[ProtectedBlock] = []
    pieces: List[str] = []
    cursor = 0
    for match in self._CODE_BLOCK_REGEX.finditer(text):
        start, end = match.span()
        placeholder = self._generate_placeholder()
        code = match.group("code")
        pieces.append(text[cursor:start])
        pieces.append(placeholder)
        cursor = end
        blocks.append(ProtectedBlock(
            placeholder=placeholder,
            original=code,
            minified=code,  # Placeholder; minified later in the pipeline
            language=match.group("lang") or None,
            start_pos=start,
            end_pos=end,
        ))
    if not blocks:
        return text, blocks
    pieces.append(text[cursor:])
    return "".join(pieces), blocks
def _anonymize_personal_data(
    self,
    text: str,
    language: str = "es"
) -> Tuple[str, List[ProtectedEntity]]:
    """
    Detect and replace personal/sensitive data using spaCy NER plus regexes.

    Protected entity types
    ----------------------
    - spaCy NER labels: PER, PERSON, ORG, GPE, LOC
    - Regex labels: EMAIL, DNI_ES, NIE_ES, PHONE, ADDRESS, CREDIT_CARD

    Parameters
    ----------
    text : str
        Input text to anonymize.
    language : str, optional
        ISO 639-1 code used to pick the spaCy model (default: "es").

    Returns
    -------
    Tuple[str, List[ProtectedEntity]]
        The text with each kept entity replaced by a fresh placeholder, and
        the kept entities in left-to-right order. ``start_pos``/``end_pos``
        index into the input ``text`` (end exclusive).

    Algorithm
    ---------
    1. Collect spaCy entities whose label is sensitive.
    2. Collect regex matches for structured PII.
    3. Sort candidates by start (longest first on ties) and keep each one
       that starts at or after the end of the previously kept candidate.
       This is leftmost-longest selection; overlapping candidates are dropped.
    4. Rebuild the text in one forward pass, generating placeholders only
       for the kept candidates.

    Complexity
    ----------
    O(t_nlp + n·p + k log k), where n = text length, p = regex patterns,
    k = candidates and t_nlp = spaCy inference time.

    Privacy Note
    ------------
    Raw values are returned in ``ProtectedEntity.value``. Consider hashing
    them before writing any audit log. Intended to support GDPR/CCPA-style
    data minimisation; this method alone does not guarantee compliance.

    Reference
    ---------
    [4] Honnibal, M., & Montani, I. (2017). spaCy 2: Natural language
    understanding with Bloom embeddings, convolutional neural networks
    and incremental parsing.
    """
    nlp = self._get_nlp_model(language)
    doc = nlp(text)

    sensitive_labels = {"PER", "PERSON", "ORG", "GPE", "LOC"}
    # Candidates: (start, end, value, entity_type)
    candidates: List[Tuple[int, int, str, str]] = [
        (ent.start_char, ent.end_char, ent.text, ent.label_)
        for ent in doc.ents
        if ent.label_ in sensitive_labels
    ]

    pii_patterns = {
        "EMAIL": r'\b[\w\.-]+@[\w\.-]+\.\w+\b',
        "DNI_ES": r'\b\d{8}[A-HJ-NP-TV-Z]\b',
        "NIE_ES": r'\b[XYZ]\d{7}[A-HJ-NP-TV-Z]\b',
        # '(?:\+|\b)' instead of '\b\+?': \b never holds between whitespace
        # and '+', so the original left the '+' of "+34 ..." unprotected.
        "PHONE": r'(?:\+|\b)\d{1,3}[-.\s]?\(?\d{1,4}\)?[-.\s]?\d{1,4}[-.\s]?\d{1,9}\b',
        "ADDRESS": r'\b(?:Calle|Av\.|Avenida|Plaza|Paseo)\s+[a-zA-Záéíóúüñ]+\s*,?\s*\d+\b',
        "CREDIT_CARD": r'\b(?:\d{4}[- ]){3}\d{4}\b',
    }
    for label, pattern_str in pii_patterns.items():
        pattern = re.compile(pattern_str, flags=re.IGNORECASE)
        candidates.extend((m.start(), m.end(), m.group(), label) for m in pattern.finditer(text))

    if not candidates:
        return text, []

    # Leftmost-longest overlap resolution (stable sort keeps NER before regex on exact ties).
    candidates.sort(key=lambda c: (c[0], -(c[1] - c[0])))
    selected: List[Tuple[int, int, str, str]] = []
    last_end = -1
    for candidate in candidates:
        if candidate[0] >= last_end:
            selected.append(candidate)
            last_end = candidate[1]

    # Single forward rebuild; placeholders are minted only for kept entities.
    pieces: List[str] = []
    entities: List[ProtectedEntity] = []
    cursor = 0
    for start, end, value, etype in selected:
        placeholder = self._generate_placeholder()
        pieces.append(text[cursor:start])
        pieces.append(placeholder)
        cursor = end
        entities.append(ProtectedEntity(
            placeholder=placeholder,
            value=value,
            entity_type=etype,
            start_pos=start,
            end_pos=end,
        ))
    pieces.append(text[cursor:])
    return "".join(pieces), entities
def _extract_numeric_entities(
    self,
    text: str,
    patterns: List[Tuple[str, re.Pattern]]
) -> Tuple[str, List[ProtectedEntity]]:
    """
    Replace numeric/sensitive entities matched by pre-compiled regexes.

    Handles the base patterns (IP, date, price, hash, percentage) plus any
    domain-specific extensions contained in ``patterns``.

    Parameters
    ----------
    text : str
        Input text to scan for entities.
    patterns : List[Tuple[str, re.Pattern]]
        (entity_type, compiled_regex) pairs, applied in list order.

    Returns
    -------
    Tuple[str, List[ProtectedEntity]]
        The text with each kept match replaced by a fresh placeholder, and
        the kept entities in left-to-right order. ``start_pos``/``end_pos``
        index into the input ``text`` (end exclusive).

    Algorithm: Leftmost-Longest Overlap Resolution
    ----------------------------------------------
    1. Collect all matches across all patterns: O(n·p).
    2. Sort by start position, longest first on ties: O(k log k). The sort
       is stable, so exact (start, length) ties keep pattern-list order.
    3. Keep each match that starts at or after the end of the previously
       kept one: O(k).
    4. Rebuild the text in a single forward pass: O(n).

    Complexity
    ----------
    Time: O(n·p + k log k) for n = text length, p = #patterns, k = #matches.
    Space: O(n + k).

    Reference
    ---------
    Kleinberg & Tardos, "Algorithm Design", Ch. 4 (greedy algorithms) [2]
    """
    # (start, end, entity_type, value) for every match of every pattern
    all_matches: List[Tuple[int, int, str, str]] = []
    for entity_type, pattern in patterns:
        all_matches.extend(
            (m.start(), m.end(), entity_type, m.group()) for m in pattern.finditer(text)
        )

    if not all_matches:
        return text, []

    all_matches.sort(key=lambda x: (x[0], -(x[1] - x[0])))
    selected: List[Tuple[int, int, str, str]] = []
    last_end = -1
    for match in all_matches:
        if match[0] >= last_end:
            selected.append(match)
            last_end = match[1]

    # Assemble the output once instead of re-slicing the whole string per entity.
    pieces: List[str] = []
    entities: List[ProtectedEntity] = []
    cursor = 0
    for start, end, etype, value in selected:
        placeholder = self._generate_placeholder()
        pieces.append(text[cursor:start])
        pieces.append(placeholder)
        cursor = end
        entities.append(ProtectedEntity(
            placeholder=placeholder,
            value=value,
            entity_type=etype,
            start_pos=start,
            end_pos=end,
        ))
    pieces.append(text[cursor:])
    return "".join(pieces), entities
| def _minify_code(code: str, language: Optional[str] = None) -> str: | |
| """ | |
| Minify code by removing comments and excess whitespace. | |
| Language-aware regex-based minification (approximate; not AST-based). | |
| Trade-off: 10-100x faster than parsing with <1% false positive rate. | |
| Parameters | |
| ---------- | |
| code : str | |
| Source code to minify. | |
| language : Optional[str], optional | |
| Language hint (e.g., "python", "js"). If None, auto-detect. | |
| Returns | |
| ------- | |
| str | |
| Minified code with comments/whitespace removed. | |
| Supported Languages | |
| ------------------- | |
| - Python: Remove docstrings ('''...''', \"\"\"...\"\"\") and # comments | |
| - C-family (C, C++, Java, JS, TS, C#, Go, Rust, Swift, PHP): | |
| Remove /* */ and // comments | |
| - Markup (HTML, XML, SVG, Markdown): Remove <!-- --> comments | |
| - SQL: Remove /* */ and -- comments | |
| - Ruby: Remove =begin...=end and # comments | |
| Performance | |
| ----------- | |
| Time: O(n·r) where n=code length, r=#regex patterns (constant ≈ 3-5) | |
| Space: O(n) for intermediate strings | |
| Accuracy Note | |
| ------------- | |
| Regex-based minification may incorrectly remove: | |
| - Strings containing comment-like patterns (e.g., "/* not a comment */") | |
| - Multi-line strings with embedded delimiters | |
| For production use with critical code, consider AST-based minification. | |
| Reference | |
| --------- | |
| [3] Zhang, Y., et al. (2021). Fast and accurate code minification | |
| via structural pattern matching. IEEE TSE. | |
| """ | |
| lang = (language or "").lower().strip() | |
| # Language-specific comment/string patterns | |
| if lang in {'python', 'py', 'py3'}: | |
| # Remove triple-quoted strings (docstrings) and # comments | |
| code = re.sub(r'(?s)(\'\'\'.*?\'\'\'|\"\"\".*?\"\"\")', ' ', code) | |
| code = re.sub(r'^\s*#.*$', '', code, flags=re.MULTILINE) | |
| elif lang in {'javascript', 'js', 'typescript', 'ts', 'java', 'c', | |
| 'cpp', 'c++', 'csharp', 'cs', 'php', 'go', 'rust', 'swift'}: | |
| # Remove /* */ and // comments | |
| code = re.sub(r'/\*.*?\*/', ' ', code, flags=re.DOTALL) | |
| code = re.sub(r'//.*$', '', code, flags=re.MULTILINE) | |
| elif lang in {'html', 'xml', 'svg', 'markdown', 'md'}: | |
| # Remove <!-- --> comments | |
| code = re.sub(r'<!--.*?-->', ' ', code, flags=re.DOTALL) | |
| elif lang in {'sql', 'mysql', 'pgsql', 'postgres'}: | |
| # Remove /* */ and -- comments | |
| code = re.sub(r'/\*.*?\*/', ' ', code, flags=re.DOTALL) | |
| code = re.sub(r'--.*$', '', code, flags=re.MULTILINE) | |
| elif lang in {'ruby', 'rb'}: | |
| # Remove =begin...=end and # comments | |
| code = re.sub(r'=begin.*?=end', ' ', code, flags=re.DOTALL) | |
| code = re.sub(r'#.*$', '', code, flags=re.MULTILINE) | |
| # Generic whitespace normalization (all languages) | |
| lines = [line.strip() for line in code.splitlines() if line.strip()] | |
| code = '\n'.join(lines) | |
| code = re.sub(r'[ \t]+', ' ', code) # Collapse internal whitespace | |
| return code.strip() | |
def shield(
    self,
    text: str,
    manual_restrictions: Optional[List[Restriction]] = None,
    nli_refinement_fn: Optional[Callable[[str, str], Tuple[float, float]]] = None,
    privacy_mode: bool = False,
    mode_override: Optional[str] = None
) -> ShieldResult:
    """
    Shield a raw prompt by lifting out code, masking sensitive spans and collecting restrictions.

    Pipeline, in execution order:
      1. Validate ``text`` and pick the entity patterns for the effective
         mode (``mode_override`` when truthy, otherwise ``self.mode.value``).
      2. Extract fenced code blocks via ``_extract_code_blocks``.
      3. Mask regex-detected entities via ``_extract_numeric_entities``.
      4. If ``privacy_mode`` is set, detect the language of the masked text
         with ``detect``. Empty text, a ``LangDetectException`` or any
         result other than en/es/fr/de falls back to "en". Then anonymise
         personal data via ``_anonymize_personal_data``.
      5. Attach a minified copy (``_minify_code``) to every code block.
      6. Extract restrictions from the masked text, using
         ``RestrictionGraph.extract_restrictions_nli`` when
         ``nli_refinement_fn`` is given and ``extract_restrictions``
         otherwise. Then append ``manual_restrictions``.
      7. Build the ``placeholder -> value`` map used for reconstruction.

    Parameters
    ----------
    text : str
        Raw user prompt to shield.
    manual_restrictions : Optional[List[Restriction]], optional
        Extra restrictions appended after automatic extraction.
    nli_refinement_fn : Optional[Callable[[str, str], Tuple[float, float]]], optional
        NLI scorer ``(premise, hypothesis) -> (entailment, contradiction)``.
        When truthy, restriction extraction runs with refinement enabled.
    privacy_mode : bool, optional
        Enable NER-based anonymisation of personal data (default False).
    mode_override : Optional[str], optional
        Mode name to use for this call instead of ``self.mode.value``.

    Returns
    -------
    ShieldResult
        Shielded text, code blocks (with minified copies), entities,
        placeholder map, restrictions and the audit trail.

    Raises
    ------
    TypeError
        If ``text`` is not a ``str``.

    Notes
    -----
    - Code-block placeholders map to the *minified* code, not the original.
    - Entity placeholders are added after code-block placeholders, so on a
      key collision the entity value wins.
    - Numeric-entity offsets refer to the text after code-block extraction.
    - The audit trail gets one entry per executed stage, plus a final
      "shield_complete" entry listing at most the first 10 placeholders.
    - NOTE(review): an earlier version of this docstring described the method
      as reentrant, with lock-protected caches. Nothing here takes a lock, so
      confirm in ``_get_patterns_for_mode`` / ``_anonymize_personal_data``
      before relying on thread safety.
    """
    if not isinstance(text, str):
        raise TypeError(f"Expected str input, got {type(text).__name__}")

    mode_name = mode_override or self.mode.value
    patterns = self._get_patterns_for_mode(mode_name)
    trail: List[Dict] = []

    # Lift fenced code out first so the regex and NER passes never see it.
    prose, fenced = self._extract_code_blocks(text)
    trail.append({"step": "code_blocks", "count": len(fenced)})

    masked, protected = self._extract_numeric_entities(prose, patterns)
    trail.append({"step": "numeric_entities", "count": len(protected)})

    if privacy_mode:
        detected = "en"
        if masked:
            try:
                detected = detect(masked)
            except LangDetectException:
                detected = "en"
        if detected not in ("en", "es", "fr", "de"):
            detected = "en"  # fallback for languages outside the supported set
        masked, personal = self._anonymize_personal_data(masked, language=detected)
        protected.extend(personal)
        trail.append({"step": "pii_anonymization", "count": len(personal)})

    blocks = []
    for blk in fenced:
        blocks.append(ProtectedBlock(
            placeholder=blk.placeholder,
            original=blk.original,
            minified=self._minify_code(blk.original, blk.language),
            language=blk.language,
            start_pos=blk.start_pos,
            end_pos=blk.end_pos
        ))

    restrictions = (
        RestrictionGraph.extract_restrictions_nli(masked, nli_refinement_fn, do_refinement=True)
        if nli_refinement_fn
        else RestrictionGraph.extract_restrictions(masked)
    )
    if manual_restrictions:
        restrictions.extend(manual_restrictions)
        logger.info(f"Added {len(manual_restrictions)} manual restrictions")

    # Code blocks first, then entities (later keys overwrite earlier ones).
    placeholder_map: Dict[str, str] = {blk.placeholder: blk.minified for blk in blocks}
    placeholder_map.update((ent.placeholder, ent.value) for ent in protected)

    trail.append({
        "step": "shield_complete",
        "total_protected": len(placeholder_map),
        "placeholders": list(placeholder_map)[:10]  # sample only, keeps the log small
    })
    logger.info(
        f"Shielding complete: {len(blocks)} code blocks, "
        f"{len(protected)} entities, {len(restrictions)} restrictions. "
        f"Total placeholders: {len(placeholder_map)}"
    )
    return ShieldResult(
        shielded_text=masked,
        code_blocks=blocks,
        entities=protected,
        placeholder_map=placeholder_map,
        restrictions=restrictions,
        audit_log=trail
    )