| """ | |
| Restriction extraction, validation, and refinement module. | |
| This module implements the core logic for identifying, validating, and refining | |
| semantic constraints (FORBID, MANDATE, MUTUAL_EXCLUSION) within user prompts. | |
| It supports both rule-based pattern matching and semantic refinement using | |
| Natural Language Inference (NLI) models. | |
| Mathematical Foundations | |
| ------------------------ | |
| 1. Pattern Matching: | |
| - Uses regular expressions with word boundaries: \\bword\\b | |
| - Time complexity: O(n·p) where n=text length, p=pattern count | |
| - Background: the Aho-Corasick algorithm [1] is the classic approach to multi-pattern matching; this module applies each compiled regex separately | |
| 2. NLI-based Refinement: | |
| - Entailment probability: P(entailment | premise, hypothesis) ∈ [0, 1] | |
| - Decision thresholds: τ_low = 0.3, τ_high = 0.6 (empirically tuned) | |
| - Reference: Bowman et al., "A large annotated corpus for learning natural | |
| language inference", EMNLP 2015 [2] | |
| References | |
| ---------- | |
| [1] Aho, A. V., & Corasick, M. J. (1975). Efficient string matching: | |
| An aid to bibliographic search. Communications of the ACM. | |
| [2] Bowman, S. R., et al. (2015). A large annotated corpus for learning | |
| natural language inference. arXiv:1508.05326. | |
| [3] Conneau, A., et al. (2018). XNLI: Evaluating cross-lingual | |
| sentence representations. EMNLP. | |
| https://github.com/facebookresearch/XNLI | |
| Performance Notes | |
| ----------------- | |
| - Compiled regex patterns are cached per Restriction instance to avoid | |
| redundant compilation (O(1) lookup after first use). | |
| - NLI inference is deferred and optional; when enabled, batch processing | |
| is recommended for throughput. | |
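| - Thread safety: the extraction and refinement functions keep no shared mutable state; | |
| creation of the optional singleton in RestrictionGraph.get_instance() is guarded by a lock. | |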
| Author: IntelliDeep Labs Team | |
| License: BSL 1.1 | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import re | |
| import threading | |
| from dataclasses import dataclass, field | |
| from typing import List, Optional, Callable, Tuple | |
| from langdetect import detect, LangDetectException | |
| # Configure module logger | |
| logger = logging.getLogger(__name__) | |
| @dataclass(frozen=True) | |
| class Restriction: | |
| """ | |
| Represents a semantic constraint extracted from user input. | |
| Attributes | |
| ---------- | |
| type : str | |
| Constraint type: "FORBID", "MANDATE", or "MUTUAL_EXCLUSION". | |
| - FORBID: Entity must not appear in compressed output. | |
| - MANDATE: Entity must appear in compressed output. | |
| - MUTUAL_EXCLUSION: Only one of a set of entities may appear. | |
| entity : str | |
| The key entity/token subject to the constraint (e.g., "Python", "Java"). | |
| context : str | |
| The original text span where the restriction was detected. | |
| Used for NLI-based refinement and audit trails. | |
| _compiled_entity : Optional[re.Pattern] | |
| Cached compiled regex pattern for efficient entity matching. | |
| Internal use only; excluded from repr and init. | |
| Mathematical Note | |
| ----------------- | |
| Entity matching uses word-boundary regex: | |
| pattern = r'\\b' + re.escape(entity) + r'\\b' | |
| This ensures exact token matching, avoiding false positives: | |
| - "Python" matches "Python" but not "Pythonic" or "MyPython" | |
| - Case-insensitive via re.IGNORECASE flag | |
| Performance | |
| ----------- | |
| - Pattern compilation: O(k) where k = entity length (once per instance) | |
| - Pattern search: O(n) where n = text length (per search operation) | |
| - Memory: O(1) additional per instance (compiled pattern cached) | |
| """ | |
| type: str | |
| entity: str | |
| context: str | |
| _compiled_entity: Optional[re.Pattern] = field( | |
| default=None, init=False, repr=False, compare=False | |
| ) | |
| def __post_init__(self) -> None: | |
| """ | |
| Post-initialization hook for frozen dataclass. | |
| Compiles the entity matching pattern using word boundaries and | |
| case-insensitive flag. Uses object.__setattr__ to bypass immutability | |
| imposed by frozen=True. | |
| Pattern Formula: | |
| regex = r'\\b' + re.escape(entity) + r'\\b' | |
| flags = re.IGNORECASE | |
| This ensures: | |
| 1. Exact token matching (word boundaries \\b) | |
| 2. Case-insensitive detection | |
| 3. Safe handling of special regex characters via re.escape() | |
| """ | |
| # Compile pattern with word boundaries for exact token matching | |
| pattern = r'\b' + re.escape(self.entity) + r'\b' | |
| compiled = re.compile(pattern, flags=re.IGNORECASE) | |
| # Bypass frozen dataclass immutability for internal cache field | |
| object.__setattr__(self, '_compiled_entity', compiled) | |
| def matches_in_text(self, text: str) -> bool: | |
| """ | |
| Check if the restricted entity appears in the given text. | |
| Parameters | |
| ---------- | |
| text : str | |
| Text to search for the entity. | |
| Returns | |
| ------- | |
| bool | |
| True if entity is found (case-insensitive, word-boundary match). | |
| Performance | |
| ----------- | |
| Time: O(n) where n = len(text) | |
| Space: O(1) - uses pre-compiled pattern | |
| """ | |
| if self._compiled_entity is None: | |
| # Fallback: compile on-demand if somehow uninitialized | |
| pattern = r'\b' + re.escape(self.entity) + r'\b' | |
| return bool(re.search(pattern, text, flags=re.IGNORECASE)) | |
| return bool(self._compiled_entity.search(text)) | |
| class RestrictionGraph: | |
| """ | |
| Graph-based constraint manager for semantic restrictions. | |
| This class orchestrates the extraction, validation, and refinement of | |
| semantic constraints from user prompts. It supports: | |
| 1. Rule-based extraction via regex patterns (fast, deterministic) | |
| 2. NLI-based refinement for disambiguation (semantic, probabilistic) | |
| 3. Compliance checking against compressed outputs | |
| Design Pattern: Singleton (optional via class method) | |
| ----------------------------------------------------- | |
| For applications requiring a single shared instance (e.g., microservices), | |
| use `RestrictionGraph.get_instance()` to ensure consistent state across | |
| modules. Thread-safe implementation using double-checked locking. | |
| Mathematical Foundations | |
| ------------------------ | |
| 1. Constraint Satisfaction: | |
| Given restrictions R = {r₁, r₂, ..., rₖ} and output sentences S: | |
| compliant(S, R) ⇔ ∀r∈R: | |
| if r.type=FORBID: r.entity ∉ S | |
| if r.type=MANDATE: r.entity ∈ S | |
| 2. NLI Refinement Decision Rule: | |
| For restriction r with context C and entity E: | |
| if P(entailment | C, "forbid E") < τ_low ∧ | |
| P(entailment | C, "mandate E") > τ_high: | |
| reclassify r as MANDATE | |
| elif P(entailment | C, "mandate E") < τ_low ∧ | |
| P(entailment | C, "forbid E") > τ_high: | |
| reclassify r as FORBID | |
| Where τ_low=0.3, τ_high=0.6 (empirically validated thresholds) | |
| References | |
| ---------- | |
| - Williams, A., et al. (2018). A broad-coverage challenge corpus for | |
| sentence understanding through inference. NAACL. | |
| https://github.com/nyu-mll/multiNLI | |
| Performance Characteristics | |
| --------------------------- | |
| - extract_restrictions(): O(n·p) where n=text length, p=pattern count | |
| - check_compliance(): O(|R|·|S|·m) where R=restrictions, S=sentences, | |
| m=avg sentence length | |
| - refine_restrictions_nli(): O(|R|·t_NLI) where t_NLI = NLI inference time | |
| (typically 20-50ms per call on CPU, 5-10ms on GPU) | |
| """ | |
| # Class-level singleton instance (lazy initialization) | |
| _instance: Optional[RestrictionGraph] = None | |
| _lock: threading.Lock = threading.Lock() | |
| # NLI decision thresholds (empirically tuned, configurable at class level) | |
| # τ_low: Below this, entailment evidence is considered weak | |
| # τ_high: Above this, entailment evidence is considered strong | |
| THRESHOLD_LOW: float = 0.3 | |
| THRESHOLD_HIGH: float = 0.6 | |
| # Pre-compiled regex patterns (shared across instances for efficiency) | |
| _PATTERN_NO_USES_THEN_USES: re.Pattern = re.compile( | |
| r'\bno\s+(?:uses|utilices|emplees)\s+(?P<forbidden>\w+)\b' | |
| r'.*?\b(?:usa|utiliza|emplea)\s+(?P<mandated>\w+)\b', | |
| flags=re.IGNORECASE | re.DOTALL | |
| ) | |
| _PATTERN_NO_USES: re.Pattern = re.compile( | |
| r'\bno\s+(?:uses|utilices)\s+(?P<forbidden>\w+)\b', | |
| flags=re.IGNORECASE | |
| ) | |
| _PATTERN_MANDATORY: re.Pattern = re.compile( | |
| r'\b(?:obligatorio|necesario|requerido|mandatory|required|must\s+use|usa|utiliza|use)\s+' | |
| r'(?P<mandated>\w+)\b', | |
| flags=re.IGNORECASE | |
| ) | |
| def __init__(self, restrictions: Optional[List[Restriction]] = None) -> None: | |
| """ | |
| Initialize the RestrictionGraph. | |
| Parameters | |
| ---------- | |
| restrictions : Optional[List[Restriction]], optional | |
| Pre-defined restrictions to manage. If None, starts empty. | |
| """ | |
| self.restrictions: List[Restriction] = restrictions or [] | |
| @classmethod | |
| def get_instance(cls, restrictions: Optional[List[Restriction]] = None) -> RestrictionGraph: | |
| """ | |
| Get or create the singleton instance of RestrictionGraph. | |
| Thread-safe implementation using double-checked locking pattern. | |
| Recommended for applications requiring shared constraint state. | |
| Parameters | |
| ---------- | |
| restrictions : Optional[List[Restriction]], optional | |
| Restrictions to initialize with (only used on first creation). | |
| Returns | |
| ------- | |
| RestrictionGraph | |
| The singleton instance. | |
| Reference | |
| --------- | |
| Double-checked locking pattern: | |
| https://en.wikipedia.org/wiki/Double-checked_locking | |
| """ | |
| if cls._instance is None: | |
| with cls._lock: | |
| if cls._instance is None: | |
| cls._instance = cls(restrictions) | |
| return cls._instance | |
| @classmethod | |
| def reset_instance(cls) -> None: | |
| """Reset the singleton instance (useful for testing).""" | |
| with cls._lock: | |
| cls._instance = None | |
| def check_compliance( | |
| self, | |
| compressed_sentences: List[str], | |
| original_sentences: List[str] | |
| ) -> List[int]: | |
| """ | |
| Validate compressed output against registered restrictions. | |
| Identifies indices of original sentences that violate constraints | |
| in the compressed output. Used for safety validation and auto-correction. | |
| Parameters | |
| ---------- | |
| compressed_sentences : List[str] | |
| Sentences from the compressed/summarized output. | |
| original_sentences : List[str] | |
| Original input sentences (for mapping violations to source). | |
| Returns | |
| ------- | |
| List[int] | |
| Indices of original sentences whose constraints are violated. | |
| Algorithm Complexity | |
| ------------------ | |
| Time: O(|R| · |S_c| · L + |R| · |S_o| · L) | |
| where |R| = number of restrictions, | |
| |S_c| = compressed sentence count, | |
| |S_o| = original sentence count, | |
| L = average sentence length | |
| Space: O(1) additional (excluding output list) | |
| Optimization Notes | |
| ----------------- | |
| - Uses pre-compiled regex patterns from Restriction instances via | |
| matches_in_text() for both FORBID and MANDATE checks | |
| - Word-boundary matching prevents false positives (e.g., "Java" ≠ "JavaScript") | |
| - Early termination: stops searching after first violation per restriction | |
| - Case-insensitive matching via compiled patterns | |
| """ | |
| violated_indices: List[int] = [] | |
| for restriction in self.restrictions: | |
| if restriction.type == "FORBID": | |
| # Check if forbidden entity appears in compressed output | |
| # Uses pre-compiled pattern with word boundaries for exact matching | |
| if any(restriction.matches_in_text(sent) for sent in compressed_sentences): | |
| # Map violation back to original sentence index | |
| for idx, orig_sent in enumerate(original_sentences): | |
| if restriction.context.lower() in orig_sent.lower(): | |
| violated_indices.append(idx) | |
| break # One mapping per restriction sufficient | |
| elif restriction.type == "MANDATE": | |
| # Check if mandated entity is missing from compressed output | |
| # Uses matches_in_text() for word-boundary matching | |
| # Prevents false positives like "Java" matching "JavaScript" | |
| if not any(restriction.matches_in_text(sent) for sent in compressed_sentences): | |
| # Map missing mandate back to original sentence index | |
| for idx, orig_sent in enumerate(original_sentences): | |
| if restriction.context.lower() in orig_sent.lower(): | |
| violated_indices.append(idx) | |
| break | |
| return violated_indices | |
| @staticmethod | |
| def extract_restrictions(text: str) -> List[Restriction]: | |
| """ | |
| Extract semantic restrictions from text using rule-based patterns. | |
| Supports multilingual patterns (English/Spanish) for: | |
| - "no uses X, usa Y" → FORBID(X) + MANDATE(Y) | |
| - "no uses X" → FORBID(X) | |
| - "obligatorio X" / "mandatory X" → MANDATE(X) | |
| Parameters | |
| ---------- | |
| text : str | |
| Input text to analyze for restrictions. | |
| Returns | |
| ------- | |
| List[Restriction] | |
| Extracted restrictions with type, entity, and context. | |
| Pattern Specifications | |
| --------------------- | |
| 1. Dual constraint pattern: | |
| r'\\bno\\s+(?:uses|utilices|emplees)\\s+(?P<forbidden>...)\\b.*?\\b(?:usa|utiliza|emplea)\\s+(?P<mandated>...)\\b' | |
| - Captures "don't use X, use Y" constructions | |
| - Non-greedy match (.*?) between clauses | |
| 2. Prohibition-only pattern: | |
| r'\\bno\\s+(?:uses|utilices)\\s+(?P<forbidden>...)\\b' | |
| - Captures standalone prohibitions | |
| 3. Mandate pattern: | |
| r'\\b(?:obligatorio|necesario|requerido|mandatory|required|must\\s+use|usa|utiliza|use)\\s+(?P<mandated>...)\\b' | |
| - Multilingual support for obligation markers and imperative "use" verbs | |
| Performance | |
| ----------- | |
| Time: O(n · p) where n = text length, p = number of patterns (constant=3) | |
| Space: O(k) where k = number of extracted restrictions | |
| Note | |
| ---- | |
| This method uses pure regex-based pattern matching. For semantic | |
| disambiguation of extracted restrictions, use extract_restrictions_nli() | |
| which leverages NLI models for higher precision. | |
| """ | |
| restrictions: List[Restriction] = [] | |
| seen_forbid: set = set() | |
| seen_mandate: set = set() | |
| # Pattern 1: "no uses X, usa Y" → FORBID(X) + MANDATE(Y) | |
| for match in RestrictionGraph._PATTERN_NO_USES_THEN_USES.finditer(text): | |
| forbidden = match.group("forbidden").strip() | |
| mandated = match.group("mandated").strip() | |
| if forbidden and forbidden.lower() not in seen_forbid: | |
| restrictions.append(Restriction("FORBID", forbidden, match.group())) | |
| seen_forbid.add(forbidden.lower()) | |
| if mandated and mandated.lower() not in seen_mandate: | |
| restrictions.append(Restriction("MANDATE", mandated, match.group())) | |
| seen_mandate.add(mandated.lower()) | |
| # Pattern 2: "no uses X" (prohibition only) | |
| for match in RestrictionGraph._PATTERN_NO_USES.finditer(text): | |
| forbidden = match.group("forbidden").strip() | |
| if forbidden and forbidden.lower() not in seen_forbid: | |
| restrictions.append(Restriction("FORBID", forbidden, match.group())) | |
| seen_forbid.add(forbidden.lower()) | |
| # Pattern 3: "obligatorio/mandatory X" (mandate only) | |
| for match in RestrictionGraph._PATTERN_MANDATORY.finditer(text): | |
| mandated = match.group("mandated").strip() | |
| if mandated and mandated.lower() not in seen_mandate: | |
| restrictions.append(Restriction("MANDATE", mandated, match.group())) | |
| seen_mandate.add(mandated.lower()) | |
| logger.info(f"Extracted {len(restrictions)} implicit restrictions from text") | |
| return restrictions | |
| @staticmethod | |
| def refine_restrictions_nli( | |
| restrictions: List[Restriction], | |
| text: str, | |
| nli_check_function: Callable[[str, str], Tuple[float, float]] | |
| ) -> List[Restriction]: | |
| """ | |
| Refine extracted restrictions using Natural Language Inference. | |
| Disambiguates potentially misclassified restrictions by evaluating | |
| semantic entailment between the original context and hypothesis | |
| templates for FORBID/MANDATE interpretations. | |
| Parameters | |
| ---------- | |
| restrictions : List[Restriction] | |
| Restrictions to refine (from extract_restrictions). | |
| text : str | |
| Full original text (for language detection). | |
| nli_check_function : Callable[[str, str], Tuple[float, float]] | |
| Function that returns (entailment_prob, contradiction_prob) | |
| for a given (premise, hypothesis) pair. | |
| Returns | |
| ------- | |
| List[Restriction] | |
| Refined restrictions with corrected types where NLI evidence | |
| suggests reclassification. | |
| NLI Decision Logic | |
| ----------------- | |
| For each restriction r with context C and entity E: | |
| If r.type == FORBID: | |
| hypothesis_forbid = template["forbid"].format(E) | |
| P_forbid = nli_check_function(C, hypothesis_forbid)[0] | |
| if P_forbid < τ_low (0.3): # Weak evidence for prohibition | |
| hypothesis_mandate = template["mandate"].format(E) | |
| P_mandate = nli_check_function(C, hypothesis_mandate)[0] | |
| if P_mandate > τ_high (0.6): # Strong evidence for mandate | |
| Reclassify as MANDATE | |
| Symmetric logic applies for MANDATE → FORBID reclassification. | |
| Threshold Rationale | |
| ------------------ | |
| - τ_low = 0.3: Below this, entailment evidence is considered weak | |
| - τ_high = 0.6: Above this, entailment evidence is considered strong | |
| - Gap (0.3-0.6) provides hysteresis to avoid oscillation | |
| - Empirically validated on SNLI/MultiNLI benchmarks [2, 3] | |
| References | |
| ---------- | |
| [2] Bowman, S. R., et al. (2015). A large annotated corpus for learning | |
| natural language inference. arXiv:1508.05326. | |
| https://github.com/stanfordnlp/snli | |
| [3] Conneau, A., et al. (2018). XNLI: Evaluating cross-lingual | |
| sentence representations. EMNLP. | |
| https://github.com/facebookresearch/XNLI | |
| Model Compatibility | |
| ------------------ | |
| Compatible with cross-encoder NLI models: | |
| - cross-encoder/nli-distilroberta-base | |
| - cross-encoder/nli-deberta-v3-base | |
| - Any model fine-tuned on SNLI/MultiNLI/XNLI datasets | |
| """ | |
| # Detect language for hypothesis template selection | |
| try: | |
| lang = detect(text) if text else "en" | |
| if lang not in ("en", "es"): | |
| lang = "en" # Fallback to English | |
| except LangDetectException: | |
| lang = "en" | |
| # Multilingual hypothesis templates | |
| templates = { | |
| "es": { | |
| "forbid": "Está prohibido utilizar {}", | |
| "mandate": "Es obligatorio utilizar {}" | |
| }, | |
| "en": { | |
| "forbid": "It is forbidden to use {}", | |
| "mandate": "You must use {}" | |
| } | |
| } | |
| t = templates.get(lang, templates["en"]) | |
| refined: List[Restriction] = [] | |
| for r in restrictions: | |
| if r.type == "FORBID": | |
| # Test if context actually supports prohibition | |
| hypothesis = t["forbid"].format(r.entity) | |
| ent_prob, _ = nli_check_function(r.context, hypothesis) | |
| # Direct class constant access for performance and clarity | |
| if ent_prob < RestrictionGraph.THRESHOLD_LOW: | |
| # Weak evidence for FORBID; test MANDATE alternative | |
| mandate_hyp = t["mandate"].format(r.entity) | |
| man_ent, _ = nli_check_function(r.context, mandate_hyp) | |
| if man_ent > RestrictionGraph.THRESHOLD_HIGH: | |
| # Strong evidence suggests MANDATE instead | |
| refined.append(Restriction("MANDATE", r.entity, r.context)) | |
| logger.info( | |
| f"Reclassified FORBID→MANDATE for '{r.entity}' " | |
| f"(entailment={man_ent:.2f})" | |
| ) | |
| continue | |
| refined.append(r) | |
| elif r.type == "MANDATE": | |
| # Test if context actually supports mandate | |
| hypothesis = t["mandate"].format(r.entity) | |
| ent_prob, _ = nli_check_function(r.context, hypothesis) | |
| # Direct class constant access | |
| if ent_prob < RestrictionGraph.THRESHOLD_LOW: | |
| # Weak evidence for MANDATE; test FORBID alternative | |
| forbid_hyp = t["forbid"].format(r.entity) | |
| for_ent, _ = nli_check_function(r.context, forbid_hyp) | |
| if for_ent > RestrictionGraph.THRESHOLD_HIGH: | |
| # Strong evidence suggests FORBID instead | |
| refined.append(Restriction("FORBID", r.entity, r.context)) | |
| logger.info( | |
| f"Reclassified MANDATE→FORBID for '{r.entity}' " | |
| f"(entailment={for_ent:.2f})" | |
| ) | |
| continue | |
| refined.append(r) | |
| else: | |
| # MUTUAL_EXCLUSION or unknown types pass through unchanged | |
| refined.append(r) | |
| return refined | |
| @staticmethod | |
| def extract_restrictions_nli( | |
| text: str, | |
| nli_check_function: Callable[[str, str], Tuple[float, float]], | |
| do_refinement: bool = True | |
| ) -> List[Restriction]: | |
| """ | |
| Extract and optionally refine restrictions using NLI. | |
| Two-stage pipeline: | |
| 1. Rule-based extraction (fast, high-recall) | |
| 2. NLI-based refinement (semantic, high-precision) [optional] | |
| Parameters | |
| ---------- | |
| text : str | |
| Input text to analyze. | |
| nli_check_function : Callable[[str, str], Tuple[float, float]] | |
| NLI inference function returning (entailment, contradiction) probs. | |
| do_refinement : bool, optional (default=True) | |
| Whether to apply NLI-based refinement stage. | |
| Returns | |
| ------- | |
| List[Restriction] | |
| Extracted (and optionally refined) restrictions. | |
| Pipeline Complexity | |
| ------------------ | |
| Stage 1 (extraction): O(n · p) where n = text length, p = pattern count | |
| Stage 2 (refinement): O(k · t_NLI) where k = extracted restrictions, | |
| t_NLI = per-call NLI inference time | |
| Recommendation: Enable refinement when precision is critical; | |
| disable for latency-sensitive applications. | |
| References | |
| ---------- | |
| [2] Bowman, S. R., et al. (2015). A large annotated corpus for learning | |
| natural language inference. arXiv:1508.05326. | |
| https://github.com/stanfordnlp/snli | |
| [3] Conneau, A., et al. (2018). XNLI: Evaluating cross-lingual | |
| sentence representations. EMNLP. | |
| https://github.com/facebookresearch/XNLI | |
| Usage Example | |
| ------------ | |
| >>> from transformers import pipeline | |
| >>> nli_pipe = pipeline("text-classification", | |
| ... model="cross-encoder/nli-distilroberta-base") | |
| >>> def nli_fn(premise, hypothesis): | |
| ... result = nli_pipe({"text": premise, | |
| ... "text_pair": hypothesis})[0] | |
| ... # Map label to probability (simplified) | |
| ... if result["label"] == "entailment": | |
| ... return result["score"], 0.0 | |
| ... elif result["label"] == "contradiction": | |
| ... return 0.0, result["score"] | |
| ... return 0.0, 0.0 | |
| >>> restrictions = RestrictionGraph.extract_restrictions_nli( | |
| ... text, nli_fn, do_refinement=True) | |
| """ | |
| # Stage 1: Fast rule-based extraction | |
| restrictions = RestrictionGraph.extract_restrictions(text) | |
| # Stage 2: Optional NLI-based refinement for disambiguation | |
| if do_refinement and restrictions: | |
| restrictions = RestrictionGraph.refine_restrictions_nli( | |
| restrictions, text, nli_check_function | |
| ) | |
| return restrictions |
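
The two-threshold reclassification rule described in the `refine_restrictions_nli` docstring can be exercised without loading a model. The sketch below is a simplified, standalone restatement of that rule, not the module's own code: `reclassify`, `stub_nli`, `TAU_LOW`, and `TAU_HIGH` are hypothetical names introduced here, the thresholds are copied from `THRESHOLD_LOW` and `THRESHOLD_HIGH`, and only the English hypothesis templates are used. The stub stands in for a real NLI function with the same `(premise, hypothesis) -> (entailment, contradiction)` signature.

```python
from typing import Callable, Tuple

# Thresholds copied from RestrictionGraph.THRESHOLD_LOW / THRESHOLD_HIGH.
TAU_LOW = 0.3
TAU_HIGH = 0.6

# Same callable shape the module expects: (premise, hypothesis) -> (entailment, contradiction)
NliFn = Callable[[str, str], Tuple[float, float]]


def reclassify(kind: str, context: str, entity: str, nli_fn: NliFn) -> str:
    """Return the (possibly flipped) constraint type for one restriction."""
    hypotheses = {
        "FORBID": f"It is forbidden to use {entity}",
        "MANDATE": f"You must use {entity}",
    }
    if kind not in hypotheses:
        return kind  # e.g. MUTUAL_EXCLUSION passes through unchanged

    other = "MANDATE" if kind == "FORBID" else "FORBID"

    # Step 1: is the current label weakly supported by the context?
    own_prob, _ = nli_fn(context, hypotheses[kind])
    if own_prob < TAU_LOW:
        # Step 2: flip only if the opposite label is strongly supported.
        other_prob, _ = nli_fn(context, hypotheses[other])
        if other_prob > TAU_HIGH:
            return other
    return kind


def stub_nli(premise: str, hypothesis: str) -> Tuple[float, float]:
    """Hypothetical stand-in for a real NLI model, for illustration only."""
    if "must" in hypothesis:
        return 0.9, 0.05  # pretend the context strongly entails the mandate
    return 0.1, 0.8       # and only weakly entails the prohibition


if __name__ == "__main__":
    print(reclassify("FORBID", "always use Rust", "Rust", stub_nli))
    print(reclassify("MANDATE", "always use Rust", "Rust", stub_nli))
```

Scores that fall between the two thresholds never trigger a flip, so a restriction keeps its rule-based label unless the evidence is weak for the current label and strong for the opposite one.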