Spaces:
Running
Running
| r""" | |
| Post-LLM response verification and hallucination detection module. | |
| This module validates that LLM-generated responses respect the semantic | |
| constraints, entity restrictions, and factual consistency established | |
| by the original prompt's TruthTable. | |
| Mathematical Foundations | |
| ------------------------ | |
| 1. Entity Authorization Check: | |
| Given authorized entities E_auth = {(typeᵢ, valueᵢ)} from shielded prompt | |
| and response entities E_resp = {(typeⱼ, valueⱼ)}: | |
| violations = E_resp \ E_auth | |
| Time complexity: O(|E_resp| · |E_auth|) for set difference. | |
| 2. NLI-based Contradiction Detection: | |
| For premise P (original context) and hypothesis H (response claim): | |
| contradiction_score = P(contradiction | P, H) ∈ [0, 1] | |
| Threshold τ_cont = 0.7: if score > τ_cont → flag as violation. | |
| Reference: Bowman et al., "A large annotated corpus for learning natural | |
| language inference", EMNLP 2015 [1] | |
| 3. Entailment-based Forbidden Implication: | |
| For forbidden entity F and response R: | |
| implication_score = P(entailment | R, "contains F") | |
| If score > τ_entail (0.6) → response implies forbidden content. | |
| Reference: Williams et al., "A broad-coverage challenge corpus for | |
| sentence understanding through inference", NAACL 2018 [2] | |
| 4. Semantic Drift via Cosine Similarity: | |
| Given embeddings e_orig, e_resp ∈ ℝᵈ: | |
| drift_score = 1 - cos(e_orig, e_resp) = 1 - (e_orig·e_resp)/(||e_orig||·||e_resp||) | |
| Values > τ_drift (0.5) indicate significant semantic deviation. | |
| Reference: Reimers & Gurevych, "Sentence-BERT", EMNLP 2019 [3] | |
| 5. Confidence Score Aggregation: | |
| confidence = max(0.3, 1 - violations / total_checks) | |
| Clamped to [0.3, 1.0] to avoid over-penalizing edge cases. | |
| References | |
| ---------- | |
| [1] Bowman, S. R., et al. (2015). A large annotated corpus for learning | |
| natural language inference. arXiv:1508.05326. | |
| [2] Williams, A., et al. (2018). A broad-coverage challenge corpus for | |
| sentence understanding through inference. NAACL-HLT 2018. | |
| https://github.com/nyu-mll/multiNLI | |
| [3] Reimers, N., & Gurevych, I. (2019). Sentence-BERT: Sentence embeddings | |
| using Siamese BERT-networks. EMNLP-IJCNLP 2019. | |
| https://github.com/UKPLab/sentence-transformers | |
| [4] Honnibal, M., & Montani, I. (2017). spaCy 2: Natural language | |
| understanding with Bloom embeddings. https://github.com/explosion/spaCy | |
| Performance Characteristics | |
| --------------------------- | |
| - _extract_entities(): O(n · p) where n=text length, p=entity patterns | |
| - _verify_with_nli(): O(s · t_nli) where s=sentences, t_nli=NLI inference time | |
| - _verify_forbidden_implications(): O(r · t_nli) where r=restrictions | |
| - verify() full pipeline: O(n·p + s·t_nli + r·t_nli + t_drift) | |
| - Typical latency: 5-15ms (deterministic only), 30-80ms (with NLI on CPU) | |
| Thread Safety | |
| ------------- | |
| - NLI model loading is instance-level; no shared mutable state | |
| - All methods are reentrant; safe for concurrent use with separate instances | |
| - Embedding model (for drift) should be thread-safe per SentenceTransformer docs | |
| Author: IntelliDeep Labs Team | |
| License: BSL 1.1 | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import re | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Set, Tuple, Callable | |
| import numpy as np | |
| import torch | |
| from langdetect import detect, LangDetectException | |
| from pysbd import Segmenter | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| # Conditional import for ONNX backend | |
| try: | |
| from optimum.onnxruntime import ORTModelForSequenceClassification | |
| _ONNX_AVAILABLE = True | |
| except ImportError: | |
| _ONNX_AVAILABLE = False | |
| ORTModelForSequenceClassification = None # type: ignore | |
| from nlproxy.core.segmenter import SemanticSegmenter | |
| # Import from sibling module (circular import handled at runtime) | |
| from nlproxy.core.restriction import Restriction, RestrictionGraph | |
| logger = logging.getLogger(__name__) | |
| class VerificationResult: | |
| """ | |
| Output container for post-LLM verification results. | |
| Attributes | |
| ---------- | |
| confidence_score : float | |
| Overall confidence in response validity ∈ [0.3, 1.0]. | |
| Computed as max(0.3, 1 - violations/total_checks). | |
| violations : List[str] | |
| Human-readable descriptions of detected policy violations. | |
| audit_trail : Dict | |
| Detailed metadata for debugging: | |
| - entities_in_response: extracted entity set from response | |
| - original_entities: authorized entity set from prompt | |
| - restrictions_checked: count of semantic constraints validated | |
| - nli_used: whether NLI-based checks were performed | |
| """ | |
| confidence_score: float | |
| violations: List[str] | |
| audit_trail: Dict = field(default_factory=dict) | |
| class PostLLMVerifier: | |
| """ | |
| Validates LLM responses against TruthTable constraints and semantic consistency. | |
| This class implements a multi-layer verification pipeline: | |
| 1. Deterministic entity authorization via regex pattern matching | |
| 2. Semantic contradiction detection via Natural Language Inference (optional) | |
| 3. Forbidden implication detection via entailment scoring (optional) | |
| 4. Semantic drift monitoring via embedding cosine similarity | |
| Key Design Decisions | |
| -------------------- | |
| - NLI verification is optional: disable for latency-sensitive deployments | |
| - Entity patterns are pre-compiled at initialization for O(1) lookup | |
| - Semantic drift requires external embedding model (shared from Segmenter) | |
| - Confidence score is clamped to avoid false negatives on sparse checks | |
| Mathematical Foundations | |
| ------------------------ | |
| 1. Contradiction Detection: | |
| For sentence s ∈ original prompt and response R: | |
| score = P(contradiction | s, R) | |
| Flag if score > τ_cont = 0.7 (empirically tuned on SNLI/MultiNLI) | |
| 2. Forbidden Implication: | |
| For forbidden entity F and response R: | |
| score = P(entailment | R, "contains F") | |
| Flag if score > τ_entail = 0.6 | |
| 3. Semantic Drift: | |
| drift = 1 - cos(e_orig, e_resp) ∈ [0, 2] | |
| Alert if drift > τ_drift = 0.5 | |
| Usage Example | |
| ------------- | |
| >>> verifier = PostLLMVerifier( | |
| ... mode="code", | |
| ... use_nli=True, | |
| ... models_dir=Path("models"), | |
| ... embedding_model=segmenter.embedding_model | |
| ... ) | |
| >>> result = verifier.verify(llm_response, shield_result) | |
| >>> if result.confidence_score < 0.7: | |
| ... print(f"Warning: {result.violations}") | |
| """ | |
| # NLI decision thresholds (empirically tuned) | |
| _THRESHOLD_CONTRADICTION: float = 0.7 | |
| _THRESHOLD_ENTAILMENT: float = 0.6 | |
| _THRESHOLD_DRIFT: float = 0.5 | |
| # Default model configuration (must be pre-downloaded) | |
| _DEFAULT_NLI_MODEL: str = "nli-distilroberta-base" | |
| _DEFAULT_MODELS_DIR: Path = Path("nlproxy") / "models" | |
| # Entity pattern definitions (shared across instances) | |
| _BASE_ENTITY_PATTERNS: Dict[str, re.Pattern] = { | |
| "ip": re.compile( | |
| r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}' | |
| r'(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)' | |
| r'|\b(?:[A-F0-9]{1,4}:){7}[A-F0-9]{1,4}\b' | |
| r'|\b(?:[A-F0-9]{1,4}:){1,7}:[A-F0-9]{1,4}\b' | |
| r'|\b::[A-F0-9]{1,4}\b' | |
| r'|::1\b', | |
| flags=re.IGNORECASE | |
| ), | |
| "date": re.compile( | |
| r'\b\d{4}-\d{2}-\d{2}\b' # ISO: 2025-06-15 | |
| r'|\b\d{2}/\d{2}/\d{4}\b' # DD/MM/YYYY | |
| r'|\b\d{2}\.\d{2}\.\d{4}\b' # DD.MM.YYYY | |
| ), | |
| "price": re.compile( | |
| r'(?:USD|EUR|GBP|JPY|CHF|CAD|AUD)?\s*[\$\€\£\¥]?\s*\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?\s*(?:USD|EUR|GBP|JPY|CHF|CAD|AUD)?', | |
| flags=re.IGNORECASE | |
| ), | |
| "hash": re.compile(r'\b[A-Fa-f0-9]{32,64}\b'), | |
| "percentage": re.compile(r'\b\d+(?:\.\d+)?\s*%\b'), | |
| } | |
| def __init__( | |
| self, | |
| mode: str = "general", | |
| use_nli: bool = True, | |
| language: Optional[str] = None, | |
| embedding_model: Optional = None, | |
| segmenter: Optional[SemanticSegmenter] = None, | |
| models_dir: Optional[Path] = None, | |
| nli_model_name: Optional[str] = None | |
| ) -> None: | |
| """ | |
| Initialize the PostLLMVerifier. | |
| Parameters | |
| ---------- | |
| mode : str, optional | |
| Domain mode for potential future pattern extensions (default: "general"). | |
| use_nli : bool, optional | |
| Enable NLI-based semantic verification (default: True). | |
| Disable for latency-critical deployments. | |
| language : Optional[str], optional | |
| ISO 639-1 language code to force sentence segmentation language. | |
| If None, auto-detects via langdetect with fallback to 'en'. | |
| embedding_model : Optional, optional | |
| Pre-loaded SentenceTransformer for semantic drift detection. | |
| If None, drift monitoring is disabled. | |
| models_dir : Optional[Path], optional | |
| Directory containing pre-downloaded NLI models (default: "models"). | |
| NLI model must exist at models_dir / nli_model_name. | |
| nli_model_name : Optional[str], optional | |
| Name of NLI model directory under models_dir (default: nli-distilroberta-base). | |
| Raises | |
| ------ | |
| FileNotFoundError | |
| If NLI is enabled but model not found in models_dir. | |
| ImportError | |
| If ONNX backend requested but optimum.onnxruntime not installed. | |
| """ | |
| self.mode = mode | |
| self.use_nli = use_nli | |
| self.language = language | |
| self.embedding_model = embedding_model | |
| self.segmenter = segmenter or SemanticSegmenter.get_instance( | |
| models_dir=models_dir or self._DEFAULT_MODELS_DIR | |
| ) | |
| # Resolve models directory to point to the specific NLI model folder by default | |
| if models_dir: | |
| candidate = Path(models_dir) | |
| model_path = candidate / (nli_model_name or self._DEFAULT_NLI_MODEL) | |
| if candidate.exists() and candidate.name == (nli_model_name or self._DEFAULT_NLI_MODEL): | |
| self.models_dir = candidate | |
| else: | |
| self.models_dir = model_path | |
| else: | |
| self.models_dir = self._DEFAULT_MODELS_DIR / (nli_model_name or self._DEFAULT_NLI_MODEL) | |
| self.nli_model_name = nli_model_name or self._DEFAULT_NLI_MODEL | |
| # Build entity patterns (deterministic checks) | |
| self.entity_patterns = self._build_entity_patterns() | |
| # Initialize NLI components (optional) | |
| self.nli_model = None | |
| self.nli_tokenizer = None | |
| self._is_onnx = False | |
| if self.use_nli: | |
| self._load_nli_model() | |
| logger.info(f"NLI verification enabled (backend={'ONNX' if self._is_onnx else 'PyTorch'})") | |
| else: | |
| logger.debug("NLI verification disabled") | |
| logger.info(f"PostLLMVerifier initialized: mode={mode}, use_nli={use_nli}") | |
| def _build_entity_patterns(self) -> List[Tuple[str, re.Pattern]]: | |
| """ | |
| Build the list of entity detection patterns. | |
| Returns | |
| ------- | |
| List[Tuple[str, re.Pattern]] | |
| List of (entity_type, compiled_regex) pairs for detection. | |
| Pattern Specifications | |
| --------------------- | |
| - IP: IPv4 (dotted decimal) or IPv6 (hex groups) with word boundaries | |
| - Date: ISO 8601, DD/MM/YYYY, or DD.MM.YYYY formats | |
| - Price: Currency code + symbol + amount with optional decimals | |
| - Hash: 32-64 character hexadecimal strings (MD5, SHA-256, etc.) | |
| - Percentage: Numeric value followed by % symbol | |
| Complexity | |
| ---------- | |
| Time: O(1) - constant number of pattern compilations | |
| Space: O(1) - fixed pattern set stored at class level | |
| """ | |
| return [(name, pattern) for name, pattern in self._BASE_ENTITY_PATTERNS.items()] | |
| def _load_nli_model(self) -> None: | |
| """ | |
| Load the NLI model from local storage. | |
| Supports both ONNX Runtime (preferred) and PyTorch backends. | |
| Model must be pre-downloaded to self.models_dir / self.nli_model_name. | |
| Raises | |
| ------ | |
| FileNotFoundError | |
| If model directory or required files are missing. | |
| ImportError | |
| If ONNX backend requested but dependencies unavailable. | |
| """ | |
| # `self.models_dir` already resolves to the model folder (see __init__ logic). | |
| # Do not append the model name again, otherwise we end up looking for | |
| # `<models_dir>/<model_name>/<model_name>` which is incorrect. | |
| model_path = self.models_dir | |
| if not model_path.exists(): | |
| raise FileNotFoundError( | |
| f"NLI model not found at {model_path}. " | |
| f"Run: python -m nlproxy download_models --models-dir {self._DEFAULT_MODELS_DIR}" | |
| ) | |
| # Prefer ONNX if available and files present | |
| onnx_model_path = model_path / "model.onnx" | |
| use_onnx = _ONNX_AVAILABLE and onnx_model_path.exists() | |
| if use_onnx: | |
| logger.info(f"Loading NLI model (ONNX) from {model_path}...") | |
| import onnxruntime | |
| available_providers = onnxruntime.get_available_providers() | |
| providers = [] | |
| if torch.cuda.is_available() and "CUDAExecutionProvider" in available_providers: | |
| providers.append("CUDAExecutionProvider") | |
| providers.append("CPUExecutionProvider") | |
| self.nli_model = ORTModelForSequenceClassification.from_pretrained( | |
| str(model_path), | |
| provider=providers[0] if len(providers) == 1 else providers | |
| ) | |
| self._is_onnx = True | |
| else: | |
| logger.info(f"Loading NLI model (PyTorch) from {model_path}...") | |
| self.nli_model = AutoModelForSequenceClassification.from_pretrained(str(model_path)) | |
| if torch.cuda.is_available(): | |
| self.nli_model = self.nli_model.to("cuda") | |
| self.nli_model.eval() | |
| self._is_onnx = False | |
| # Load tokenizer (shared backend) | |
| self.nli_tokenizer = AutoTokenizer.from_pretrained(str(model_path)) | |
| logger.debug(f"NLI model loaded: is_onnx={self._is_onnx}") | |
| def _nli_inference(self, premise: str, hypothesis: str) -> Tuple[float, float]: | |
| """ | |
| Perform NLI inference: returns (entailment_prob, contradiction_prob). | |
| Parameters | |
| ---------- | |
| premise : str | |
| Context sentence from original prompt. | |
| hypothesis : str | |
| Claim or statement from LLM response to evaluate. | |
| Returns | |
| ------- | |
| Tuple[float, float] | |
| (P(entailment | premise, hypothesis), P(contradiction | premise, hypothesis)) | |
| Backend Handling | |
| ---------------- | |
| - ONNX: Uses ORTModelForSequenceClassification with PyTorch tensor I/O | |
| - PyTorch: Direct model.forward() with torch.no_grad() for inference | |
| Complexity | |
| ---------- | |
| Time: O(L · d²) where L = tokenized length, d = model hidden dimension | |
| Space: O(L · d) for intermediate activations | |
| Note | |
| ---- | |
| Input is truncated to max_length=256 tokens for efficiency. | |
| For longer contexts, consider chunking or using long-context NLI models. | |
| """ | |
| inputs = self.nli_tokenizer( | |
| premise, | |
| hypothesis, | |
| return_tensors="pt", | |
| truncation=True, | |
| max_length=256, | |
| padding=True | |
| ) | |
| if self._is_onnx: | |
| # ONNX backend accepts PyTorch tensors directly | |
| logits = self.nli_model(**inputs).logits | |
| else: | |
| # PyTorch backend: move inputs to device if needed | |
| if torch.cuda.is_available(): | |
| inputs = {k: v.to("cuda") for k, v in inputs.items()} | |
| with torch.no_grad(): | |
| logits = self.nli_model(**inputs).logits | |
| # Convert logits to probabilities: [neutral, entailment, contradiction] | |
| probs = torch.softmax(logits, dim=-1) | |
| entailment_prob = probs[0][1].item() | |
| contradiction_prob = probs[0][2].item() | |
| return entailment_prob, contradiction_prob | |
| def _extract_entities(self, text: str) -> Set[Tuple[str, str]]: | |
| """ | |
| Extract typed entities from text using registered patterns. | |
| Parameters | |
| ---------- | |
| text : str | |
| Text to scan for entities. | |
| Returns | |
| ------- | |
| Set[Tuple[str, str]] | |
| Set of (entity_type, entity_value) pairs found in text. | |
| Complexity | |
| ---------- | |
| Time: O(n · p) where n = text length, p = number of patterns | |
| Space: O(e) where e = number of unique entities found | |
| """ | |
| found: Set[Tuple[str, str]] = set() | |
| for entity_type, pattern in self.entity_patterns: | |
| for match in pattern.finditer(text): | |
| found.add((entity_type, match.group())) | |
| return found | |
| def _verify_with_nli(self, response_text: str, shielded_text: str) -> List[str]: | |
| """ | |
| Detect semantic contradictions between response and original context. | |
| Parameters | |
| ---------- | |
| response_text : str | |
| LLM-generated response to validate. | |
| shielded_text : str | |
| Original prompt text (with placeholders) serving as premise. | |
| Returns | |
| ------- | |
| List[str] | |
| List of contradiction violation descriptions. | |
| Algorithm | |
| --------- | |
| 1. Segment shielded_text into sentences using language-aware PySBD | |
| 2. For each sentence s: | |
| a. Compute P(contradiction | s, response_text) via NLI | |
| b. If score > τ_cont (0.7): flag as semantic contradiction | |
| 3. Return list of flagged violations with scores | |
| Complexity | |
| ---------- | |
| Time: O(s · t_nli) where s = sentence count, t_nli = NLI inference time | |
| Space: O(s) for storing violation messages | |
| Reference | |
| --------- | |
| [1] Bowman et al. (2015). SNLI corpus for NLI training/evaluation. | |
| """ | |
| violations: List[str] = [] | |
| if not self.use_nli or not shielded_text: | |
| return violations | |
| # Segment original text into sentences for granular checking | |
| sentences = [s.strip() for s in self.segmenter.split_sentences(shielded_text) if s.strip()] | |
| for sentence in sentences: | |
| _, contradiction_score = self._nli_inference(sentence, response_text) | |
| if contradiction_score > self._THRESHOLD_CONTRADICTION: | |
| violations.append( | |
| f"Semantic contradiction detected (score={contradiction_score:.2f}): " | |
| f"'{sentence[:80]}...' vs response" | |
| ) | |
| return violations | |
| def _verify_forbidden_implications( | |
| self, | |
| response_text: str, | |
| restrictions: List[Restriction] | |
| ) -> List[str]: | |
| """ | |
| Detect if response semantically implies a forbidden entity. | |
| Parameters | |
| ---------- | |
| response_text : str | |
| LLM-generated response to validate. | |
| restrictions : List[Restriction] | |
| List of semantic constraints (FORBID/MANDATE) from prompt analysis. | |
| Returns | |
| ------- | |
| List[str] | |
| List of forbidden implication violation descriptions. | |
| Algorithm | |
| --------- | |
| For each FORBID restriction with entity F: | |
| 1. Construct hypothesis: "response contains F" | |
| 2. Compute P(entailment | response, hypothesis) via NLI | |
| 3. If score > τ_entail (0.6): response implies forbidden content | |
| Note | |
| ---- | |
| This catches implicit violations where entity F is not explicitly | |
| mentioned but semantically entailed (e.g., "Python code" implies "Python"). | |
| Complexity | |
| ---------- | |
| Time: O(r · t_nli) where r = FORBID restrictions, t_nli = NLI inference time | |
| Space: O(r) for storing violation messages | |
| """ | |
| violations: List[str] = [] | |
| if not self.use_nli: | |
| return violations | |
| for restriction in restrictions: | |
| if restriction.type == "FORBID": | |
| # Hypothesis: response entails presence of forbidden entity | |
| hypothesis = f"The response contains or refers to {restriction.entity}" | |
| entailment_score, _ = self._nli_inference(response_text, hypothesis) | |
| if entailment_score > self._THRESHOLD_ENTAILMENT: | |
| violations.append( | |
| f"Response implies forbidden entity '{restriction.entity}' " | |
| f"(entailment score={entailment_score:.2f})" | |
| ) | |
| return violations | |
| def check_semantic_drift( | |
| self, | |
| original_text: str, | |
| response: str, | |
| threshold: Optional[float] = None | |
| ) -> float: | |
| """ | |
| Compute cosine similarity between original prompt and response embeddings. | |
| Parameters | |
| ---------- | |
| original_text : str | |
| Original user prompt (before compression). | |
| response : str | |
| LLM-generated response to evaluate. | |
| threshold : Optional[float], optional | |
| Override default drift threshold. If None, uses _THRESHOLD_DRIFT. | |
| Returns | |
| ------- | |
| float | |
| Cosine similarity ∈ [-1, 1]. Values < threshold indicate drift. | |
| Mathematical Note | |
| ----------------- | |
| Given L2-normalized embeddings e₁, e₂ ∈ ℝᵈ: | |
| cos(e₁, e₂) = e₁ · e₂ (dot product equals cosine for unit vectors) | |
| Semantic drift score: | |
| drift = 1 - cos(e₁, e₂) ∈ [0, 2] | |
| Lower = more similar, Higher = more divergent | |
| Complexity | |
| ---------- | |
| Time: O(L · d) for encoding + O(d) for dot product | |
| Space: O(d) for embedding vectors | |
| Reference | |
| --------- | |
| [3] Reimers & Gurevych (2019). Sentence-BERT for semantic similarity. | |
| """ | |
| if self.embedding_model is None: | |
| return 1.0 # Cannot evaluate without embedding model | |
| # Encode and normalize embeddings | |
| emb_orig = self.embedding_model.encode( | |
| [original_text], | |
| normalize_embeddings=True, | |
| show_progress_bar=False, | |
| convert_to_numpy=True | |
| ) | |
| emb_resp = self.embedding_model.encode( | |
| [response], | |
| normalize_embeddings=True, | |
| show_progress_bar=False, | |
| convert_to_numpy=True | |
| ) | |
| # Cosine similarity for L2-normalized vectors = dot product | |
| similarity = float(np.dot(emb_orig[0], emb_resp[0])) | |
| return similarity | |
| def get_nli_check_function(self) -> Callable[[str, str], Tuple[float, float]]: | |
| """ | |
| Return a callable for NLI inference: (premise, hypothesis) → (entailment, contradiction). | |
| Returns | |
| ------- | |
| Callable[[str, str], Tuple[float, float]] | |
| Function suitable for passing to RestrictionGraph.refine_restrictions_nli(). | |
| Usage | |
| ----- | |
| >>> verifier = PostLLMVerifier(use_nli=True) | |
| >>> nli_fn = verifier.get_nli_check_function() | |
| >>> ent, cont = nli_fn("Python is slow", "Java is faster") | |
| >>> if ent > 0.6: ... # Strong entailment evidence | |
| """ | |
| def check(premise: str, hypothesis: str) -> Tuple[float, float]: | |
| inputs = self.nli_tokenizer( | |
| premise, | |
| hypothesis, | |
| return_tensors="pt", | |
| truncation=True, | |
| max_length=256, | |
| padding=True | |
| ) | |
| if torch.cuda.is_available() and not self._is_onnx: | |
| inputs = {k: v.to("cuda") for k, v in inputs.items()} | |
| with torch.no_grad(): | |
| logits = self.nli_model(**inputs).logits | |
| probs = torch.softmax(logits, dim=-1) | |
| return probs[0][1].item(), probs[0][2].item() | |
| return check | |
| def verify( | |
| self, | |
| response_text: str, | |
| shield_result | |
| ) -> VerificationResult: | |
| """ | |
| Execute the complete post-LLM verification pipeline. | |
| Validation Layers | |
| ----------------- | |
| 1. Deterministic Entity Check: | |
| - Extract entities from response via regex patterns | |
| - Compare against authorized entities from shield_result | |
| - Flag unauthorized entities as violations | |
| 2. Restriction Compliance: | |
| - Check FORBID entities: must not appear in response | |
| - Check MANDATE entities: must appear in response | |
| - Use RestrictionGraph for constraint validation | |
| 3. Semantic Contradiction (NLI, optional): | |
| - Compare response against original prompt sentences | |
| - Flag high-contradiction scores as violations | |
| 4. Forbidden Implication (NLI, optional): | |
| - Detect if response semantically entails forbidden content | |
| - Catch implicit violations not caught by exact matching | |
| 5. Confidence Aggregation: | |
| - confidence = max(0.3, 1 - violations / total_checks) | |
| - Clamped to [0.3, 1.0] for interpretability | |
| Parameters | |
| ---------- | |
| response_text : str | |
| LLM-generated response to validate. | |
| shield_result : ShieldResult | |
| Result from PromptShield containing: | |
| - entities: authorized entity set | |
| - restrictions: semantic constraints (FORBID/MANDATE) | |
| - shielded_text: original prompt for NLI premise | |
| Returns | |
| ------- | |
| VerificationResult | |
| Container with confidence score, violations, and audit metadata. | |
| Complexity | |
| ---------- | |
| Overall: O(n·p + s·t_nli + r·t_nli) where: | |
| n = response length, p = entity patterns, | |
| s = prompt sentences, r = restrictions, | |
| t_nli = NLI inference time (~20-50ms on CPU) | |
| Example | |
| ------- | |
| >>> verifier = PostLLMVerifier(mode="code", use_nli=True) | |
| >>> result = verifier.verify(llm_response, shield_result) | |
| >>> if result.confidence_score < 0.7: | |
| ... print(f"Low confidence: {result.violations}") | |
| """ | |
| violations: List[str] = [] | |
| # Layer 1: Deterministic entity authorization | |
| original_entities: Set[Tuple[str, str]] = set() | |
| if hasattr(shield_result, 'entities'): | |
| for entity in shield_result.entities: | |
| original_entities.add((entity.entity_type, entity.value)) | |
| response_entities = self._extract_entities(response_text) | |
| for entity_type, value in response_entities: | |
| if (entity_type, value) not in original_entities: | |
| violations.append( | |
| f"Unauthorized entity in response: {value} (type: {entity_type})" | |
| ) | |
| # Layer 2: Explicit restriction compliance (FORBID/MANDATE) | |
| if hasattr(shield_result, 'restrictions') and shield_result.restrictions: | |
| # Use RestrictionGraph for constraint validation | |
| compliance_violations = RestrictionGraph.get_instance().check_compliance( | |
| [response_text], | |
| [shield_result.shielded_text] if hasattr(shield_result, 'shielded_text') else [] | |
| ) | |
| for restriction in shield_result.restrictions: | |
| if restriction.type == "FORBID": | |
| if restriction.matches_in_text(response_text): | |
| violations.append( | |
| f"Response contains forbidden entity '{restriction.entity}'" | |
| ) | |
| elif restriction.type == "MANDATE": | |
| if not restriction.matches_in_text(response_text): | |
| violations.append( | |
| f"Response missing mandated entity '{restriction.entity}'" | |
| ) | |
| # Layer 3: Semantic contradiction detection (NLI) | |
| if self.use_nli and hasattr(shield_result, 'shielded_text'): | |
| nli_violations = self._verify_with_nli( | |
| response_text, | |
| shield_result.shielded_text | |
| ) | |
| violations.extend(nli_violations) | |
| # Layer 4: Forbidden implication detection (NLI) | |
| if self.use_nli and hasattr(shield_result, 'restrictions') and shield_result.restrictions: | |
| implication_violations = self._verify_forbidden_implications( | |
| response_text, | |
| shield_result.restrictions | |
| ) | |
| violations.extend(implication_violations) | |
| # Layer 5: Confidence score aggregation | |
| total_checks = len(original_entities) + len(shield_result.restrictions if hasattr(shield_result, 'restrictions') else []) | |
| total_checks = max(1, total_checks) # Avoid division by zero | |
| failed_checks = len(violations) | |
| confidence_score = max(0.3, 1.0 - (failed_checks / total_checks)) | |
| # Assemble audit trail for observability | |
| audit_trail = { | |
| "entities_in_response": list(response_entities), | |
| "original_entities": list(original_entities), | |
| "restrictions_checked": len(shield_result.restrictions) if hasattr(shield_result, 'restrictions') else 0, | |
| "nli_used": self.use_nli, | |
| "violation_count": len(violations) | |
| } | |
| logger.debug( | |
| f"Verification complete: confidence={confidence_score:.2f}, " | |
| f"violations={len(violations)}, checks={total_checks}" | |
| ) | |
| return VerificationResult( | |
| confidence_score=round(confidence_score, 2), | |
| violations=violations, | |
| audit_trail=audit_trail | |
| ) | |