"""
Safety validation and intent preservation module.

This module validates that compressed prompts preserve critical domain-specific
intentions and constraints. If key phrases or entities are lost during
compression, they are automatically re-injected from the original shielded text.

Mathematical Foundations
------------------------
1. Safety Score Calculation:
   Given original intents I_orig and forced keywords K_orig:
       total = |I_orig| + |K_orig|
       lost = |I_missing| + |K_missing|
       safety_score = 1 - (lost / total) ∈ [0, 1]
   The score is clamped to [0, 1] defensively and defined as 1.0 when
   total = 0 (nothing critical to lose).

2. Perplexity as Coherence Metric:
   PPL(text) = exp(-1/N · Σᵢ log P(wᵢ | w_<ᵢ))
   Lower perplexity → more coherent/predictable text.
   Threshold τ_ppl = 500.0 (empirically tuned for distilgpt2).
   Reference: Jelinek et al. [1].

3. Intent Extraction via Pattern Matching:
   For pattern set P and text T, split T into clauses C(T) at '.', '!', '?'
   and line breaks:
       intents(T, P) = {c ∈ C(T) : ∃p ∈ P such that p.search(c) matches}
   Whole clauses (not just the matched span) are kept, deduplicated
   case-insensitively.
   Time complexity: O(|T| · |P|) for typical (non-backtracking) patterns.

4. Keyword Preservation Check:
   Given forced keyword set K and text T:
       preserved(K, T) = {k ∈ K : k.lower() ∈ T.lower()}
   Case-insensitive substring matching for robustness.

References
----------
[1] Jelinek, F., et al. (1977). Perplexity—a measure of the difficulty of
    speech recognition tasks. Journal of the Acoustical Society of America.
[2] Bowman, S. R., et al. (2015). A large annotated corpus for learning
    natural language inference. EMNLP 2015.
[3] Conneau, A., et al. (2018). XNLI: Evaluating cross-lingual sentence
    representations. EMNLP 2018.
    https://github.com/facebookresearch/XNLI

Performance Characteristics
---------------------------
- _extract_critical_intents(): O(n · |P|) where n = text length, |P| = pattern count
- _find_forced_keywords(): O(|K| · n) for substring searches
- _force_include_missing(): O(|M| · |S| · L) where M = missing keywords,
  S = original sentences, L = avg sentence length
- validate() full pipeline: O(n · (|P| + |K|) + |M| · |S| · L)
- Perplexity computation (optional): O(ℓ · (N · d² + N² · d)) for N tokens
  (N ≤ 512 after truncation), hidden size d, and ℓ transformer layers

Thread Safety
-------------
- Perplexity model loading uses double-checked locking for thread-safe lazy
  init. The model and tokenizer are cached at class level, so the first model
  loaded is shared by all instances, whatever their `perplexity_model` setting.
- The compiled-pattern cache is per instance; concurrent calls on one instance
  may compile the same entry twice, which is harmless.
- validate() keeps all per-call state in local variables. Concurrent perplexity
  calls share one model and tokenizer, and this module does not serialize them.

Author: IntelliDeep Labs Team
License: BSL 1.1
"""

from __future__ import annotations

import logging
import re
import threading
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from nlproxy.core.restriction import RestrictionGraph

logger = logging.getLogger(__name__)


@dataclass
class SafetyReport:
    """
    Output container for safety validation results.

    Attributes
    ----------
    safety_score : float
        Preservation score ∈ [0, 1]; higher = better intent retention.
        Computed as 1 - (lost_items / total_items), where the items are the
        critical intent clauses plus the forced keywords found in the original.
    missing_intents : List[str]
        Critical clauses and forced keywords from the original that are still
        absent from final_text after re-insertion.
    forced_sentences_added : int
        Count of sentences re-inserted to preserve mandatory content.
    final_text : str
        Validated (and potentially corrected) compressed prompt ready for LLM.
    perplexity : Optional[float]
        Language model perplexity of final_text (lower = more coherent).
        None if perplexity calculation was not requested or model unavailable.
    should_relax : bool
        Recommendation to reduce compression aggressiveness in future runs.
        True if perplexity exceeds threshold OR ≥3 sentences were force-added.
    suggested_threshold : Optional[float]
        Suggested new aggressiveness value if should_relax=True (currently
        always 0.1).
    alerts : List[str]
        Human-readable warnings about validation issues or recommendations.
    audit_entries : List[Dict]
        Detailed step-by-step log for debugging and observability.
    """

    safety_score: float
    missing_intents: List[str]
    forced_sentences_added: int
    final_text: str
    perplexity: Optional[float]
    should_relax: bool
    suggested_threshold: Optional[float]
    alerts: List[str] = field(default_factory=list)
    audit_entries: List[Dict] = field(default_factory=list)


class SafetyChecker:
    """
    Validates compressed prompts for critical intent preservation.

    This class ensures that domain-specific constraints, warnings, and mandatory
    content are not lost during semantic compression. If critical elements are
    missing, they are automatically re-injected from the original shielded text.

    Key Features
    ------------
    - Domain-aware pattern matching (legal, finance, code, general)
    - Forced keyword preservation with case-insensitive substring matching
    - Optional perplexity-based coherence validation
    - Automatic sentence re-insertion for missing critical content
    - Adaptive relaxation recommendations based on validation results

    Mathematical Foundations
    ------------------------
    1. Safety Score:
       safety = 1 - (|I_missing| + |K_missing|) / (|I_orig| + |K_orig|)
       Clamped to [0, 1] for interpretability.

    2. Perplexity Thresholding:
       if PPL(final_text) > τ_ppl: recommend relaxation
       Default τ_ppl = 500.0 for distilgpt2 (empirically validated).

    3. Pattern Matching Complexity:
       For |P| patterns and text length n:
       Time: O(n · |P|) for typical patterns
       Space: O(|P| · m) for compiled patterns, m = avg pattern length

    Domain Pattern Configuration
    ----------------------------
    Each domain defines regex patterns for critical phrases:
        legal: confidentiality, privilege, non-disclosure
        finance: risk disclaimers, non-guarantee statements, advisor references
        code: language constraints ("do not use X", "use Y"), deprecation and
              security warnings
        general: importance markers, critical reminders, obligation phrases

    Usage Example
    -------------
    >>> checker = SafetyChecker(mode="code")
    >>> report = checker.validate(
    ...     original_text=original_prompt,
    ...     compressed_text=compressed_prompt,
    ...     shield_result=shield_result,
    ...     original_sentences=sentences
    ... )
    >>> if report.safety_score < 0.8:
    ...     print(f"Warning: {report.missing_intents}")
    """

    # Class-level caches for perplexity model (shared across instances)
    _perplexity_model: Optional[AutoModelForCausalLM] = None
    _perplexity_tokenizer: Optional[AutoTokenizer] = None
    _perplexity_lock: threading.Lock = threading.Lock()

    # Domain-specific regex patterns for critical intent detection.
    # Word boundaries (\b) prevent matches inside longer words.
    MODE_PATTERNS: Dict[str, List[str]] = {
        "legal": [
            r'\bno\s+reveal\b',
            r'\bconfidential\b',
            r'\bprivileged\b',
            r'\bstrictly\s+prohibited\b',
            r'\bdo\s+not\s+disclose\b',
            r'\bnon[- ]?disclosure\s+agreement\b',
            r'\battorney[- ]?client\s+privilege\b'
        ],
        "finance": [
            r'\bdo\s+not\s+invest\b',
            r'\bhigh\s+risk\b',
            r'\bnot\s+guaranteed\b',
            r'\bpast\s+performance\b',
            r'\bconsult\s+your\s+advisor\b',
            r'\bnot\s+financial\s+advice\b',
            r'\bfigures\s+in\s+thousands\b'
        ],
        "code": [
            r'\bdo\s+not\s+use\s+\w+\b',
            r'\buse\s+\w+\b',
            r'\bavoid\s+using\b',
            r'\bimportant\b',
            r'\brequired\b',
            r'\bmandatory\b',
            r'\bdo\s+not\s+compile\b',
            r'\bdeprecated\b',
            r'\bsecurity\b'
        ],
        "general": [
            r'\bdo\s+not\s+use\b',
            r'\bavoid\s+using\b',
            r'\bimportant\b',
            r'\bcritical\b',
            r'\bessential\b',
            r'\bmandatory\b',
            r'\bdo\s+not\s+forget\b',
            r'\bremember\s+that\b',
            r'\bnote\s+that\b'
        ]
    }

    # Forced keywords: if missing, trigger sentence re-insertion.
    # Stored lowercase and matched as case-insensitive substrings, so short
    # entries such as "use" or "java" also match inside "because" or "javascript".
    MODE_FORCED_KEYWORDS: Dict[str, List[str]] = {
        "legal": ["confidential", "privileged", "do not reveal"],
        "finance": ["risk", "not guaranteed", "advisor"],
        "code": ["java", "python", "do not use", "use"],
        "general": ["important", "critical", "do not use"]
    }

    # Configuration defaults
    DEFAULT_PERPLEXITY_MODEL: str = "distilgpt2"
    DEFAULT_MAX_PERPLEXITY: float = 500.0
    DEFAULT_MODELS_DIR: Path = Path("nlproxy") / "models"

    def __init__(
        self,
        mode: str = "general",
        custom_patterns: Optional[List[str]] = None,
        forced_keywords: Optional[List[str]] = None,
        perplexity_model: str = DEFAULT_PERPLEXITY_MODEL,
        max_perplexity: float = DEFAULT_MAX_PERPLEXITY,
        models_dir: Optional[Path] = None
    ) -> None:
        """
        Initialize the SafetyChecker.

        Parameters
        ----------
        mode : str, optional
            Domain mode for pattern selection: "legal", "finance", "code", "general".
            Unknown modes get no built-in patterns or keywords.
        custom_patterns : Optional[List[str]], optional
            Additional regex patterns to supplement mode defaults.
        forced_keywords : Optional[List[str]], optional
            Additional keywords that must be preserved (case-insensitive).
        perplexity_model : str, optional
            Model name for perplexity calculation (default: distilgpt2).
            Must be pre-downloaded (see models_dir).
        max_perplexity : float, optional
            Threshold above which to recommend compression relaxation.
        models_dir : Optional[Path], optional
            Parent directory of pre-downloaded models (default: "nlproxy/models").
            The model is expected at models_dir / perplexity_model; a path that
            already ends in the model name is used as-is if it exists.

        Notes
        -----
        No exception is raised for a missing perplexity model. The model is
        loaded lazily; if loading fails, validate() logs a warning and skips
        the perplexity check.
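
        Examples
        --------
        Illustration of how the model folder is resolved (the paths are
        hypothetical; nothing is loaded at construction time):

        >>> SafetyChecker().models_dir.as_posix()
        'nlproxy/models/distilgpt2'
        >>> SafetyChecker(models_dir="/opt/models").models_dir.as_posix()
        '/opt/models/distilgpt2'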
        """
        self.mode = mode
        self.max_perplexity = max_perplexity
        self.perplexity_model_name = perplexity_model

        # self.models_dir always points at the model folder itself,
        # defaulting to nlproxy/models/<perplexity_model>.
        if models_dir:
            candidate = Path(models_dir)
            if candidate.exists() and candidate.name == self.perplexity_model_name:
                # models_dir already is the model folder
                self.models_dir = candidate
            else:
                self.models_dir = candidate / self.perplexity_model_name
        else:
            self.models_dir = self.DEFAULT_MODELS_DIR / self.perplexity_model_name

        # Resolve patterns: mode defaults + custom extensions
        base_patterns = self.MODE_PATTERNS.get(mode, [])
        self.patterns = base_patterns + (custom_patterns or [])

        # Resolve forced keywords: mode defaults + custom extensions (lowercased)
        base_keywords = self.MODE_FORCED_KEYWORDS.get(mode, [])
        self.forced_keywords = [
            kw.lower() for kw in (base_keywords + (forced_keywords or []))
        ]

        # Compile patterns with caching
        self._compiled_patterns_cache: Dict[Tuple[str, Tuple[str, ...]], List[re.Pattern]] = {}
        self.compiled_patterns = self._compile_patterns(mode, custom_patterns or [])

        logger.info(
            f"SafetyChecker initialized: mode={mode}, patterns={len(self.patterns)}, "
            f"forced_keywords={len(self.forced_keywords)}, max_ppl={max_perplexity}"
        )

    def _ensure_perplexity_model(self) -> bool:
        """
        Load the perplexity model from local storage (thread-safe lazy init).

        Returns
        -------
        bool
            True if model loaded successfully, False otherwise.

        Note
        ----
        The model must be pre-downloaded to self.models_dir, which __init__
        resolves to <models_dir>/<perplexity_model>. Use scripts/download_models.py
        to fetch required models before deployment. The loaded model is cached
        at class level and shared by all instances.
        """
        # Fast path: already loaded, no lock needed
        if SafetyChecker._perplexity_model is not None:
            return True

        with SafetyChecker._perplexity_lock:
            # Re-check under the lock in case another thread loaded it meanwhile
            if SafetyChecker._perplexity_model is not None:
                return True

            # self.models_dir already includes the model folder name; appending
            # perplexity_model_name again would look in <model>/<model>.
            model_path = self.models_dir
            if not model_path.exists():
                logger.warning(
                    f"Perplexity model not found at {model_path}. "
                    f"Run: python scripts/download_models.py --model {self.perplexity_model_name}"
                )
                return False

            try:
                logger.info(f"Loading perplexity model from {model_path}...")
                tokenizer = AutoTokenizer.from_pretrained(str(model_path))
                model = AutoModelForCausalLM.from_pretrained(str(model_path))

                if torch.cuda.is_available():
                    model = model.to("cuda")
                    logger.debug("Perplexity model loaded on CUDA")

                # Publish the tokenizer first: the fast path only checks the model
                SafetyChecker._perplexity_tokenizer = tokenizer
                SafetyChecker._perplexity_model = model
                logger.info(f"Perplexity model '{self.perplexity_model_name}' loaded successfully")
                return True
            except Exception as e:
                logger.warning(f"Failed to load perplexity model: {e}")
                return False

    def _compile_patterns(self, mode: str, extra_patterns: List[str]) -> List[re.Pattern]:
        """
        Compile regex patterns for a given mode (with caching).

        Parameters
        ----------
        mode : str
            Domain mode for pattern selection.
        extra_patterns : List[str]
            Additional patterns to include.

        Returns
        -------
        List[re.Pattern]
            Compiled regex patterns with re.IGNORECASE flag.

        Complexity
        ----------
        Time: O(|P| · m) for |P| patterns of avg length m (compilation)
        Space: O(|P| · m) for cached compiled patterns
        """
        # Tuple key: order-insensitive, and unlike a comma-joined string it cannot
        # collide when a pattern itself contains a comma (e.g. "{1,3}").
        cache_key = (mode, tuple(sorted(extra_patterns)))
        if cache_key not in self._compiled_patterns_cache:
            patterns = self.MODE_PATTERNS.get(mode, []) + list(extra_patterns)
            self._compiled_patterns_cache[cache_key] = [
                re.compile(p, flags=re.IGNORECASE) for p in patterns
            ]
        return self._compiled_patterns_cache[cache_key]

    def _compute_perplexity(self, text: str) -> Optional[float]:
        """
        Compute language model perplexity for the given text.

        Perplexity formula:
            PPL = exp(-1/N · Σᵢ log P(wᵢ | w_<ᵢ))
        Lower values indicate more coherent/predictable text.

        Parameters
        ----------
        text : str
            Text to evaluate; only the first 512 tokens are scored.

        Returns
        -------
        Optional[float]
            Perplexity score, or None if model unavailable or error occurred.
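
        Examples
        --------
        Hugging Face causal-LM models return the mean token negative
        log-likelihood (natural log) as ``outputs.loss`` when ``labels`` are
        given, so exp(loss) is the perplexity above. Worked arithmetic for an
        illustrative mean NLL of 2.0 nats:

        >>> import math
        >>> round(math.exp(2.0), 2)
        7.39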
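        Complexity
        ----------
        Time: O(ℓ · (N · d² + N² · d)) for N tokens, hidden size d, ℓ layers
        Space: O(N · d + N²) for activations and attention scores
        """
        model = SafetyChecker._perplexity_model
        tokenizer = SafetyChecker._perplexity_tokenizer
        if model is None or tokenizer is None:
            return None

        try:
            # A single sequence needs no padding; GPT-2 tokenizers define no pad
            # token, so padding=True would raise a ValueError here.
            inputs = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=512  # Limit context for efficiency
            )

            # The loss needs at least two tokens (one next-token target)
            if inputs["input_ids"].shape[1] < 2:
                return None

            # Move tensors to wherever the model lives (CPU or CUDA)
            device = next(model.parameters()).device
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                # labels=input_ids: the model shifts labels internally and returns
                # the mean token negative log-likelihood as `loss`
                outputs = model(**inputs, labels=inputs["input_ids"])

            loss = outputs.loss
            if loss is not None:
                return torch.exp(loss).item()
        except Exception as e:
            logger.warning(f"Error computing perplexity: {e}")

        return None

    def _get_effective_patterns_and_keywords(
        self,
        mode: Optional[str] = None,
        custom_patterns: Optional[List[str]] = None
    ) -> Tuple[List[re.Pattern], List[str]]:
        """
        Get compiled patterns and forced keywords for validation.

        Parameters
        ----------
        mode : Optional[str], optional
            Override instance mode for this call. When given, that mode's
            default patterns and keywords replace the instance configuration.
        custom_patterns : Optional[List[str]], optional
            Additional regex patterns for this call.

        Returns
        -------
        Tuple[List[re.Pattern], List[str]]
            Compiled patterns and lowercased forced keywords.
        """
        if mode is None and not custom_patterns:
            return self.compiled_patterns, self.forced_keywords

        if mode is None:
            # Keep the instance configuration and add the per-call patterns
            extra = [re.compile(p, flags=re.IGNORECASE) for p in custom_patterns]
            return self.compiled_patterns + extra, self.forced_keywords

        # Mode override. Regex patterns are not plain substrings, so they are not
        # reused as forced keywords; only the mode's default keywords apply.
        compiled = self._compile_patterns(mode, custom_patterns or [])
        keywords = [kw.lower() for kw in self.MODE_FORCED_KEYWORDS.get(mode, [])]
        return compiled, keywords

    def _extract_critical_intents(
        self,
        text: str,
        compiled_patterns: Optional[List[re.Pattern]] = None
    ) -> List[str]:
        """
        Extract clauses that contain critical intent phrases.

        The text is split into clauses at '.', '!', '?' and line breaks; every
        clause matching at least one domain pattern is returned whole.

        Parameters
        ----------
        text : str
            Text to scan for intents.
        compiled_patterns : Optional[List[re.Pattern]], optional
            Pre-compiled patterns to use. If None, uses instance patterns.

        Returns
        -------
        List[str]
            Matching clauses in order of appearance, with original casing;
            duplicates (compared case-insensitively) are dropped.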
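
        Examples
        --------
        Illustration with the default "general" patterns; the repeated clause
        differs only in case, so it is reported once:

        >>> SafetyChecker(mode="general")._extract_critical_intents(
        ...     "Important: back up first. Then run it! IMPORTANT: back up first."
        ... )
        ['Important: back up first']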
        Complexity
        ----------
        Time: O(n · |P|) where n = text length, |P| = pattern count
        Space: O(|I|) for storing unique intents
        """
        if compiled_patterns is None:
            compiled_patterns = self.compiled_patterns

        # Split into clauses at sentence punctuation and line breaks
        clauses = [c.strip() for c in re.split(r'[.!?\n]+', text) if c.strip()]

        intents: List[str] = []
        seen: Set[str] = set()

        for clause in clauses:
            for pattern in compiled_patterns:
                if pattern.search(clause):
                    clause_key = clause.lower()
                    if clause_key not in seen:
                        seen.add(clause_key)
                        intents.append(clause)
                    break  # Once matched, no need to check other patterns for this clause

        return intents

    def _find_forced_keywords(
        self,
        text: str,
        forced_keywords: Optional[List[str]] = None
    ) -> Set[str]:
        """
        Find which forced keywords appear in the text.

        Uses case-insensitive substring matching for robustness.

        Parameters
        ----------
        text : str
            Text to search.
        forced_keywords : Optional[List[str]], optional
            Keywords to look for. If None, uses instance keywords.

        Returns
        -------
        Set[str]
            Found keywords, exactly as given in the keyword list.

        Complexity
        ----------
        Time: O(|K| · n) where |K| = keyword count, n = text length
        Space: O(|K_found|) for result set
        """
        if forced_keywords is None:
            forced_keywords = self.forced_keywords

        text_lower = text.lower()
        # Lowercase each keyword too, so mixed-case keywords still match
        return {kw for kw in forced_keywords if kw.lower() in text_lower}

    def _force_include_missing(
        self,
        compressed_text: str,
        original_sentences: List[str],
        missing_keywords: Set[str]
    ) -> List[Tuple[int, str]]:
        """
        Re-insert original sentences for missing critical keywords.

        Parameters
        ----------
        compressed_text : str
            Current compressed text (for checking keyword presence).
        original_sentences : List[str]
            Original sentences to draw from for re-insertion.
        missing_keywords : Set[str]
            Keywords, intent clauses, or __GRAPH_REINSERT_{idx}__ markers that
            are missing and need preservation.

        Returns
        -------
        List[Tuple[int, str]]
            (original_index, sentence) pairs to re-insert, unsorted; the caller
            sorts by index to restore the original order.

        Algorithm
        ---------
        1. For each missing item:
           a. If it is a marker __GRAPH_REINSERT_{idx}__: take original
              sentence idx (if in range)
           b. Else: find the first original sentence containing the item
              (case-insensitive)
           c. Add the sentence if it is not already in the result list
        2. Return the collected (index, sentence) pairs

        Complexity
        ----------
        Time: O(|M| · |S| · L) where M = missing keywords, S = original sentences,
              L = avg sentence length for substring search
        Space: O(|A| · L) for the added sentences list
        """
        added_sentences: List[Tuple[int, str]] = []
        seen_sentences: Set[str] = set()
        compressed_lower = compressed_text.lower()

        for keyword in missing_keywords:
            # Handle special graph-based re-insertion markers
            if keyword.startswith("__GRAPH_REINSERT_"):
                try:
                    idx_str = keyword.split("__GRAPH_REINSERT_")[1].rstrip("_")
                    idx = int(idx_str)
                    if 0 <= idx < len(original_sentences):
                        sentence = original_sentences[idx]
                        if sentence not in seen_sentences:
                            added_sentences.append((idx, sentence))
                            seen_sentences.add(sentence)
                            logger.info(f"Re-inserted sentence by graph constraint: '{sentence[:60]}...'")
                except (ValueError, IndexError):
                    pass  # Malformed marker; skip
                continue

            # Standard keyword-based re-insertion
            keyword_lower = keyword.lower()
            if keyword_lower not in compressed_lower:
                for idx, sentence in enumerate(original_sentences):
                    if keyword_lower in sentence.lower():
                        if sentence not in seen_sentences:
                            added_sentences.append((idx, sentence))
                            seen_sentences.add(sentence)
                            logger.info(f"Re-inserted sentence for keyword '{keyword}': '{sentence[:60]}...'")
                        break  # Only re-insert first matching sentence per keyword

        return added_sentences

    def validate(
        self,
        original_text: str,
        compressed_text: str,
        shield_result: Optional[Any] = None,
        original_sentences: Optional[List[str]] = None,
        compressed_indices: Optional[List[int]] = None,
        mode: Optional[str] = None,
        custom_patterns: Optional[List[str]] = None,
        use_perplexity: bool = False
    ) -> SafetyReport:
        """
        Validate compressed prompt for critical intent preservation.

        Main entry point for safety validation. Checks that domain-specific
        constraints, warnings, and mandatory content are preserved after
        semantic compression.
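
        Worked example of the score (illustrative numbers): if original_text
        yields 4 critical intent clauses and 2 forced keywords, and 1 clause is
        still missing after re-insertion, then
        safety_score = 1 - 1 / (4 + 2) = 5/6 ≈ 0.833.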
        Parameters
        ----------
        original_text : str
            Original uncompressed prompt (source of truth for intents).
        compressed_text : str
            Compressed prompt to validate, one sentence per line.
        shield_result : Optional[Any], optional
            Result from PromptShield; a non-empty `restrictions` attribute
            enables the RestrictionGraph check.
        original_sentences : Optional[List[str]], optional
            Original sentence list for re-insertion; without it, nothing is
            re-inserted.
        compressed_indices : Optional[List[int]], optional
            Original sentence index of each line of compressed_text, used to
            restore the original order after re-insertion. If omitted, or if its
            length does not match the number of lines, line positions are used,
            which only approximates the original order.
        mode : Optional[str], optional
            Override instance mode for this validation call.
        custom_patterns : Optional[List[str]], optional
            Additional regex patterns for this call.
        use_perplexity : bool, optional
            Whether to compute perplexity as coherence metric.

        Returns
        -------
        SafetyReport
            Validation results with safety score, missing intents, and corrections.

        Validation Pipeline
        -------------------
        1. Extract critical intents from original text via pattern matching
        2. Identify missing intents in compressed text
        3. Check forced keyword preservation
        4. Validate implicit restrictions from RestrictionGraph (if available)
        5. Re-insert original sentences for missing critical content
        6. Reconstruct final text preserving original sentence order
        7. Re-validate intents on corrected text
        8. Optionally compute perplexity for coherence assessment
        9. Generate relaxation recommendations if quality is low
        10. Compute final safety score and assemble report

        Complexity
        ----------
        Overall: O(n · (|P| + |K|) + |M| · |S| · L + t_ppl)
        where n = text length, |P| = patterns, |K| = keywords, |M| = missing items,
        |S| = sentences, L = avg length, t_ppl = perplexity computation time (optional)
        """
        # Resolve patterns and keywords for this validation
        compiled_patterns, forced_keywords = self._get_effective_patterns_and_keywords(
            mode, custom_patterns
        )

        alerts: List[str] = []
        audit: List[Dict] = []

        # Stage 1: Extract critical intents from original text
        original_intents = self._extract_critical_intents(original_text, compiled_patterns)

        # Stage 2: Identify missing intents in compressed text
        compressed_lower = compressed_text.lower()
        missing_intents = [
            intent for intent in original_intents
            if intent.lower() not in compressed_lower
        ]

        # Stage 3: Check forced keyword preservation
        original_forced = self._find_forced_keywords(original_text, forced_keywords)
        compressed_forced = self._find_forced_keywords(compressed_text, forced_keywords)
        missing_forced = original_forced - compressed_forced

        all_missing = set(missing_intents) | missing_forced

        # Stage 4: Validate implicit restrictions from RestrictionGraph
        missing_graph_indices: Set[int] = set()
        if shield_result and hasattr(shield_result, 'restrictions') and shield_result.restrictions:
            # Flatten compressed text for restriction checking
            flat_compressed = " ".join(
                s.strip() for s in compressed_text.splitlines() if s.strip()
            )
            violations = RestrictionGraph.get_instance().check_compliance(
                [flat_compressed],
                original_sentences or []
            )
            missing_graph_indices.update(violations)

        if missing_graph_indices:
            logger.info(f"Implicit restrictions violated: {len(missing_graph_indices)} sentences to re-insert")
            for idx in missing_graph_indices:
                all_missing.add(f"__GRAPH_REINSERT_{idx}__")

        # Stage 5: Re-insert sentences for missing critical content
        added_sentences: List[Tuple[int, str]] = []
        if original_sentences and all_missing:
            added_sentences = self._force_include_missing(
                compressed_text,
                original_sentences,
                all_missing
            )

        forced_added = len(added_sentences)

        # Stage 6: Reconstruct final text with original ordering
        if added_sentences:
            existing_lines = compressed_text.splitlines()
            # zip() would silently drop lines if the index list were shorter, so
            # compressed_indices is only trusted when it aligns one-to-one.
            if compressed_indices and len(compressed_indices) == len(existing_lines):
                line_indices = list(compressed_indices)
            else:
                if compressed_indices:
                    logger.warning(
                        "compressed_indices does not match compressed_text lines; "
                        "ordering by line position instead"
                    )
                line_indices = list(range(len(existing_lines)))

            # Combine existing and added sentences, skipping exact duplicates
            combined = list(zip(line_indices, existing_lines))
            present = set(existing_lines)
            for idx, sentence in added_sentences:
                if sentence not in present:
                    combined.append((idx, sentence))
                    present.add(sentence)

            # Sort by original index (stable sort keeps ties in insertion order)
            combined.sort(key=lambda x: x[0])
            final_text = "\n".join(s for _, s in combined)
        else:
            final_text = compressed_text

        # Stage 7: Re-validate on corrected text with the same substring test as
        # Stage 2, so both stages agree on what counts as "present"
        final_lower = final_text.lower()
        still_missing_intents = [
            i for i in original_intents if i.lower() not in final_lower
        ]

        final_forced = self._find_forced_keywords(final_text, forced_keywords)
        still_missing_forced = original_forced - final_forced

        # Stage 8: Compute perplexity if requested
        perplexity: Optional[float] = None
        if use_perplexity:
            if self._ensure_perplexity_model():
                perplexity = self._compute_perplexity(final_text)
            else:
                logger.warning("Perplexity requested but model unavailable")

        # Stage 9: Evaluate need for compression relaxation
        should_relax = False
        suggested_threshold: Optional[float] = None

        if perplexity is not None and perplexity > self.max_perplexity:
            should_relax = True
            suggested_threshold = 0.1
            alerts.append(
                f"High perplexity ({perplexity:.1f} > {self.max_perplexity}); "
                "consider reducing compression aggressiveness."
            )

        if forced_added > 0:
            alerts.append(
                f"Re-inserted {forced_added} sentence(s) to preserve critical intents."
            )

        if forced_added >= 3:
            should_relax = True
            suggested_threshold = 0.1
            alerts.append(
                "Multiple re-insertions suggest overly aggressive compression; "
                f"recommended aggressiveness ≤ {suggested_threshold}."
            )

        # Stage 10: Compute safety score
        total_critical = len(original_intents) + len(original_forced)
        lost_critical = len(still_missing_intents) + len(still_missing_forced)

        if total_critical > 0:
            safety_score = 1.0 - (lost_critical / total_critical)
        else:
            safety_score = 1.0  # No critical items to lose = perfect score

        # Clamp to [0, 1] for interpretability
        safety_score = max(0.0, min(1.0, safety_score))

        # Assemble audit log
        audit.append({
            "original_intents": original_intents,
            "missing_intents_before": missing_intents,
            "missing_intents_after": still_missing_intents,
            "forced_keywords_found": list(final_forced),
            "forced_keywords_missing": list(still_missing_forced),
            "forced_sentences_added": forced_added,
            "perplexity": perplexity,
            "safety_score": safety_score
        })

        logger.info(
            f"Safety validation complete: score={safety_score:.2f}, "
            f"missing={len(still_missing_intents) + len(still_missing_forced)}, "
            f"re-inserted={forced_added}"
        )

        return SafetyReport(
            safety_score=safety_score,
            missing_intents=still_missing_intents + list(still_missing_forced),
            forced_sentences_added=forced_added,
            final_text=final_text,
            perplexity=perplexity,
            should_relax=should_relax,
            suggested_threshold=suggested_threshold,
            alerts=alerts,
            audit_entries=audit
        )